
import os
import torch
import torch.nn as nn

from torch.utils.data import Sampler
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
from torch.nn import Module
import torch.nn.functional as F
from transformers import Trainer
from typing import List, Optional
from llava_trainer import LLaVATrainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)

def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False, tokenizer=None) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Labels are shifted by one position against the logits (``labels[:, 1:]``
    is scored with ``logits[:, :-1]``), so every per-token result has
    ``sequence_length - 1`` columns.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens
            with a value of -100 are ignored; -300 is treated the same as -100.
            Shape: (batch_size, sequence_length). The input tensor is not
            modified (a clone is used).
        return_per_token_logp: If True, return only the per-token log
            probabilities. Ignored positions are NOT zeroed in this tensor;
            they hold the log probability of token id 0.
        return_all: If True (and ``return_per_token_logp`` is False), return
            ``(per_token_logps, log_prob, average_log_prob, loss_mask)``.
        tokenizer: Unused; kept so existing call sites keep working.

    Returns:
        By default the tuple ``(log_prob, average_log_prob)``, each of shape
        (batch_size,): the sum and the mean of the label log probabilities over
        non-ignored positions. A row without any non-ignored position yields
        0 for both values (instead of NaN for the mean). See
        ``return_per_token_logp`` / ``return_all`` for the other return forms.
    """
    assert logits.shape[:-1] == labels.shape, f'logits.shape[:-1]={logits.shape[:-1]}, labels.shape={labels.shape}'

    # Position t of the logits predicts token t + 1 of the labels.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    labels[labels == -300] = -100
    loss_mask = (labels != -100)

    # gather() needs a valid vocabulary index everywhere; ignored positions get
    # a dummy index 0 and are removed again through loss_mask below.
    labels[~loss_mask] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,
                                   index=labels.unsqueeze(2)).squeeze(2)

    if return_per_token_logp:
        return per_token_logps

    log_prob = (per_token_logps * loss_mask).sum(-1)
    # A fully ignored row would divide 0 by 0; clamping the count keeps the
    # average finite (0) so one such row cannot turn a batch loss into NaN.
    token_count = loss_mask.sum(-1).clamp(min=1)
    average_log_prob = log_prob / token_count

    if return_all:
        return per_token_logps, log_prob, average_log_prob, loss_mask

    return log_prob, average_log_prob




def pad_or_truncate_logp_list(ref_per_token_logp, target_shape_tensor, pad_value=-20.0):
    """Fit a ragged collection of per-token values into a dense 2-D tensor.

    Args:
        ref_per_token_logp: Indexable collection with at least as many rows as
            ``target_shape_tensor``; each row is a sequence of numbers.
        target_shape_tensor: 2-D tensor supplying the output shape, dtype and
            device. Its values are not read.
        pad_value: Fill value for positions a row does not cover.

    Returns:
        A new tensor shaped like ``target_shape_tensor``. Row ``i`` starts with
        the values of ``ref_per_token_logp[i]`` (cut off if the row is too long)
        and is filled with ``pad_value`` after that.
    """
    num_rows, num_cols = target_shape_tensor.shape
    out_dtype = target_shape_tensor.dtype
    # new_full inherits dtype and device from the template tensor.
    padded = target_shape_tensor.new_full((num_rows, num_cols), pad_value)

    for row_idx in range(num_rows):
        values = ref_per_token_logp[row_idx]
        keep = min(len(values), num_cols)
        padded[row_idx, :keep] = torch.tensor(values[:keep], dtype=out_dtype)

    return padded


def estimate_pathwise_kl(ref_per_token_logp: torch.Tensor,
                         per_token_logp: torch.Tensor,
                         attention_mask: torch.Tensor = None,):
    """Estimate a mean per-token KL term between reference and current log-probs.

    For every token the estimate is ``exp(delta) - delta - 1`` with
    ``delta = ref_logp - logp``; this value is never negative and is zero when
    both log probabilities agree.

    Args:
        ref_per_token_logp: Reference per-token log probabilities, one row per
            sample. Rows may differ in length; they are padded / truncated to
            the shape of ``per_token_logp`` by ``pad_or_truncate_logp_list``.
        per_token_logp: 2-D tensor of current per-token log probabilities.
        attention_mask: Optional mask with the shape of ``per_token_logp``;
            non-zero / True marks the tokens that count.

    Returns:
        A scalar tensor. With a mask: the mean over rows of each row's masked
        average (a row whose mask is all zero contributes 0 rather than NaN).
        Without a mask: the plain mean over all positions.
    """
    ref_per_token_logp = pad_or_truncate_logp_list(ref_per_token_logp, per_token_logp)

    delta = ref_per_token_logp - per_token_logp
    per_token_kl = torch.exp(delta) - delta - 1

    if attention_mask is None:
        return per_token_kl.mean()

    masked_sum = (per_token_kl * attention_mask).sum(dim=1)
    # Avoid 0/0 for rows without any counted token; a NaN here would survive
    # multiplication by a zero loss weight and break the whole loss.
    token_count = attention_mask.sum(dim=1).clamp(min=1)
    return (masked_sum / token_count).mean()





def _mean_layer_attention_scores(attentions):
    """Average per-layer (image, text) attention statistics over the layers.

    Returns ``(None, None)`` when ``attentions`` is None. Otherwise each entry
    of ``attentions`` is indexed with ``[0]`` (image) and ``[1]`` (text); entries
    that are None are skipped.
    """
    if attentions is None:
        return None, None

    img_per_layer = [layer[0] for layer in attentions if layer[0] is not None]
    txt_per_layer = [layer[1] for layer in attentions if layer[1] is not None]

    # Stack along a new leading layer axis, move it to position 1, then reduce
    # it. Presumably each entry is [batch_size], giving [batch_size] results —
    # TODO confirm against the model's attention-statistics output.
    img_score = torch.stack(img_per_layer, dim=0).transpose(0, 1).mean(dim=1)
    txt_score = torch.stack(txt_per_layer, dim=0).transpose(0, 1).mean(dim=1)
    return img_score, txt_score


def get_logps(data_dict, model, args, is_llava15=False ,attention=False):
    """Run one forward pass and return label log-probabilities and related terms.

    NOTE: ``data_dict`` is consumed in place. Every key listed below is popped
    from it, and whatever is left is forwarded to ``model.forward`` as keyword
    arguments.

    Args:
        data_dict: Batch mapping. Must hold ``input_ids``, ``labels``,
            ``images`` and ``attention_mask``. When ``args.contrastive`` is
            truthy it must also hold ``index``, ``ref_logp``, ``ref_avg_logp``,
            ``ref_per_token_logp_wo_img``, ``ref_sub`` and ``ref_logp_wo_img``.
        model: Model exposing ``forward`` and
            ``prepare_inputs_labels_for_multimodal`` (looked up on
            ``model.base_model`` when ``args.lora_enable`` is truthy).
        args: Object providing ``lora_enable``, ``contrastive``,
            ``contrastive_grad`` and ``contrastive_logits``.
        is_llava15: Must be True; no other model family is implemented.
        attention: Forwarded to the model as ``output_attention_statistics``.

    Returns:
        If ``args.contrastive`` is truthy, the 11-tuple
        ``(log_prob, average_log_prob, ref_logp, sub_, ref_logp_wo_img,
        ref_avg_logp, log_prob_wo_img, average_log_prob_wo_img, loss_sft,
        kl_wo_img, ref_avg_logp_wo_img)``.
        Otherwise the 4-tuple
        ``(log_prob, loss_sft, img_attn_score, txt_attn_score)``, where the two
        attention scores are None if the model returned no attentions.

    Raises:
        NotImplementedError: If ``is_llava15`` is False. (Previously this path
            only printed a message and returned None, which made callers fail
            later while unpacking the result.)
    """
    if not is_llava15:
        raise NotImplementedError('get_logps is only implemented for is_llava15=True')

    input_ids = data_dict.pop('input_ids')
    labels = data_dict.pop('labels')
    images = data_dict.pop('images')
    attention_mask = data_dict.pop('attention_mask')

    if args.contrastive:
        # 'index' is not used here but must be removed so it is not passed on
        # to model.forward through **data_dict.
        data_dict.pop('index')
        ref_logp = data_dict.pop('ref_logp')
        ref_avg_logp = data_dict.pop('ref_avg_logp')
        ref_per_token_logp_wo_img = data_dict.pop('ref_per_token_logp_wo_img')
        ref_sub = data_dict.pop('ref_sub')
        ref_logp_wo_img = data_dict.pop('ref_logp_wo_img')

    # With LoRA enabled the multimodal preparation method is reached through
    # the wrapped base model.
    multimodal_owner = model.base_model if args.lora_enable else model
    (
        _,
        _,
        attention_mask,
        _,
        inputs_embeds,
        labels
    ) = multimodal_owner.prepare_inputs_labels_for_multimodal(
        input_ids=input_ids,
        position_ids=None,
        attention_mask=attention_mask,
        past_key_values=None,
        labels=labels,
        images=images,
    )

    output = model.forward(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        labels=labels,
        **data_dict,
        contrastive_grad=args.contrastive_grad,
        contrastive_logits=args.contrastive_logits,
        output_attention_statistics=attention,
    )
    _, log_prob, average_log_prob, loss_mask = get_batch_logps(
        output.logits, labels, return_all=True)
    loss_sft = output.loss

    if not args.contrastive:
        img_attn_score, txt_attn_score = _mean_layer_attention_scores(output.attentions)
        return log_prob, loss_sft, img_attn_score, txt_attn_score

    # output.weight is a 5-tuple; only entries 0, 2 and 3 are needed here.
    # Judging by how they are used: a per-token tensor, labels aligned with it,
    # and the logits of the pass without image — TODO confirm against the model.
    sub, _, labels_, logits_wo_img, _ = output.weight

    # Mean absolute value of `sub` over the supervised positions of each sample,
    # expressed relative to the stored reference value.
    sub_ = torch.stack([
        sub_row[label_row != -100].abs().mean()
        for sub_row, label_row in zip(sub, labels_)
    ])
    ref_sub = torch.tensor(ref_sub, dtype=sub_.dtype, device=sub_.device)
    sub_ = sub_ - ref_sub

    per_token_logp_wo_img, log_prob_wo_img, average_log_prob_wo_img, loss_mask_wo_img = get_batch_logps(
        logits=logits_wo_img, labels=labels, return_all=True)
    # Clamp the token count so a sample without supervised tokens cannot
    # produce a division by zero.
    ref_avg_logp_wo_img = ref_logp_wo_img / loss_mask_wo_img.sum(-1).clamp(min=1)
    kl_wo_img = estimate_pathwise_kl(
        ref_per_token_logp_wo_img, per_token_logp_wo_img, attention_mask=loss_mask)

    return (
        log_prob,
        average_log_prob,
        ref_logp,
        sub_,
        ref_logp_wo_img,
        ref_avg_logp,
        log_prob_wo_img,
        average_log_prob_wo_img,
        loss_sft,
        kl_wo_img,
        ref_avg_logp_wo_img,
    )
  
    


def collect_metrics_conctrastive(metrics, task,
                               rewards,
                                logp,
                                ref_logp,
                                sub,
                                margins_wo_img,
                                kl_wo_img,
                                rewards_margins_mean,
                                contrastive_margins_mean,
                               preprocess_func,
                               ):
    """Build the dictionary of logged metrics for one task.

    Every value is passed through ``preprocess_func`` and stored under a key
    that embeds ``task``. The calls to ``preprocess_func`` happen in the order
    the keys appear in the returned dict.

    Note: the incoming ``metrics`` argument is not read or modified; a new
    dictionary is always created and returned.

    Returns:
        dict mapping metric name to the preprocessed value.
    """
    reward_prefix = f'rewards_{task}'
    # Dict displays evaluate entries top to bottom, which keeps the order of
    # the preprocess_func calls fixed.
    return {
        reward_prefix: preprocess_func(rewards),
        f'logps_{task}': preprocess_func(logp),
        f'logps_{task}/ref': preprocess_func(ref_logp),
        f'{reward_prefix}/sub': preprocess_func(sub),
        f'{reward_prefix}/contrastive_margins': preprocess_func(margins_wo_img),
        f'{reward_prefix}/rewards_margins_mean': preprocess_func(rewards_margins_mean),
        f'{reward_prefix}/contrastive_margins_mean': preprocess_func(contrastive_margins_mean),
        f'kl_wo_img_{task}': preprocess_func(kl_wo_img),
    }


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Distribute `indices` over `num_chunks` chunks with roughly equal total length.

    If the number of indices is not a multiple of `num_chunks`, the indices are
    simply dealt out round-robin. Otherwise each index goes, in input order, to
    the chunk with the smallest accumulated `lengths` total (lowest chunk number
    on ties) until that chunk holds its equal share of indices.
    """
    if len(indices) % num_chunks:
        return [indices[offset::num_chunks] for offset in range(num_chunks)]

    capacity = len(indices) // num_chunks
    buckets = [[] for _ in range(num_chunks)]
    loads = [0] * num_chunks

    for idx in indices:
        # First chunk with the minimal load wins.
        target = min(range(num_chunks), key=loads.__getitem__)
        bucket = buckets[target]
        bucket.append(idx)
        loads[target] += lengths[idx]
        if len(bucket) == capacity:
            # A full chunk must never be selected again.
            loads[target] = float("inf")

    return buckets


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    """Return a sampling order that keeps each megabatch within one modality.

    The sign of a length encodes the modality: positive entries are treated as
    multimodal samples, negative entries as language-only samples (their
    absolute value is used as the length). Zero is not allowed.

    Args:
        lengths: Signed sample lengths, one per dataset index.
        batch_size: Per-device batch size.
        world_size: Number of devices; one megabatch has
            ``world_size * batch_size`` indices.
        generator: Optional ``torch.Generator`` used for every random
            permutation made here, including the per-modality shuffles, so that
            a seeded generator gives a reproducible order.

    Returns:
        A flat list of dataset indices. Full megabatches of both modalities
        come first in random order; the trailing megabatch of each modality is
        merged into one final group whose indices are sorted ascending.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    # Forward the caller's generator: these shuffles used to be hard-wired to
    # generator=None and therefore ignored an explicitly provided one.
    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=generator)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=generator)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    # The last megabatch of each modality may be incomplete; both are taken out
    # of the shuffle and appended together at the very end.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Return a shuffled index order in which each device chunk has similar total length.

    The indices ``0 .. len(lengths) - 1`` are randomly permuted, cut into
    megabatches of ``world_size * batch_size`` indices, sorted inside each
    megabatch by descending length, and then split into ``world_size`` chunks
    with ``split_to_even_chunks``. The chunks are concatenated into one flat
    list. ``merge`` is accepted but not used.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    permutation = torch.randperm(len(lengths), generator=generator)
    span = world_size * batch_size

    ordered = []
    for start in range(0, len(lengths), span):
        megabatch = permutation[start : start + span].tolist()
        # Stable sort, longest samples first.
        megabatch.sort(key=lambda idx: lengths[idx], reverse=True)
        for chunk in split_to_even_chunks(megabatch, lengths, world_size):
            ordered.extend(chunk)

    return ordered

class LengthGroupedSampler(Sampler):
    r"""
    Sampler yielding dataset indices so that samples of roughly the same length
    end up together, while the order still contains some randomness.

    The index order is recomputed on every call to ``__iter__`` by
    ``get_modality_length_grouped_indices`` (when ``group_by_modality`` is
    truthy) or by ``get_length_grouped_indices`` (otherwise).
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        # `lengths` has a default only for signature reasons; it is mandatory.
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size, self.world_size = batch_size, world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        # One index is produced per entry of `lengths`.
        return len(self.lengths)

    def __iter__(self):
        grouping_fn = (
            get_modality_length_grouped_indices
            if self.group_by_modality
            else get_length_grouped_indices
        )
        order = grouping_fn(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(order)








class LLaVA15_LBR_Trainer(LLaVATrainer):
    """LLaVATrainer variant whose loss is a weighted sum of SFT, contrastive,
    preference-margin, margin-regularisation and KL terms.

    The weight of each term is read once, at construction time, from
    environment variables (see ``__init__``); a missing variable means 0.0.
    """

    def __init__(self, *args, ref_model=None, **kwargs):
        """
        Initializes a custom trainer that includes a reference model (ref_model).

        Args:
            ref_model (nn.Module): Optional reference model. It is stored on the
                trainer and switched to eval mode; it is not used elsewhere in
                this class.
            *args, **kwargs: Additional arguments for the base Trainer class.

        Environment variables read (each parsed with ``float``, default 0.0):
            SFT_weight, SUB_weight, L1_weight, L2_weight, KL_weight,
            CONTRASTIVE_weight, L1_weight_mean, L2_weight_mean,
            attention_weight.
        """
        super().__init__(*args, **kwargs)

        self.ref_model = ref_model
        if self.ref_model is not None:
            self.ref_model.eval()

        def env_weight(name):
            # Unset variables fall back to 0.0, which disables the loss term.
            return float(os.environ.get(name, 0.0))

        self.SFT_weight = env_weight('SFT_weight')
        self.SUB_weight = env_weight('SUB_weight')
        self.L1_weight = env_weight('L1_weight')
        self.L2_weight = env_weight('L2_weight')
        self.KL_weight = env_weight('KL_weight')
        self.contrastive_weight = env_weight('CONTRASTIVE_weight')
        self.L1_weight_mean = env_weight('L1_weight_mean')
        self.L2_weight_mean = env_weight('L2_weight_mean')
        # Read and printed here; not used by compute_loss below.
        self.attention_weight = env_weight('attention_weight')

        print("SFT_weight:", self.SFT_weight
              , "SUB_weight:", self.SUB_weight
              , "L1_weight:", self.L1_weight
              , "L2_weight:", self.L2_weight
              , "KL_weight:", self.KL_weight
              , "contrastive_weight:", self.contrastive_weight
              , "L1_weight_mean:", self.L1_weight_mean
              , "L2_weight_mean:", self.L2_weight_mean
              , "attention_weight:", self.attention_weight)

    def compute_loss(self, model: Module, inputs: dict):
        """Compute the training loss for one batch and log the related metrics.

        Args:
            model: The model being trained; passed on to ``get_logps``.
            inputs: Batch dictionary. It is handed to ``get_logps``, which pops
                the keys it consumes from this very dictionary.

        Returns:
            The scalar loss tensor.

        Raises:
            NotImplementedError: If ``self.args.past_index >= 0``, or if
                ``self.args.contrastive_logits`` is falsy (only the contrastive
                objective is implemented; this path used to fail later with an
                unbound-variable error).
        """
        if self.args.past_index >= 0:
            raise NotImplementedError

        if not self.args.contrastive_logits:
            raise NotImplementedError(
                'compute_loss is only implemented for args.contrastive_logits=True')

        def gather_and_do_mean(x):
            # Mean on this process, gathered across processes, averaged again.
            return self._nested_gather(x.mean()).mean().item()

        (
            log_prob,
            average_log_prob,
            ref_logp,
            sub,
            ref_logp_wo_img,
            ref_avg_logp,
            log_prob_wo_img,
            average_log_prob_wo_img,
            loss_sft,
            kl_wo_img,
            ref_avg_logp_wo_img,
        ) = get_logps(inputs, model, self.args, is_llava15=True, attention=False)

        # Margins between the policy and the stored reference values, with the
        # image (rewards_*) and without it (contrastive_*), on summed and on
        # token-averaged log probabilities.
        contrastive_margins = ref_logp_wo_img - log_prob_wo_img
        rewards_margins = log_prob - ref_logp
        contrastive_margins_mean = ref_avg_logp_wo_img - average_log_prob_wo_img
        rewards_margins_mean = average_log_prob - ref_avg_logp

        loss = (
            self.SFT_weight * loss_sft
            - self.contrastive_weight * average_log_prob_wo_img.mean()
            - self.SUB_weight * F.logsigmoid(self.args.lambda_ * (rewards_margins + contrastive_margins)).mean()
            + self.L1_weight * contrastive_margins.abs().mean()
            + self.L2_weight * contrastive_margins.pow(2).mean()
            + self.L1_weight_mean * contrastive_margins_mean.abs().mean()
            + self.L2_weight_mean * contrastive_margins_mean.pow(2).mean()
            + self.KL_weight * kl_wo_img.mean()
        )

        metrics = collect_metrics_conctrastive(
            {},
            'train',
            rewards=rewards_margins,
            logp=log_prob,
            ref_logp=ref_logp,
            sub=sub,
            margins_wo_img=contrastive_margins,
            kl_wo_img=kl_wo_img,
            rewards_margins_mean=rewards_margins_mean,
            contrastive_margins_mean=contrastive_margins_mean,
            preprocess_func=gather_and_do_mean,
        )
        self.log(metrics)

        return loss