import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel, Trainer
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.tv_utils import *
from torch.nn import functional as F
from tqdm import tqdm
from src.dataset import FewShotDataset
from src.args import parse_arguments
from src.data_arguments import data_args
from src.trainer import Trainer


# Shared root for cached model files and checkpoints.
root = "/data/common/lm-bff"

# Tasks whose fine-tuned checkpoints take part in the merge.
task_list = [
    "SST-2", "cr", "mr", "mpqa", "trec", "subj",
    "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP",
]
# task_list = ["SST-2", "cr"]

# NOTE(review): not referenced in the visible part of this file.
ROBERTA_PARAM = 163941810

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model class and base architecture shared by every checkpoint.
model_fn = RobertaForPromptFinetuning
# modelname = "roberta-base"
modelname = "roberta-large"
cache_dir = f"{root}/model_files"
config = AutoConfig.from_pretrained(modelname, cache_dir=cache_dir)
        
def get_model_path(task):
    """Return the checkpoint directory of the model fine-tuned on ``task``.

    The directory is chosen from the module-level ``modelname``.

    Args:
        task: Task name used in the checkpoint directory name.

    Returns:
        The checkpoint directory as a string ending in ``/``.

    Raises:
        ValueError: If ``modelname`` is neither ``'roberta-base'`` nor
            ``'roberta-large'``.
    """
    if modelname == 'roberta-base':
        ckpt_root = "/data/common/lm-bff/ckpt_paths/log_noembed_SGD_graft/"
    elif modelname == 'roberta-large':
        ckpt_root = "/data/common/lm-bff/ckpt_paths/large_log_noembed_SGD_graft/"
    else:
        # Falling through used to return None, which only failed later and
        # far from the cause when the path was handed to the model loader.
        raise ValueError(f"No checkpoint location known for model {modelname!r}")
    return ckpt_root + task + f"-prompt-64-0-{modelname}-2-2e-5/"

def initialize_model(modelname):
    """Load a ``model_fn`` instance from ``modelname``.

    Args:
        modelname: Value forwarded as the first argument of
            ``model_fn.from_pretrained``. This parameter shadows the
            module-level ``modelname``; callers in this file pass either that
            name or a checkpoint directory.

    Returns:
        The model built with the module-level ``config`` and ``cache_dir``.
    """
    return model_fn.from_pretrained(modelname, config=config, cache_dir=cache_dir)

# Tokenizer of the base model, loaded with an empty list of additional
# special tokens; it is shared by every dataset and saved with the merged model.
tokenizer = AutoTokenizer.from_pretrained(
    modelname, cache_dir=cache_dir, additional_special_tokens=[]
)


def compute_fisher(train_dataloader, finetuned_model, fisher_variant='hard'):
    """Estimate per-parameter squared-gradient (Fisher) statistics of a model.

    For every batch a negative log-likelihood loss is built from the model's
    own predictions, back-propagated, and the squared gradients are collected
    with ``collect_squared_gradients``. The per-batch results are summed and
    divided by the number of batches.

    Args:
        train_dataloader: Sized iterable of batches. Each batch is indexed
            with the keys ``'input_ids'``, ``'attention_mask'`` and
            ``'mask_pos'``, whose values are moved to the module-level
            ``device``.
        finetuned_model: Model called as
            ``model(input_ids, attention_mask, mask_pos)``; element ``0`` of
            its output is used as the logits. The model is moved to
            ``device`` and left in training mode.
        fisher_variant: ``'hard'`` uses the arg-max prediction as the target
            label; ``'soft'`` weights the loss of every label by its
            (detached) predicted probability.

    Returns:
        dict mapping parameter name to the batch-averaged squared gradient,
        for the parameters that received a gradient.

    Raises:
        ValueError: If ``fisher_variant`` is unknown or the dataloader yields
            no batch.
    """
    # Reject unknown variants up front; otherwise the first batch would hit an
    # unbound ``b_n2fisher`` instead of a readable error.
    if fisher_variant not in ("hard", "soft"):
        raise ValueError(
            f"Unknown fisher_variant {fisher_variant!r}; expected 'hard' or 'soft'"
        )

    model = finetuned_model.to(device)
    # NOTE(review): training mode means any dropout layers stay stochastic
    # while the gradients are taken -- confirm this is intended.
    model.train()
    fisher = {}
    n_b = 0

    for data in tqdm(train_dataloader, total=len(train_dataloader)):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        mask_pos = data['mask_pos'].to(device)

        logits = model(input_ids, attention_mask, mask_pos)[0]
        n_b += 1

        # compute empirical fisher
        log_probs = torch.log_softmax(logits, -1)
        if fisher_variant == "hard":
            _, target_labels = logits.max(-1)
            loss = F.nll_loss(log_probs, target_labels)
        else:  # "soft"
            # assumes logits are [batch, num_labels] -- TODO confirm
            probs = torch.softmax(logits, -1).detach()  # [b,c]
            # The per-label NLL is simply -log_probs, so the expected NLL
            # under the predicted distribution needs no loop over labels.
            loss = (probs * -log_probs).sum(-1).mean()

        model.zero_grad()
        loss.backward()

        for n, f in collect_squared_gradients(model).items():
            if n in fisher:
                fisher[n] += f
            else:
                fisher[n] = f

    if n_b == 0:
        # An explicit raise survives ``python -O``, unlike an assert.
        raise ValueError("train_dataloader yielded no batches")
    for n, f in fisher.items():
        fisher[n] = f / n_b
    return fisher


def collect_squared_gradients(model):
    """Return the element-wise squared gradient of each parameter.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        dict mapping parameter name to ``grad ** 2`` (computed on the detached
        gradient). Parameters whose ``grad`` is ``None`` are left out.
    """
    return {
        name: param.grad.detach() ** 2
        for name, param in model.named_parameters()
        if param.grad is not None
    }


def fisher_weighted_average(all_params, fisher_weights, model_coeffs=None, fisher_smooth=1.0e-10):
    """Merge per-model parameters with an element-wise Fisher-weighted average.

    For every parameter name the merged tensor is
    ``sum_i(p_i * F_i * c_i) / sum_i(F_i * c_i)``, where ``p_i`` is model
    ``i``'s parameter, ``F_i`` its Fisher tensor plus ``fisher_smooth`` and
    ``c_i`` its scalar coefficient.

    Args:
        all_params: dict mapping parameter name to a list of tensors, one per
            model, in the same order as ``fisher_weights``.
        fisher_weights: list with one dict per model mapping parameter name to
            its Fisher tensor.
        model_coeffs: Optional per-model coefficients (sequence or tensor with
            one entry per model). Defaults to ``0.3`` for every model.
        fisher_smooth: Constant added to every Fisher entry so the denominator
            cannot be exactly zero.

    Returns:
        dict mapping parameter name to the merged tensor, on the device of the
        stacked input parameters.

    Raises:
        ValueError: If ``fisher_weights`` is empty, or if ``model_coeffs`` or
            a parameter list does not have one entry per model.
        KeyError: If a name in ``all_params`` is missing from a Fisher dict.
    """
    # Size everything from the models actually passed in rather than from the
    # global task list, so any number of models can be merged.
    num_models = len(fisher_weights)
    if num_models == 0:
        raise ValueError("fisher_weights must contain at least one model")

    if model_coeffs is None:
        model_coeffs = torch.ones(num_models) * 0.3
    else:
        model_coeffs = torch.as_tensor(model_coeffs, dtype=torch.float32)
        if model_coeffs.numel() != num_models:
            raise ValueError(
                f"Expected {num_models} model coefficients, got {model_coeffs.numel()}"
            )

    avg_params = {}
    for n, params in all_params.items():
        if len(params) != num_models:
            raise ValueError(
                f"Parameter {n!r} has {len(params)} tensors but there are "
                f"{num_models} Fisher dicts"
            )
        params = torch.stack(params)  # [N, *]
        fisher = (
            torch.stack([x[n] for x in fisher_weights]) + fisher_smooth
        ).to(params.device)  # [N, *]

        # Reshape to [N, 1, ..., 1] so each coefficient broadcasts over its
        # model's whole tensor.
        coeff = model_coeffs.view(-1, *([1] * (params.dim() - 1))).to(params.device)

        sum_p = (params * fisher * coeff).sum(0)
        denom = (fisher * coeff).sum(0)
        avg_params[n] = sum_p / denom

    return avg_params


# Build one FewShotDataset per task in "dev" mode without demonstrations.
# A previous variant, kept here for reference, used mode="test" for every task
# except "qqp", which stayed on mode="dev".
test_datasets = {}
for task in task_list:
    print(task)
    test_datasets[task] = FewShotDataset(
        data_args[task],
        tokenizer=tokenizer,
        cache_dir="/data/common/lm-bff/model_files",
        mode="dev",
        use_demo=False,
    )


# For every task: load its fine-tuned checkpoint, keep its weights for the
# merge, and estimate its Fisher statistics on the task's dataset.
fisher_weights, finetuned_models = [], []
for task in task_list:
    dataset = test_datasets[task]
    model_path = get_model_path(task)
    finetuned_model = initialize_model(model_path)
    # Taken before compute_fisher moves the model to `device`.
    finetuned_models.append(finetuned_model.state_dict())

    finetuned_model.model_args = parse_arguments()
    finetuned_model.model_args.use_lm_head = True
    # Follow the module-level `device` rather than hard-coding CUDA, so the
    # CPU fallback chosen there actually works on machines without a GPU.
    finetuned_model.label_word_list = torch.tensor(dataset.label_word_list).long().to(device)
    # The trainer is only used to obtain the collator and the eval sampler.
    trainer = Trainer(model=finetuned_model, eval_dataset=dataset)
    data_collator = trainer._get_collator_with_removed_columns(trainer.data_collator, description="evaluation")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=16,  # base 64 (presumably the batch size used for roberta-base)
        collate_fn=data_collator,
        sampler=trainer._get_eval_sampler(dataset),
    )

    fisher = compute_fisher(dataloader, finetuned_model, fisher_variant='hard')
    # Keep the statistics on the CPU so accelerator memory does not grow with
    # the number of tasks; fisher_weighted_average moves them to the
    # parameters' device itself.
    fisher_weights.append({name: value.cpu() for name, value in fisher.items()})

# Resolve the output directory first so an unsupported model name fails
# before the merge instead of leaving `path` undefined at save time.
if modelname == 'roberta-base':
    path = root + "/ckpt_paths/merged_models/fisher_0.3_all-merged-roberta-base-2-2e-5"
elif modelname == 'roberta-large':
    path = root + f"/ckpt_paths/merged_models/large/fisher_0.3_all-merged-{modelname}-2-2e-5"
else:
    raise ValueError(f"No output location known for model {modelname!r}")

# Entries excluded from the merge; the saved model keeps whatever values the
# freshly initialised base model has for them.
frozen = {'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.decoder.bias'}

# Group every mergeable tensor by name, one list entry per fine-tuned model.
params = {}
for local_model in finetuned_models:
    for n, tensor in local_model.items():
        if n not in frozen:
            params.setdefault(n, []).append(tensor)

merged_state_dict = fisher_weighted_average(params, fisher_weights)

model_temp = initialize_model(modelname)
# strict=False because merged_state_dict deliberately omits the frozen entries.
model_temp.load_state_dict(merged_state_dict, strict=False)
model_temp.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)