import torch

from .base_trainer import BaseTrainer
from .losses import CulumativeRewardLoss

class ValueModelTrainer(BaseTrainer):
    """Trainer for a pair of value models: a V model and a Q model.

    ``model``, ``optimizer`` and ``criterion`` are dictionaries keyed by
    ``"v"`` and ``"q"``. Both models receive the same batch and are fit
    against ``batch["context_rewards"]``, each with its own optimizer and
    loss function.
    """

    def __init__(self, model, optimizer, criterion=None):
        """
        Initialize the value model trainer.

        Notice that DiT consists of three parts: policy, V and Q. This trainer
        only uses the V and Q entries, so the passed-in arguments have to be
        dictionaries containing at least the keys "v" and "q".

        Args:
            model (dict): Value models keyed by "v" and "q". Each model is
                called as ``model(batch)`` and must expose a ``_gamma``
                attribute, which is forwarded to its loss function.
            optimizer (dict): Optimizers keyed by "v" and "q".
            criterion (dict, optional): Loss functions keyed by "v" and "q",
                each called as ``loss(predictions, rewards, gamma=gamma)``.
                Defaults to ``CulumativeRewardLoss.mse_loss`` for both.
        """
        if criterion is None:
            # Build the default per instance: a dict literal used as the
            # default value would be one object shared by every trainer.
            criterion = {
                "v": CulumativeRewardLoss.mse_loss,
                "q": CulumativeRewardLoss.mse_loss,
            }
        super().__init__(model, optimizer, criterion)
        self._v_model = self.model["v"]
        self._q_model = self.model["q"]

        self._v_optimizer = self.optimizer["v"]
        self._q_optimizer = self.optimizer["q"]

        self._v_criterion = self.criterion["v"]
        self._q_criterion = self.criterion["q"]

    def forward(self, batch):
        """Run both value models on ``batch``.

        Returns:
            dict: ``{"v_values": ..., "q_values": ...}`` holding the raw
            outputs of the V and Q models.
        """
        return {
            "v_values": self._v_model(batch),
            "q_values": self._q_model(batch),
        }

    def _value_losses(self, output, batch):
        """Compute the V and Q loss tensors against ``batch["context_rewards"]``.

        Each criterion receives the ``_gamma`` attribute of its own model.
        """
        rewards = batch["context_rewards"]
        return {
            "v_loss": self._v_criterion(output["v_values"], rewards, gamma=self._v_model._gamma),
            "q_loss": self._q_criterion(output["q_values"], rewards, gamma=self._q_model._gamma),
        }

    def train_step(self, batch):
        """Run one optimization step of both value models on ``batch``.

        Returns:
            dict: ``{"v_loss": float, "q_loss": float}`` as produced by
            ``update``.
        """
        output = self.forward(batch)
        return self.update(self._value_losses(output, batch))

    def update(self, loss):
        """Back-propagate and step the V and Q models one after the other.

        Args:
            loss (dict): Scalar loss tensors under "v_loss" and "q_loss".

        Returns:
            dict: The same keys with the losses converted to Python floats.
        """
        itemed_loss = {}

        # Each loss has its own optimizer: gradients are zeroed, computed and
        # applied for V first, then for Q.
        self._v_optimizer.zero_grad()
        loss["v_loss"].backward()
        self._v_optimizer.step()
        itemed_loss["v_loss"] = loss["v_loss"].item()

        self._q_optimizer.zero_grad()
        loss["q_loss"].backward()
        self._q_optimizer.step()
        itemed_loss["q_loss"] = loss["q_loss"].item()

        return itemed_loss

    def validate_step(self, batch):
        """Compute the V and Q losses on ``batch`` without updating the models.

        Returns:
            dict: ``{"v_loss": float, "q_loss": float}``.
        """
        # The losses are only reported, so skip building an autograd graph.
        with torch.no_grad():
            output = self.forward(batch)
            losses = self._value_losses(output, batch)
        return {key: value.item() for key, value in losses.items()}

    def save_checkpoint(self, filepath, info=None):
        """Save both models and both optimizers to ``filepath``.

        Args:
            filepath: Destination passed to ``torch.save``.
            info (dict, optional): Extra entries stored alongside the weights.
                Keys that collide with the state-dict keys written below are
                overwritten by them.
        """
        state = {} if info is None else dict(info.items())

        state["v_state_dict"] = self._v_model.state_dict()
        state["v_optimizer_dict"] = self._v_optimizer.state_dict()
        state["q_state_dict"] = self._q_model.state_dict()
        state["q_optimizer_dict"] = self._q_optimizer.state_dict()

        torch.save(state, filepath)

    def load_checkpoint(self, filepath, optimizer=False, map_location=None):
        """Restore the V and Q models and, optionally, their optimizers.

        Args:
            filepath: Checkpoint path, in the format written by
                ``save_checkpoint``.
            optimizer (bool): Also restore the optimizer states when the
                checkpoint contains them.
            map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
                loads a GPU checkpoint on a CPU-only machine. ``None`` keeps
                ``torch.load``'s default behavior.

        Returns:
            dict: The full loaded checkpoint, including any ``info`` entries.

        Note:
            ``torch.load`` is pickle-based; only load trusted checkpoints.
        """
        print("=> Loading checkpoint")
        state = torch.load(filepath, map_location=map_location)
        self._v_model.load_state_dict(state["v_state_dict"])
        self._q_model.load_state_dict(state["q_state_dict"])
        if optimizer and "v_optimizer_dict" in state:
            self._v_optimizer.load_state_dict(state["v_optimizer_dict"])
        if optimizer and "q_optimizer_dict" in state:
            self._q_optimizer.load_state_dict(state["q_optimizer_dict"])
        return state

class PreferenceValueModelTrainer(ValueModelTrainer):
    """Value-model trainer whose reward targets come from a preference model.

    Each training batch holds a pair of trajectories. Their per-step rewards
    are re-labeled with ``preference_model``, and the two trajectories are
    stacked along the batch dimension before the regular V/Q update.
    """

    def __init__(self, model, optimizer, criterion=None, preference_model=None):
        """
        Initialize the preference-labeled value model trainer.

        Notice that DiT consists of three parts: policy, V and Q. Only the V
        and Q entries are used here, so the passed-in arguments have to be
        dictionaries containing at least the keys "v" and "q".

        Args:
            model (dict): Value models keyed by "v" and "q".
            optimizer (dict): Optimizers keyed by "v" and "q".
            criterion (dict, optional): Loss functions keyed by "v" and "q".
                Defaults to ``CulumativeRewardLoss.mse_loss`` for both.
            preference_model: Model used by ``train_step`` to label rewards.
                It must be provided before ``train_step`` is called and must
                expose a ``_device`` attribute.
        """
        if criterion is None:
            # Fresh dict per instance instead of a shared mutable default.
            criterion = {
                "v": CulumativeRewardLoss.mse_loss,
                "q": CulumativeRewardLoss.mse_loss,
            }
        super().__init__(model, optimizer, criterion)
        self._preference_model = preference_model

    def _label_batch(self, batch):
        """Convert a trajectory-pair batch into a preference-labeled value batch.

        Returns a dict with "context_states", "context_actions",
        "context_next_states" (traj_1 stacked on top of traj_2 along dim 0)
        and "context_rewards" of shape (2 * batch_size, context_length, 1),
        filled step by step from the preference model. The trajectories' own
        "context_rewards" are not read.
        """
        if self._preference_model is None:
            raise ValueError(
                "PreferenceValueModelTrainer.train_step requires a preference_model "
                "to label rewards, but none was provided."
            )

        traj_pair = (batch["traj_1"], batch["traj_2"])
        batch_size, horizon, _ = traj_pair[0]["context_states"].shape

        converted_batch = {
            key: torch.cat([traj[key] for traj in traj_pair], dim=0)
            for key in ("context_states", "context_actions", "context_next_states")
        }
        rewards = torch.zeros(
            (batch_size * 2, horizon, 1), device=self._preference_model._device
        )

        # The preference model is switched to eval mode and left there.
        self._preference_model.eval()
        # The labels are fixed targets; do not build an autograd graph through
        # the preference model for each of the 2 * horizon queries.
        with torch.no_grad():
            for step in range(horizon):
                # The whole pair batch is the context; each trajectory's
                # (state, action) at this step is the query. Presumably each
                # call returns (batch_size, 1) - TODO confirm.
                step_rewards = [
                    self._preference_model(
                        batch,
                        traj["context_states"][:, step, :],
                        traj["context_actions"][:, step, :],
                        test=True,
                    ).detach()
                    for traj in traj_pair
                ]
                rewards[:, step, :] = torch.cat(step_rewards, dim=0)

        converted_batch["context_rewards"] = rewards
        return converted_batch

    def train_step(self, batch):
        """
        Label a trajectory-pair batch with the preference model and train V and Q.

        Batch is a dictionary with the following structure:
        {
            "traj_1": {
                "context_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_rewards": Tensor of shape (batch_size, context_length, 1)
            },
            "traj_2": { same keys and shapes as "traj_1" },
        }
        We need to label the rewards with the preference model and convert the
        batch into the format used for training the value models, that is
        {
            "context_states": Tensor of shape (2 * batch_size, context_length, state_dim)
            "context_actions": Tensor of shape (2 * batch_size, context_length, action_dim)
            "context_next_states": Tensor of shape (2 * batch_size, context_length, state_dim)
            "context_rewards": Tensor of shape (2 * batch_size, context_length, 1)
        }
        where the first batch_size rows come from traj_1 and the rest from traj_2.

        Returns:
            dict: ``{"v_loss": float, "q_loss": float}`` from the V/Q update.

        Raises:
            ValueError: If no preference model was provided.
        """
        return super().train_step(self._label_batch(batch))