classdef SteadyStateReducedForm < handle
    % SteadyStateReducedForm  Steady state of a savings model with an
    % exogenous Markov income state and a borrowing limit.
    %
    % The constructor copies a calibration structure into properties and then
    % calls calculate_steady_state. That method uses fzero to find the
    % discount factor beta at which aggregate consumption C equals output Y.
    % When calibration.beta is supplied, it solves for the gross interest
    % rate R instead.
    %
    % For each candidate value:
    %   * policies are solved by endogenous-grid iteration (Carroll 2006);
    %   * they are aggregated with the stationary distribution over
    %     (asset grid point, Markov state).
    %
    % Requires the external functions uprime(c, calibration) and
    % uprimeinv(x, calibration), which are defined outside this file
    % (presumably marginal utility and its inverse -- confirm).

    properties (GetAccess = public, SetAccess = public)
        calibration         % the calibration structure being used
        Nstate              % number of states of the Markov process
        ss                  % state values of the Markov process
        pp                  % transition probabilities; pp(:,s) is used as the next-state distribution given s
        ppcum               % cumulative transition probabilities, column-wise like pp
        ssmean              % mean value of the Markov process (also returned as Y)
        sigma               % risk aversion
        delta               % capital share
        beta = NaN;         % discount factor (NaN => beta is solved for)
        beta_high           % upper end of the fzero bracket for beta
        R                   % gross steady-state interest rate (q uses R-1)
        R_high              % upper end of the fzero bracket for R
        alow                % borrowing constraint limit
        a                   % asset grid (column vector, agrid x 1)
        a_NL                % non-linear grid
        grid_NL             % measures grid non-linearity (linear case = 1)
        agrid               % asset grid size
        cc                  % consumption policy function (agrid x Nstate)
        aa                  % savings policy function (agrid x Nstate)
        q                   % asset price, delta*Y/(R-1)
        frac_consMC         % fraction of borrowing-constrained agents from Monte Carlo simulation

        invariant           % stationary distribution over (asset grid, state), agrid x Nstate
    end

    properties (Access = protected)   % tuning knobs: adjust these if convergence fails
        verbose    = false;           % when true, print policy-iteration diagnostics
        tol        = 1e-6;            % max absolute change in cc between policy iterations for convergence
        weight_new = 1;               % not referenced anywhere in this class
    end

    properties (Dependent = true, SetAccess = protected)   % recomputed on every access
        Y                   % Aggregate Output (= ssmean)
        C                   % Aggregate Consumption
        A                   % Aggregate Asset Savings
        frac_cons           % invariant mass at borrowing-constrained grid points
    end

    methods

        function self = SteadyStateReducedForm(calibr)
            % Constructor. Pass it a calibration structure.
            % Required fields: Nstate, pp, ppcum, ss, ssmean, sigma, delta,
            % R, alow, a, a_NL, grid_NL, agrid.
            % Optional fields:
            %   beta          - if present, R is solved for instead of beta
            %   beta_high     - needed only when beta is absent
            %   R_high        - needed only when beta is present
            %   beta_low/R_low - lower bracket ends (defaults 0.80 / 0.9)
            self.calibration = calibr;
            self.Nstate = calibr.Nstate;
            self.pp = calibr.pp;
            self.ppcum = calibr.ppcum;
            self.ss = calibr.ss;
            self.ssmean = calibr.ssmean;
            self.sigma = calibr.sigma;
            self.delta = calibr.delta;
            if isfield(calibr, 'beta')
                self.beta = calibr.beta;
            end
            if isfield(calibr, 'beta_high')
                self.beta_high = calibr.beta_high;
            end
            self.R = calibr.R;
            if isfield(calibr, 'R_high')
                self.R_high = calibr.R_high;
            end
            self.alow = calibr.alow;
            self.a = calibr.a;
            self.a_NL = calibr.a_NL;
            self.grid_NL = calibr.grid_NL;
            self.agrid = calibr.agrid;

            self.cc = zeros(self.agrid, self.Nstate);
            self.aa = zeros(self.agrid, self.Nstate);

            self.calculate_steady_state();
        end

        function res = get.Y(self)
            res = self.ssmean;
        end

        function res = get.C(self)
            res = sum(sum(self.cc.*self.invariant));
        end

        function res = get.A(self)
            res = sum(sum(self.aa.*self.invariant));
        end

        function res = get.frac_cons(self)
            % Invariant mass at grid points whose assets lie below the
            % state-specific threshold returned by borrowing_threshold.
            a_assoc = self.borrowing_threshold();
            bc = repmat(self.a, 1, self.Nstate) < repmat(a_assoc, self.agrid, 1);
            res = sum(self.invariant(bc));
        end

        function cons = setfrac_consMC(self, S, T)
            % Monte Carlo estimate of the borrowing-constrained fraction.
            %
            % Simulates S agents for T periods. All agents start with zero
            % assets and a uniformly drawn Markov state.
            %   cons(n)         fraction of agents below their state's
            %                   borrowing threshold in period n
            %   self.frac_consMC  mean of cons over the last 51 periods
            %                   (all periods if T < 51)
            %
            % NOTE(review): this permanently clamps the stored savings policy
            % self.aa at zero, which also affects A afterwards. The behaviour
            % is kept from the original; confirm it is intended.
            self.aa = max(self.aa, 0);
            a_assoc = self.borrowing_threshold();

            assets = zeros(S, 1);
            states = randi(self.Nstate, S, 1);   % was randi(7,...): wrong unless Nstate == 7
            cons = zeros(T, 1);

            for n = 1:T
                draw = rand(S, 1);
                next_state = ones(S, 1);
                for state = 1:self.Nstate
                    mask = states == state;
                    cons(n) = cons(n) + sum(assets(mask) < a_assoc(state))/S;
                    assets(mask) = interp1(self.a, max(self.aa(:,state), 0), assets(mask));
                    % Bin the uniform draw against the cumulative transition probabilities.
                    [~, next_state(mask)] = histc(draw(mask), [0; self.ppcum(:,state)]);
                end
                states = next_state;
            end

            % The original cons(end-50:end) errored for T < 51.
            window = max(1, T-50):T;
            self.frac_consMC = mean(cons(window));
        end

    end

    methods(Access = public)

        function calculate_steady_state(self)
            % Find the discount factor beta (or, when beta is given, the
            % gross rate R) at which C equals Y, using fzero.
            % Sets cc, aa, invariant and q.
            self.cc = self.a*ones(1,self.Nstate) + 0.1;
            self.aa = self.a*ones(1,self.Nstate);

            % Outer loop: choose beta (or R) so that the goods market clears, C=Y.
            options = optimset('TolX',1e-8, 'Display', 'iter');

            % One-to-one relation between R and beta; find one or the other.
            if isnan(self.beta)
                if isempty(self.beta_high)
                    error('SteadyStateReducedForm:missingBracket', ...
                        'calibration.beta_high is required when beta is not supplied.');
                end
                beta_low = self.calibration_field('beta_low', 0.80);
                beta_opt = fzero(@(x) self.trybeta(x), [beta_low self.beta_high], options);
                self.trybeta(beta_opt);
            else
                if isempty(self.R_high)
                    error('SteadyStateReducedForm:missingBracket', ...
                        'calibration.R_high is required when beta is supplied.');
                end
                R_low = self.calibration_field('R_low', 0.9);
                R_ss = fzero(@(x) self.tryR(x), [R_low self.R_high], options);
                self.tryR(R_ss);
            end

            self.q = self.delta*self.Y/(self.R-1);
        end

        function calculate_policy_function(self)
            %% Calculate the optimal consumption function (Carroll 2006).
            % Repeatedly applies get_next_policy, starting from the stored
            % cc, until the max absolute change in consumption is <= self.tol.
            % Stores the result in cc and aa.
            c_pol = self.cc;
            a_pol = self.aa;

            iter = 0;
            crit_c = Inf;   % guarantees at least one iteration
            while crit_c > self.tol
                [c_pol_new, a_pol_new] = self.get_next_policy(c_pol);
                if any(isnan(c_pol_new(:)))
                    % max() ignores NaN, so the loop would otherwise exit
                    % silently with a NaN policy.
                    error('SteadyStateReducedForm:nanPolicy', ...
                        'Consumption policy became NaN at iteration %d.', iter + 1);
                end
                crit_c = max(abs(c_pol(:) - c_pol_new(:)));

                c_pol = c_pol_new;
                a_pol = a_pol_new;
                iter = iter + 1;
            end
            if self.verbose
                fprintf('Policy iteration: %d iterations, final max |dc| = %g\n', iter, crit_c);
            end
            self.cc = c_pol;
            self.aa = a_pol;
        end

        function [c_pol_new, a_pol_new] = get_next_policy(self, c_pol)
            %% One endogenous-grid step.
            % Computes today's policies given tomorrow's consumption policy
            % c_pol (agrid x Nstate).
            %
            % Returns consumption c_pol_new and next-period assets
            %   a_pol_new = (a + y - c)*R,
            % where y = ss(state)*(1-delta)*Y is non-asset income.
            Y_loop = self.ssmean;

            c_pol_new = zeros(self.agrid,self.Nstate);
            a_pol_new = zeros(self.agrid,self.Nstate);

            % Copy fields into locals to avoid sending the whole object to the parallel pool.
            agrid_loop = self.agrid;
            calibration_loop = self.calibration;
            beta_loop = self.beta;
            R_loop = self.R;
            ss_loop = self.ss;
            alow_loop = self.alow;
            a_loop = self.a;
            delta_loop = self.delta;
            Nstate_loop = self.Nstate;

            parfor state = 1:Nstate_loop
                income = ss_loop(state)*(1-delta_loop)*Y_loop;
                c_assoc_loop = zeros(agrid_loop,1);
                a_assoc_loop = zeros(agrid_loop,1);

                % Euler equation.
                % For each next-period asset grid point a(i), recover today's
                % consumption and the beginning-of-period assets that lead to it.
                for i = 1:agrid_loop
                    marguprime_loop = uprime(c_pol(i,:),calibration_loop)*calibration_loop.pp(:,state);
                    c_assoc_loop(i) = uprimeinv(beta_loop*R_loop*marguprime_loop,calibration_loop);
                    a_assoc_loop(i) = a_loop(i)/R_loop + c_assoc_loop(i) - income;
                end

                ind_low = (a_loop < a_assoc_loop(1));
                ind_high = (a_loop > a_assoc_loop(end));

                % Interior points: interpolate on the endogenous grid.
                c_new = interp1(a_assoc_loop, c_assoc_loop, a_loop, 'linear', 'extrap');

                % Constrained points: next-period assets equal alow, so the
                % budget a' = (a + y - c)*R gives c = a + y - alow/R in closed
                % form. This replaces a per-point fzero on the same linear
                % equation.
                c_new(ind_low) = a_loop(ind_low) + income - alow_loop/R_loop;
                if any(c_new(ind_low) <= 0)
                    error('SteadyStateReducedForm:infeasible', ...
                        'Non-positive consumption at the borrowing constraint in state %d.', state);
                end

                % Above the grid: extrapolate log-consumption linearly in assets.
                slope = (log(c_assoc_loop(end)) - log(c_assoc_loop(end-1))) ...
                    /(a_assoc_loop(end) - a_assoc_loop(end-1));
                c_new(ind_high) = exp(log(c_assoc_loop(end)) + slope*(a_loop(ind_high) - a_assoc_loop(end)));

                c_pol_new(:,state) = c_new;
                a_pol_new(:,state) = (a_loop + income - c_new)*R_loop;
            end
        end

        function calculate_invariant(self)
            %% Stationary distribution of the Markov chain induced by aa.
            % Each policy value aa(k,i) is split linearly between its two
            % bracketing grid points. The split is combined with pp(:,i) over
            % next states to build a sparse transition matrix. The matrix is
            % then iterated from a point mass until the distribution changes
            % by at most 1e-15. The result is stored in invariant
            % (agrid x Nstate).
            nStates = self.agrid*self.Nstate;
            perState = 2*self.agrid*self.Nstate;   % nonzeros contributed by each current state
            h = (self.a_NL(end)-self.a_NL(1))/(self.agrid-1);
            indices_a_next = zeros(self.Nstate*perState,1);
            p_values = zeros(self.Nstate*perState,1);

            for i = 1:self.Nstate
                % Normalised position on the grid. Clamp at 0 so the
                % fractional power below stays real: small negative values
                % would otherwise produce complex indices.
                pos = (self.aa(:,i) - self.a(1))/(self.a(end)-self.a(1));
                pos = max(pos, 0);
                a_next = min(max(floor((pos.^(1/self.grid_NL))/h)' + 1, 1), self.agrid-1);
                u = max(min(((self.aa(:,i) - self.a(a_next))./(self.a(a_next+1) - self.a(a_next))), 1), 0)';
                u = repmat(u, self.Nstate,1);
                u = [(1-u(:))'; u(:)'];
                a_next = repmat(a_next,self.Nstate,1) + self.agrid*repmat((0:(self.Nstate-1))',1,self.agrid);
                a_next = [a_next(:)'; a_next(:)'+1]; % interleave lower/upper neighbours once (:) is applied
                span = ((i-1)*perState+1):(i*perState);
                indices_a_next(span) = a_next(:);
                p_values(span) = u(:).*kron(repmat(self.pp(:,i),self.agrid,1), ones(2,1));
            end

            x_trans = sparse(kron(1:nStates,ones(1,2*self.Nstate)),indices_a_next,p_values,nStates,nStates);

            % Power iteration from a point mass on the first grid point.
            % One mat-vec per step instead of two.
            invar = zeros(nStates,1);
            invar(1) = 1;
            invar_next = x_trans'*invar;
            while max(abs(invar_next - invar)) > 1e-15
                invar = invar_next;
                invar_next = x_trans'*invar;
            end

            self.invariant = reshape(invar,self.agrid,self.Nstate);
        end

        function res = trybeta(self, beta)
            % Set beta, re-solve policies and the invariant distribution,
            % and return the excess demand C-Y (the residual for fzero).
            self.beta = beta;
            res = self.solve_and_report('beta', self.beta);
        end

        function res = tryR(self, R)
            % Set R, re-solve policies and the invariant distribution,
            % and return the excess demand C-Y (the residual for fzero).
            self.R = R;
            res = self.solve_and_report('R', R);
        end

    end

    methods (Access = private)

        function res = solve_and_report(self, name, value)
            % Shared body of trybeta/tryR: solve, time, and print progress.
            tic;
            self.calculate_policy_function();
            self.calculate_invariant();
            C_now = self.C;
            Y_now = self.Y;
            res = C_now - Y_now;
            disp(['--- Outer loop for ', name, '=', num2str(value,'%.8f'), ' ---']);
            toc;
            fprintf('C=%s, Y=%s, C-Y=%s\n\n', num2str(C_now), num2str(Y_now), num2str(C_now - Y_now));
        end

        function a_assoc = borrowing_threshold(self)
            % Per-state asset level (1 x Nstate).
            % An agent whose assets are below this level chooses next-period
            % assets alow (the borrowing constraint binds). It is the
            % endogenous-grid point associated with the first asset grid
            % point, evaluated with alow.
            a_assoc = zeros(1, self.Nstate);
            marg_low = uprime(self.cc(1,:), self.calibration);
            for state = 1:self.Nstate
                c_assoc = uprimeinv(self.beta*self.R*(marg_low*self.pp(:,state)), self.calibration);
                a_assoc(state) = self.alow/self.R + c_assoc - self.ss(state)*(1-self.delta)*self.Y;
            end
        end

        function val = calibration_field(self, name, default)
            % Optional calibration field with a fallback default.
            if isfield(self.calibration, name)
                val = self.calibration.(name);
            else
                val = default;
            end
        end

    end

end