import torch

from .base_trainer import BaseTrainer
from .losses import PreferenceLoss

class RewardModelTrainer(BaseTrainer):
    """Trainer that fits a reward model from pairs of trajectories.

    For each pair, the model's predicted rewards are summed over the context
    length of each trajectory and the two sums are handed to ``criterion``.
    """

    def __init__(self, model, optimizer, criterion=PreferenceLoss.loglikelyhood_loss):
        """
        Initialize the reward model trainer with a model, optimizer and loss function.

        Args:
            model (torch.nn.Module): The reward model to train. It is called as
                ``model(batch, query_state, query_action)``.
            optimizer (torch.optim.Optimizer): The optimizer for training.
            criterion (callable): The loss function. It is called as
                ``criterion(traj_1_reward_sum, traj_2_reward_sum)``.
        """
        # NOTE(review): presumably BaseTrainer stores these as self.model,
        # self.optimizer and self.criterion, which the methods below rely on.
        super().__init__(model, optimizer, criterion)

    def forward(self, batch):
        """
        Compute the summed predicted reward of both trajectories in a batch.

        Args:
            batch (dict): A batch of data.
            It should be in the following format:
            batch["traj_1"]: {
                "context_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_rewards": Tensor of shape (batch_size, context_length, 1)
            }
            batch["traj_2"]: {
                "context_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                "context_rewards": Tensor of shape (batch_size, context_length, 1)
            }
            Only "context_states" and "context_actions" are read directly
            here; the whole batch is additionally passed to the model.

        Returns:
            tuple: ``(traj_1_reward_sum, traj_2_reward_sum)``, the model
            outputs accumulated over every step of the context. If the
            context length is 0 both entries are the integer ``0``.
        """
        # The batch is already in the format the model expects; we only need
        # to pull out the per-step query states and actions.
        traj_1_query_states = batch['traj_1']['context_states']
        traj_1_query_actions = batch['traj_1']['context_actions']

        traj_2_query_states = batch['traj_2']['context_states']
        traj_2_query_actions = batch['traj_2']['context_actions']

        # The step count is taken from trajectory 1 and reused for
        # trajectory 2, so both must have the same context length.
        context_length = traj_1_query_states.shape[1]

        traj_1_reward_sum = 0
        traj_2_reward_sum = 0
        for i in range(context_length):
            current_traj_1_query_state = traj_1_query_states[:, i, :]  # [batch_size, state_dim]
            current_traj_1_query_action = traj_1_query_actions[:, i, :]  # [batch_size, action_dim]
            current_traj_2_query_state = traj_2_query_states[:, i, :]
            current_traj_2_query_action = traj_2_query_actions[:, i, :]

            traj_1_reward = self.model(batch, current_traj_1_query_state, current_traj_1_query_action)
            traj_2_reward = self.model(batch, current_traj_2_query_state, current_traj_2_query_action)
            # NOTE(review): the original author states traj_*_reward has shape
            # (batch_size, horizon-1, 1) -- confirm against the model. Plain
            # summation is used here; other aggregations could be tried.
            traj_1_reward_sum += traj_1_reward
            traj_2_reward_sum += traj_2_reward

        return traj_1_reward_sum, traj_2_reward_sum

    def train_step(self, batch):
        """
        Run one optimisation step on a batch.

        Args:
            batch (dict): A batch in the format described in ``forward``.

        Returns:
            dict: ``{"reward loss": loss}`` where ``loss`` is whatever
            ``self.update`` returns for the computed loss.
        """
        traj_1_reward_sum, traj_2_reward_sum = self.forward(batch)
        loss = self.criterion(traj_1_reward_sum, traj_2_reward_sum)
        # self.update is inherited; presumably it performs the backward pass
        # and optimizer step -- see BaseTrainer.
        loss = self.update(loss)
        print(f"reward loss: {loss}")
        return {"reward loss": loss}

    def validate_step(self, batch):
        """
        Compute the loss on a batch without updating the model.

        Args:
            batch (dict): A batch in the format described in ``forward``.

        Returns:
            float: The loss value for the batch.
        """
        # No gradients are needed for validation; disabling autograd avoids
        # building (and holding in memory) a graph that is never used.
        with torch.no_grad():
            traj_1_reward_sum, traj_2_reward_sum = self.forward(batch)
            loss = self.criterion(traj_1_reward_sum, traj_2_reward_sum)

        return loss.item()

    def save_checkpoint(self, filepath, info=None):
        """
        Save the model and optimizer state, plus optional extra information.

        Args:
            filepath: Destination passed to ``torch.save``.
            info (dict, optional): Extra key/value pairs to store alongside
                the state dicts. The keys "model_state_dict" and
                "optimizer_state_dict" are always overwritten with the
                current model and optimizer state.
        """
        # Copy so the caller's dict is never mutated; None replaces the former
        # mutable ``{}`` default.
        state = dict(info) if info else {}
        state["model_state_dict"] = self.model.state_dict()
        state["optimizer_state_dict"] = self.optimizer.state_dict()
        torch.save(state, filepath)

    def load_checkpoint(self, filepath, optimizer=False):
        """
        Load a checkpoint written by ``save_checkpoint``.

        Args:
            filepath: Source passed to ``torch.load``.
            optimizer (bool): If true and the checkpoint contains
                "optimizer_state_dict", the optimizer state is restored too.

        Returns:
            dict: The full loaded checkpoint, including any extra information
            stored at save time.
        """
        print("=> Loading checkpoint")
        # NOTE(security): torch.load relies on pickle; only load checkpoints
        # from trusted sources.
        state = torch.load(filepath)
        self.model.load_state_dict(state["model_state_dict"])
        if optimizer and "optimizer_state_dict" in state:
            self.optimizer.load_state_dict(state["optimizer_state_dict"])
        return state