from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.distributed as dist
from torch import LongTensor, Tensor, FloatTensor
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.common.templates import DATA_TYPE_DICT
from src.model.modeling_utils import concat_all_gather


class AlignmentLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config, *model_args, **model_kargs):
        """Build the parent causal LM and, when configured, a scoring head.

        Args:
            config: Model configuration forwarded to the parent constructor;
                ``config.hidden_size`` sizes the optional scoring head.
            *model_args: Accepted for call compatibility; not used here.
            **model_kargs: Must contain the key ``"model_args"``, an object
                exposing ``rank_pooler_type`` and ``rank_type``.

        Raises:
            KeyError: If ``"model_args"`` is missing from ``model_kargs``.
        """
        super().__init__(config)
        rank_cfg = model_kargs["model_args"]
        self.model_args = rank_cfg
        self.num_labels = 2

        # The bias-free linear head is only created for the hidden-state
        # pooler or for the classification ranking mode.
        wants_score_head = (
            rank_cfg.rank_pooler_type == "hidden_state" or rank_cfg.rank_type == "cls"
        )
        if wants_score_head:
            self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        self.post_init()

    def gather_logits(self, logits, output_ids, token_mask):
        probs = (
            torch.gather(logits, dim=-1, index=output_ids[:, :, None]).squeeze(-1)
            * token_mask
        )
        return probs

    def get_score(self, probs, token_mask):
        scores = probs.sum(-1) / (token_mask.sum(-1))
        return scores

    def get_sequence_prob_score(self, input_ids, shift_logits, shift_labels):
        token_mask = shift_labels.ne(-100).float()
        output_ids = input_ids[..., 1:].contiguous()
        probs = self.gather_logits(
            F.log_softmax(shift_logits, dim=-1), output_ids, token_mask
        )
        scores = self.get_score(probs, token_mask)
        return scores

    def get_pooled_logits(self, input_ids, hidden_states, scorer):
        cls_logits = scorer(hidden_states)
        batch_size = hidden_states.shape[0]
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            sequence_lengths = (
                torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            ).to(cls_logits.device)
        pooled_logits = cls_logits[
            torch.arange(batch_size, device=cls_logits.device), sequence_lengths
        ]
        return pooled_logits

    def get_eos_score(self, input_ids, hidden_states, scorer):
        pooled_logits = self.get_pooled_logits(input_ids, hidden_states, scorer)
        pooled_logits = F.log_softmax(pooled_logits, dim=-1)[:, 1]
        return pooled_logits

    def listwise_loss(self, scores, rewards):
        s_i = scores.unsqueeze(-1)  # Shape: [list_length, 1]
        s_j = (
            scores.unsqueeze(-2)
            if not self.model_args.rank_detach
            else scores.unsqueeze(-2).detach().clone()
        )
        r_i = rewards.unsqueeze(-1)
        r_j = rewards.unsqueeze(-2)
        mask = (r_i > r_j).float()
        differences = (mask * (s_j - s_i)).sum()  # [list_length, list_length]
        loss = F.softplus(differences)
        return loss

    def stable_alignment_origin(
        self, logits: torch.Tensor, labels: torch.Tensor, feedback_scores: torch.Tensor
    ) -> torch.Tensor:
        # Calculate the SFT loss
        sorted_ratings, indices = torch.sort(feedback_scores.squeeze(), descending=True)
        best_idx = indices[0] if indices.dim() != 0 else indices.item()
        best_score = sorted_ratings[0] if indices.dim() != 0 else sorted_ratings.item()
        loss_fct = CrossEntropyLoss(ignore_index=-100)

        # Calculate the penalty from low-rating responses.
        batch_losses = []
        for logit, label in zip(logits, labels):
            batch_losses.append(
                loss_fct(logit.view(-1, logits.size(-1)), label.view(-1))
            )
        batch_loss = torch.stack(batch_losses, dim=0)

        # Modulate the penalty by the difference in ratings.
        min_loss = batch_loss[best_idx]
        neg_losses = []
        if indices.dim() != 0 and indices.size(-1) > 1:
            for idx in indices[1:]:
                margin = (
                    (best_score - sorted_ratings[idx]) / 4 * self.model_args.rank_margin
                )
                neg_loss = min_loss - batch_loss[idx] + margin
                neg_losses.append(neg_loss)
        if dist.get_rank() == 0:
            print(neg_losses)

        if len(neg_losses) > 0:
            neg_losses_ts = torch.stack(neg_losses)
            diff = torch.max(neg_losses_ts.mean(), torch.tensor(0.0).cuda())
        else:
            diff = torch.tensor(0.0).cuda()

        return min_loss, diff

    def stable_alignment(self, scores, rewards):
        s_i = scores.unsqueeze(-1)  # Shape: [batch_size, list_length, 1]
        s_j = (
            scores.unsqueeze(-2)
            if not self.model_args.rank_detach
            else scores.unsqueeze(-2).detach().clone()
        )
        r_i = rewards.unsqueeze(-1)
        r_j = rewards.unsqueeze(-2)
        mask = (r_i > r_j).float()
        margin = (r_i - r_j) / 4 * self.model_args.rank_margin
        differences = mask * (s_j - s_i + margin)
        differences = torch.clamp(differences, min=0).sum(-1).mean()
        return differences

    # L_rank = − \sum_{r_{tea} > r_{stu}} (r_{tea} − r_{stu}) min (0, p_{tea} − p_{stu})
    def pangu_loss(self, scores, rewards):
        s_i = scores.unsqueeze(-1)  # Shape: [list_length, 1]
        s_j = (
            scores.unsqueeze(-2)
            if not self.model_args.rank_detach
            else scores.unsqueeze(-2).detach().clone()
        )
        r_i = rewards.unsqueeze(-1)
        r_j = rewards.unsqueeze(-2)
        mask = (r_i > r_j).float() * (r_i - r_j)
        differences = mask * (s_j - s_i)
        differences = torch.clamp(differences, min=0).sum(-1).mean()
        return differences

    def cls_loss(self, scores, cls_labels):
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(scores.view(-1, self.num_labels), cls_labels.view(-1))
        return loss

    def aft_binary_loss(self, scores, rw_scores):
        pos_indexs = []
        neg_indexs = []
        max_scores = torch.max(rw_scores, dim=-1)[0]
        tmp = 0
        lower_bounad = 1000
        positive_lower_boundary = 10000
        for idx, rw_score in enumerate(rw_scores):
            if rw_score.item() == max_scores.item():
                pos_indexs.append(idx)
                positive_lower_boundary = min(
                    scores[idx].item(), positive_lower_boundary
                )
            else:
                neg_indexs.append(idx)
            lower_bounad = min(scores[idx].item(), lower_bounad)

        for neg_index in neg_indexs:
            for pos_index in pos_indexs:
                tmp += torch.exp((scores[neg_index].detach() - scores[pos_index]))
                # if self.args.boundary:
                #     tmp += torch.exp(
                #         2 * positive_lower_boundary
                #         - 2 * self.args.beta
                #         - scores[pos_index].item()
                #         - scores[neg_index]
                #     )
        if tmp != 0:
            loss = torch.log(1 + tmp)
        else:
            loss = 0

        return loss

    def forward(
        self,
        input_ids: LongTensor = None,
        attention_mask: Tensor | None = None,
        rewards: FloatTensor | None = None,
        position_ids: LongTensor | None = None,
        past_key_values: List[FloatTensor] | None = None,
        inputs_embeds: FloatTensor | None = None,
        labels: LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> Tuple | CausalLMOutputWithPast:
        """Run the parent causal LM and, when labels are given, add the losses.

        The parent forward is always called without labels, with hidden
        states enabled and with a dict-style output, because the code below
        reads ``logits`` and ``hidden_states`` by attribute. The
        ``output_hidden_states`` and ``return_dict`` arguments are accepted
        for interface compatibility but do not change the result.

        When ``labels`` is provided:
          * ``rank_type == "stable_origin"`` delegates both losses to
            ``stable_alignment_origin``.
          * Otherwise per-sample scores are computed (pooled logits for
            ``"cls"``, sequence log-probability for the ``"seq_prob"``
            pooler, else the EOS score) and passed to the ranking loss
            selected by ``rank_type``. The LM loss is the cross-entropy over
            the samples with ``rewards > 0``.
          * The total is ``lm_loss + rank_weight * rank_loss`` when
            ``rank_weight > 0`` and ``lm_loss`` alone otherwise.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_attentions: Forwarded to the
                parent ``forward`` by keyword.
            rewards: Per-sample rewards (class labels for ``"cls"``);
                required whenever ``labels`` is given.
            labels: Target token ids; shifted by one position internally.

        Returns:
            A ``CausalLMOutputWithPast``; ``loss`` is ``None`` when no labels
            were supplied.

        Raises:
            ValueError: If ``labels`` is given without ``rewards``, or if
                ``rank_weight > 0`` and ``rank_type`` is not a supported mode.
        """
        # Keywords (not positions) keep the call valid even if the parent
        # signature gains or reorders parameters.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        logits = outputs.logits
        # Stays None on the inference path (no labels).
        loss = None

        if labels is not None:
            if rewards is None:
                raise ValueError("`rewards` must be provided together with `labels`.")

            rank_type = self.model_args.rank_type
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_labels = shift_labels.to(shift_logits.device)

            rank_loss = None
            if rank_type == "stable_origin":
                # This mode derives both losses from the per-sample LM loss
                # and never uses the pooled scores.
                lm_loss, rank_loss = self.stable_alignment_origin(
                    shift_logits, shift_labels, rewards
                )
            else:
                if rank_type == "cls":
                    scores = self.get_pooled_logits(
                        input_ids, outputs.hidden_states[-1], self.score
                    )
                elif self.model_args.rank_pooler_type == "seq_prob":
                    scores = self.get_sequence_prob_score(
                        input_ids, shift_logits, shift_labels
                    )
                else:
                    scores = self.get_eos_score(
                        input_ids, outputs.hidden_states[-1], self.score
                    )

                if rank_type == "stable":
                    rank_loss = self.stable_alignment(scores, rewards)
                elif rank_type == "pangu":
                    rank_loss = self.pangu_loss(scores, rewards)
                elif rank_type == "cross":
                    rank_loss = self.listwise_loss(scores, rewards)
                elif rank_type == "aft":
                    rank_loss = self.aft_binary_loss(scores, rewards)
                elif rank_type == "cls":
                    rank_loss = self.cls_loss(scores, rewards)

                # Only samples with a positive reward feed the LM loss.
                sample_mask = (rewards > 0).to(shift_logits.device)
                if sample_mask.any():
                    loss_fct = CrossEntropyLoss()
                    lm_loss = loss_fct(
                        shift_logits[sample_mask].view(-1, self.config.vocab_size),
                        shift_labels[sample_mask].view(-1),
                    )
                else:
                    # Cross-entropy over an empty selection would be NaN; use
                    # a zero that is still attached to the graph.
                    lm_loss = shift_logits.sum() * 0.0

            if self.model_args.rank_weight > 0.0:
                if rank_loss is None:
                    raise ValueError(f"Unsupported rank_type: {rank_type!r}")
                # Log on rank 0, or always when no process group exists.
                distributed = dist.is_available() and dist.is_initialized()
                if not distributed or dist.get_rank() == 0:
                    print(
                        {
                            "lm": lm_loss.item(),
                            "rank": rank_loss.item()
                            if torch.is_tensor(rank_loss)
                            else rank_loss,
                        }
                    )
                loss = lm_loss + self.model_args.rank_weight * rank_loss
            else:
                loss = lm_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
