import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

from trainer.unlearn.base import UnlearnTrainer
from model.uld import get_assistant_model
import logging

# Module-level logger; every message from this file is emitted under the
# fixed name "ULD Trainer" (not the module's import path).
logger = logging.getLogger("ULD Trainer")

class ULD(UnlearnTrainer):
    """Trainer that fits a LoRA-equipped assistant model for ULD unlearning.

    The model handed to the constructor is swapped for the assistant returned
    by ``get_assistant_model`` before the parent trainer is initialised. The
    training objective is the language-modelling loss on the forget batch plus
    ``retain_loss_weight`` times the KL divergence between a uniform
    distribution and the assistant's predictions on the retain batch.
    """

    def __init__(
        self,
        lora,
        num_layers=8,
        retain_loss_weight=0.8,
        *args,
        **kwargs
    ):
        """Build the assistant model and initialise the parent trainer with it.

        Args:
            lora: LoRA configuration forwarded to ``get_assistant_model``.
            num_layers: Forwarded to ``get_assistant_model``; presumably the
                number of layers kept in the assistant (confirm in model.uld).
            retain_loss_weight: Multiplier applied to the retain (uniform-KL)
                term of the loss.
            *args: Remaining positional arguments for the parent trainer. If
                ``model`` is not given as a keyword, the first one is taken
                to be the base model.
            **kwargs: Keyword arguments for the parent trainer; ``model`` is
                replaced by the assistant model.

        Raises:
            ValueError: If no model is supplied positionally or by keyword.
            TypeError: If the supplied model is not a ``LlamaForCausalLM``.
        """
        self.lora = lora
        self.num_layers = num_layers
        self.retain_loss_weight = retain_loss_weight

        # The parent must be initialised with the assistant, so the base model
        # is pulled out of the arguments before calling super().__init__.
        if "model" in kwargs:
            original_model = kwargs.pop("model")
        elif args:
            original_model, args = args[0], args[1:]
        else:
            raise ValueError("Couldn't parse model - no model provided")

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not isinstance(original_model, LlamaForCausalLM):
            raise TypeError("ULD currently only supports Llama models.")

        logger.info(f"Creating assistant model with the LoRA Config: {self.lora}")
        assistant_model = get_assistant_model(original_model, self.lora, self.num_layers)

        kwargs["model"] = assistant_model
        super().__init__(*args, **kwargs)

        logger.info("Initialized ULD Trainer")
        trainable = 0
        total = 0
        for param in self.model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        # Guard the ratio so a parameter-less model cannot crash the log line.
        ratio = trainable / total if total else 0.0
        logger.info(f"Model trainable parameters: {trainable} of {total} ({ratio:.4%})")

    def compute_retain_loss(self, model, retain_inputs):
        """Return the KL divergence from uniform to the model's predictions.

        Args:
            model: Callable model whose output exposes ``.logits``.
            retain_inputs: Mapping of model inputs. A ``labels`` entry, if
                present, is ignored. An ``attention_mask`` entry, if present,
                selects the positions that contribute to the loss.

        Returns:
            Scalar tensor: per-position KL(uniform || softmax(logits)) summed
            over attended positions and divided by the batch size. For a batch
            without padding this equals ``kl_div(..., reduction='batchmean')``.
        """
        # Only logits are needed; dropping labels avoids computing an unused
        # cross-entropy loss inside the model's forward pass.
        model_inputs = {
            key: value for key, value in retain_inputs.items() if key != "labels"
        }
        logits = model(**model_inputs).logits

        vocab_size = logits.size(-1)
        uniform_dist = torch.full_like(logits, 1.0 / vocab_size)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Reduce over the vocabulary only, keeping one KL value per position
        # so that padded positions can be masked out below.
        per_position_kl = torch.nn.functional.kl_div(
            log_probs, uniform_dist, reduction="none"
        ).sum(dim=-1)

        attention_mask = retain_inputs.get("attention_mask")
        if attention_mask is not None:
            # Predictions at padded positions are meaningless; excluding them
            # keeps the loss independent of how much padding a batch has.
            per_position_kl = per_position_kl * attention_mask.to(per_position_kl)

        # NOTE(review): dividing by batch size (the original 'batchmean'
        # behaviour) means this term still grows with sequence length, unlike
        # the token-averaged forget loss. Confirm that scale is intended.
        return per_position_kl.sum() / logits.size(0)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Combine the forget-set LM loss with the weighted retain-set loss.

        Args:
            model: The assistant model being trained.
            inputs: Mapping with ``"forget"`` and ``"retain"`` entries, each
                providing ``input_ids``, ``attention_mask`` and ``labels``.
            return_outputs: When true, also return the forget-batch outputs.
            num_items_in_batch: Accepted for compatibility with trainer
                versions that pass it to ``compute_loss``; not used here.

        Returns:
            The scalar loss, or ``(loss, forget_outputs)`` when
            ``return_outputs`` is true.
        """
        forget_batch = inputs["forget"]
        forget_inputs = {
            "input_ids": forget_batch["input_ids"],
            "attention_mask": forget_batch["attention_mask"],
            "labels": forget_batch["labels"],
        }

        # Assistant model learns to memorize forget set while approaching uniform on retain set
        forget_outputs = model(**forget_inputs)
        forget_loss = forget_outputs.loss

        retain_batch = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_batch["input_ids"],
            "attention_mask": retain_batch["attention_mask"],
            "labels": retain_batch["labels"],
        }
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss = forget_loss + self.retain_loss_weight * retain_loss
        return (loss, forget_outputs) if return_outputs else loss

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the trained model, merging LoRA weights when it is a PeftModel.

        Args:
            output_dir: Target directory. When omitted, falls back to
                ``self.args.output_dir`` (assumes the parent trainer exposes
                Hugging Face-style ``args``; confirm against UnlearnTrainer).
            _internal_call: Accepted so that trainer internals which pass this
                keyword (e.g. checkpointing) do not fail; not used here.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        # Merge LoRA weights if we're dealing with a PeftModel so we save the full base model.
        # NOTE(review): merge_and_unload() folds the adapters into the wrapped
        # model's weights; if this runs for an intermediate checkpoint, confirm
        # that continuing training afterwards is still valid.
        model_to_save = self.model
        if isinstance(model_to_save, PeftModel):
            logger.info("Merging PEFT model for saving")
            model_to_save = model_to_save.merge_and_unload()

        logger.info(f"Saving merged model to {output_dir}")
        model_to_save.save_pretrained(output_dir)