from trl import DPOTrainer
import torch, random
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch.nn.functional as F
import torch.nn as nn
from transformers import PreTrainedModel
from scipy.stats import dirichlet
from copy import deepcopy
class PrefARMTrainer(DPOTrainer):

    def __init__(self, **kwargs):
        """Build the trainer and cache the preference-conditioning settings.

        All keyword arguments are forwarded unchanged to ``DPOTrainer``. The
        ``args`` entry must additionally expose ``pref_sample_p``,
        ``beta_obj``, ``obj_key``, ``gamma``, ``lambda_loss``,
        ``length_normalization`` and ``prompt_config``.

        ``prompt_config`` maps a name to an iterable of prompt strings; every
        string is tokenized once here. ``get_batch_loss_metrics`` looks the
        tokenized prompts up by the entries of ``obj_key``.

        Raises:
            KeyError: If ``args`` is not passed as a keyword argument.
        """
        super().__init__(**kwargs)

        training_args = kwargs["args"]
        self.pref_sample_p = training_args.pref_sample_p
        self.obj_key = training_args.obj_key

        # One beta per objective plus a trailing copy of the first one, which
        # get_batch_loss_metrics uses for the extra "global" comparison.
        # A new list is built so that training_args.beta_obj is not mutated
        # (appending in place grew the caller's list on every construction).
        self.beta_obj = [*training_args.beta_obj, training_args.beta_obj[0]]

        # get_batch_loss_metrics reads self.unseen_obj. Initialise it here
        # unless something else (e.g. a parent class) already provided it.
        # NOTE(review): assumes the setting lives on the training args under
        # the same name; "none" means that no objective is held out.
        if not hasattr(self, "unseen_obj"):
            self.unseen_obj = getattr(training_args, "unseen_obj", "none")

        self.gamma = training_args.gamma
        self.lambda_loss = training_args.lambda_loss

        self.length_normalization = training_args.length_normalization
        self.num_step = 0

        self.prompt_config = training_args.prompt_config

        print( self.prompt_config )
        # Pre-tokenize every candidate prompt so training steps only have to
        # pick one of the stored id lists.
        self.prompt_input_ids = {
            p_name: [self.tokenizer.encode_plus(p_value)["input_ids"] for p_value in p_values]
            for p_name, p_values in self.prompt_config.items()
        }
        if self.length_normalization:
            print('\nUsing length normalization. This is not default for training Autoregressive RM and should only be used for testing purposes!\n')
        if self.gamma != 0 or self.length_normalization:
            print(f'\nARM Trainer: gamma = {self.gamma}, length_normalization = {self.length_normalization}\n')

    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
        """Tokenize one example with the parent trainer and keep its ``labels``.

        Args:
            feature: A single dataset row; must contain a ``labels`` entry.
            model: Forwarded unchanged to the parent ``tokenize_row``.

        Returns:
            The parent's tokenized row with ``feature['labels']`` stored
            under the ``labels`` key.
        """
        tokenized = super().tokenize_row(feature, model)
        # The parent result does not carry the row's labels forward, so they
        # are copied over explicitly for use further down the pipeline.
        tokenized['labels'] = feature['labels']
        return tokenized

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        is_vision_model: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate chosen/rejected inputs and carry the batch ``labels`` along.

        Delegates to ``DPOTrainer.concatenated_inputs`` and then copies
        ``batch['labels']`` into the result under the ``labels`` key.

        Args:
            batch: Collated batch; must contain a ``labels`` entry.
            is_encoder_decoder: Forwarded to the parent implementation.
            is_vision_model: Forwarded to the parent implementation.
            label_pad_token_id: Forwarded to the parent implementation.
            padding_value: Forwarded to the parent implementation.
            device: Forwarded to the parent implementation.

        Returns:
            The parent's concatenated batch plus the ``labels`` entry.
        """
        # Forward by keyword so the call does not depend on the parent's
        # parameter order, and pass `device` through: it was previously
        # accepted here but never handed to the parent.
        concatenated_batch = DPOTrainer.concatenated_inputs(
            batch,
            is_encoder_decoder=is_encoder_decoder,
            is_vision_model=is_vision_model,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            device=device,
        )
        concatenated_batch['labels'] = batch['labels']
        return concatenated_batch

    def arm_loss(
        self,
        beta: float,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        obj_policy_chosen_last_hidden: torch.FloatTensor,
        obj_policy_rejected_last_hidden: torch.FloatTensor,

    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Pairwise preference loss on the policy log-probability gap.

        The scored quantity is
        ``beta * ((chosen_logps - rejected_logps) - gamma / beta)``; it is
        passed through the loss selected by ``self.loss_type``.

        Args:
            beta: Scale applied to the log-probability gap and to the rewards.
            policy_chosen_logps: Log probabilities of the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the rejected responses. Shape: (batch_size,)
            obj_policy_chosen_last_hidden: Accepted for interface compatibility; not used here.
            obj_policy_rejected_last_hidden: Accepted for interface compatibility; not used here.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)``. ``losses`` holds
            one value per example; the rewards are the detached log
            probabilities scaled by ``beta``.

        Raises:
            ValueError: If ``self.loss_type`` is neither ``"sigmoid"`` nor ``"hinge"``.
        """
        logp_gap = policy_chosen_logps - policy_rejected_logps
        margin = self.gamma / beta
        target_device = self.accelerator.device
        logits = logp_gap.to(target_device) - margin
        scaled_logits = beta * logits

        if self.loss_type == "hinge":
            losses = torch.relu(1 - scaled_logits)
        elif self.loss_type == "sigmoid":
            smoothing = self.label_smoothing
            # Label smoothing mixes in the loss of the flipped preference.
            losses = (
                -F.logsigmoid(scaled_logits) * (1 - smoothing)
                - F.logsigmoid(-scaled_logits) * smoothing
            )
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']"
            )

        # Rewards are reporting-only quantities, hence detached.
        chosen_rewards = beta * policy_chosen_logps.to(target_device).detach()
        rejected_rewards = beta * policy_rejected_logps.to(target_device).detach()

        return losses, chosen_rewards, rejected_rewards
    def get_batch_entropy(self, 
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ):
        """Aggregate a per-token confidence score over the labelled positions.

        Despite the name, the per-token quantity is ``1 - H(p) / log(V)``,
        where ``H(p)`` is the entropy of the predicted distribution and ``V``
        the size of the last logits dimension: 1 for a one-hot prediction,
        0 for a uniform one. Positions whose label equals
        ``label_pad_token_id`` contribute nothing.

        Args:
            logits: Model logits. Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels, used only to decide which positions count. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the mean over counted positions
                (the divisor is clamped to at least 1); otherwise the sum.
            label_pad_token_id: Label value marking positions to ignore.
            is_encoder_decoder: If False, labels are shifted left by one and
                the last logit position is dropped so that position ``t``
                is paired with the label at ``t + 1``.

        Returns:
            Tensor of shape (batch_size,) with the aggregated score.

        Raises:
            ValueError: If the leading dimensions of ``logits`` do not match ``labels``.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        vocab_size = logits.shape[-1]

        if not is_encoder_decoder:
            labels = labels[:, 1:]
            logits = logits[:, :-1, :]

        # Labels only select which positions count; their values are never
        # gathered. The former in-place overwrite of pad labels was therefore
        # dead work and, in the encoder-decoder path (no copy is made there),
        # modified the caller's tensor.
        loss_mask = labels != label_pad_token_id

        logp = F.log_softmax(logits, dim=-1)  # [B, L, V]
        p = torch.exp(logp)

        token_entropy = -(p * logp).sum(dim=-1)  # [B, L]

        # Normalise by the maximum possible entropy, log(V), and flip so
        # that larger means more confident.
        token_entropy = 1 - token_entropy / torch.log(torch.tensor(vocab_size, device=logits.device))

        token_entropy = token_entropy * loss_mask

        if average_log_prob:
            return token_entropy.sum(dim=-1) / loss_mask.sum(dim=-1).clamp(min=1)  # [B]
        return token_entropy.sum(dim=-1)  # [B]

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run the model once on the concatenated chosen and rejected inputs.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Args:
            model: Model to evaluate; it must honour ``output_hidden_states=True``.
            batch: Collated batch containing at least ``chosen_labels`` and
                whatever ``concatenated_inputs`` requires.

        Returns:
            An 8-tuple ``(chosen_logps, rejected_logps, chosen_logits,
            rejected_logits, chosen_last_hidden, rejected_last_hidden,
            chosen_seq_entropy, rejected_seq_entropy)``. The ``*_last_hidden``
            entries are the final-layer hidden states at the last labelled
            position of each sequence.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )

        len_chosen = batch["chosen_labels"].shape[0]
        concatenated_labels = concatenated_batch["concatenated_labels"]

        # Position of the last labelled token per row, computed before the
        # labels are handed to any helper. This replaces `valid_length - 1`:
        # valid_length is what length normalization divides by, i.e. a count
        # of scored tokens, so it only equals a position when nothing before
        # the labelled tokens is masked out. Rows without any labelled token
        # resolve to position 0.
        label_mask = concatenated_labels != self.label_pad_token_id
        positions = torch.arange(label_mask.size(1), device=label_mask.device)
        last_token_indices = (label_mask * positions).max(dim=-1).values

        # `output_hidden_states` is passed explicitly in the call below, so it
        # must not also appear here: the duplicate keyword raised a TypeError
        # for encoder-decoder models.
        model_kwargs = (
            {
                "labels": concatenated_labels,
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )
        all_outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            output_hidden_states=True,
            use_cache=False,
            **model_kwargs,
        )

        all_logits = all_outputs.logits

        # Seq2seq outputs expose the decoder stack as `decoder_hidden_states`;
        # decoder-only outputs use `hidden_states`.
        if self.is_encoder_decoder:
            all_hidden_state = all_outputs.decoder_hidden_states[-1]
        else:
            all_hidden_state = all_outputs.hidden_states[-1]

        all_logps, valid_length = self.get_batch_logps(
            all_logits,
            concatenated_labels,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        all_seq_entropy = self.get_batch_entropy(
            all_logits,
            concatenated_labels,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        if self.length_normalization:
            # Clamp so that a row without scored tokens yields 0 rather than
            # NaN (same guard as the averaging in get_batch_entropy).
            all_logps = all_logps / valid_length.clamp(min=1)

        # Pick one hidden vector per row, then split like everything else.
        row_indices = torch.arange(all_hidden_state.size(0), device=all_hidden_state.device)
        all_last_hidden = all_hidden_state[row_indices, last_token_indices.to(all_hidden_state.device)]

        # The first `len_chosen` rows are the chosen responses; the remainder
        # are the rejected ones.
        chosen_logps, rejected_logps = all_logps[:len_chosen], all_logps[len_chosen:]
        chosen_logits, rejected_logits = all_logits[:len_chosen], all_logits[len_chosen:]
        chosen_seq_entropy, rejected_seq_entropy = all_seq_entropy[:len_chosen], all_seq_entropy[len_chosen:]
        chosen_last_hidden, rejected_last_hidden = all_last_hidden[:len_chosen], all_last_hidden[len_chosen:]

        return (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_last_hidden,
            rejected_last_hidden,
            chosen_seq_entropy,
            rejected_seq_entropy,
        )


    def _sample_preference(self) -> torch.Tensor:
        """Draw one Dirichlet weight per objective, pinning a held-out objective to 0."""
        # `unseen_obj` may not have been set on the instance; treat a missing
        # attribute like "none" (no objective held out).
        unseen_obj = getattr(self, "unseen_obj", "none")
        num_obj = len(self.obj_key)

        if unseen_obj == "none":
            return torch.tensor(dirichlet.rvs([self.pref_sample_p] * num_obj)[0].tolist())

        if unseen_obj not in self.obj_key:
            # Without this check the weight vector would silently end up one
            # entry short and the wrong loss would be picked as "global".
            raise ValueError(
                f"unseen_obj must be 'none' or one of {list(self.obj_key)}, got {unseen_obj!r}"
            )

        alphas = dirichlet.rvs([self.pref_sample_p] * (num_obj - 1))[0].tolist()
        for idx, obj in enumerate(self.obj_key):
            if obj == unseen_obj:
                alphas.insert(idx, 0.0)
        return torch.tensor(alphas)

    def _objective_prompt_embeddings(self, model, device) -> torch.Tensor:
        """Return one mean-pooled input embedding per objective, stacked along dim 0."""
        # Wrapped models keep the real model under `.module`.
        unwrapped = model.module if hasattr(model, "module") else model
        embed_tokens = unwrapped.base_model.get_input_embeddings()

        pooled = []
        # The result is only ever written into `param.data`, so no autograd
        # graph is needed for the lookup.
        with torch.no_grad():
            for obj in self.obj_key:
                # A random one of the pre-tokenized prompts for this objective.
                input_ids = torch.tensor(random.choice(self.prompt_input_ids[obj]))
                pooled.append(torch.mean(embed_tokens(input_ids.to(device)), dim=0))
        return torch.stack(pooled, dim=0)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the arm loss and other metrics for the given batch of inputs for train or test.

        Each call samples a fresh preference vector, writes it (and one
        pooled prompt embedding per objective) into the model parameters
        whose names contain ``pref_vec`` / ``promptlora_pref_embed_vec``,
        runs one concatenated forward pass, and evaluates ``arm_loss`` once
        per objective plus once for a "global" comparison based on the
        preference-weighted scores.

        Args:
            model: Model to evaluate (optionally wrapped, exposing ``.module``).
            batch: Collated batch. ``batch['labels'][i][obj]`` is the index
                (0 = chosen, 1 = rejected) of the response preferred under
                objective ``obj``; ``chosen_score`` / ``rejected_score`` hold
                per-objective scores for each example.
            train_eval: ``"eval"`` prefixes metric names with ``eval_``.

        Returns:
            ``(total_loss, metrics)`` where ``total_loss`` is the
            preference-weighted sum of the per-objective losses plus
            ``lambda_loss`` times the global loss, and ``metrics`` maps
            ``bt_loss_<obj>`` and ``rewards_<obj>/accuracies`` to CPU tensors.

        Raises:
            ValueError: If ``unseen_obj`` is neither ``"none"`` nor an entry of ``obj_key``.
        """
        metrics = {}

        self.num_step += 1

        preference = self._sample_preference()
        masked_prompt_embed = self._objective_prompt_embeddings(model, model.device)

        # Inject the sampled conditioning into the model. These tensors are
        # inputs rather than weights, so they are frozen.
        for name, param in model.named_parameters():
            if 'pref_vec' in name:
                param.data = preference.to(param.device)
                param.requires_grad = False
            if 'promptlora_pref_embed_vec' in name:
                param.data = masked_prompt_embed.to(param.device)
                param.requires_grad = False

        (
            policy_chosen_logps,
            policy_rejected_logps,
            _policy_chosen_logits,
            _policy_rejected_logits,
            chosen_last_hidden,
            rejected_last_hidden,
            _chosen_seq_entropy,
            _rejected_seq_entropy,
        ) = self.concatenated_forward(model, batch)

        bs = policy_chosen_logps.size()[0]
        rows = range(bs)
        # Column 0 holds the "chosen" response, column 1 the "rejected" one,
        # so an index list can re-pick the winner per objective. The full
        # logits are deliberately not stacked: nothing below reads them and
        # the stack would copy a [bs, 2, seq, vocab] tensor.
        concat_logps = torch.stack([policy_chosen_logps, policy_rejected_logps], dim=1)  # [bs, 2]
        concat_last_hidden = torch.stack([chosen_last_hidden, rejected_last_hidden], dim=1)  # [bs, 2, hidden_size]

        prefix = "eval_" if train_eval == "eval" else ""
        labels = batch['labels']
        loss_list = []

        for obj_idx, obj in enumerate([*self.obj_key, "global"]):
            if obj != "global":
                chosen_idx = [labels[i][obj] for i in rows]
                rejected_idx = [1 - labels[i][obj] for i in rows]
            else:
                # Collapse the per-objective scores with the sampled weights
                # and let the larger total decide the preferred response.
                overall_chosen_scores = [
                    sum(score * weight for score, weight in zip(scores, preference))
                    for scores in batch["chosen_score"]
                ]
                overall_rejected_scores = [
                    sum(score * weight for score, weight in zip(scores, preference))
                    for scores in batch["rejected_score"]
                ]
                # NOTE(review): on a tie the "rejected" response is treated
                # as the preferred one - confirm this is intended.
                chosen_idx = [
                    0 if c > r else 1
                    for c, r in zip(overall_chosen_scores, overall_rejected_scores)
                ]
                rejected_idx = [1 - idx for idx in chosen_idx]

            # beta_obj carries one extra trailing entry for the "global" pass.
            bt_losses, chosen_rewards, rejected_rewards = self.arm_loss(
                self.beta_obj[obj_idx],
                concat_logps[rows, chosen_idx],
                concat_logps[rows, rejected_idx],
                concat_last_hidden[rows, chosen_idx],
                concat_last_hidden[rows, rejected_idx],
            )

            reward_accuracies = (chosen_rewards > rejected_rewards).float()
            loss_list.append(bt_losses.mean())

            metrics[f"{prefix}bt_loss_{obj}"] = bt_losses.mean().detach().cpu()
            metrics[f"{prefix}rewards_{obj}/accuracies"] = reward_accuracies.mean().cpu()

        num_obj = len(preference)
        local_total_loss = sum(
            each_loss * each_pref
            for each_loss, each_pref in zip(loss_list[:num_obj], preference)
        )

        # The entry right after the per-objective losses is the "global" one.
        global_loss = loss_list[num_obj]

        total_loss = local_total_loss + self.lambda_loss * global_loss
        return total_loss, metrics