import torch
from torch.autograd.functional import jacobian
from torch.nn.functional import relu
from typing import List

from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.environments.environment import Environment
from policy_selection_for_inventories.environments.finite_warehouse_pre import Finite_Warehouse_Pre
from policy_selection_for_inventories.environments.product import Product

class GAPSI_Special_Constrained(Algorithm) :
    """
    GAPSI with hand-written right derivatives for ``Finite_Warehouse_Pre``
    (inventory problems with a shared warehouse-capacity constraint).

    Policy: for each product ``k``, an order-up-to rule
    ``max(<theta_k, features[k][t]> - sum(state_k), 0)`` (see :meth:`policy`).

    Learning: at every period ``t >= 1`` the gradient of the cost of period
    ``t-1`` w.r.t. the parameter is approximated by chaining derivatives through
    at most ``buffer_size - 1`` past transitions, and the parameter takes a
    projected AdaGrad-style step (see :meth:`get_control`).

    The cost/transition derivatives are right derivatives derived by hand
    rather than with autograd. They can be cross-checked against
    ``torch.autograd.functional.jacobian``, e.g.
    ``jacobian(lambda s: self.environment.cost(t, s, control), state)``.

    The "state-control" vector used throughout is, product by product, the
    product's state block followed by its control
    (see :meth:`compute_state_control_and_its_volume`).
    """

    def __init__(self, environment:Environment, param_dim: List[int], learning_rate:float, buffer_size:int, initial_parameter:torch.Tensor, min_value:torch.Tensor, max_value:torch.Tensor, features: List[torch.Tensor]) :
        """
        Parameters
        ----------
        environment : Finite_Warehouse_Pre
            The environment; any other type is rejected by an assertion.
        param_dim : list of int
            Number of policy parameters of each product. The parameter vector
            is the concatenation of the per-product blocks.
        learning_rate : float
            Step-size multiplier of the AdaGrad-style update.
        buffer_size : int
            Length of the memory buffer; ``buffer_size - 1`` past transitions
            contribute to the gradient (``<= 1`` means no memory at all).
        initial_parameter : torch.Tensor
            Parameter used at ``t = 0``.
        min_value, max_value : torch.Tensor
            Bounds of the box onto which the parameter is clipped; their
            difference also scales the step size.
        features : list of torch.Tensor
            ``features[k][t]`` is dotted with product ``k``'s parameter block.
        """
        assert isinstance(environment, Finite_Warehouse_Pre), "This class only handles problems with warehouse-capacity constraints"
        self.environment = environment
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.initial_parameter = initial_parameter
        self.min_value = min_value
        self.max_value = max_value
        self.features = features
        self.param_dim = param_dim
        self.reset()

    def reset(self) :
        """
        Clear the gradient memory, the AdaGrad accumulator and every cached
        quantity, so that the next ``get_control(0, ...)`` starts from
        ``initial_parameter``.
        """
        # buffer[b], b >= 1: after get_control(t) with t >= b+1, derivative of the
        # state of period t-1 w.r.t. the parameter used at period t-1-b.
        # buffer[0] is never written and stays zero.
        self.buffer = [torch.zeros(sum([product.state_dim for product in self.environment.products_list]),sum(self.param_dim)) for _ in range(self.buffer_size)]
        # Per-coordinate running sum of squared gradients (AdaGrad denominator).
        self.sum_of_squared_gradients = torch.zeros(sum(self.param_dim))

        # Values of the previous ("old_") and current period, set by get_control.
        self.old_parameter = None
        self.parameter = None
        self.old_state = None
        self.state = None
        self.old_control = None
        self.control = None
        # Policy derivatives of the previous period, reused by the next call.
        self.d_control_state = None
        self.d_control_parameter = None

        # Cache of compute_thresholded_stuff_if_necessary (keyed on the period only).
        self.last_thresholded_state_control = None
        self.last_auxiliary_computation_period = None
        self.last_auxiliary_derivative = [None for _ in self.environment.products_list]

    def policy(self, t:int, state:torch.Tensor, parameter:torch.Tensor) -> torch.Tensor :
        """
        Order-up-to policy.

        For each product ``k``, returns ``max(<theta_k, features[k][t]> - sum(x_k), 0)``.
        Here ``theta_k`` is the ``param_dim[k]``-long block of ``parameter`` and
        ``x_k`` is the ``state_dim``-long block of ``state`` belonging to product
        ``k``.

        Returns
        -------
        torch.Tensor
            One control per product.
        """
        state_idx = 0
        parameter_idx = 0
        control = torch.zeros(len(self.environment.products_list))
        for product_idx, product in enumerate(self.environment.products_list) :
            diff = torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[product_idx]], self.features[product_idx][t])-state[state_idx:state_idx+product.state_dim].sum()
            control[product_idx] = torch.where(diff>=0.0, diff, 0.0)
            state_idx += product.state_dim
            parameter_idx += self.param_dim[product_idx]
        return control

    def compute_state_control_and_its_volume(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Build the state-control vector and its volume.

        The vector concatenates, product by product, the state block and the
        control of each product. The volume is the sum over products of
        ``environment.volumes[k]`` times the sum of the first ``lifetime``
        entries of product ``k``'s block.

        Returns
        -------
        tuple
            ``(state_control, volume)``; ``volume`` is a scalar.
        """
        state_control = torch.zeros(0)
        volume = 0.0
        for p_idx, product in enumerate(self.environment.products_list) :
            state_control = torch.cat([state_control, state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim], control[p_idx:p_idx+1]])
            volume += self.environment.volumes[p_idx]*state_control[self.environment.state_control_indexes[p_idx] : self.environment.state_control_indexes[p_idx] + product.lifetime].sum()
        return state_control, volume

    def compute_thresholded_stuff_if_necessary(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Compute and cache, for period ``t``, the thresholded state-control vector
        and the right derivatives of its thresholded entries.

        The cache is keyed on ``t`` alone. A second call with the same ``t``
        returns immediately even if ``state``/``control`` differ, so all callers
        must pass the same state and control for a given period.

        Writes
        ------
        self.last_thresholded_state_control
            Copy of the state-control vector. For each product ``k``
            (``m_k = lifetime``, ``v_k = volumes[k]``), the entry ``z_{k,m_k-1}``
            is replaced by ``z~_k = max(z_{k,m_k-1} - max(e_k, 0) / v_k, 0)``.
            Here ``e_0 = max(volume - total_volume, 0)`` and
            ``e_k = e_{k-1} - v_{k-1} z_{k-1,m_{k-1}-1}``, i.e. the excess volume
            is charged to the products in list order.
        self.last_auxiliary_derivative[k]
            Right derivative of ``z~_k`` w.r.t. every entry of the state-control
            vector.
        self.last_auxiliary_computation_period
            Set to ``t``.
        """
        if( t != self.last_auxiliary_computation_period) :
            state_control, volume = self.compute_state_control_and_its_volume(t, state, control)
            self.last_thresholded_state_control = state_control.clone()

            for k, product in enumerate(self.environment.products_list) :
                # tmp = e_k (excess volume still to be charged to product k)
                if(k==0) :
                    if(volume > self.environment.total_volume) :
                        tmp = volume - self.environment.total_volume
                    else :
                        tmp = 0.0
                else :
                    tmp = tmp - self.environment.volumes[k-1]*state_control[self.environment.state_control_indexes[k-1]+self.environment.products_list[k-1].lifetime-1]

                # tmp_2 = z_{k,m_k-1} - max(e_k, 0) / v_k
                if(tmp > 0.0) :
                    tmp_2 = state_control[self.environment.state_control_indexes[k]+product.lifetime -1] - tmp/self.environment.volumes[k]
                else :
                    tmp_2 = state_control[self.environment.state_control_indexes[k]+product.lifetime -1]

                # z~_k = max(tmp_2, 0)
                if(tmp_2 > 0.0) :
                    self.last_thresholded_state_control[self.environment.state_control_indexes[k]+product.lifetime -1] = tmp_2
                else :
                    self.last_thresholded_state_control[self.environment.state_control_indexes[k]+product.lifetime -1] = 0.0

                # Right derivative of e_k: the volume term (only when the warehouse
                # is at or above capacity) minus the entries already charged to
                # the products before k.
                deriv = torch.zeros_like(state_control)
                if(volume >= self.environment.total_volume) :
                    deriv[self.environment.all_on_hand_sc_indexes] += 1.0
                deriv[self.environment.all_just_received_sc_indexes[:k]] -= 1.0
                deriv *= self.environment.all_volumes_as_sc

                # Right derivative of max(x, 0) given the right derivative dx of x:
                # dx where x > 0, max(dx, 0) where x == 0, 0 where x < 0.
                deriv = torch.where((tmp>0.0) | ((tmp==0.0) & (deriv>0.0)), deriv, 0.0)
                deriv = -deriv/self.environment.volumes[k]
                deriv[self.environment.all_just_received_sc_indexes[k]] += 1.0
                deriv = torch.where((tmp_2>0.0) | ((tmp_2==0.0) & (deriv>0.0)), deriv, 0.0)

                self.last_auxiliary_derivative[k] = deriv
            self.last_auxiliary_computation_period = t

    def cost_state_partial_derivative_term(self, t:int, k:int, product:Product, z_tilde_0:float, sum_z_tilde_m:float) :
        """
        Right derivative, w.r.t. the whole state vector, of product ``k``'s cost
        terms at period ``t`` (``k`` is 0-based).

        Requires :meth:`compute_thresholded_stuff_if_necessary` to have run for ``t``.

        The formulas correspond to the following per-product terms. This is
        read off the derivatives; the cost itself is defined by the environment.
        With ``d`` the demand and ``S = sum_z_tilde_m``:

        * overflow ``c_o (z_{m_k-1} - z~)``;
        * outdating ``c_out max(z_0 - d, 0)``, with ``z~`` instead of ``z_0`` when
          ``m_k = 1``;
        * holding ``c_h max(S - d, 0)``;
        * penalty ``c_p max(d - S, 0)``.

        Parameters
        ----------
        z_tilde_0 : first entry of product ``k``'s thresholded state-control block.
        sum_z_tilde_m : sum of the first ``lifetime`` entries of that block.
        """
        result = torch.zeros(len(self.environment.all_state_indexes_of_sc))
        m_k = product.lifetime
        demand = product.demands[t]
        dz_tilde = self.last_auxiliary_derivative[k][self.environment.all_state_indexes_of_sc]

        # Overflow cost; z_{m_k-1} is a state entry only when m_k-1 < state_dim,
        # otherwise it is the control (handled in the control derivative).
        result -= self.environment.overflow_cost[k][t]*dz_tilde
        if(m_k-1 < product.state_dim) :
            result[self.environment.state_indexes[k]+m_k-1] += self.environment.overflow_cost[k][t]

        # Outdating cost
        if(m_k == 1) :
            result += torch.where((z_tilde_0>demand) | ((z_tilde_0==demand) & (dz_tilde > 0.0)), product.outdating_cost[t]*dz_tilde, 0.0)
        else :
            if(z_tilde_0 >= demand) :
                result[self.environment.state_indexes[k]] += product.outdating_cost[t]

        # tmp = right derivative of S: 1 on the untouched entries 0..m_k-2, plus dz~.
        tmp = dz_tilde.clone()
        tmp[self.environment.state_indexes[k]:self.environment.state_indexes[k]+m_k-1] += 1.0
        # Holding cost
        result += torch.where((sum_z_tilde_m>demand) | ((sum_z_tilde_m==demand) & (tmp > 0.0)), product.holding_cost[t]*tmp, 0.0)
        # Penalty cost (d - S moves by -tmp)
        result -= torch.where((sum_z_tilde_m<demand) | ((sum_z_tilde_m==demand) & (tmp<0.0)), product.penalty_cost[t]*tmp, 0.0)
        return result

    def cost_control_partial_derivative_term(self, t:int, k:int, product:Product, z_tilde_0:float, sum_z_tilde_m:float) :
        """
        Right derivative, w.r.t. the control vector, of product ``k``'s cost terms
        at period ``t`` (``k`` is 0-based).

        Requires :meth:`compute_thresholded_stuff_if_necessary` to have run for ``t``.

        Same cost terms as :meth:`cost_state_partial_derivative_term`, plus
        purchase ``c_pur u_k``. Controls only enter ``S`` through ``z~``. The
        control itself is ``z_{m_k-1}`` when the leadtime is 0.
        """
        result = torch.zeros(len(self.environment.products_list))
        m_k = product.lifetime
        L_k = product.leadtime
        demand = product.demands[t]
        dz_tilde = self.last_auxiliary_derivative[k][self.environment.all_control_indexes_of_sc]

        # Purchase cost
        result[k] += product.purchase_cost[t]

        # Overflow cost
        result -= self.environment.overflow_cost[k][t]*dz_tilde
        if(L_k == 0) :
            result[k] += self.environment.overflow_cost[k][t]

        # Outdating cost
        if(m_k == 1) :
            result += torch.where((z_tilde_0>demand) | ((z_tilde_0==demand) & (dz_tilde > 0.0)), product.outdating_cost[t]*dz_tilde, 0.0)

        # Holding cost
        result += torch.where((sum_z_tilde_m>demand) | ((sum_z_tilde_m==demand) & (dz_tilde > 0.0)), product.holding_cost[t]*dz_tilde, 0.0)

        # Penalty cost
        result -= torch.where((sum_z_tilde_m<demand) | ((sum_z_tilde_m==demand) & (dz_tilde<0.0)), product.penalty_cost[t]*dz_tilde, 0.0)

        return result

    def derivative_cost_state(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Right derivative of the cost of period ``t`` w.r.t. the state.

        Returns
        -------
        torch.Tensor
            Vector of length ``len(state)``: the sum of the per-product terms of
            :meth:`cost_state_partial_derivative_term`.
        """
        self.compute_thresholded_stuff_if_necessary(t,state, control)
        computed_jacobian = torch.zeros(len(state))

        for k, product in enumerate(self.environment.products_list) :
            z_tilde_0 = self.last_thresholded_state_control[self.environment.state_control_indexes[k]]
            sum_z_tilde_m = self.last_thresholded_state_control[self.environment.state_control_indexes[k]:self.environment.state_control_indexes[k]+product.lifetime].sum()
            computed_jacobian += self.cost_state_partial_derivative_term(t, k, product, z_tilde_0, sum_z_tilde_m)
        return computed_jacobian

    def derivative_cost_control(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Right derivative of the cost of period ``t`` w.r.t. the controls.

        Returns
        -------
        torch.Tensor
            Vector of length ``len(control)``: the sum of the per-product terms
            of :meth:`cost_control_partial_derivative_term`.
        """
        self.compute_thresholded_stuff_if_necessary(t,state, control)
        computed_jacobian = torch.zeros(len(control))

        for k, product in enumerate(self.environment.products_list) :
            z_tilde_0 = self.last_thresholded_state_control[self.environment.state_control_indexes[k]]
            sum_z_tilde_m = self.last_thresholded_state_control[self.environment.state_control_indexes[k]:self.environment.state_control_indexes[k]+product.lifetime].sum()
            computed_jacobian += self.cost_control_partial_derivative_term(t, k, product, z_tilde_0, sum_z_tilde_m)
        return computed_jacobian

    def derivative_transition_state(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Right derivative of the transition of period ``t`` w.r.t. the state.

        For each product (lifetime ``m``, leadtime ``L``, demand ``d``, block ``z``
        = its state followed by its control), the filled rows correspond to the
        following. This is read off the formulas; the transition itself is
        defined by the environment.

        * ``i`` in ``[m-1, m+L-3]``: pipeline shift ``x'_i = z_{i+1}``;
        * ``i`` in ``[0, m-3]``: ``x'_i = max(z_{i+1} - max(d - sum_{j<=i} z_j, 0), 0)``;
        * ``i = m-2``: as above with ``z_{m-1}`` replaced by ``z~``, which depends
          on the whole state through the warehouse volume.

        Returns
        -------
        torch.Tensor
            Matrix of shape ``(len(state), len(state))``.
        """
        self.compute_thresholded_stuff_if_necessary(t, state, control)
        computed_jacobian = torch.zeros(len(state),len(state))

        for p_idx, product in enumerate(self.environment.products_list) :
            m = product.lifetime
            L = product.leadtime
            demand = product.demands[t]

            state_control = torch.cat([
                state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim],
                control[p_idx:p_idx+1]
            ])

            # Pipeline shift
            for i in range(m-1, m+L-2) :
                computed_jacobian[
                    self.environment.state_indexes[p_idx]+i,
                    self.environment.state_indexes[p_idx]+i+1
                ] = 1.0
            # Ageing of on-hand inventory net of demand (right derivatives)
            for i in range(m-2) :
                if(state_control[i+1] >= relu(demand-state_control[:i+1].sum())) :
                    computed_jacobian[
                        self.environment.state_indexes[p_idx]+i,
                        self.environment.state_indexes[p_idx]+i+1
                    ] = 1.0
                if(state_control[i+1] >= demand-state_control[:i+1].sum()) and (demand > state_control[:i+1].sum()) :
                    computed_jacobian[
                        self.environment.state_indexes[p_idx]+i,
                        self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+i+1
                    ] = 1.0
            # Row m-2, which involves the thresholded z~_{m-1}
            if(m >= 2) :
                #   value = f_k^7 = z~_{m-1} - max(d - sum_{j<m-1} z_j, 0)
                value = self.last_thresholded_state_control[self.environment.state_control_indexes[p_idx]+m-1]
                value = value - relu(demand - state_control[:m-1].sum())
                #   deriv = right derivative of f_k^7
                deriv = self.last_auxiliary_derivative[p_idx][self.environment.all_state_indexes_of_sc].clone()
                if(demand > state_control[:m-1].sum()) :
                    deriv[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+m-1] += 1.0
                computed_jacobian[self.environment.state_indexes[p_idx]+m-2] = torch.where((value>0.0) | ((value==0.0) & (deriv>0.0)), deriv, 0.0)
        return computed_jacobian

    def derivative_transition_control(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Right derivative of the transition of period ``t`` w.r.t. the controls.

        For a product with leadtime ``L > 0``, its own control enters state entry
        ``m+L-2`` with derivative 1. For a product with lifetime ``m >= 2``, row
        ``m-2`` involves the thresholded ``z~_{m-1}``. Through the shared warehouse
        volume, ``z~`` can depend on other products' controls even when this
        product's own order is still in the pipeline. That row is therefore
        filled whatever ``L`` is, mirroring :meth:`derivative_transition_state`.

        Returns
        -------
        torch.Tensor
            Matrix of shape ``(len(state), len(control))``.
        """
        self.compute_thresholded_stuff_if_necessary(t, state, control)
        computed_jacobian = torch.zeros(len(state),len(control))

        for p_idx, product in enumerate(self.environment.products_list) :
            m = product.lifetime
            L = product.leadtime
            demand = product.demands[t]

            state_control = torch.cat([
                state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim],
                control[p_idx:p_idx+1]
            ])

            if(L > 0) :
                # The order enters the end of the pipeline.
                computed_jacobian[
                    self.environment.state_indexes[p_idx]+m+L-2,
                    p_idx
                ] = 1.0
            # Not an `elif`: when L > 0 this is a different row (m-2 != m+L-2),
            # and it still carries the cross-product terms of z~.
            if(m >= 2) :
                #   value = f_k^7 = z~_{m-1} - max(d - sum_{j<m-1} z_j, 0)
                value = self.last_thresholded_state_control[self.environment.state_control_indexes[p_idx]+m-1]
                value = value - relu(demand - state_control[:m-1].sum())
                #   deriv = right derivative of f_k^7 w.r.t. the controls
                deriv = self.last_auxiliary_derivative[p_idx][self.environment.all_control_indexes_of_sc].clone()
                computed_jacobian[self.environment.state_indexes[p_idx]+m-2] = torch.where((value>0.0) | ((value==0.0) & (deriv>0.0)), deriv, 0.0)
        return computed_jacobian

    def derivative_policy_state(self, t:int, state:torch.Tensor, parameter:torch.Tensor) :
        """
        Right derivative of :meth:`policy` w.r.t. the state.

        Entry ``(k, j)`` is ``-1`` for ``j`` in product ``k``'s state block when
        ``<theta_k, features[k][t]>`` is strictly greater than the sum of that
        block, and ``0`` otherwise. At equality, increasing the state keeps the
        control at 0.

        Returns
        -------
        torch.Tensor
            Matrix of shape ``(number of products, len(state))``.
        """
        computed_jacobian = torch.zeros(len(self.environment.products_list), len(state))

        parameter_idx = 0
        for p_idx, product in enumerate(self.environment.products_list) :
            computed_jacobian[p_idx, self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim] = -float(
                torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[p_idx]], self.features[p_idx][t])
                > state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim].sum()
            )
            parameter_idx += self.param_dim[p_idx]
        return computed_jacobian

    def derivative_policy_parameter(self, t:int, state:torch.Tensor, parameter:torch.Tensor) :
        """
        Right derivative of :meth:`policy` w.r.t. the parameter.

        Row ``k`` holds ``features[k][t]`` on product ``k``'s parameter block
        when the order-up-to gap is positive. At a zero gap, only coordinates
        whose feature is positive are kept (right derivative of ``max(., 0)``).

        Returns
        -------
        torch.Tensor
            Matrix of shape ``(number of products, len(parameter))``.
        """
        computed_jacobian = torch.zeros(len(self.environment.products_list), len(parameter))

        parameter_idx = 0
        for p_idx, product in enumerate(self.environment.products_list) :
            value = (
                torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[p_idx]], self.features[p_idx][t])
                - state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim].sum()
            )
            computed_jacobian[p_idx,parameter_idx:parameter_idx+self.param_dim[p_idx]] = torch.where((value>0.0) | ((value==0.0) & (self.features[p_idx][t]>0.0)), self.features[p_idx][t], 0.0)
            parameter_idx += self.param_dim[p_idx]
        return computed_jacobian

    def get_control(self, t:int, state:torch.Tensor) -> torch.Tensor :
        """
        Update the parameter with one gradient step, then return the control of
        period ``t``.

        At ``t = 0`` the parameter is set to ``initial_parameter``. For ``t >= 1``,
        the gradient ``g`` of the period ``t-1`` cost w.r.t. the parameter is
        approximated through the memory buffer. The parameter then becomes
        ``clip(theta - lr * (max_value - min_value) * g / sqrt(sum g^2), min_value, max_value)``.
        Coordinates whose accumulated squared gradient is still zero are not
        moved.

        Assumes one call per period with consecutive ``t = 0, 1, 2, ...``: the
        derivatives at ``t-1`` and ``t-2`` are evaluated at the state/control
        stored by the previous calls.

        Parameters
        ----------
        t : int
            Current period.
        state : torch.Tensor
            Current state of the environment.

        Returns
        -------
        torch.Tensor
            One control per product, from :meth:`policy`.
        """
        if(t>=1) :
            if(t>=2) :
                # Transition from period t-2 to period t-1.
                d_transition_control = self.derivative_transition_control(t-2, self.old_state, self.old_control)

                # Total d state_{t-1} / d state_{t-2}, through the transition and the
                # policy (self.d_control_state still refers to period t-2 here).
                d_state_state_old = self.d_control_state
                d_state_state_old = torch.matmul(d_transition_control, d_state_state_old)
                d_state_state_old += self.derivative_transition_state(t-2, self.old_state, self.old_control)

            # Shift the memory by one period, highest index first so that
            # buffer[b-1] still holds its previous value when it is read.
            for b in reversed(range(2, self.buffer_size)) :
                if(t>=b+1) :
                    self.buffer[b] = torch.matmul(d_state_state_old, self.buffer[b-1])

            # Direct effect of the parameter of period t-2 on state_{t-1}
            # (skipped when the buffer has no slot 1).
            if(t>=2 and self.buffer_size>=2) :
                self.buffer[1] = self.d_control_parameter
                self.buffer[1] = torch.matmul(d_transition_control, self.buffer[1])

            # Policy and cost derivatives at period t-1 (previous call's state/control).
            self.d_control_state = self.derivative_policy_state(t-1, self.state, self.parameter)
            self.d_control_parameter = self.derivative_policy_parameter(t-1, self.state, self.parameter)
            d_cost_control = self.derivative_cost_control(t-1, self.state, self.control)

            # g = (dc/dx + dc/du du/dx) * sum_b buffer[b] + dc/du du/dtheta
            gradient = self.derivative_cost_state(t-1, self.state, self.control)
            gradient += torch.matmul(d_cost_control, self.d_control_state)
            # Explicit zero start: keeps the sum a matrix even when buffer[1:] is empty.
            d_state_parameter = sum(self.buffer[1:], torch.zeros(len(gradient), sum(self.param_dim)))
            gradient = torch.matmul(gradient, d_state_parameter)
            gradient += torch.matmul(d_cost_control, self.d_control_parameter)

            self.sum_of_squared_gradients += gradient**2
            diameter = self.max_value-self.min_value

            # 1/sqrt(0) = inf on coordinates that only ever saw zero gradients;
            # zero it so that inf * 0 does not produce NaN steps.
            inverse_of_RSSG = 1/torch.sqrt(self.sum_of_squared_gradients)
            inverse_of_RSSG[inverse_of_RSSG == float("inf")] = 0.0

            self.old_parameter = self.parameter
            self.parameter = torch.clip(self.parameter - self.learning_rate*diameter*inverse_of_RSSG*gradient, self.min_value, self.max_value)

        else :
            self.old_parameter = self.parameter
            self.parameter = self.initial_parameter

        self.old_state = self.state
        self.state = state

        self.old_control = self.control
        self.control = self.policy(t, state, self.parameter)
        return self.control