classdef DynamicsReducedForm < SteadyStateReducedForm
    % DynamicsReducedForm  Transition path after a one-period interest-rate shock.
    %   Extends SteadyStateReducedForm with time-indexed policy functions and
    %   aggregate sequences. The output path is found by K outer iterations:
    %   each iteration solves household policies backwards given the previous
    %   iteration's output path, then moves the distribution forward, solving
    %   the goods-market condition period by period with fsolve.
    %
    %   NOTE(review): methods assign to obj without returning it, so this
    %   presumably relies on SteadyStateReducedForm being a handle class -
    %   confirm.

    properties (GetAccess = public, SetAccess = public)
        ct                  % Consumption policy functions at time t, cell(T, K+1); emptied at the end of one_period_shock
        at                  % Savings policy functions at time t (assets you start with next period), cell(T, K+1); emptied at the end of one_period_shock
        Rt                  % Nominal interest rate sequence, T x 1
        Yt                  % Aggregate output, T x (K+1); column 1 is the steady-state guess, column k the result of iteration k-1
        qt                  % Asset price, T x (K+1); column 1 is the steady-state price obj.q
        t_change            % Period in which the interest-rate shock hits
    end

    methods

        function obj = DynamicsReducedForm(calibr)
            % Constructor. Pass it a calibration structure; it is forwarded
            % unchanged to the SteadyStateReducedForm constructor.
            obj  = obj@SteadyStateReducedForm(calibr);
            disp('Steady state found successfully');
        end

        function one_period_shock(obj, t_change, shock, T, K)
            % ONE_PERIOD_SHOCK  Compute the transition after a one-period rate shock.
            %   obj.one_period_shock(t_change, shock, T, K) multiplies the
            %   steady-state rate obj.R by SHOCK in period T_CHANGE only,
            %   over a horizon of T periods, and runs K outer iterations on
            %   the output path. Results are stored in obj.Rt, obj.Yt and
            %   obj.qt; obj.ct and obj.at are emptied at the end to save memory.
            %
            %   Period T is pinned to the steady-state policies obj.cc / obj.aa.
            K = K+1;   % extra column 1 holds the steady-state initial guess

            % Policy functions; the terminal period uses the steady-state policies
            obj.ct = cell(T,K);
            obj.ct(T,:) = {obj.cc};
            obj.at = cell(T,K);
            obj.at(T,:) = {obj.aa};

            % Interest-rate path: steady state except in the shock period
            obj.Rt = obj.R*ones(T,1);
            obj.Rt(t_change) = obj.Rt(t_change)*shock;
            obj.t_change = t_change;

            obj.Yt = ones(T,K)*obj.Y;

            obj.qt = ones(T,K);
            obj.qt(:,1) = obj.q;

            % Solver options do not change inside the loops; build them once
            solver_opts = optimoptions('fsolve','Display','none'); %,'UseParallel',true);

            for k = 2:K
                disp(['K=', num2str(k-1)]);

                disp('Calculating policy functions backwards...');
                % Backward pass uses the previous iteration's output path
                for t = T-1:-1:2
                    [obj.ct{t,k}, obj.at{t,k}] = obj.get_policy_today(t, k, obj.Yt(t,k-1));
                end

                disp('Moving economy forward...');
                distr = obj.invariant;

                for t=1:T-1
                    if obj.verbose && ~mod(t,20)
                        disp(['t=', num2str(t)]);
                    end

                    % Goods-market clearing pins down output in period t
                    eq_cond = @(x) obj.equilibrium_conditions(t, k, distr, x);
                    obj.Yt(t,k) = fsolve(eq_cond, obj.Yt(t,k-1), solver_opts);

                    % Same pricing formula as the t=1 revaluation in get_policy_today
                    obj.qt(t,k) = obj.asset_price(t, k);

                    % Set the corresponding policy functions for time t level k
                    [obj.ct{t,k}, obj.at{t,k}] = obj.get_policy_today(t, k, obj.Yt(t,k));

                    % Simulate next period distribution
                    distr = obj.simulate_economy_forward(t, k, distr);
                end
            end
            % Save memory: the policy cells are only needed during the solve
            obj.ct = {};
            obj.at = {};
        end

    end


    methods(Access = public)
        function res = equilibrium_conditions(obj, t, k, distr, Yt_guess)
            % EQUILIBRIUM_CONDITIONS  Goods-market residual for an output guess.
            %   Aggregates the period-t consumption policy implied by
            %   Yt_guess against DISTR and subtracts Yt_guess; fsolve drives
            %   this to zero.
            [c_pol, ~] = obj.get_policy_today(t, k, Yt_guess);
            res = sum(sum(c_pol.*distr)) - Yt_guess;
        end

        function [c_pol_new, a_pol_new] = get_policy_today(obj, t, k, Y_loop)
            %% Calculate policy function today given policy function tomorrow
            % Endogenous-grid step: from the period-(t+1) consumption policy
            % obj.ct{t+1,k}, invert the Euler equation on the asset grid,
            % recover the implied current assets, and interpolate back onto
            % obj.a. Returns agrid x Nstate consumption and savings policies
            % for period t when period-t output equals Y_loop.
            c_next = obj.ct{t+1,k};
            R_now = obj.Rt(t);
            revaluation = 0;
            if t==1
                % At t=1 the reference values are the steady state; the asset
                % price moves from obj.q to the price of the current path.
                Y_exp = obj.Y;
                R_last = obj.R;
                q_exp = obj.q;
                q_new = obj.asset_price(t, k);
                revaluation = 1/R_last/obj.q*(q_new-q_exp);
            else
                Y_exp = obj.Yt(t,k-1);
                R_last = obj.Rt(t-1);
                q_exp = obj.qt(t-1,k);
            end
            % Dividend update factor: deviation of delta*Y_loop from delta*Y_exp
            % per unit of asset value
            dividend = 1/R_last/q_exp*obj.delta*(Y_loop-Y_exp);
            % Gross factor applied to assets held at the start of period t
            gross = 1 + dividend + revaluation;

            c_pol_new = zeros(obj.agrid,obj.Nstate);
            a_pol_new = zeros(obj.agrid,obj.Nstate);

            parfor state=1:obj.Nstate
                % Non-asset income in this state
                income = obj.ss(state)*(1-obj.delta)*Y_loop;

                c_assoc = zeros(obj.agrid,1);
                a_assoc = zeros(obj.agrid,1);
                for i=1:obj.agrid
                    % Expected marginal utility over next-period states
                    marg_u = uprime(c_next(i,:),obj.calibration)*obj.pp(:,state);
                    % Euler equation inverted for consumption today
                    c_assoc(i) = uprimeinv(obj.beta*R_now*marg_u,obj.calibration);
                    % Current assets for which a(i) is the choice for tomorrow
                    a_assoc(i) = (obj.a(i)/R_now + c_assoc(i) - income)/gross;
                end

                ind_low = (obj.a<a_assoc(1));
                ind_high = (obj.a>a_assoc(end));

                % Below the endogenous grid: consume so that next-period
                % assets equal obj.alow (assumes obj.a is sorted ascending,
                % so the ind_low points are the leading entries)
                c_extrap_low = zeros(obj.agrid,1);
                for i = 1:sum(ind_low)
                    c_extrap_low(i) = fzero(@(c) obj.a(i)*gross + income - obj.alow/R_now - c, [0.0001, c_assoc(1)]);
                end

                % Above the endogenous grid: log-linear extrapolation of consumption
                slope = (log(c_assoc(end))-log(c_assoc(end-1)))/(a_assoc(end)-a_assoc(end-1));
                c_extrap_high = exp(log(c_assoc(end)) + slope*(obj.a-a_assoc(end)));

                in_range = (ones(obj.agrid,1)-ind_low).*(ones(obj.agrid,1)-ind_high);
                c_state = in_range.*interp1(a_assoc,c_assoc,obj.a,'linear','extrap') + c_extrap_low.*ind_low + c_extrap_high.*ind_high;

                c_pol_new(:,state) = c_state;
                % Budget constraint gives next-period assets
                a_pol_new(:,state) = (obj.a*gross + income - c_state)*R_now;
            end
        end

        function new_distr = simulate_economy_forward(obj, t, k, distr)
            %% Returns next period distribution using savings policy function for time t level k
            % Each (asset, state) mass point is split between the two asset
            % nodes bracketing its savings choice (weights 1-u and u) and
            % spread over next-period states with probabilities pp(:,state).
            na = obj.agrid;
            ns = obj.Nstate;
            h = (obj.a_NL(end)-obj.a_NL(1))/(na-1);
            block = 2*na*ns;   % sparse entries contributed by one current state
            indices_a_next = zeros(block*ns,1);
            p_values  = zeros(block*ns,1);

            for i=1:ns
                a_pol = obj.at{t,k}(:,i);
                % Normalised position on [0,1], inverted through the
                % (presumed) power-spaced grid to locate the lower node.
                % Clamped at 0: a choice marginally below a(1) would otherwise
                % give a negative base and a complex result from .^(1/grid_NL).
                x = max((a_pol - obj.a(1))/(obj.a(end)-obj.a(1)), 0);
                a_next = min(max(floor((x.^(1/obj.grid_NL))/h)' + 1, 1), na-1);
                u = ((a_pol - obj.a(a_next))./(obj.a(a_next+1) - obj.a(a_next)))';
                u = repmat(u, ns,1);
                u = [(1-u(:))'; u(:)'];
                a_next = repmat(a_next,ns,1) + na*repmat((0:(ns-1))',1,na);
                a_next = [a_next(:)'; a_next(:)'+1]; % mix them in alternating order once (:) is applied
                indices_a_next(((i-1)*block+1):i*block) = a_next(:);
                p_values(((i-1)*block+1):i*block) = u(:).*kron(repmat(obj.pp(:,i),na,1), ones(2,1));
            end

            x_trans = sparse(kron(1:(na*ns),ones(1,2*ns)),indices_a_next,p_values,na*ns,na*ns);

            new_distr = x_trans'*distr(:);
            new_distr = reshape(new_distr,na,ns);
        end

    end

    methods (Access = private)
        function q = asset_price(obj, t, k)
            % ASSET_PRICE  Period-t price of the claim paying delta*Y per period.
            %   Discounts delta*Yt(t+1:T, k-1) with the cumulative rates from
            %   period t onward and adds a steady-state perpetuity (delta*Y
            %   per period) beyond T. With constant R and Y this reduces to
            %   delta*Y/(R-1).
            cumR = cumprod(obj.Rt(t:end));
            q = obj.delta*(obj.Yt(t+1:end,k-1)'*(1./cumR(1:end-1)) ...
                + obj.Y*obj.R/(cumR(end)*(obj.R-1)));
        end
    end
end