import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel, Trainer
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.tv_utils import *
from torch.nn import functional as F
from torch import nn
from transformers.pytorch_utils import Conv1D
from tqdm import tqdm
from src.dataset import FewShotDataset
from src.args import parse_arguments
from src.data_arguments import data_args
from src.trainer import Trainer
import re


# Filesystem root under which model files and checkpoints are stored.
root = "/data/common/lm-bff"
# Tasks whose fine-tuned checkpoints are merged.
task_list = [
    "SST-2", "cr", "mr", "mpqa", "trec", "subj",
    "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP",
]
# task_list = ["SST-2", "cr"]
# NOTE(review): presumably a parameter count; it is not used in the code shown here.
ROBERTA_PARAM = 163941810
# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model class and base architecture shared by every checkpoint.
model_fn = RobertaForPromptFinetuning
# modelname = "roberta-base"
modelname = "roberta-large"
cache_dir = f"{root}/model_files"
config = AutoConfig.from_pretrained(modelname, cache_dir=cache_dir)
        
def get_model_path(task):
    """Return the checkpoint directory of the model fine-tuned on ``task``.

    The directory depends on the module-level ``modelname``.

    Args:
        task: Task name used as the prefix of the checkpoint folder.

    Returns:
        The checkpoint directory path as a string, with a trailing slash.

    Raises:
        ValueError: If ``modelname`` has no known checkpoint directory.
            Previously this case silently returned ``None``.
    """
    ckpt_dirs = {
        'roberta-base': "log_noembed_SGD_graft",
        'roberta-large': "large_log_noembed_SGD_graft",
    }
    try:
        ckpt_dir = ckpt_dirs[modelname]
    except KeyError:
        raise ValueError(
            f"No checkpoint directory known for model {modelname!r}; "
            f"expected one of {sorted(ckpt_dirs)}"
        ) from None
    return (
        f"/data/common/lm-bff/ckpt_paths/{ckpt_dir}/"
        f"{task}-prompt-64-0-{modelname}-2-2e-5/"
    )

def initialize_model(modelname):
    """Load a ``model_fn`` instance with the shared module-level config.

    Args:
        modelname: Value forwarded to ``from_pretrained``. Callers in this
            script pass both the base model name and per-task checkpoint
            directories.

    Returns:
        The model returned by ``model_fn.from_pretrained``.
    """
    return model_fn.from_pretrained(modelname, config=config, cache_dir=cache_dir)

# Tokenizer of the base model; reused for every task dataset and saved next to
# the merged checkpoint at the end of the script.
tokenizer = AutoTokenizer.from_pretrained(modelname, additional_special_tokens=[], cache_dir=cache_dir)


def filter_modules_by_regex(base_module, include_patterns, include_type):
    """Select submodules of ``base_module`` by name pattern and/or class.

    Args:
        base_module: Module whose ``named_modules()`` are scanned.
        include_patterns: Regular expressions tried with ``re.match`` (i.e.
            anchored at the start of the module name). A falsy value accepts
            every name.
        include_type: Classes tested with ``isinstance``. A falsy value
            accepts every module type.

    Returns:
        Dict mapping module name to module, in ``named_modules()`` order,
        containing the modules that satisfy both filters.
    """
    def name_ok(name):
        if not include_patterns:
            return True
        return any(re.match(patt, name) for patt in include_patterns)

    def type_ok(module):
        if not include_type:
            return True
        return any(isinstance(module, md_cls) for md_cls in include_type)

    return {
        name: module
        for name, module in base_module.named_modules()
        if name_ok(name) and type_ok(module)
    }


def compute_grams(finetuned_model, train_dataloader, n_step=1000):
    """Estimate the input Gram matrix of every linear layer of a model.

    A forward hook is attached to each ``nn.Linear`` / ``nn.Conv1d`` /
    ``Conv1D`` submodule. For every batch the hook flattens the layer's first
    positional input to ``[rows, h]`` and folds ``x^T x`` into a running
    average over all rows seen so far.

    Args:
        finetuned_model: Model to probe. It is moved to the module-level
            ``device``; its train/eval mode is not changed here.
        train_dataloader: Iterable of batches, each indexable by
            ``'input_ids'``, ``'attention_mask'`` and ``'mask_pos'``.
        n_step: Maximum number of batches to process. A value ``<= 0``
            consumes the whole dataloader.

    Returns:
        Dict mapping module name to its averaged ``[h, h]`` Gram matrix (on
        the device of the layer input), plus a ``"meta_info"`` entry whose
        ``"conv1d"`` list names the ``nn.Conv1d`` / ``Conv1D`` modules.
        Modules that are never called during the forward passes get no entry.
    """
    covs = {}
    xn = {}  # number of rows accumulated so far, per module name

    def get_grams(name):
        def hook(module, inputs, output):
            """
            Note: adhere to signature of hook functions
            """
            x = inputs[0].detach()  # [b,t,h]
            # reshape (not view) so non-contiguous layer inputs are accepted too
            x = x.reshape(-1, x.size(-1))
            xtx = torch.matmul(x.transpose(0, 1), x)  # [h,h]
            rows = x.size(0)
            if name not in covs:
                covs[name] = xtx / rows
                xn[name] = rows
            else:
                # Row-weighted running mean of x^T x.
                covs[name] = (covs[name] * xn[name] + xtx) / (rows + xn[name])
                xn[name] += rows

        return hook

    model = finetuned_model.to(device)
    linear_modules = filter_modules_by_regex(
        model, None, [nn.Linear, nn.Conv1d, Conv1D]
    )

    # mark cov modules as special
    covs["meta_info"] = {
        "conv1d": [
            n
            for n, m in linear_modules.items()
            if isinstance(m, (nn.Conv1d, Conv1D))
        ]
    }

    # Size the progress bar by the number of batches that will actually run.
    try:
        n_batches = len(train_dataloader)
    except TypeError:
        n_batches = None
    if n_step > 0:
        total = n_step if n_batches is None else min(n_step, n_batches)
    else:
        total = n_batches

    handles = [
        module.register_forward_hook(get_grams(name))
        for name, module in linear_modules.items()
    ]
    try:
        # Only activations are needed, so skip building autograd graphs.
        with torch.no_grad():
            for step, data in tqdm(
                enumerate(train_dataloader), total=total, desc="Computing gram matrix"
            ):
                if n_step > 0 and step == n_step:
                    break
                input_ids = data['input_ids'].to(device)
                attention_mask = data['attention_mask'].to(device)
                mask_pos = data['mask_pos'].to(device)

                _ = model(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos)
    finally:
        # Always detach the hooks, even if a forward pass fails.
        for handle in handles:
            handle.remove()

    return covs


def to_diag(cov_mat):
    """Return ``cov_mat`` with every off-diagonal entry multiplied by zero.

    The mask is an identity matrix of side ``cov_mat.size(0)`` on the same
    device as ``cov_mat``; the input is expected to be a square 2-D tensor.
    """
    identity = torch.eye(cov_mat.size(0), device=cov_mat.device)
    return identity * cov_mat


def reduce_non_diag(cov_mat, a):
    """Scale the off-diagonal entries of ``cov_mat`` by ``a``.

    The element-wise weight is ``a`` everywhere plus ``1 - a`` on the
    diagonal, so diagonal entries are multiplied by ``(1 - a) + a``.

    Args:
        cov_mat: Square 2-D tensor (side length taken from ``size(0)``).
        a: Factor applied to the off-diagonal entries.

    Returns:
        The re-weighted matrix.
    """
    side = cov_mat.size(0)
    on_diag = torch.diag(torch.ones(side) - a).to(cov_mat.device)
    everywhere = torch.empty_like(on_diag).fill_(a)
    return cov_mat * (on_diag + everywhere)


def regmean_merge(all_params, all_grams, reduce_non_diag_ratio=0.9):
    """Merge the parameters of several models, using Gram matrices where available.

    For a parameter named ``<module>.weight`` whose ``<module>`` has an entry
    in the Gram dicts, the merged weight is

        W = ((sum_i G_i)^-1 (sum_i G_i W_i^T))^T

    where ``G_i`` is model i's Gram matrix after its off-diagonal entries are
    scaled by ``reduce_non_diag_ratio``. Every other floating-point (or
    complex) parameter is averaged element-wise. Tensors of any other dtype
    cannot be averaged with ``mean``, so the first model's copy is used.

    Args:
        all_params: Dict mapping parameter name to a list of tensors, one per
            model, in the same model order as ``all_grams``.
        all_grams: List of per-model dicts mapping module name to Gram matrix.
            May be empty, in which case everything is simply averaged.
        reduce_non_diag_ratio: Factor applied to the off-diagonal Gram
            entries before merging.

    Returns:
        Dict mapping parameter name to the merged tensor.
    """
    suffix = '.weight'
    avg_params = {}
    for name, tensors in all_params.items():
        module_name = name[:-len(suffix)] if name.endswith(suffix) else None
        if module_name is not None and all_grams and module_name in all_grams[0]:
            sum_gram, sum_gram_m_ws = None, None
            for model_id, model_grams in enumerate(all_grams):
                param_grams = model_grams[module_name].cpu()

                # for roberta we don't need this; but it is important for deberta and t5
                param_grams = reduce_non_diag(param_grams, a=reduce_non_diag_ratio)

                gram_m_w = torch.matmul(param_grams, tensors[model_id].transpose(0, 1))
                if sum_gram is None:
                    sum_gram, sum_gram_m_ws = param_grams, gram_m_w
                else:
                    sum_gram = sum_gram + param_grams
                    sum_gram_m_ws = sum_gram_m_ws + gram_m_w
            wt = torch.matmul(torch.inverse(sum_gram), sum_gram_m_ws)
            avg_params[name] = wt.transpose(0, 1)
            continue

        # No Gram matrix for this tensor: fall back to a simple average.
        first = tensors[0]
        if first.is_floating_point() or first.is_complex():
            avg_params[name] = torch.stack(tensors, 0).mean(0)
        else:
            # mean() is undefined for integer/bool tensors; such entries are
            # presumably identical buffers across models, so keep the first.
            avg_params[name] = first.clone()

    return avg_params


# Build the "dev"-mode FewShotDataset of every task once, keyed by task name.
test_datasets = {}
for task in task_list:
    print(task)
    test_datasets[task] = FewShotDataset(
        data_args[task],
        tokenizer=tokenizer,
        cache_dir="/data/common/lm-bff/model_files",
        mode="dev",
        use_demo=False,
    )


# For every task: load its fine-tuned checkpoint, keep its state dict for the
# merge, and estimate the Gram matrices of its linear layers on the task data.
all_grams, finetuned_models = [], []
for task in task_list:
    model_path = get_model_path(task)
    finetuned_model = initialize_model(model_path)
    finetuned_models.append(finetuned_model.state_dict())

    dataset = test_datasets[task]
    finetuned_model.model_args = parse_arguments()
    finetuned_model.model_args.use_lm_head = True
    # Use the module-level `device` instead of a hard-coded .cuda() so the
    # script also runs when CUDA is unavailable.
    finetuned_model.label_word_list = torch.tensor(dataset.label_word_list).long().to(device)
    # The Trainer is only used to obtain the collator and sampler it would use for evaluation.
    trainer = Trainer(model=finetuned_model, eval_dataset=dataset)
    data_collator = trainer._get_collator_with_removed_columns(trainer.data_collator, description="evaluation")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=data_collator, sampler=trainer._get_eval_sampler(dataset)) # 64 for roberta-base

    grams = compute_grams(finetuned_model, dataloader)
    # Park the Gram matrices on the CPU so they do not pile up in accelerator
    # memory across tasks; regmean_merge moves them to the CPU anyway.
    grams = {k: (v.cpu() if torch.is_tensor(v) else v) for k, v in grams.items()}
    all_grams.append(grams)


# Parameters listed here are left out of the merge; because the merged state
# dict is loaded with strict=False below, their absence is tolerated.
frozen = ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.decoder.bias']

# Regroup the per-model state dicts into {parameter name: [tensor per model]}.
params = {}
for local_model in finetuned_models:
    for n, tensor in local_model.items():
        if n in frozen:
            continue
        params.setdefault(n, []).append(tensor)

merged_state_dict = regmean_merge(params, all_grams)

# Output directory of the merged checkpoint depends on the base architecture.
if modelname == 'roberta-base':
    path = root + "/ckpt_paths/merged_models/regmean_all-merged-roberta-base-2-2e-5"
elif modelname == 'roberta-large':
    path = root + f"/ckpt_paths/merged_models/large/regmean_all-merged-{modelname}-2-2e-5"

# Load the merged weights into a freshly initialised base model and save it
# together with the tokenizer.
model_temp = initialize_model(modelname)
model_temp.load_state_dict(merged_state_dict, strict=False)
model_temp.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)