# merged_training/encoder.py

import logging
from dataclasses import dataclass
from typing import Dict, Optional, Sequence

import torch
import torch.distributed as dist
import wandb
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_kbit_training
from torch import nn, Tensor
from torch.nn import functional as F
from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.file_utils import ModelOutput

# Assuming these argument classes are defined in your project
from tevatron.co_retriever.arguments import ModelArguments, DataArguments, TevatronTrainingArguments as TrainingArguments

logger = logging.getLogger(__name__)

def cast_lora_params_to_dtype(model: nn.Module, dtype=torch.bfloat16):
    """Cast every LoRA parameter of ``model`` to ``dtype`` in place.

    A parameter counts as LoRA when its qualified name contains ``"lora_"``.
    Only each parameter's ``.data`` is replaced; the parameter objects
    themselves are kept. Returns ``None``.
    """
    lora_params = (
        param for name, param in model.named_parameters() if "lora_" in name
    )
    for param in lora_params:
        param.data = param.data.to(dtype)

@dataclass
class JointEncoderOutput(ModelOutput):
    """Output container returned by ``JointEncoderModel.forward``.

    Every field defaults to ``None``.
    """
    # Weighted combination of the contrastive and REVELA losses.
    loss: Optional[Tensor] = None
    # Cross-entropy loss over the query-passage score matrix.
    contrastive_loss: Optional[Tensor] = None
    # Mean of the reference model's per-group REVELA losses (0.0 when there are no groups).
    revela_loss: Optional[Tensor] = None
    # Query representations (gathered across ranks when training distributed).
    q_reps: Optional[Tensor] = None
    # Passage representations (gathered across ranks when training distributed).
    p_reps: Optional[Tensor] = None
    # Query-passage dot products divided by the contrastive temperature.
    scores: Optional[Tensor] = None

class JointEncoderModel(nn.Module):
    """Retriever trained with a contrastive loss plus a REVELA loss from a reference LM.

    ``encoder`` supplies query/passage representations through ``encode_query`` /
    ``encode_passage``, which subclasses must implement. ``reference`` is called
    with the extra keyword arguments ``inbatch_attn``, ``cached_key_values`` and
    ``disable_v_norm``, so it is presumably a patched model class that accepts
    them -- TODO confirm against the reference implementation.
    """

    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(
        self,
        encoder: PreTrainedModel,
        reference: PreTrainedModel,
        pooling: str = 'cls',
        contrastive_loss_weight: float = 0.5,
        normalize: bool = False,
        temperature: float = 1.0,
        attn_temperature: float = 1.0,
        exclude_diagonal: bool = True,
        disable_v_norm: bool = False,
    ):
        """Store the two models and the loss hyper-parameters.

        Args:
            encoder: retriever model used by ``encode_query`` / ``encode_passage``.
            reference: model whose ``.loss`` provides the REVELA loss.
            pooling: pooling strategy name, stored for subclasses.
            contrastive_loss_weight: weight ``w`` of the contrastive loss; the
                REVELA loss is weighted ``1 - w``.
            normalize: stored for subclasses (not read in this class).
            temperature: divisor applied to the contrastive scores.
            attn_temperature: divisor applied to REVELA passage similarities.
            exclude_diagonal: mask each passage's self-similarity out of the
                REVELA attention.
            disable_v_norm: forwarded to the reference model's REVELA call.
        """
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.reference = reference
        self.pooling = pooling
        self.normalize = normalize
        self.contrastive_loss_weight = contrastive_loss_weight
        self.temperature = temperature
        self.attn_temperature = attn_temperature

        self.exclude_diagonal = exclude_diagonal
        self.disable_v_norm = disable_v_norm

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        # Single-process defaults so rank / world-size lookups never raise.
        self.process_rank = 0
        self.world_size = 1
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def calculate_attn(self, cosine_similarity: torch.Tensor) -> torch.Tensor:
        """Turn a square passage-similarity matrix into row-wise attention weights.

        Similarities are divided by ``attn_temperature`` and soft-maxed over the
        last dim. With ``exclude_diagonal`` each passage's self-similarity is set
        to ``-inf`` first, so it attends only to the other passages.
        NOTE: a 1x1 input with ``exclude_diagonal`` has no finite entry, and the
        softmax returns NaN.
        """
        scaled_cosine_similarity = cosine_similarity / self.attn_temperature
        if self.exclude_diagonal:
            mask = torch.eye(scaled_cosine_similarity.size(0), device=scaled_cosine_similarity.device).bool()
            scaled_cosine_similarity = scaled_cosine_similarity.masked_fill(mask, float('-inf'))
        return F.softmax(scaled_cosine_similarity, dim=-1)

    def forward(
        self,
        query_contrastive: Dict[str, Tensor],
        passage_contrastive: Dict[str, Tensor],
        revela_retriever_input: Sequence[Dict[str, Tensor]],
        revela_llm_input: Sequence[Dict[str, Tensor]],
    ):
        """Compute the weighted sum of the contrastive and REVELA losses.

        Args:
            query_contrastive: model inputs for ``encode_query``.
            passage_contrastive: model inputs for ``encode_passage``. Each query
                owns ``group_size`` consecutive passages, and the first passage of
                the group is the cross-entropy target.
            revela_retriever_input: one ``encode_passage`` input per REVELA group.
            revela_llm_input: one keyword-argument dict per REVELA group for
                ``self.reference``, zipped with ``revela_retriever_input``.

        Returns:
            JointEncoderOutput. ``loss`` is the tensor to backpropagate.
            ``contrastive_loss`` and ``revela_loss`` are detached reporting values,
            averaged over ranks during distributed training.
        """
        # --- 1. Contrastive loss ---
        q_reps = self.encode_query(query_contrastive)
        p_reps = self.encode_passage(passage_contrastive)

        if self.training and self.is_ddp:
            # Use every rank's examples as negatives; only the local slice carries gradient.
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)

        scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) / self.temperature

        # Query i's target is the first passage of its group: index i * group_size.
        group_size = p_reps.size(0) // q_reps.size(0)
        target = torch.arange(
            0, q_reps.size(0), device=scores.device, dtype=torch.long
        ) * group_size
        loss_contrastive = self.cross_entropy(scores, target)

        # --- 2. REVELA loss ---
        loss_revela = self._compute_revela_loss(
            revela_retriever_input, revela_llm_input, loss_contrastive.device
        )

        # --- 3. Combine losses ---
        contrastive_term = loss_contrastive
        if self.training and self.is_ddp:
            # Each rank computes the full global contrastive loss, but gradient
            # reaches only its local reps, and DDP then averages gradients over
            # ranks. Without correction this gradient is world_size times too
            # small. Restore it without changing the loss value.
            contrastive_term = self._scale_gradient(loss_contrastive, self.world_size)

        weight = self.contrastive_loss_weight
        total_loss = weight * contrastive_term + (1 - weight) * loss_revela

        # Detached, rank-averaged copies for reporting. Collectives stay off the autograd graph.
        report_contrastive = self._average_for_report(loss_contrastive)
        report_revela = self._average_for_report(loss_revela)
        report_total = weight * report_contrastive + (1 - weight) * report_revela

        # Log only from rank 0 (process_rank is 0 outside DDP), and only if a wandb run exists.
        if self.training and self.process_rank == 0 and wandb.run is not None:
            wandb.log({
                "contrastive_loss": report_contrastive.item(),
                "revela_loss": report_revela.item(),
                "total_loss": report_total.item(),
            })

        return JointEncoderOutput(
            loss=total_loss,
            contrastive_loss=report_contrastive,
            revela_loss=report_revela,
            q_reps=q_reps,
            p_reps=p_reps,
            scores=scores,
        )

    def _compute_revela_loss(
        self,
        revela_retriever_input: Sequence[Dict[str, Tensor]],
        revela_llm_input: Sequence[Dict[str, Tensor]],
        device: torch.device,
    ) -> Tensor:
        """Return the mean REVELA loss over all groups, or 0.0 on ``device`` if there are none."""
        revela_losses = []
        # NOTE: groups are processed one at a time (one encoder call and two
        # reference calls per group), which is slow when there are many groups.
        for retriever_group, llm_group in zip(revela_retriever_input, revela_llm_input):
            revela_reps = F.normalize(self.encode_passage(retriever_group), p=2, dim=1)
            passage_similarity = torch.matmul(revela_reps, revela_reps.transpose(0, 1))
            attn_weights = self.calculate_attn(passage_similarity)

            # First pass produces the key/value cache consumed by the REVELA pass.
            cached_outputs = self.reference(**llm_group, use_cache=True, output_hidden_states=True)
            loss_revela_group = self.reference(
                **llm_group,
                inbatch_attn=attn_weights,
                cached_key_values=cached_outputs.past_key_values,
                disable_v_norm=self.disable_v_norm,
            ).loss
            revela_losses.append(loss_revela_group)

        if revela_losses:
            return torch.stack(revela_losses).mean()
        return torch.tensor(0.0, device=device)

    @staticmethod
    def _scale_gradient(t: Tensor, factor: float) -> Tensor:
        """Return a tensor with the same value as ``t`` whose gradient is multiplied by ``factor``."""
        detached = t.detach()
        # (t - detached) is exactly zero in value but carries t's gradient path.
        return detached + (t - detached) * factor

    def _average_for_report(self, loss: Tensor) -> Tensor:
        """Return a detached copy of ``loss``, averaged over ranks when training under DDP."""
        report = loss.detach().clone()
        if self.training and self.is_ddp:
            # SUM then divide: ReduceOp.AVG is only available on the NCCL backend.
            dist.all_reduce(report, op=dist.ReduceOp.SUM)
            report = report / self.world_size
        return report

    def encode_passage(self, psg):
        """Encode passages into representations; subclasses must implement this."""
        raise NotImplementedError('EncoderModel is an abstract class')

    def encode_query(self, qry):
        """Encode queries into representations; subclasses must implement this."""
        raise NotImplementedError('EncoderModel is an abstract class')

    def compute_similarity(self, q_reps, p_reps):
        """Return the query-passage dot-product matrix."""
        return torch.matmul(q_reps, p_reps.transpose(0, 1))

    def compute_loss(self, scores, target):
        """Return the mean cross-entropy of ``scores`` against ``target``."""
        return self.cross_entropy(scores, target)

    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on the encoder's inner ``.model`` only.

        Keyword arguments are accepted for interface compatibility but ignored.
        The reference model is not touched here.
        """
        self.encoder.model.gradient_checkpointing_enable()

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """Concatenate ``t`` from every rank along dim 0. ``None`` passes through.

        ``dist.all_gather`` is not differentiable, so the local slice is put back
        to keep its autograd graph; the other ranks' slices are constants. Every
        rank must pass a tensor of the same shape.
        """
        if t is None:
            return None
        t = t.contiguous()
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        gathered[self.process_rank] = t
        return torch.cat(gathered, dim=0)

    @staticmethod
    def _new_lora_config(model_args: ModelArguments, base_model_name_or_path: str) -> LoraConfig:
        """Build a fresh, trainable LoRA config from ``model_args`` for the given base model."""
        return LoraConfig(
            base_model_name_or_path=base_model_name_or_path,
            task_type=TaskType.FEATURE_EXTRACTION,
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            target_modules=model_args.lora_target_modules.split(','),
            inference_mode=False,
        )

    @staticmethod
    def _freeze(module: nn.Module) -> None:
        """Set ``requires_grad = False`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            train_args: TrainingArguments,
            data_args: DataArguments,
            **hf_kwargs,
    ):
        """Load the retriever and reference models and wrap them for joint training.

        Both models are loaded with ``cls.TRANSFORMER_CLS.from_pretrained``.
        ``model_args`` configures optional 4-bit quantization, LoRA adapters (new
        or loaded from disk), and freezing of the reference model.
        ``data_args`` is accepted for interface compatibility but not read here.
        """
        if model_args.bnb_4bit:
            hf_kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            # Drop any caller-supplied torch_dtype; the 4-bit config sets the compute dtype.
            hf_kwargs.pop('torch_dtype', None)

        retriever_model = cls.TRANSFORMER_CLS.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
        reference_name = model_args.reference_model_name_or_path
        reference_model = cls.TRANSFORMER_CLS.from_pretrained(reference_name, **hf_kwargs)

        if model_args.bnb_4bit:
            # The two models are separate objects, so each needs k-bit preparation.
            retriever_model = prepare_model_for_kbit_training(
                retriever_model,
                use_gradient_checkpointing=train_args.gradient_checkpointing
            )
            reference_model = prepare_model_for_kbit_training(
                reference_model,
                use_gradient_checkpointing=train_args.gradient_checkpointing
            )

        # Ensure pad_token_id is set on both models.
        for loaded in (retriever_model, reference_model):
            if loaded.config.pad_token_id is None:
                loaded.config.pad_token_id = 0

        # --- Retriever: LoRA wraps the backbone (`.base_model`), not the LM wrapper ---
        if model_args.lora or model_args.retriever_lora_name_or_path:
            if train_args.gradient_checkpointing:
                retriever_model.enable_input_require_grads()

            if model_args.retriever_lora_name_or_path:
                retriever_model = PeftModel.from_pretrained(
                    retriever_model.base_model,
                    model_args.retriever_lora_name_or_path,
                    is_trainable=True,
                )
            else:
                retriever_model = get_peft_model(
                    retriever_model.base_model,
                    cls._new_lora_config(model_args, model_args.model_name_or_path),
                )

            if model_args.bnb_4bit:
                cast_lora_params_to_dtype(retriever_model, dtype=torch.bfloat16)

        # --- Reference ---
        if train_args.gradient_checkpointing:
            reference_model.enable_input_require_grads()

        if model_args.reference_lora_name_or_path:
            reference_model = PeftModel.from_pretrained(
                reference_model,
                model_args.reference_lora_name_or_path,
                is_trainable=model_args.reference_training,
            )
            if model_args.bnb_4bit:
                cast_lora_params_to_dtype(reference_model, dtype=torch.bfloat16)
            if model_args.freeze_reference:
                cls._freeze(reference_model)
        elif model_args.freeze_reference:
            cls._freeze(reference_model)
        elif model_args.lora:
            reference_model = get_peft_model(
                reference_model, cls._new_lora_config(model_args, reference_name)
            )
            if model_args.bnb_4bit:
                cast_lora_params_to_dtype(reference_model, dtype=torch.bfloat16)

        return cls(
            encoder=retriever_model,
            reference=reference_model,
            pooling=model_args.pooling,
            contrastive_loss_weight=model_args.contrastive_loss_weight,
            normalize=model_args.normalize,
            temperature=model_args.temperature,
            attn_temperature=model_args.attn_temperature,
            exclude_diagonal=model_args.exclude_diagonal,
            disable_v_norm=model_args.disable_v_norm,
        )

    @classmethod
    def _load_base_model(cls, model_name_or_path: str, **hf_kwargs):
        """Load ``cls.TRANSFORMER_CLS`` from ``model_name_or_path``, defaulting pad_token_id to 0."""
        model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
        if model.config.pad_token_id is None:
            model.config.pad_token_id = 0
        return model

    @classmethod
    def load(cls,
             model_name_or_path: str,
             pooling: str = 'cls',
             normalize: bool = False,
             retriever_lora_name_or_path: str = None,
             reference_lora_name_or_path: str = None,
             **hf_kwargs):
        """Build a model, usually for inference, with LoRA adapters merged into the weights.

        The encoder is the backbone (``.model``) of ``model_name_or_path``, with
        the retriever LoRA merged in when given. The reference is the full model,
        with the reference LoRA merged in when given.

        ``merge_and_unload`` rewrites the wrapped modules in place, and the encoder
        shares its modules with the base model. The reference LoRA is therefore
        merged into a separately loaded copy of the base, which costs a second
        load. Without a reference LoRA, the reference shares weights with the
        encoder, including any merged retriever LoRA.
        """
        base_model = cls._load_base_model(model_name_or_path, **hf_kwargs)

        if retriever_lora_name_or_path:
            lora_config = LoraConfig.from_pretrained(retriever_lora_name_or_path, **hf_kwargs)
            encoder_lora_model = PeftModel.from_pretrained(base_model.model, retriever_lora_name_or_path, config=lora_config)
            encoder_lora_model = encoder_lora_model.merge_and_unload()
        else:
            encoder_lora_model = base_model.model

        if reference_lora_name_or_path:
            # Merging into base_model would also rewrite base_model.model, i.e. the encoder.
            reference_base = cls._load_base_model(model_name_or_path, **hf_kwargs)
            lora_config = LoraConfig.from_pretrained(reference_lora_name_or_path, **hf_kwargs)
            reference_lora_model = PeftModel.from_pretrained(reference_base, reference_lora_name_or_path, config=lora_config)
            reference_lora_model = reference_lora_model.merge_and_unload()
        else:
            reference_lora_model = base_model

        return cls(
            encoder=encoder_lora_model,
            reference=reference_lora_model,
            pooling=pooling,
            normalize=normalize
        )

    def save(self, output_dir: str):
        """Save only the encoder via ``save_pretrained``; the reference model is not written."""
        self.encoder.save_pretrained(output_dir)
