from modules.Adv import FreeLB, GaussianNoise

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer
import numpy as np
import wandb
from typing import List, Dict, Union, Any


# Custom trainer for weak-to-strong finetuning with auxiliary loss
class WTS_Trainer(Trainer):
    """Base trainer for weak-to-strong finetuning with an optional auxiliary loss.

    This class only stores the auxiliary-loss configuration and provides
    `compute_aux_loss`; it does not override `compute_loss` itself.

    Args:
        enable_aux_loss: Flag stored on the instance; subclasses consult it to
            decide whether to mix the auxiliary loss into the training loss.
        task: Key used to look up the class weights in `task2class`.
        alpha_max: Maximum mixing weight for the auxiliary loss.
        burn_in_period: Fraction of `max_steps` over which subclasses ramp the
            mixing weight up to `alpha_max`.
        task2class: Mapping from task to a mapping of label string -> class
            weight (the fraction of instances expected to carry that label).
        *args, **kwargs: Forwarded unchanged to `Trainer.__init__`.
    """

    def __init__(self, enable_aux_loss=False, task=None, alpha_max=0.5, burn_in_period=0.2, task2class=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.enable_aux_loss = enable_aux_loss
        self.task = task
        self.alpha_max = alpha_max
        self.burn_in_period = burn_in_period
        self.task2class = task2class

    def compute_aux_loss(self, model, inputs, outputs):
        """Cross-entropy of the logits against pseudo-labels derived from them.

        Pseudo-labels are assigned class by class, in the iteration order of
        `self.task2class[self.task]`: for each class, the still-unlabelled
        instances whose logit for that class lies strictly above a quantile
        threshold receive that label. The quantile is chosen so that roughly
        the class's weight share of the batch is selected. Instances left
        unlabelled at the end receive the last class label.

        Args:
            model: Unused; kept for interface compatibility.
            inputs: Unused; kept for interface compatibility.
            outputs: Model outputs exposing a `logits` attribute indexed as
                `[instance, class]`.

        Returns:
            Scalar tensor with the auxiliary cross-entropy loss.

        Raises:
            ValueError: If the class-weight mapping for the task is empty.
        """
        logits = outputs.logits

        # Get labels from logits by thresholding based on class weights
        class_weights = self.task2class[self.task]
        if not class_weights:
            raise ValueError(f"No class weights configured for task {self.task!r}")

        num_instances = logits.size(0)
        # Leftover instances always end up with the last class label.
        fallback_label = int(list(class_weights)[-1])

        # Pseudo-labels are targets only, so work on a detached view and keep
        # the bookkeeping in tensors rather than syncing per element.
        detached = logits.detach()
        pending = torch.ones(num_instances, dtype=torch.bool, device=logits.device)
        logit_labels = torch.full(
            (num_instances,), fallback_label, dtype=torch.long, device=logits.device
        )
        remaining_fraction = 1.0

        for label_str in class_weights:
            # Nothing left to distribute (or nobody left to label): the
            # remaining instances keep the fallback label. This also avoids a
            # division by zero and a quantile over an empty array.
            if remaining_fraction <= 0 or not bool(pending.any()):
                break

            weight = class_weights[label_str]
            class_fraction = min(weight / remaining_fraction, 1.0)
            label = int(label_str)

            # Threshold is computed over the still-unlabelled instances only.
            column = detached[:, label]
            pending_scores = column[pending].double().cpu().numpy()
            threshold = float(np.quantile(pending_scores, 1.0 - class_fraction))

            # Assign label to unlabelled instances with logits above threshold
            selected = pending & (column > threshold)
            logit_labels[selected] = label
            pending = pending & ~selected

            remaining_fraction -= weight

        # Compute auxiliary loss
        aux_loss = F.cross_entropy(logits, logit_labels)

        return aux_loss


class WTS_Trainer_Naive(WTS_Trainer):
    """Trainer using plain cross-entropy, optionally mixed with the auxiliary loss."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cross-entropy loss, optionally blended with the auxiliary loss.

        Args:
            model: Model called with the remaining `inputs` as keyword arguments.
            inputs: Batch dict; its "labels" entry is popped (the dict is mutated).
            return_outputs: When true, return `(loss, outputs)` instead of `loss`.
            num_items_in_batch: Accepted for compatibility with transformers
                releases that pass it to `compute_loss`; not used here.

        Returns:
            The loss tensor, or `(loss, outputs)` when `return_outputs` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss = F.cross_entropy(logits, labels)

        # Auxiliary confidence loss
        if self.enable_aux_loss:
            # Warm up alpha linearly from 0 to alpha_max over the first
            # `burn_in_period` fraction of training. With no warm-up window
            # (burn_in_period == 0, or max_steps == 0 such as when evaluating
            # without training) use alpha_max directly instead of dividing by zero.
            warmup_steps = self.burn_in_period * self.state.max_steps
            if warmup_steps > 0:
                warmup_progress = min(1.0, self.state.global_step / warmup_steps)
            else:
                warmup_progress = 1.0
            alpha = self.alpha_max * warmup_progress
            aux_loss = self.compute_aux_loss(model, inputs, outputs)
            loss = ((1 - alpha) * loss) + (alpha * aux_loss)

        return (loss, outputs) if return_outputs else loss


class WTS_Trainer_GaussianNoise(WTS_Trainer):
    """Trainer that perturbs the input embeddings via `GaussianNoise` before the forward pass.

    Args:
        noise_std: Passed to `GaussianNoise`.
        *args, **kwargs: Forwarded unchanged to `WTS_Trainer.__init__`.
    """

    def __init__(self, noise_std, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.gaussian_noise = GaussianNoise(noise_std)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute cross-entropy on perturbed embeddings, optionally blended with the auxiliary loss.

        Args:
            model: Model (possibly wrapped in `DataParallel` or
                `DistributedDataParallel`) that accepts `inputs_embeds`.
            inputs: Batch dict with "input_ids", "attention_mask" and "labels";
                the "labels" entry is popped (the dict is mutated).
            return_outputs: When true, return `(loss, outputs)` instead of `loss`.
            num_items_in_batch: Accepted for compatibility with transformers
                releases that pass it to `compute_loss`; not used here.

        Returns:
            The loss tensor, or `(loss, outputs)` when `return_outputs` is true.
        """
        labels = inputs.pop("labels")
        input_ids = inputs["input_ids"]
        attention_masks = inputs["attention_mask"]

        # Parallel wrappers do not expose `get_input_embeddings`; reach through
        # to the wrapped module for the lookup only.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            word_embeddings = model.module.get_input_embeddings()
        else:
            word_embeddings = model.get_input_embeddings()

        embeds_init = word_embeddings(input_ids)
        embeds_perturbed = self.gaussian_noise.attack(embeds_init)
        outputs = model(inputs_embeds=embeds_perturbed, attention_mask=attention_masks, labels=labels)
        logits = outputs.logits
        loss = F.cross_entropy(logits, labels)

        # Auxiliary confidence loss
        if self.enable_aux_loss:
            # Warm up alpha linearly from 0 to alpha_max over the first
            # `burn_in_period` fraction of training. With no warm-up window
            # (burn_in_period == 0, or max_steps == 0 such as when evaluating
            # without training) use alpha_max directly instead of dividing by zero.
            warmup_steps = self.burn_in_period * self.state.max_steps
            if warmup_steps > 0:
                warmup_progress = min(1.0, self.state.global_step / warmup_steps)
            else:
                warmup_progress = 1.0
            alpha = self.alpha_max * warmup_progress
            aux_loss = self.compute_aux_loss(model, inputs, outputs)
            loss = ((1 - alpha) * loss) + (alpha * aux_loss)

        return (loss, outputs) if return_outputs else loss


class WTS_Trainer_FreeLB(WTS_Trainer):
    """Trainer whose training step is delegated to a `FreeLB` attack object.

    Args:
        adv_K, adv_lr, adv_init_mag, adv_max_norm, adv_norm_type: Forwarded
            unchanged to `FreeLB` under the same keyword names.
        *args, **kwargs: Forwarded unchanged to `WTS_Trainer.__init__`.
    """

    def __init__(self, adv_K=2, adv_lr=1e-1, adv_init_mag=6e-1, adv_max_norm=0., adv_norm_type='l2', *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Built after super().__init__ because it needs `self.accelerator`.
        self.freelb = FreeLB(
            adv_K=adv_K,
            adv_lr=adv_lr,
            adv_init_mag=adv_init_mag,
            adv_max_norm=adv_max_norm,
            adv_norm_type=adv_norm_type,
            hf_accelerator=self.accelerator
        )

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs using the FreeLB attack.

        Puts the model in train mode, prepares the inputs, and runs
        `self.freelb.attack` inside the trainer's loss-computation context.
        No backward pass is issued here; per the original inline note the
        attack performs it internally (verify against `FreeLB.attack`).

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model, passed to
                `self.freelb.attack` after `_prepare_inputs`.
            num_items_in_batch:
                Accepted for compatibility with transformers releases that
                pass it to `training_step`; not used here.

        Return:
            `torch.Tensor`: The detached loss returned by the attack, divided
            by `self.args.gradient_accumulation_steps`.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.freelb.attack(model, inputs)  # backward() inserted

        del inputs

        return loss.detach() / self.args.gradient_accumulation_steps
