classdef DynamicsInfUpstream < SteadyStateUpstream
    % DynamicsInfUpstream  Transition paths after a one-period shock to the
    % nominal interest rate, computed on top of the steady state provided by
    % SteadyStateUpstream.
    %
    % The main entry point is one_period_shock, which iterates K times on
    % guessed sequences (column k-1 holds the guess used to build column k).
    % Methods that take no output argument mutate obj in place, so the
    % superclass is presumably a handle class -- TODO confirm.

    properties (GetAccess = public, SetAccess = public)
        ct                  % Consumption policy function at time t
        nt                  % Labor supply policy function at time t
        at                  % Savings policy function at time t (assets you start with next period)
        Rt                  % Nominal interest rate sequence
        Nt                  % Employment sequence
        Ntildet             % Intermediate aggregate index sequence
        wt                  % Nominal Wages
        pt                  % Price level
        qt                  % Asset price
        Deltat              % Price dispersion distortion
        p_upstrt            % Price index for the intermediate good
        t_change            % Time of the experiment
        pt_realized         % Price level realized
        wt_realized         % Nominal Wage realized
        p_upstrt_realized   % Intermediate prices realized
        taylor_inf = 1.5    % Taylor inflation coefficient (overwritten by the constructor)
    end
    
    properties (Dependent=true, SetAccess=public)
        Yt                  % Aggregate Output
        inflt               % Inflation rate
    end
    
    methods
        
        function res = get.Yt(obj)
            % Aggregate output implied by employment and price dispersion:
            % Nt^(1-delta) / Deltat, element by element.
            res = obj.Nt.^(1-obj.delta)./obj.Deltat;
        end
        
        function res = get.inflt(obj)
            % Gross inflation of the realized price level. The first row is
            % the first-period price level itself; the remaining rows are
            % ratios of consecutive rows of pt_realized.
            res = [obj.pt_realized(1,:); obj.pt_realized(2:end,:)./obj.pt_realized(1:end-1,:)];
        end
        
        function obj = DynamicsInfUpstream(calibr)
            % Constructor. Pass it a calibration structure.
            %   calibr is forwarded to the SteadyStateUpstream constructor
            %   and must also carry the field taylor_inf.
            obj  = obj@SteadyStateUpstream(calibr);
            obj.taylor_inf = calibr.taylor_inf;
            disp('Steady state found successfully');
            
        end
        
        function one_period_shock(obj, t_change, shock, T, K)
            % Compute the transition after a one-period interest rate shock.
            %   t_change - period whose nominal rate is multiplied by shock
            %   shock    - multiplicative factor applied to Rt(t_change)
            %   T        - number of periods in the simulated path
            %   K        - number of iterations on the guessed sequences
            %
            % Results are stored in the object's sequence properties, one
            % column per iteration (column 1 is the initial guess). The
            % policy-function cell arrays are emptied at the end.
            K = K+1; % column 1 stores the initial guess, so K iterations need K+1 columns
            % Policy functions
            obj.ct = cell(T,K);
            obj.ct(T,:) = {obj.cc};
            obj.nt = cell(T,K);
            obj.nt(T,:) = {obj.nn};
            obj.at = cell(T,K);
            obj.at(T,:) = {obj.aa};
            
            % Relevant sequences to calculate policy functions
            obj.Rt = obj.R*ones(T,1);
            obj.Rt(t_change) = obj.Rt(t_change)*shock;
            obj.t_change = t_change;
            
            obj.Nt = ones(T,K)*obj.N;
            
            obj.wt = ones(T,K)*obj.w;
            obj.wt_realized = ones(T,K)*obj.w;
            
            obj.pt = ones(T,K);
            obj.pt_realized = ones(T,K);
                        
            obj.p_upstrt = ones(T,K)*obj.p_upstr;
            obj.p_upstrt_realized = ones(T,K)*obj.p_upstr;
            
            obj.qt = ones(T,K);
            obj.qt(:,1) = obj.q;
            
            obj.Deltat = ones(T,K)*obj.Delta;
            
            for k = 2:K
                disp(['K=', num2str(k-1)]);
                
                % Impose Taylor Rule after the experiment
                obj.Rt(t_change+1:end) = obj.R*(obj.pt(t_change+1:end,k-1)./obj.pt(t_change:end-1,k-1)).^obj.taylor_inf;
                obj.Deltat(:,k) = obj.Deltat(:,k-1);
                
                disp('Calculating policy functions backwards...');
                obj.pt(:,k) = obj.pt(:,k-1);
                for t = T-1:-1:2
                    [obj.ct{t,k}, obj.nt{t,k}, obj.at{t,k}] = obj.get_policy_today(t, k, obj.Nt(t,k-1), obj.wt(t,k-1)/obj.pt(t,k-1));
                end
                
                disp('Moving economy forward...');
                distr = obj.invariant;
                for t=1:T-1
                    % Calculate equilibrium prices given the distribution
                    % and expected prices for the future
                    if obj.verbose && ~mod(t,20)
                        disp(['t=', num2str(t)]);
                    end
                    
                    % Calculate inflation (it is predetermined in t-1)
                    cumR = cumprod([1; obj.Rt(t:end-1)]);
                    cumlambda = obj.lambda.^(0:length(cumR)-1);
                    if t==1
                        numerator = obj.p_upstr;
                        denominator = 1;
                    else
                        numerator = cumlambda(1:end-1)./cumR(1:end-1)'*(obj.Yt(t:end-1,k-1).*(obj.pt(t:end-1,k-1)).^obj.epsilon.*obj.p_upstrt(t:end-1,k-1)) + obj.p_upstrt(end,k-1)*obj.pt(end,k-1)^obj.epsilon*obj.Y*obj.R/(obj.R-obj.lambda)*cumlambda(end)/cumR(end);
                        denominator = cumlambda(1:end-1)./cumR(1:end-1)'*(obj.Yt(t:end-1,k-1).*(obj.pt(t:end-1,k-1)).^obj.epsilon) + obj.pt(end,k-1)^obj.epsilon*obj.Y*obj.R/(obj.R-obj.lambda)*cumlambda(end)/cumR(end);
                    end
                    pstar = obj.epsilon/(obj.epsilon-1)*numerator/denominator;
                    if t==1
                        p_last = 1;
                        Delta_last = obj.Delta;
                    else
                        p_last = obj.pt(t-1,k);
                        Delta_last = obj.Deltat(t-1,k);
                    end
                    obj.pt(t,k) = (obj.lambda*p_last^(1-obj.epsilon) + (1-obj.lambda)*pstar^(1-obj.epsilon))^(1/(1-obj.epsilon));
                    infl = obj.pt(t,k)/p_last; % Intermediate price index inflation
                    if obj.lambda ~= 1
                        obj.Deltat(t,k) = obj.lambda*Delta_last*infl^obj.epsilon + (1-obj.lambda)*(max(1-obj.lambda*infl^(obj.epsilon-1),0)/(1-obj.lambda))^(obj.epsilon/(obj.epsilon-1));
                    else
                        % Rigid prices: the general formula divides by
                        % (1-lambda), so set this period's dispersion
                        % directly. Write the period-t entry rather than
                        % the steady-state scalar obj.Delta.
                        obj.Deltat(t,k) = 1;
                    end
                    
                    % Impose Taylor Rule once inflation is realized
                    if t>obj.t_change
                        obj.Rt(t) = obj.R*infl^obj.taylor_inf;
                    end
                    
                    % Update price sequence to let expected inflation to be constant
                    obj.pt(t+1:end,k-1) = obj.pt(t+1:end,k-1)/obj.pt(t,k-1)*obj.pt(t,k);
                    obj.p_upstrt(t+1:end,k-1) = obj.p_upstrt(t+1:end,k-1)/obj.pt(t,k-1)*obj.pt(t,k);
                    
                    % x(1) is the real wage and x(2) is the employment guess
                    eq_prices = @(x) obj.equilibrium_conditions(t, k, distr, x(1), x(2));
                    
                    res = fsolve(eq_prices, [obj.wt(t,k-1)/obj.pt(t,k-1), obj.Nt(t,k-1)], optimoptions('fsolve','disp','none'));
                    
                    obj.wt(t,k) = res(1)*obj.pt(t,k);
                    obj.Nt(t,k) = res(2);
                    
                    % Upstream good price ends up being
                    obj.p_upstrt(t,k) = obj.wt(t,k)/(1-obj.delta)*obj.Nt(t,k)^obj.delta;
                                                            
                    cumR = cumprod(obj.Rt(t:end));
                    obj.qt(t,k) = obj.delta*(obj.Yt(t+1:end,k-1)'*(obj.pt(t+1:end,k-1)./cumR(1:end-1)) + obj.Y*obj.pt(end,k-1)*obj.R/(cumR(end)*(obj.R-1)))/obj.pt(t,k);
                    
                    % Set the corresponding policy functions for time t level k
                    [obj.ct{t,k}, obj.nt{t,k}, obj.at{t,k}] = obj.get_policy_today(t, k, obj.Nt(t,k), obj.wt(t,k)/obj.pt(t,k));
                    
                    % Simulate next period distribution
                    distr = obj.simulate_economy_forward(t, k, distr);
                end
                % Terminal period: carry the last price level forward and
                % scale the steady-state wage and upstream price by it.
                obj.pt(T,k) = obj.pt(T-1,k);
                obj.wt(T,k) = obj.w*obj.pt(T,k);
                obj.p_upstrt(T,k) = obj.p_upstr*obj.pt(T,k);
                obj.pt_realized(:,k) = obj.pt(:,k);
                obj.wt_realized(:,k) = obj.wt(:,k);
                obj.p_upstrt_realized(:,k) = obj.p_upstrt(:,k);
            end
            % Save memory
            obj.ct(:)=[];
            obj.nt(:)=[];
            obj.at(:)=[];
        end
        
    end

    
    methods(Access = public)
        function res = equilibrium_conditions(obj, t, k, distr, w_real, Nt_guess)
            % Residuals of the period-t market clearing conditions.
            %   t, k      - period and iteration indices
            %   distr     - distribution used to aggregate the policies
            %   w_real    - candidate real wage
            %   Nt_guess  - candidate aggregate employment
            % Returns res(1): aggregate consumption minus output, and
            % res(2): aggregate labor supply minus Nt_guess.
            [c_pol, n_pol, ~] = obj.get_policy_today(t, k, Nt_guess, w_real);
            
            cy = sum(sum(c_pol.*distr)) - Nt_guess^(1-obj.delta)/obj.Deltat(t,k);
            nN = sum(sum(n_pol.*distr)) - Nt_guess;
            res(1)=cy;
            res(2)=nN;
        end
        
        function [c_pol_new, n_pol_new, a_pol_new] = get_policy_today(obj, t, k, N_loop, w_real)
            %% Calculate policy function today given policy function tomorrow
            %   t, k     - period and iteration indices; reads obj.ct{t+1,k}
            %   N_loop   - aggregate employment used for this period
            %   w_real   - real wage used for this period
            % Returns consumption, labor and savings policies, each of size
            % obj.agrid-by-obj.Nstate.

            % Calculate relevant prices for consumption problem
            Delta = obj.Deltat(t,k);
            c_pol = obj.ct{t+1,k};
            Y_loop = N_loop^(1-obj.delta)/Delta;
            R_real_loop = obj.Rt(t)/obj.pt(t+1,k-1)*obj.pt(t,k);
            cumR = cumprod(obj.Rt(t:end));
            revaluation=0;
            if t==1
                Y_exp = obj.Y;
                R_last = obj.R;
                q_exp = obj.q;
                % Initial revaluation factor
                q_new = obj.delta*(obj.Yt(t+1:end,k-1)'*(obj.pt(t+1:end,k-1)./cumR(1:end-1)) + obj.Y*obj.pt(end,k-1)*obj.R/(cumR(end)*(obj.R-1)))/obj.pt(t,k);
                revaluation = 1/R_last/obj.q*(q_new-q_exp);
                % Dividend update factor
                dividend = 1/R_last/obj.q*obj.delta*(Y_loop-Y_exp);
            else
                % No revaluation when no expectations change
                Y_exp = obj.Yt(t,k-1);
                R_last = obj.Rt(t-1)/obj.pt(t,k-1)*obj.pt(t-1,k);
                % Dividend update factor
                dividend = 1/R_last/obj.qt(t-1,k)*obj.delta*(Y_loop-Y_exp);
                if t>obj.t_change
                    % Revaluation when inflation today turns out to be
                    % different than expected so Taylor Rule updates R_t
                    R_exp = obj.R*(obj.pt_realized(t,k-1)./obj.pt_realized(t-1,k-1)).^obj.taylor_inf;
                    R_exp = [R_exp; obj.Rt(t+1:end)];
                    cumR_exp = cumprod(R_exp);
                    q_exp = obj.delta*(obj.Yt(t+1:end,k-1)'*(obj.pt(t+1:end,k-1)./cumR_exp(1:end-1)) + obj.Y*obj.pt(end,k-1)*obj.R/(cumR_exp(end)*(obj.R-1)))/obj.pt(t,k);
                    % Revaluation factor
                    q_new = obj.delta*(obj.Yt(t+1:end,k-1)'*(obj.pt(t+1:end,k-1)./cumR(1:end-1)) + obj.Y*obj.pt(end,k-1)*obj.R/(cumR(end)*(obj.R-1)))/obj.pt(t,k);
                    revaluation = 1/R_last/obj.qt(t-1,k)*(q_new-q_exp);
                end
            end
            
            c_pol_new         = zeros(obj.agrid,obj.Nstate);
            a_pol_new         = zeros(obj.agrid,obj.Nstate);
            n_pol_new         = zeros(obj.agrid,obj.Nstate);
            
            parfor state=1:obj.Nstate
                % Consumption today and the asset level associated with
                % each next-period grid point obj.a(i).
                c_assoc_loop=zeros(obj.agrid,1);
                a_assoc_loop=zeros(obj.agrid,1);
                
                for i=1:obj.agrid
                    
                    marguprime_loop = uprime(c_pol(i,:),obj.calibration)*obj.pp(:,state);
                    
                    c_assoc_loop(i) = uprimeinv(obj.beta*R_real_loop*marguprime_loop,obj.calibration);
                    n = (c_assoc_loop(i)^-obj.sigma*obj.ss(state)^(1+obj.gamma)*w_real)^(1/obj.gamma);
                    a_assoc_loop(i) = (obj.a(i)/R_real_loop + c_assoc_loop(i) - n/N_loop*(1-obj.delta)*Y_loop )/(1 + dividend + revaluation);
                end
                
                % Grid points below / above the range covered by a_assoc_loop
                ind_low=(obj.a<a_assoc_loop(1));
                ind_high=(obj.a>a_assoc_loop(end));
                
                % Below the range: solve the budget constraint with next
                % period assets fixed at obj.alow. The loop indexes the
                % first sum(ind_low) grid points, i.e. it relies on the
                % low points being the leading entries of obj.a.
                c_extrap_low          = zeros(obj.agrid,1);
                for i = 1:sum(ind_low)
                    try
                        c_extrap_low(i) = fzero(@(c) obj.a(i)*(1 + dividend + revaluation) + (c^-obj.sigma*(obj.ss(state)^(1+obj.gamma))*w_real)^(1/obj.gamma)/N_loop*(1-obj.delta)*Y_loop - obj.alow/R_real_loop - c, [0.0001, c_assoc_loop(1)]);
                    catch
                        % Root finding failed (e.g. no sign change on the
                        % bracket); the entry is left at its initial zero.
                    end
                end
                % Above the range: log-linear extrapolation from the last
                % two associated points.
                c_extrap_high         = exp(log(c_assoc_loop(end))+(log(c_assoc_loop(end))-log(c_assoc_loop(end-1)))/(a_assoc_loop(end)-a_assoc_loop(end-1))*(obj.a-a_assoc_loop(end)));
                
                try
                    c_pol_new(:,state)         = (ones(obj.agrid,1)-ind_low).*(ones(obj.agrid,1)-ind_high).*interp1(a_assoc_loop,c_assoc_loop,obj.a,'linear','extrap')+c_extrap_low.*ind_low+c_extrap_high.*ind_high;
                catch
                    % Interpolation failed; flag the whole column as NaN.
                    c_pol_new(:,state) = NaN;
                end
                
                n_pol_new(:,state)         = (c_pol_new(:,state).^-obj.sigma.*(obj.ss(state)^(1+obj.gamma))*w_real).^(1/obj.gamma);
                a_pol_new(:,state)         = (obj.a*(1 + dividend + revaluation) + n_pol_new(:,state)/N_loop*(1-obj.delta)*Y_loop - c_pol_new(:,state))*R_real_loop;
            end
        end
                
        function new_distr = simulate_economy_forward(obj, t, k, distr)
            %% Returns next period distribution using savings policy function for time t level k
            %   distr is reshaped to a column, pushed through a sparse
            %   transition matrix built from obj.at{t,k} and obj.pp, and
            %   returned as an obj.agrid-by-obj.Nstate array.
            h = (obj.a_NL(end)-obj.a_NL(1))/(obj.agrid-1);
            indices_a_next = zeros(2*obj.agrid*obj.Nstate^2,1);
            p_values  = zeros(2*obj.agrid*obj.Nstate^2,1);
            
            for i=1:obj.Nstate
                % Index of the grid point just below each savings choice,
                % clamped to [1, agrid-1]
                a_next = min(max(floor((((obj.at{t,k}(:,i) - obj.a(1))/(obj.a(end)-obj.a(1))).^(1/obj.grid_NL))/h)' + 1, 1), obj.agrid-1);
                % Weight on the upper neighbouring grid point, clamped to [0, 1]
                u = max(min(((obj.at{t,k}(:,i) - obj.a(a_next))./(obj.a(a_next+1) - obj.a(a_next))), 1), 0)';
                u = repmat(u, obj.Nstate,1);
                u = [(1-u(:))'; u(:)'];
                a_next = repmat(a_next,obj.Nstate,1) + obj.agrid*repmat((0:(obj.Nstate-1))',1,obj.agrid);
                a_next = [a_next(:)'; a_next(:)'+1]; % mix them in alternating order once (:) is applied
                indices_a_next(((i-1)*2*obj.agrid*obj.Nstate+1):i*2*obj.agrid*obj.Nstate) = a_next(:);
                p_values(((i-1)*2*obj.agrid*obj.Nstate+1):i*2*obj.agrid*obj.Nstate) = u(:).*kron(repmat(obj.pp(:,i),obj.agrid,1), ones(2,1));
            end
            
            x_trans = sparse(kron(1:(obj.agrid*obj.Nstate),ones(1,2*obj.Nstate)),indices_a_next,p_values,obj.agrid*obj.Nstate,obj.agrid*obj.Nstate);
            
            new_distr = x_trans'*distr(:);
            new_distr = reshape(new_distr,obj.agrid,obj.Nstate);
        end
    end
end