from dataclasses import dataclass
import eval_glue
from transformers import Trainer, TrainingArguments, TrainerCallback
from collections import defaultdict
import torch
from param import param

class FisherCallback(TrainerCallback):
    """Accumulate a diagonal Fisher estimate (mean squared gradient) during training.

    The squared gradient of every trainable parameter of ``trainer.model`` is
    summed over all backward passes. At the end of each epoch, the running sum
    divided by ``len(trainer.train_dataset)`` is stored on the trainer as
    ``trainer.fisher_weights``: a dict mapping parameter name to a tensor shaped
    like that parameter.
    """

    def __init__(self, trainer):
        self.trainer = trainer
        self.data_items = len(trainer.train_dataset)
        # Running sum of squared gradients, keyed by parameter name. A single
        # sum (instead of one dict per step) bounds the extra memory to one
        # copy of the trainable parameters.
        self.fisher_weights = {}
        self._hook_handles = []

    def _make_hook(self, name):
        """Return a gradient hook that adds ``grad ** 2`` to the running sum for ``name``."""

        def hook(grad):
            squared = grad.detach() ** 2
            if name in self.fisher_weights:
                self.fisher_weights[name] += squared
            else:
                self.fisher_weights[name] = squared
            # Returning None leaves the gradient used by autograd unchanged.
            return None

        return hook

    def _remove_hooks(self):
        """Detach every gradient hook registered by this callback."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def on_train_begin(self, args: TrainingArguments, state, control, **kwargs):
        # Capture gradients at backward time with tensor hooks. Reading
        # ``p.grad`` in ``on_step_end`` does not work: the Trainer has already
        # run the optimizer step and ``model.zero_grad()`` by then, so the
        # gradients are None or zero. The hooks also see the raw gradient,
        # before any gradient clipping.
        # NOTE(review): under fp16/AMP loss scaling the hooks would see scaled
        # gradients; mixed precision is not enabled where this callback is used here.
        self._remove_hooks()
        for name, p in self.trainer.model.named_parameters():
            # Frozen parameters never receive gradients (register_hook would raise).
            if p.requires_grad:
                self._hook_handles.append(p.register_hook(self._make_hook(name)))

    def on_epoch_end(self, args: TrainingArguments, state, control, **kwargs):
        # Normalise the accumulated squared gradients by the number of training examples.
        self.trainer.fisher_weights = {
            name: total / self.data_items
            for name, total in self.fisher_weights.items()
        }

    def on_train_end(self, args: TrainingArguments, state, control, **kwargs):
        # Stop accumulating once training is over (e.g. if the model is reused).
        self._remove_hooks()


@dataclass
class FisherMerge:
    """Merge fine-tuned models by Fisher-weighted parameter averaging.

    For every parameter the merged value is

        sum_i c_i * F_i * theta_i / sum_i c_i * F_i

    where:

    * ``theta_i`` is model i's parameter;
    * ``F_i`` is its diagonal Fisher estimate plus ``min_fish_weight``;
    * ``c_i`` is the per-model coefficient from ``scaling``. When
      ``norm_fish_weight`` is set, ``c_i`` is additionally multiplied by the
      normalised inverse of model i's global Fisher norm.

    Fields:
        names: task names; one Fisher estimate is computed per name via ``eval_glue``.
        data_nums: number of training examples used for each Fisher estimate.
            Either one int shared by all tasks or one int per task.
        scaling: per-model coefficients (a sequence/tensor of length n_models,
            or a single number). ``None`` means equal weights.
        norm_fish_weight: also weight each model by the inverse of its global
            Fisher norm, normalised to sum to 1.
        min_fish_weight: added to every Fisher entry and to the norms, to avoid
            division by zero.
        models_to_merge: one parameter mapping per model, in the same order as
            ``names``. Must support ``.keys()`` (and whatever
            ``param.vectorize_reduce`` requires). Needed by ``get_merged``/``merge``.
    """

    names: list
    data_nums: list
    scaling: list = None
    norm_fish_weight: bool = True
    min_fish_weight: float = 1e-6
    # Declared last, with a default, so existing positional construction keeps working.
    models_to_merge: list = None

    def merge(self, ):
        """Compute Fisher estimates for ``names`` and return the merged parameters."""
        fisher_weight = self.get_coefficient()
        return self.get_merged(fisher_weight)

    def _num_examples(self, idx):
        """Return the Fisher sample size for task ``idx`` (``data_nums`` is one int or one per task)."""
        if isinstance(self.data_nums, int):
            return self.data_nums
        return self.data_nums[idx]

    def _scaling_tensor(self, n_models):
        """Return per-model scaling coefficients as a 1-D float tensor of length ``n_models``.

        ``scaling=None`` gives equal weights. A single number is broadcast to
        every model.

        Raises:
            ValueError: if ``scaling`` has neither 1 nor ``n_models`` entries.
        """
        if self.scaling is None:
            return torch.ones(n_models)
        coef = torch.as_tensor(self.scaling, dtype=torch.float32).flatten()
        if coef.numel() == 1:
            coef = coef.repeat(n_models)
        if coef.numel() != n_models:
            raise ValueError(
                f"scaling has {coef.numel()} entries but {n_models} models are merged"
            )
        return coef

    @staticmethod
    def _fisher_average(params, fishers, coef, fisher_norm=None, min_fish_weight=1e-6):
        """Compute the Fisher-weighted average of one parameter across models.

        Args:
            params: n_models tensors of identical shape.
            fishers: n_models Fisher tensors, each shaped like ``params[i]``.
            coef: 1-D tensor of n_models per-model coefficients.
            fisher_norm: optional 1-D tensor of n_models global Fisher norms.
                When given, the coefficients are multiplied by the normalised
                inverse norms.
            min_fish_weight: added to every Fisher entry and to the norms.

        Returns:
            A tensor shaped like one entry of ``params``.
        """
        ps = torch.stack(list(params), dim=0)
        fw = torch.stack([f.to(ps.device) for f in fishers], dim=0) + min_fish_weight
        # (n_models, 1, 1, ...) so per-model scalars broadcast over the parameter.
        shape = (-1, *([1] * (ps.dim() - 1)))

        _coef = coef.to(ps.device).reshape(shape)
        if fisher_norm is not None:
            inv_norm = 1.0 / (fisher_norm.to(ps.device) + min_fish_weight)
            _coef = _coef * (inv_norm / inv_norm.sum()).reshape(shape)

        numerator = (_coef * fw * ps).sum(dim=0)
        denominator = (_coef * fw).sum(dim=0)
        return numerator / denominator

    def get_coefficient(self, ):
        """Estimate a diagonal Fisher for each task in ``self.names``.

        For each task:

        1. Load the GLUE classifier and its tokenizer through ``eval_glue``.
        2. Select the first ``n`` training examples, where ``n`` comes from
           ``data_nums`` and is clamped to the split size.
        3. Run one ``Trainer`` epoch with a Fisher loss while ``FisherCallback``
           accumulates squared gradients.

        Returns:
            A list of dicts (one per task, in ``self.names`` order) mapping
            parameter names to Fisher tensors.
        """
        # TODO: tasker

        def compute_fisher_loss(model, inputs, return_outputs=False, **kwargs):
            # Installed as an *instance attribute* on the Trainer, so it is not
            # a bound method: the Trainer calls compute_loss(model, inputs, ...).
            # Extra keyword arguments the Trainer may pass are accepted and ignored.
            outputs = model(**inputs)
            logits = outputs.logits  # presumably [batch, num_label]

            # For regression task
            if logits.shape[-1] == 1:
                # The label information is only used here.
                loss = outputs.loss

            # For classification / generation task
            else:
                # detach() removes the probabilities from the computation graph.
                _probabilities = torch.softmax(logits, dim=-1).detach()
                _log_probabilities = torch.log_softmax(logits, dim=-1)
                # sqrt(p) because the Fisher estimate squares these gradients.
                labels_expectations = torch.sqrt(_probabilities) * _log_probabilities
                # Sum over label classes and the batch dimension.
                loss = labels_expectations.sum(dim=-1).sum(dim=0)

            return (loss, outputs) if return_outputs else loss

        fisher_weights = []
        for idx, name in enumerate(self.names):
            model, tokenizer = eval_glue.load_glue_classifier(name, )
            full_train = eval_glue.load_glue_dataset(tokenizer, name, split='train')
            n_examples = min(self._num_examples(idx), len(full_train))
            train_dataset = full_train.select(range(n_examples))
            trainer = Trainer(
                model=model,
                args=TrainingArguments(
                    per_device_train_batch_size=16,
                    num_train_epochs=1,
                    # The Trainer still performs optimizer steps. A zero learning
                    # rate keeps the weights fixed, so every batch's gradient is
                    # taken at the fine-tuned parameters.
                    learning_rate=0.0,
                    # This pass only collects gradients; nothing to checkpoint.
                    save_strategy="no",
                    report_to=[],  # disable wandb
                ),
                train_dataset=train_dataset,
                tokenizer=tokenizer,
            )
            trainer.compute_loss = compute_fisher_loss
            trainer.add_callback(FisherCallback(trainer))
            trainer.train()

            fisher_weights.append(trainer.fisher_weights)

        return fisher_weights

    def get_merged(
        self,
        fisher_weights,
    ):
        """Return the Fisher-weighted average of ``self.models_to_merge``.

        Args:
            fisher_weights: list with one dict per model (same order as
                ``models_to_merge``), mapping parameter name to a Fisher tensor
                of the same shape as the parameter.

        Returns:
            Whatever ``param.vectorize_reduce`` returns, holding one merged
            tensor per parameter name.

        Raises:
            ValueError: if ``models_to_merge`` is unset, or its length differs
                from ``fisher_weights``.
            KeyError: if a Fisher dict lacks one of the models' parameters.
        """
        if not self.models_to_merge:
            raise ValueError("FisherMerge.models_to_merge must be set before merging")
        n_models = len(self.models_to_merge)
        if len(fisher_weights) != n_models:
            raise ValueError(
                f"got {len(fisher_weights)} Fisher estimates for {n_models} models"
            )

        para_names = list(self.models_to_merge[0].keys())
        # Re-key every Fisher dict in the models' parameter order, so its entries
        # line up with the model entries whether vectorize_reduce matches by key
        # or by position. Assumes all models share models_to_merge[0]'s key order.
        fisher_weights = [{n: fw[n] for n in para_names} for fw in fisher_weights]
        coef = self._scaling_tensor(n_models)

        models_fisher_norm = None
        if self.norm_fish_weight:
            # Shape (n_models, n_params): for parameter j, the L2 (Frobenius)
            # norm of each model's Fisher tensor.
            per_param_norms = torch.stack([
                torch.stack([fw[n] for fw in fisher_weights])
                .reshape(n_models, -1)
                .norm(dim=1)
                for n in para_names
            ], dim=1)
            # Shape (n_models,): one global Fisher norm per model.
            models_fisher_norm = per_param_norms.norm(dim=1)

        def fisher_process(ps):
            # NOTE(review): assumes vectorize_reduce calls this once per
            # parameter, with that parameter's entry from every mapping passed
            # below, in order: n_models weights, then n_models Fishers. Confirm
            # against param.vectorize_reduce.
            return self._fisher_average(
                ps[:n_models], ps[n_models:], coef,
                models_fisher_norm, self.min_fish_weight,
            )

        merged_param = param.vectorize_reduce(
            fisher_process,
            list(self.models_to_merge) + fisher_weights,
        )
        return merged_param
        

if __name__ == "__main__":
    # Hyper-parameter grids for a Fisher-merging sweep; nothing consumes them yet.
    # Example counts are the powers of two from 2**8 (256) to 2**11 (2048).
    num_fisher_examples_range = [2 ** k for k in range(8, 12)]
    fisher_scaling_coefficient_range = [
        0.1, 0.3, 0.5, 0.7, 0.9,
        1.0,
    ]

