import torch
from torch.autograd.functional import jacobian
from torch.nn.functional import relu
from typing import List

from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.environments.environment import Environment

class GAPSI_Autodiff(Algorithm) :
    """
    Online policy-parameter tuning whose derivatives are all obtained with
    ``torch.autograd.functional.jacobian``.

    On every call ``get_control(t, state)`` with ``t >= 1``:

    1. The gradient, with respect to the parameter, of the cost incurred at
       step ``t-1`` is estimated.
    2. An AdaGrad-style step is taken: per coordinate,
       ``learning_rate * (max_value - min_value) / sqrt(sum of squared gradients)``.
    3. The result is clipped coordinate-wise to ``[min_value, max_value]``.

    The effect of the parameter on the state is propagated through a buffer of
    state-to-parameter sensitivity matrices. ``buffer[b]`` (``b >= 1``) carries
    the contribution through the control applied ``b`` steps before the state
    at which the gradient is evaluated. ``buffer[0]`` is never written nor read,
    so ``buffer_size - 1`` past controls are accounted for. With
    ``buffer_size < 2`` only the direct effect of the parameter on the current
    control is used.
    """

    def __init__(self, environment:Environment, param_dim: List[int], learning_rate:float, buffer_size:int, initial_parameter:torch.Tensor, min_value:torch.Tensor, max_value:torch.Tensor, features: List[torch.Tensor]) :
        """
        :param environment: provides ``products_list`` (items expose ``state_dim``),
            ``cost(t, state, control)`` and ``transition(t, state, control)``.
        :param param_dim: number of parameter entries per product. The full
            parameter vector is the concatenation of the per-product slices.
        :param learning_rate: base step size.
        :param buffer_size: number of buffer slots. Slot 0 is unused, so
            ``buffer_size - 1`` past steps are propagated.
        :param initial_parameter: parameter used at ``t == 0``.
        :param min_value: lower clipping bound of the parameter.
        :param max_value: upper clipping bound of the parameter. The difference
            ``max_value - min_value`` also scales the step.
        :param features: ``features[product_idx][t]`` is dotted with that
            product's parameter slice. It must be 1-D with
            ``param_dim[product_idx]`` entries, as required by ``torch.dot``.
        """
        self.environment = environment
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.initial_parameter = initial_parameter
        self.min_value = min_value
        self.max_value = max_value
        self.features = features
        self.param_dim = param_dim
        self.reset()

    def reset(self) -> None :
        """Clear the sensitivity buffer, the stored trajectory, the cached policy Jacobians and the squared-gradient accumulator."""
        state_dim = sum(product.state_dim for product in self.environment.products_list)
        # One (state_dim x total parameter dim) sensitivity matrix per slot.
        self.buffer = [torch.zeros(state_dim, sum(self.param_dim)) for _ in range(self.buffer_size)]
        self.old_parameter = None
        self.parameter = None
        self.old_state = None
        self.state = None
        self.old_control = None
        self.control = None
        # Policy Jacobians from the previous get_control call, reused to extend the buffer.
        self.d_control_state = None
        self.d_control_parameter = None
        self.sum_of_squared_gradients = torch.zeros(sum(self.param_dim))

    def policy(self, t:int, state:torch.Tensor, parameter:torch.Tensor) -> torch.Tensor :
        """
        Return one control per product.

        For product ``i``, the control is
        ``relu(dot(parameter_slice_i, features[i][t]) - sum(state_slice_i))``.
        Here ``state_slice_i`` spans ``product.state_dim`` entries of ``state``
        and ``parameter_slice_i`` spans ``param_dim[i]`` entries of
        ``parameter``, both taken consecutively in ``products_list`` order.
        The result is a 1-D tensor of torch's default dtype.
        """
        state_idx = 0
        parameter_idx = 0
        control = torch.zeros(len(self.environment.products_list))
        for product_idx, product in enumerate(self.environment.products_list) :
            diff = torch.dot(parameter[parameter_idx:parameter_idx+self.param_dim[product_idx]], self.features[product_idx][t])-state[state_idx:state_idx+product.state_dim].sum()
            control[product_idx] = relu(diff)
            state_idx += product.state_dim
            parameter_idx += self.param_dim[product_idx]
        return control

    def derivative_cost_state(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """Jacobian of ``environment.cost(t, state, control)`` w.r.t. ``state`` (control held fixed)."""
        return jacobian(lambda input : self.environment.cost(t, input, control), state)

    def derivative_cost_control(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """Jacobian of ``environment.cost(t, state, control)`` w.r.t. ``control`` (state held fixed)."""
        return jacobian(lambda input : self.environment.cost(t, state, input), control)

    def derivative_transition_state(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """Jacobian of ``environment.transition(t, state, control)`` w.r.t. ``state`` (control held fixed)."""
        return jacobian(lambda input : self.environment.transition(t, input, control), state)

    def derivative_transition_control(self, t:int, state:torch.Tensor, control:torch.Tensor) :
        """Jacobian of ``environment.transition(t, state, control)`` w.r.t. ``control`` (state held fixed)."""
        return jacobian(lambda input : self.environment.transition(t, state, input), control)

    def derivative_policy_state(self, t:int, state:torch.Tensor, parameter:torch.Tensor) :
        """Jacobian of ``policy(t, state, parameter)`` w.r.t. ``state`` (parameter held fixed)."""
        return jacobian(lambda input : self.policy(t, input, parameter), state)

    def derivative_policy_parameter(self, t:int, state:torch.Tensor, parameter:torch.Tensor) :
        """Jacobian of ``policy(t, state, parameter)`` w.r.t. ``parameter`` (state held fixed)."""
        return jacobian(lambda input : self.policy(t, state, input), parameter)

    def get_control(self, t:int, state:torch.Tensor) -> torch.Tensor :
        """
        Updates the parameter and outputs the corresponding control.

        At ``t < 1`` the parameter is set to ``initial_parameter``.

        At ``t >= 1`` the gradient of the step-``t-1`` cost is computed. It is
        evaluated at the state and control stored by the previous call. A
        clipped, AdaGrad-style step is then applied to the parameter.

        In both cases ``state`` is stored and ``policy(t, state, parameter)``
        is stored and returned.

        NOTE(review): the buffer bookkeeping assumes calls with consecutive
        ``t = 0, 1, 2, ...`` after ``reset()`` — confirm with callers.
        """
        if(t>=1) :
            # Without a slot beyond buffer[0] there is nothing to propagate.
            has_memory = self.buffer_size >= 2

            if(t>=2 and has_memory) :
                # Linearise around the trajectory point stored two calls ago (step t-2).
                d_transition_control = self.derivative_transition_control(t-2, self.old_state, self.old_control)

                # Closed-loop Jacobian of the step-(t-1) state w.r.t. the step-(t-2) state:
                # dT/du @ dpi/dx + dT/dx. self.d_control_state still holds the step-(t-2) value.
                d_state_state_old = torch.matmul(d_transition_control, self.d_control_state)
                d_state_state_old += self.derivative_transition_state(t-2, self.old_state, self.old_control)

            # Age every buffered sensitivity by one step. Iterating from the end
            # guarantees buffer[b-1] still holds its previous value when
            # buffer[b] is overwritten. The guard also ensures that
            # d_state_state_old exists (t >= b+1 >= 3).
            for b in reversed(range(2, self.buffer_size)) :
                if(t>=b+1) :
                    self.buffer[b] = torch.matmul(d_state_state_old, self.buffer[b-1])

            if(t>=2 and has_memory) :
                # Direct effect of the parameter through the control applied at t-2.
                # This must run before self.d_control_parameter is overwritten below.
                self.buffer[1] = torch.matmul(d_transition_control, self.d_control_parameter)

            self.d_control_state = self.derivative_policy_state(t-1, self.state, self.parameter)
            self.d_control_parameter = self.derivative_policy_parameter(t-1, self.state, self.parameter)
            d_cost_control = self.derivative_cost_control(t-1, self.state, self.control)

            if has_memory :
                # Total derivative of the cost w.r.t. the state (dc/dx + dc/du @ dpi/dx),
                # chained with the summed state-to-parameter sensitivities.
                gradient = self.derivative_cost_state(t-1, self.state, self.control)
                gradient += torch.matmul(d_cost_control, self.d_control_state)
                gradient = torch.matmul(gradient, sum(self.buffer[1:]))
                gradient += torch.matmul(d_cost_control, self.d_control_parameter)
            else :
                # No buffered history: only the parameter's effect on the current control counts.
                # (sum([]) would be the int 0 and torch.matmul would reject it.)
                gradient = torch.matmul(d_cost_control, self.d_control_parameter)

            self.sum_of_squared_gradients += gradient**2
            diameter = self.max_value-self.min_value

            # Coordinates whose accumulator is still zero would get an infinite
            # scale (and 0 * inf = NaN), so they get a zero step instead.
            inverse_of_RSSG = 1/torch.sqrt(self.sum_of_squared_gradients)
            inverse_of_RSSG[inverse_of_RSSG == float("inf")] = 0.0

            self.old_parameter = self.parameter
            self.parameter = torch.clip(self.parameter - self.learning_rate*diameter*inverse_of_RSSG*gradient, self.min_value, self.max_value)

        else :
            self.old_parameter = self.parameter
            self.parameter = self.initial_parameter

        self.old_state = self.state
        self.state = state

        self.old_control = self.control
        self.control = self.policy(t, state, self.parameter)
        return self.control