import logging
import torch
from torch import autocast
from contextlib import nullcontext
import torch.nn.functional as F
import torch.nn as nn
from typing import Union, Literal, Optional
import torch.distributed as dist
from accelerate import PartialState
from transformers import PreTrainedModel
from trl import DPOTrainer

logger = logging.getLogger(__name__)

class LeanFinderDPOTrainer(DPOTrainer):
    """DPO trainer adapted to an embedding model (queries vs. passages).

    Instead of sequence log-probabilities from a language-model head, each
    prompt is embedded together with a chosen and a rejected passage; a 2-way
    log-softmax over the two (prompt, passage) similarity scores stands in for
    the policy log-probs that the DPO loss consumes.

    With ``dual_loss`` enabled, an in-batch contrastive (InfoNCE-style) loss
    over the ``contrastive_query_*`` / ``contrastive_passage_*`` batch entries
    is added to the DPO loss, weighted by ``args.rpo_alpha``.
    """

    def __init__(self, dual_loss=False, contrastive_loss_temp=0.2, *args, **kwargs):
        """Create the trainer.

        Args:
            dual_loss: If True, add the contrastive loss to the DPO loss and
                skip tokenization in ``_prepare_dataset``.
            contrastive_loss_temp: Temperature dividing the similarity scores
                inside the contrastive loss.
            *args, **kwargs: Forwarded unchanged to ``DPOTrainer.__init__``.
                Because ``dual_loss`` and ``contrastive_loss_temp`` come
                first, pass the remaining trainer arguments by keyword.
        """
        # Set before super().__init__(): the base constructor presumably calls
        # self._prepare_dataset(), which reads self.dual_loss.
        self.dual_loss = dual_loss
        self.contrastive_loss_temp = contrastive_loss_temp
        super().__init__(*args, **kwargs)

        self.is_ddp = dist.is_available() and dist.is_initialized()
        # Single-process defaults keep every method usable without a process group.
        self.process_rank = dist.get_rank() if self.is_ddp else 0
        self.world_size = dist.get_world_size() if self.is_ddp else 1
        # Divisor used to undo the world-size scaling of the contrastive loss
        # when it is reported as a metric.
        self._dist_loss_scale_factor = self.world_size
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

    def tokenize_row(self, row, processing_class, query_max_length=610, corpus_max_length=210):
        """Tokenize a batch of preference rows and terminate every sequence with EOS.

        Args:
            row: Batch dict (``dataset.map`` is called with ``batched=True``)
                holding lists of strings under "prompt", "chosen" and "rejected".
            processing_class: The tokenizer.
            query_max_length: Maximum token length of a prompt, EOS included.
            corpus_max_length: Maximum token length of a chosen/rejected
                passage, EOS included.

        Returns:
            Dict with "prompt_input_ids", "chosen_input_ids" and
            "rejected_input_ids", each a list of token-id lists.

        Raises:
            ValueError: If the tokenizer has no ``eos_token_id``.
        """
        tokenizer = processing_class
        eos_id = tokenizer.eos_token_id
        if eos_id is None:
            raise ValueError(
                "tokenize_row requires a tokenizer with an eos_token_id; "
                "every sequence is terminated with it."
            )

        def encode(texts, max_length):
            # Truncate to max_length - 1 to leave room for the EOS appended below.
            ids = tokenizer(
                texts,
                padding=False,
                truncation=True,
                max_length=max_length - 1,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=True,
            )["input_ids"]
            return [seq + [eos_id] for seq in ids]

        return {
            "prompt_input_ids": encode(row["prompt"], query_max_length),
            "chosen_input_ids": encode(row["chosen"], corpus_max_length),
            "rejected_input_ids": encode(row["rejected"], corpus_max_length),
        }

    def _prepare_dataset(self, dataset, processing_class, args, dataset_name):
        """Tokenize ``dataset`` with ``tokenize_row`` unless ``dual_loss`` is on.

        With ``dual_loss`` the dataset is returned untouched (presumably it is
        already tokenized, including the contrastive_* fields — confirm against
        the data pipeline). Otherwise the "chosen"/"rejected" text columns are
        replaced by token ids; tokenization runs on the main process first so
        other processes can reuse its cache.
        """
        if self.dual_loss:
            return dataset

        map_kwargs = {
            "num_proc": args.dataset_num_proc,
            "writer_batch_size": 10,
            "batched": True,
            "batch_size": 2,
            "desc": f"Tokenizing {dataset_name} dataset"
        }

        with PartialState().main_process_first():
            dataset = dataset.map(
                self.tokenize_row,
                remove_columns=["chosen", "rejected"],
                fn_kwargs={
                    "processing_class": processing_class
                },
                **map_kwargs,
            )

        return dataset

    def compute_similarity(self, q_reps, p_reps):
        """Return the (num_queries, num_passages) matrix of dot products."""
        return torch.matmul(q_reps, p_reps.transpose(0, 1))

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """All-gather ``t`` from every rank and concatenate along dim 0.

        Tensors filled in by ``dist.all_gather`` carry no autograd history, so
        this rank's slot is replaced by the local ``t`` to keep its gradient.
        Outside distributed training ``t`` is returned unchanged; ``None``
        passes through as ``None``.
        """
        if t is None:
            return None
        if not self.is_ddp:
            # Single process: nothing to gather and no process group to use.
            return t
        t = t.contiguous()

        # all_gather requires identically shaped tensors on every rank.
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)

    def _pooling(self, last_hidden_state, attention_mask):
        """Select each sequence's last real-token hidden state and L2-normalize it.

        If every row has a real token in the final position (left padding, or
        no padding at all) that position is used directly. Otherwise the index
        ``attention_mask.sum(1) - 1`` is used, which assumes right padding with
        a contiguous run of ones at the start of each row.
        """
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _encode_and_pool(self, model: nn.Module, input_ids, attention_mask) -> torch.Tensor:
        """Run ``model`` on one tokenized batch and return pooled, normalized embeddings."""
        hidden_states = model(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        return self._pooling(hidden_states, attention_mask)

    def model_dpo_forward(
        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
    ) -> dict[str, torch.Tensor]:
        """Score the chosen and rejected passage of every prompt.

        Returns a dict with:
            chosen_logps / rejected_logps: 2-way log-softmax over the chosen
                and rejected scores; these stand in for DPO's sequence log-probs.
            chosen_scores / rejected_scores: raw similarity scores (dot products
                of L2-normalized embeddings, i.e. cosine similarities).

        ``is_ref_model`` is accepted for call compatibility and is unused here.
        """
        query_reps = self._encode_and_pool(model, batch["prompt_input_ids"], batch["prompt_attention_mask"])
        chosen_reps = self._encode_and_pool(model, batch["chosen_input_ids"], batch["chosen_attention_mask"])
        rejected_reps = self._encode_and_pool(model, batch["rejected_input_ids"], batch["rejected_attention_mask"])

        chosen_scores = torch.sum(query_reps * chosen_reps, dim=-1)
        rejected_scores = torch.sum(query_reps * rejected_reps, dim=-1)

        # Over two entries, logp_chosen - logp_rejected == chosen - rejected score,
        # so the DPO log-ratio reduces to the score margin.
        pair_scores = torch.stack([chosen_scores, rejected_scores], dim=-1)
        pair_log_probs = torch.log_softmax(pair_scores, dim=-1)

        return {
            "chosen_logps": pair_log_probs[:, 0],
            "rejected_logps": pair_log_probs[:, 1],
            "chosen_scores": chosen_scores,
            "rejected_scores": rejected_scores,
        }

    def model_contrastive_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]) -> dict[str, torch.Tensor]:
        """Embed the contrastive queries/passages and gather them across ranks.

        Returns a dict with "query_reps" and "passage_reps", each concatenated
        over all ranks along dim 0 (or the local tensors when not distributed).
        """
        query_reps = self._encode_and_pool(
            model, batch["contrastive_query_input_ids"], batch["contrastive_query_attention_mask"]
        )
        passage_reps = self._encode_and_pool(
            model, batch["contrastive_passage_input_ids"], batch["contrastive_passage_attention_mask"]
        )

        return {
            "query_reps": self._dist_gather_tensor(query_reps),
            "passage_reps": self._dist_gather_tensor(passage_reps),
        }

    def contrastive_loss(self, query_reps: torch.Tensor, passage_reps: torch.Tensor) -> torch.Tensor:
        """Temperature-scaled cross-entropy of each query against all passages.

        The target for query ``i`` is passage ``i * (P // Q)``, so passages are
        expected to come in groups of ``P // Q`` per query with the positive
        first (confirm against the collator). Returns a scalar.
        """
        scores = self.compute_similarity(query_reps, passage_reps)
        scores = scores.view(query_reps.size(0), -1)

        group_size = passage_reps.size(0) // query_reps.size(0)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) * group_size

        loss = self.cross_entropy(scores / self.contrastive_loss_temp, target)

        if self.is_ddp:
            # NOTE(review): presumably compensates for DDP averaging gradients
            # over ranks, since only the local shard carries gradient after
            # _dist_gather_tensor.
            loss = loss * self.world_size
        return loss

    def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Reference-model chosen/rejected log-probs for ``batch``, without gradients.

        Uses ``self.ref_model`` when one exists; otherwise runs ``self.model``
        inside ``null_ref_context()`` (the adapter-disabled base model in the
        PEFT case).
        """
        compute_ref_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )
        with torch.no_grad(), compute_ref_context_manager:
            if self.ref_model is None:
                with self.null_ref_context():
                    ref_model_output = self.model_dpo_forward(self.model, batch, is_ref_model=True)
            else:
                ref_model_output = self.model_dpo_forward(self.ref_model, batch, is_ref_model=True)
        return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]

    def dpo_loss(
        self,
        chosen_logps: torch.FloatTensor,
        rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
        loss_type: str = "sigmoid",
        model_output: dict[str, torch.FloatTensor] = None,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Per-example sigmoid DPO loss with label smoothing.

        Returns ``(losses, chosen_rewards, rejected_rewards)``; the rewards are
        ``beta * (policy - reference)`` log-probs, detached. ``model_output``
        is accepted for call compatibility and is unused.

        Raises:
            ValueError: If ``loss_type`` is not "sigmoid".
        """
        device = self.accelerator.device

        logratios = (chosen_logps - rejected_logps).to(device)
        ref_logratios = (ref_chosen_logps - ref_rejected_logps).to(device)
        logits = logratios - ref_logratios

        if loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        else:
            raise ValueError(
                f"Unknown loss type: {loss_type}. Should be one of ['sigmoid']"
            )

        chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
        rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()

        return losses, chosen_rewards, rejected_rewards

    def get_batch_loss_metrics(
        self,
        model: Union[PreTrainedModel, nn.Module],
        batch: dict[str, Union[list, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the mean loss for ``batch`` and the metrics to log.

        The loss is the DPO loss, plus ``args.rpo_alpha`` times the contrastive
        loss when ``dual_loss`` is enabled. Metric keys get an "eval_" prefix
        when ``train_eval == "eval"``.
        """
        metrics = {}

        model_output_for_dpo = self.model_dpo_forward(model, batch)
        if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
            # Reference log-probs were precomputed and shipped with the batch.
            ref_chosen_logps = batch["ref_chosen_logps"]
            ref_rejected_logps = batch["ref_rejected_logps"]
        else:
            ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)

        dpo_losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            model_output_for_dpo["chosen_logps"],
            model_output_for_dpo["rejected_logps"],
            ref_chosen_logps,
            ref_rejected_logps,
            "sigmoid",
            model_output_for_dpo,
        )

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.dual_loss:
            model_output_for_contrastive = self.model_contrastive_forward(model, batch)
            contrastive_losses = self.contrastive_loss(
                model_output_for_contrastive["query_reps"], model_output_for_contrastive["passage_reps"]
            )
            # NOTE(review): args.rpo_alpha is reused as the contrastive weight and
            # must be a number (not None) whenever dual_loss is enabled.
            losses = dpo_losses + self.args.rpo_alpha * contrastive_losses
        else:
            losses = dpo_losses

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
        metrics[f"{prefix}rewards/margins"] = (
            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
        )
        metrics[f"{prefix}logps/chosen"] = (
            self.accelerator.gather_for_metrics(model_output_for_dpo["chosen_logps"]).detach().mean().item()
        )
        metrics[f"{prefix}logps/rejected"] = (
            self.accelerator.gather_for_metrics(model_output_for_dpo["rejected_logps"]).detach().mean().item()
        )
        metrics[f"{prefix}raw_scores/chosen"] = (
            self.accelerator.gather_for_metrics(model_output_for_dpo["chosen_scores"]).detach().mean().item()
        )
        metrics[f"{prefix}raw_scores/rejected"] = (
            self.accelerator.gather_for_metrics(model_output_for_dpo["rejected_scores"]).detach().mean().item()
        )
        if self.dual_loss:
            # Undo the world-size scaling applied in contrastive_loss for reporting.
            metrics[f"{prefix}contrastive_loss"] = (
                self.accelerator.gather_for_metrics(contrastive_losses / self._dist_loss_scale_factor).detach().mean().item()
            )

        return losses.mean(), metrics