import torch
import numpy as np

from .base_trainer import BaseTrainer
from .losses import CulumativeRewardLoss, WeightedPolicyLoss

class DiTPolicyModelTrainer(BaseTrainer):
    """Trainer that fits the policy part of a DiT model using advantage weights.

    The V and Q models are only used for inference here (to build the
    ``Q - V`` weights); the policy model is the only one whose optimizer is
    stepped by this trainer.
    """

    def __init__(self, model, optimizer, criterion=WeightedPolicyLoss.mse_loss):
        """
        Initialize the policy model trainer with models, optimizers and a loss function.

        Notice that DiT consists of three parts: Policy, V and Q, so ``model``
        and ``optimizer`` have to be dictionaries.

        Args:
            model (dict): Models keyed by "policy", "v" and "q".
            optimizer (dict): Optimizers; only the "policy" entry is used here.
            criterion (callable): Weighted loss invoked as
                ``criterion(pred_actions, true_actions, weights)``.
        """
        # BaseTrainer is expected to expose the arguments as self.model,
        # self.optimizer and self.criterion (they are read right below).
        super().__init__(model, optimizer, criterion)
        self._policy_model = self.model["policy"]
        self._v_model = self.model["v"]
        self._q_model = self.model["q"]

        # We now only train the policy model.
        self._policy_optimizer = self.optimizer["policy"]
        self._policy_criterion = self.criterion

    def forward(self, batch):
        """
        Run the V, Q and policy models on a batch with freshly sampled query states.

        One context position per sample is drawn uniformly at random and used
        as the query state. ``batch`` is modified in place: the keys
        "query_states" and "query_state_indices" are added before the models
        are called.

        Args:
            batch (dict): A batch of data in the following format:
                {
                    "context_states": Tensor of shape (batch_size, context_length, state_dim)
                    "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                    "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                    "context_rewards": Tensor of shape (batch_size, context_length, 1)
                }

        Returns:
            dict: {"v_values", "q_values", "policy_values"} holding the raw
            outputs of the three models. The V/Q outputs carry no gradient.
        """
        batch_size, horizon, _ = batch["context_states"].shape

        # Sample one query position per sample (with replacement across the batch).
        query_state_indices = np.random.choice(horizon, size=batch_size, replace=True)
        batch["query_states"] = batch["context_states"][
            torch.arange(batch_size), query_state_indices, :
        ]
        batch["query_state_indices"] = query_state_indices

        output = {}

        # The critics only provide weights, so evaluate them in eval mode and
        # without building an autograd graph, then restore whatever mode they
        # were in instead of unconditionally forcing train mode.
        v_was_training = self._v_model.training
        q_was_training = self._q_model.training
        self._v_model.eval()
        self._q_model.eval()
        try:
            with torch.no_grad():
                output["v_values"] = self._v_model(batch)
                output["q_values"] = self._q_model(batch)
        finally:
            self._v_model.train(v_was_training)
            self._q_model.train(q_was_training)

        output["policy_values"] = self._policy_model(batch)

        return output

    def _compute_policy_loss(self, batch, output):
        """Return the advantage-weighted policy loss tensor for one forward pass.

        Every policy prediction of a sample is compared against the action
        taken at that sample's query position, weighted by
        ``exp((Q - V) / weight_scale)`` evaluated at the query position.
        Requires ``batch["query_state_indices"]`` as set by ``forward``.
        """
        batch_size, horizon, _ = batch["context_states"].shape
        rows = torch.arange(batch_size)
        query_indices = batch["query_state_indices"]
        action_dim = self._policy_model._action_dim
        weight_scale = self._policy_model._config.weight_scale

        # Advantage at the query position, repeated once per context position
        # so that it lines up with the flattened predictions below.
        advantages = (output["q_values"].detach() - output["v_values"].detach())[rows, query_indices]
        weights = torch.exp(advantages.repeat_interleave(horizon) / weight_scale)

        # Target: the query-position action, broadcast over the horizon.
        true_actions = batch["context_actions"][rows, query_indices]
        true_actions = true_actions.unsqueeze(1).repeat(1, horizon, 1).reshape(-1, action_dim)
        pred_actions = output["policy_values"].reshape(-1, action_dim)

        return self._policy_criterion(pred_actions, true_actions, weights)

    def train_step(self, batch):
        """
        Perform a single optimization step of the policy model.

        Args:
            batch (dict): Same format as accepted by ``forward``.

        Returns:
            dict: {"policy_loss": float} for this step.
        """
        output = self.forward(batch)
        loss = {"policy_loss": self._compute_policy_loss(batch, output)}
        return self.update(loss)

    def update(self, loss):
        """
        Backpropagate the policy loss and step the policy optimizer.

        Args:
            loss (dict): Must contain the tensor "policy_loss".

        Returns:
            dict: {"policy_loss": float} with the loss converted to a Python number.
        """
        self._policy_optimizer.zero_grad()
        loss["policy_loss"].backward()
        self._policy_optimizer.step()

        return {"policy_loss": loss["policy_loss"].item()}

    def validate_step(self, batch):
        """
        Evaluate the training objective on a batch without updating any model.

        The loss is computed exactly as in ``train_step`` (same query-state
        sampling, same ``weight_scale``), with the policy model temporarily in
        eval mode and gradient tracking disabled.

        Args:
            batch (dict): Same format as accepted by ``forward``.

        Returns:
            dict: {"policy_loss": float}.
        """
        policy_was_training = self._policy_model.training
        self._policy_model.eval()
        try:
            with torch.no_grad():
                output = self.forward(batch)
                policy_loss = self._compute_policy_loss(batch, output)
        finally:
            self._policy_model.train(policy_was_training)

        return {"policy_loss": policy_loss.item()}

    def save_checkpoint(self, filepath, info=None):
        """
        Save the model weights and the policy optimizer state to ``filepath``.

        Args:
            filepath: Destination passed to ``torch.save``.
            info (dict, optional): Extra entries stored alongside the state
                dicts. Entries named like one of the state-dict keys are
                overwritten by the actual state dicts.
        """
        # Copy so the caller's dict is never mutated (and no shared default is needed).
        state = dict(info) if info else {}
        state["policy_state_dict"] = self._policy_model.state_dict()
        state["policy_optimizer_dict"] = self._policy_optimizer.state_dict()
        state["v_state_dict"] = self._v_model.state_dict()
        state["q_state_dict"] = self._q_model.state_dict()
        torch.save(state, filepath)

    def load_checkpoint(self, filepath, optimizer=False, map_location=None):
        """
        Restore the policy, V and Q weights from a checkpoint.

        Args:
            filepath: Source passed to ``torch.load``.
            optimizer (bool): Also restore the policy optimizer state when the
                checkpoint contains it.
            map_location: Forwarded to ``torch.load``; lets a checkpoint saved
                on one device be loaded on another (e.g. "cpu").

        Returns:
            dict: The full checkpoint content, including any extra info saved with it.
        """
        print("=> Loading checkpoint")
        state = torch.load(filepath, map_location=map_location)
        self._policy_model.load_state_dict(state["policy_state_dict"])
        self._v_model.load_state_dict(state["v_state_dict"])
        self._q_model.load_state_dict(state["q_state_dict"])
        if optimizer and "policy_optimizer_dict" in state:
            self._policy_optimizer.load_state_dict(state["policy_optimizer_dict"])
        return state
    
class PreferenceDiTPolicyModelTrainer(DiTPolicyModelTrainer):
    """Policy trainer whose rewards are labelled by a preference model.

    Batches arrive as pairs of trajectories; the preference model scores every
    (state, action) step and the result is fed to the parent's ``train_step``.
    """

    def __init__(self, model, optimizer, criterion = WeightedPolicyLoss.mse_loss, preference_model = None):
        """
        Initialize the policy model trainer with models, optimizers, a loss function
        and a preference model.

        Notice that DiT consists of three parts: Policy, V and Q, so ``model``
        and ``optimizer`` have to be dictionaries.

        Args:
            model (dict): Models keyed by "policy", "v" and "q".
            optimizer (dict): Optimizers; the parent trainer uses the "policy" entry.
            criterion (callable): Weighted loss function forwarded to the parent trainer.
            preference_model: Model used to label rewards in ``train_step``. It is
                called as ``preference_model(batch, query_states, optimal_actions, test=True)``
                and must expose a ``_device`` attribute.
        """
        # We expect the keys of the dictionary to be "policy", "v" and "q".
        super().__init__(model, optimizer, criterion)
        self._preference_model = preference_model

    def train_step(self, batch):
        """
        Label a pair of trajectories with rewards and run one policy training step.

        Batch is now a dictionary with the following structure:
        {
            "traj_1": {
                "context_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_rewards": Tensor of shape (batch_size, context_length, 1)
            },
            "traj_2": {
                "context_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_rewards": Tensor of shape (batch_size, context_length, 1)
            }
        }
        We label the rewards with the preference model and convert the pair into
        the single-batch format used by the parent trainer, that is
        {
            "context_states": Tensor of shape (2 * batch_size, context_length, state_dim)
            "context_actions": Tensor of shape (2 * batch_size, context_length, action_dim)
            "context_next_states": Tensor of shape (2 * batch_size, context_length, state_dim)
            "context_rewards": Tensor of shape (2 * batch_size, context_length, 1)
        }
        with traj_1 occupying the first half of the batch dimension and traj_2
        the second half. Any "context_rewards" present in the input are ignored.

        Returns:
            dict: The loss dictionary returned by the parent ``train_step``.

        Raises:
            ValueError: If the trainer was constructed without a preference model.
        """
        preference_model = self._preference_model
        if preference_model is None:
            raise ValueError(
                "PreferenceDiTPolicyModelTrainer.train_step requires a preference_model."
            )

        traj_1, traj_2 = batch["traj_1"], batch["traj_2"]
        batch_size, horizon, _ = traj_1["context_states"].shape

        # Stack both trajectories along the batch dimension.
        converted_batch = {
            key: torch.cat([traj_1[key], traj_2[key]], dim=0)
            for key in ("context_states", "context_actions", "context_next_states")
        }

        # Allocate the reward buffer directly on the preference model's device.
        rewards = torch.zeros((batch_size * 2, horizon, 1), device=preference_model._device)

        # Reward labelling is pure inference: no autograd graph is needed, and
        # the model's previous train/eval mode is restored afterwards.
        was_training = preference_model.training
        preference_model.eval()
        try:
            with torch.no_grad():
                for step in range(horizon):
                    step_rewards = [
                        preference_model(
                            batch,
                            traj["context_states"][:, step, :],
                            traj["context_actions"][:, step, :],
                            test=True,
                        ).detach()
                        for traj in (traj_1, traj_2)
                    ]
                    rewards[:, step, :] = torch.cat(step_rewards, dim=0)
        finally:
            preference_model.train(was_training)

        converted_batch["context_rewards"] = rewards
        return super().train_step(converted_batch)