import torch
from torch.autograd.functional import jacobian
from torch.nn.functional import relu
from typing import List

from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.environments.environment import Environment
from policy_selection_for_inventories.environments.infinite_warehouse import Infinite_Warehouse

class GAPSI_Special(Algorithm) :
    def __init__(
        self,
        environment:Environment,
        param_dim: List[int],
        learning_rate:float,
        buffer_size:int,
        initial_parameter:torch.Tensor,
        min_value:torch.Tensor,
        max_value:torch.Tensor,
        features: List[torch.Tensor],
    ) :
        """
        Stores the configuration of the algorithm and resets its internal state.

        environment : must be an Infinite_Warehouse instance (checked by an assertion).
        param_dim : number of parameter coordinates owned by each product; the products
            own consecutive slices of the flat parameter vector.
        learning_rate : scalar factor applied to every parameter update.
        buffer_size : number of matrices kept in the sensitivity buffer.
        initial_parameter : parameter used for the first decision (t = 0).
        min_value, max_value : bounds the parameter is clipped to after each update.
        features : features[product_idx][t] is the vector that is dotted with the
            parameter slice of that product at time t.
        """
        assert isinstance(environment, Infinite_Warehouse), "This class only handles problems without warhouse-capacity constraints"

        # Problem description.
        self.environment = environment
        self.features = features
        self.param_dim = param_dim

        # Parameter domain and starting point.
        self.initial_parameter = initial_parameter
        self.min_value = min_value
        self.max_value = max_value

        # Optimisation hyper-parameters.
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size

        # reset() reads environment, param_dim and buffer_size, so it must come last.
        self.reset()

    def reset(self) :
        self.buffer = [torch.zeros(sum([product.state_dim for product in self.environment.products_list]),sum(self.param_dim)) for _ in range(self.buffer_size)]
        self.old_parameter = None
        self.parameter = None
        self.old_state = None
        self.state = None
        self.old_control = None
        self.control = None
        self.d_control_state = None
        self.d_control_parameter = None
        self.sum_of_squared_gradients = torch.zeros(sum(self.param_dim))

    def policy(self, t:int, state:torch.Tensor, parameter:torch.Tensor) -> torch.Tensor :
        state_idx = 0
        parameter_idx = 0
        control = torch.zeros(len(self.environment.products_list))
        for product_idx, product in enumerate(self.environment.products_list) :
            diff = torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[product_idx]], self.features[product_idx][t])-state[state_idx:state_idx+product.state_dim].sum()
            control[product_idx] = torch.where(diff>=0.0, diff, 0.0)
            state_idx += product.state_dim
            parameter_idx += self.param_dim[product_idx]
        return control
    
    def derivative_cost_state(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Implements the left derivative
        """
        computed_jacobian = torch.zeros(len(state))
        #assert (jacobian(lambda input : self.environment.cost(t, input, control), state).shape == computed_jacobian.shape), "Error in shape c/s"
        
        for p_idx, product in enumerate(self.environment.products_list) :
            m = product.lifetime
            L = product.lifetime
            state_control = torch.cat([
                state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim],
                control[p_idx:p_idx+1]
            ])
            demand = product.demands[t]

            if(state_control[0] > demand) :
                computed_jacobian[self.environment.state_indexes[p_idx]] = product.outdating_cost[t]

            if(L==0) :
                if(state_control[:m].sum() > demand) :
                    computed_jacobian[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+m-1] += product.holding_cost[t]
                else :
                    computed_jacobian[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+m-1] -= product.penalty_cost[t]
            else :
                if(state_control[:m].sum() > demand) : 
                    computed_jacobian[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+m] += product.holding_cost[t]
                else :
                    computed_jacobian[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+m] -= product.penalty_cost[t]
        return computed_jacobian

    def derivative_cost_control(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Implements the left derivative
        """
        computed_jacobian = torch.zeros(len(control))
        #assert (jacobian(lambda input : self.environment.cost(t, state, input), control).shape == computed_jacobian.shape), "Error in shape c/c"

        for p_idx, product in enumerate(self.environment.products_list) :
            m = product.lifetime
            L = product.leadtime
            state_control = torch.cat([
                state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim],
                control[p_idx:p_idx+1]
            ])
            demand = product.demands[t]

            computed_jacobian[p_idx] = product.purchase_cost[t]
            if(L == 0) :
                if(state_control[:m].sum() > demand) :
                    computed_jacobian[p_idx] = computed_jacobian[p_idx] + product.holding_cost[t]
                else :
                    computed_jacobian[p_idx] = computed_jacobian[p_idx] - product.penalty_cost[t]
            # we do not need to compute the gradient of the outdating cost w.r.t. the control because state_dim >= 1.
        return computed_jacobian


    # def transition_partial_derivative(i:int, j:int, state_control:torch.Tensor, m:int, demand:int) :
    #     """
    #     Here i and j are considered as math indexes thus starting from 1
    #     """
    #     if(i >= m) :
    #         return float((i+1) == j)
    #     if(j == i+1) :
    #         return float(state_control[i] > relu(demand-state_control[:i].sum()))
    #     if(j <= i) :
    #         return float((state_control[i] > demand-state_control[:i].sum()) and (demand >= state_control[:i].sum()))
    #     # here both i<m and j>i+1
    #     return 0.0

    def derivative_transition_state(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Implements the left derivative
        """
        computed_jacobian = torch.zeros(len(state),len(state))
        # computed_jacobian_bis = torch.zeros(len(state),len(state))
        # assert (jacobian(lambda input : self.environment.transition(t, input, control), state).shape == computed_jacobian.shape), "Error in shape t/s"

        for p_idx, product in enumerate(self.environment.products_list) :
            m = product.lifetime
            L = product.leadtime
            demand = product.demands[t]
            
            state_control = torch.cat([
                state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim],
                control[p_idx:p_idx+1]
            ])

            for i in range(m-1, m+L-2) :
                computed_jacobian[
                    self.environment.state_indexes[p_idx]+i,
                    self.environment.state_indexes[p_idx]+i+1
                ] = 1.0
            for i in range(m-1) :
                if(state_control[i+1] > relu(demand-state_control[:i+1].sum()) and i+1 < m+L-1) :
                    computed_jacobian[
                        self.environment.state_indexes[p_idx]+i,
                        self.environment.state_indexes[p_idx]+i+1
                    ] = 1.0
                if(state_control[i+1] > demand-state_control[:i+1].sum()) and (demand >= state_control[:i+1].sum()) :
                    computed_jacobian[
                        self.environment.state_indexes[p_idx]+i,
                        self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+i+1
                    ] = 1.0

        #     for i in range(1,m+L) :
        #         for j in range(1,m+L) :
        #             computed_jacobian_bis[self.environment.state_indexes[p_idx]+i-1, self.environment.state_indexes[p_idx]+j-1] = GAPSI_Special.transition_partial_derivative(
        #                 i, j, state_control, m, demand
        #             )
        # assert (computed_jacobian == computed_jacobian_bis).all(), "Error in derivative_transition_state"
        return computed_jacobian

    def derivative_transition_control(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """
        Implements the left derivative
        """
        computed_jacobian = torch.zeros(len(state),len(control))
        # computed_jacobian_bis = torch.zeros(len(state),len(control))
        # assert (jacobian(lambda input : self.environment.transition(t, state, input), control).shape == computed_jacobian.shape), "Error in shape t/c"

        for p_idx, product in enumerate(self.environment.products_list) :
            m = product.lifetime
            L = product.leadtime
            demand = product.demands[t]
            state_control = torch.cat([
                state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim],
                control[p_idx:p_idx+1]
            ])

            if(L > 0) :
                computed_jacobian[
                    self.environment.state_indexes[p_idx]+m+L-2,
                    p_idx
                ] = 1.0
            elif(m >= 2) :
                i = m+L-2
                if(state_control[i+1] > relu(demand-state_control[:i+1].sum())) :
                    computed_jacobian[
                        self.environment.state_indexes[p_idx]+i,
                        p_idx
                    ] = 1.0

            # j = m+L
            # for i in range(1,m+L) :
            #     computed_jacobian_bis[self.environment.state_indexes[p_idx]+i-1, p_idx] = GAPSI_Special.transition_partial_derivative(
            #         i, j, state_control, m, demand
            #     )
        return computed_jacobian
    
    def derivative_policy_state(self, t:int, state:torch.Tensor, parameter:torch.Tensor) :
        """
        Implements the right derivative
        """
        computed_jacobian = torch.zeros(len(self.environment.products_list), len(state))
        #assert (jacobian(lambda input : self.policy(t, input, parameter), state).shape == computed_jacobian.shape), "Error in shape p/s"

        parameter_idx = 0
        for p_idx, product in enumerate(self.environment.products_list) :
            computed_jacobian[p_idx, self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim] = -float(
                torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[p_idx]], self.features[p_idx][t])
                > state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim].sum()
            )
            parameter_idx += self.param_dim[p_idx]
        return computed_jacobian

    def derivative_policy_parameter(self, t:int, state:torch.Tensor, parameter:torch.Tensor) :
        """
        Implements the right derivative
        """
        computed_jacobian = torch.zeros(len(self.environment.products_list), len(parameter))
        #assert (jacobian(lambda input : self.policy(t, state, input), parameter).shape == computed_jacobian.shape), "Error in shape p/p"

        parameter_idx = 0
        for p_idx, product in enumerate(self.environment.products_list) :
            value = (
                torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[p_idx]], self.features[p_idx][t])
                - state[self.environment.state_indexes[p_idx]:self.environment.state_indexes[p_idx]+product.state_dim].sum()
            )
            computed_jacobian[p_idx,parameter_idx:parameter_idx+self.param_dim[p_idx]] = torch.where((value>0.0) | ((value==0.0) & (self.features[p_idx][t]>0.0)), self.features[p_idx][t], 0.0)
            parameter_idx += self.param_dim[p_idx]
        return computed_jacobian
    
    def get_control(self, t:int, state:torch.Tensor) -> torch.Tensor :
        """
        Updates the parameter and outputs the corresponding control
        """
        if(t>=1) :
            if(t>=2) :
                d_transition_control = self.derivative_transition_control(t-2, self.old_state, self.old_control)

                d_state_state_old = self.d_control_state
                d_state_state_old = torch.matmul(d_transition_control, d_state_state_old)
                d_state_state_old += self.derivative_transition_state(t-2, self.old_state, self.old_control)
    
            for b in reversed(range(2, self.buffer_size)) :
                if(t>=b+1) :
                    self.buffer[b] = torch.matmul(d_state_state_old, self.buffer[b-1])

            if(t>=2) :
                self.buffer[1] = self.d_control_parameter
                self.buffer[1] = torch.matmul(d_transition_control, self.buffer[1])


            self.d_control_state = self.derivative_policy_state(t-1, self.state, self.parameter)
            self.d_control_parameter = self.derivative_policy_parameter(t-1, self.state, self.parameter)
            d_cost_control = self.derivative_cost_control(t-1, self.state, self.control)

            gradient = self.derivative_cost_state(t-1, self.state, self.control)
            gradient += torch.matmul(d_cost_control, self.d_control_state)
            gradient = torch.matmul(gradient, sum(self.buffer[1:]))
            gradient += torch.matmul(d_cost_control, self.d_control_parameter)

            self.sum_of_squared_gradients += gradient**2
            #self.sum_of_squared_gradients = torch.maximum(self.sum_of_squared_gradients, torch.abs(gradient))
            #self.sum_of_squared_gradients = torch.ones_like(gradient)
            diameter = self.max_value-self.min_value

            inverse_of_RSSG = 1/torch.sqrt(self.sum_of_squared_gradients)
            inverse_of_RSSG[inverse_of_RSSG == float("inf")] = 0.0

            self.old_parameter = self.parameter
            self.parameter = torch.clip(self.parameter - self.learning_rate*diameter*inverse_of_RSSG*gradient, self.min_value, self.max_value)

        else :
            self.old_parameter = self.parameter
            self.parameter = self.initial_parameter

        self.old_state = self.state
        self.state = state

        self.old_control = self.control
        self.control = self.policy(t, state, self.parameter)
        return self.control