import torch
import tqdm
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.tv_utils import *
from src.dataset import FewShotDataset
from src.data_arguments import data_args
from src.trainer import Trainer
from src.args import parse_arguments


# ---------------------------------------------------------------------------
# Module-level configuration shared by the helpers and the script below.
# ---------------------------------------------------------------------------
root = "/data/common/lm-bff"
cache_dir = f"{root}/model_files"
modelname = "roberta-base"
model_fn = RobertaForPromptFinetuning
ROBERTA_PARAM = 163941810  # not referenced anywhere else in the visible part of this file

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Config and tokenizer are loaded once here and reused by initialize_model().
config = AutoConfig.from_pretrained(modelname, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(
    modelname,
    cache_dir=cache_dir,
    additional_special_tokens=[],
)

class TaskVector:
    """Element-wise difference between a fine-tuned and a pretrained model.

    ``self.vector`` maps every selected state-dict key to
    ``finetuned[key] - pretrained[key]``.  Keys are stored in the iteration
    order of the pretrained state dict; keys that are not selected are
    left out.
    """

    def __init__(self, pretrained_checkpoint, finetuned_checkpoint, param_names=None):
        """Compute the task vector.

        Args:
            pretrained_checkpoint: Object exposing ``state_dict()`` (the base).
            finetuned_checkpoint: Object exposing ``state_dict()``; it must
                contain every selected key of the pretrained state dict.
            param_names: Optional iterable of state-dict keys to include.
                When omitted, the module-level ``names`` is used, so that
                global has to exist before the first instance is created.

        Raises:
            KeyError: If a selected key is missing from the fine-tuned
                state dict.
            NameError: If ``param_names`` is omitted and the module-level
                ``names`` has not been defined yet.
        """
        # Build the lookup once: set membership is O(1), whereas testing
        # against a list rescans it for every state-dict key.
        selected = set(names if param_names is None else param_names)
        with torch.no_grad():
            pretrained_state_dict = pretrained_checkpoint.state_dict()
            finetuned_state_dict = finetuned_checkpoint.state_dict()
            self.vector = {
                key: finetuned_state_dict[key] - base
                for key, base in pretrained_state_dict.items()
                if key in selected
            }


def select_trainable_parameters(model):
    """Collect the parameters whose name contains ``'encoder.layer'``.

    NOTE: a second function with this same name is defined further down in
    this module; being defined later, that one is what the name refers to
    once the module has finished loading.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.

    Returns:
        dict mapping each matching name to its parameter, in iteration order.
    """
    return {
        name: param
        for name, param in model.named_parameters()
        if 'encoder.layer' in name
    }

def del_attr(obj, names):
    """Delete a nested attribute addressed by a path of attribute names.

    Args:
        obj: Root object the path starts from.
        names: Non-empty sequence of attribute names; every element but the
            last is followed with ``getattr``, the last one is deleted.

    Raises:
        IndexError: If ``names`` is empty.
        AttributeError: If any attribute on the path does not exist.
    """
    owner = obj
    # Walk down to the object that directly holds the final attribute.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    delattr(owner, names[-1])

def set_attr(obj, names, val):
    """Assign ``val`` to a nested attribute addressed by a path of names.

    Args:
        obj: Root object the path starts from.
        names: Non-empty sequence of attribute names; every element but the
            last is followed with ``getattr``, the last one is assigned.
        val: Value to store.

    Raises:
        IndexError: If ``names`` is empty.
        AttributeError: If an intermediate attribute does not exist.
    """
    holder = obj
    # Resolve everything except the final path component.
    for attr in names[:-1]:
        holder = getattr(holder, attr)
    setattr(holder, names[-1], val)

def make_functional(mod):
    """Remove every named parameter from ``mod`` and return what was removed.

    After this call the attributes listed in the returned names no longer
    exist on ``mod`` (they are deleted with ``del_attr``), so other values can
    later be assigned under the same names.

    Args:
        mod: Object exposing ``parameters()`` and ``named_parameters()``.

    Returns:
        tuple ``(orig_params, names)``: the parameters as yielded by
        ``mod.parameters()`` before deletion, and the dotted parameter names
        in ``named_parameters()`` order.
    """
    orig_params = tuple(mod.parameters())
    # Snapshot the names first: deleting while iterating the live generator
    # would change the structure being walked.
    names = [name for name, _ in list(mod.named_parameters())]
    for dotted_name in names:
        del_attr(mod, dotted_name.split("."))
    return orig_params, names

def load_weights(mod, names, params):
    """Assign each value in ``params`` to ``mod`` under the matching dotted name.

    Args:
        mod: Root object that receives the attributes.
        names: Iterable of dotted attribute paths (e.g. ``"a.b.weight"``).
        params: Iterable of values, paired positionally with ``names``;
            pairing stops at the shorter of the two.
    """
    for dotted_name, value in zip(names, params):
        path = dotted_name.split(".")
        set_attr(mod, path, value)

def initialize_model(modelname):
    """Load a model through ``model_fn`` and attach the shared tokenizer.

    Args:
        modelname: Model identifier or checkpoint path handed to
            ``model_fn.from_pretrained``.

    Returns:
        The loaded model with its ``tokenizer`` attribute set to the
        module-level ``tokenizer``.
    """
    # The module-level config and cache directory are reused for every load.
    load_kwargs = {"config": config, "cache_dir": cache_dir}
    loaded = model_fn.from_pretrained(modelname, **load_kwargs)
    loaded.tokenizer = tokenizer
    return loaded

def select_trainable_parameters(model):
    """Return the parameters of ``model`` named like ``...encoder.layer...``.

    This definition has the same name as an earlier one in this module and,
    being defined later, is the one in effect.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.

    Returns:
        dict of ``name -> parameter`` for every name containing the substring
        ``'encoder.layer'``, in iteration order.
    """
    selected = {}
    for name, param in model.named_parameters():
        if 'encoder.layer' not in name:
            continue
        selected[name] = param
    return selected


def get_task_vector(pretrained_model, finetuned_model):
    """Return the fine-tuned state minus the pretrained state, as one vector.

    Both state dicts are converted with ``state_dict_to_vector`` (not defined
    in this file; presumably provided by the ``src.tv_utils`` star import).

    Args:
        pretrained_model: Object exposing ``state_dict()`` (the base).
        finetuned_model: Object exposing ``state_dict()``.

    Returns:
        ``state_dict_to_vector(finetuned) - state_dict_to_vector(pretrained)``.
    """
    base_vec = state_dict_to_vector(pretrained_model.state_dict())
    tuned_vec = state_dict_to_vector(finetuned_model.state_dict())
    delta = tuned_vec - base_vec
    return delta


class AdaMerging(torch.nn.Module):
    """Merge pretrained weights with task vectors using learnable coefficients.

    ``paramslist[0]`` holds the pretrained tensors and each of
    ``paramslist[1:]`` holds the tensors of one task vector.  Entry ``i`` of
    every tuple is combined into the tensor stored under ``names[i]``:

        merged[i] = sum_j lambdas()[i, j] * paramslist[j][i]

    Column 0 of ``lambdas()`` is the constant 1 (pretrained weights); the
    other columns are trainable, start at 0.3 and are clamped to ``[0, 1]``
    whenever they are read.
    """

    def __init__(self, paramslist, model, names):
        """Store the merge inputs and create the trainable coefficients.

        Args:
            paramslist: Sequence of tensor tuples as described on the class;
                ``paramslist[0]`` also fixes the number of coefficient rows.
            model: Module that ``forward`` runs after assigning the merged
                tensors to it under ``names``.
            names: Dotted attribute names, one per merged tensor.
        """
        super(AdaMerging, self).__init__()
        self.paramslist = paramslist
        self.model = model
        self.names = names
        # Coefficient of the pretrained tensors: constant, never optimised.
        self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1)
        prior = 0.3
        # One row per tensor, one column per task vector.
        rlambdas = torch.ones(len(paramslist[0]), len(paramslist)-1) * prior
        self.lambdas_raw = torch.nn.Parameter(rlambdas)

        self.model.model_args = parse_arguments()
        self.model.model_args.use_lm_head = True

    def lambdas(self):
        """Return all coefficients, shape ``(rows, 1 + number of task vectors)``.

        The trainable part is clamped to ``[0, 1]`` and prefixed with the
        constant pretrained column.
        """
        task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
        lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1)
        return lambdass

    def collect_trainable_params(self):
        """Return the list of tensors an optimizer should update."""
        return [self.lambdas_raw]

    def get_model(self):
        """Return the merged weights as a ``{name: tensor}`` state dict.

        The number of merged components is taken from ``self.paramslist``
        itself, so the method does not depend on any module-level task list.
        The result is computed without autograd tracking because it is meant
        to be exported, not trained through.

        Returns:
            dict with one entry per element of ``self.names``.
        """
        alph = self.lambdas()
        new_state_dict = {}
        with torch.no_grad():
            for i, name in enumerate(self.names):
                merged = torch.zeros_like(self.paramslist[0][i])
                for j, component in enumerate(self.paramslist):
                    merged += alph[i][j] * component[i]
                new_state_dict[name] = merged

        return new_state_dict

    def forward(self, input_ids, attention_mask, mask_pos):
        """Run ``self.model`` with weights merged from the current coefficients.

        The merge is built from differentiable tensor operations on
        ``lambdas()``.

        Args:
            input_ids: Tensor passed to the model; its device decides where
                the merged weights are placed.
            attention_mask: Passed to the model unchanged.
            mask_pos: Passed to the model unchanged.

        Returns:
            Whatever ``self.model(input_ids, attention_mask, mask_pos)`` returns.
        """
        alph = self.lambdas()
        # zip(*paramslist) groups the j-th tensor of every component; the
        # grouping stops at the shortest tuple in paramslist.
        params = tuple(
            sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu())))
            for j, p in enumerate(zip(*self.paramslist))
        )
        # Follow the inputs' device instead of assuming CUDA is available.
        target_device = input_ids.device
        params = tuple(p.to(target_device) for p in params)

        load_weights(self.model, self.names, params)
        out = self.model(input_ids, attention_mask, mask_pos)

        return out


# Tasks whose fine-tuned checkpoints are merged.  For a quick smoke test the
# list can be shortened, e.g. to ["SST-2", "cr"].
task_list = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]
ckpt_pth = "/data/common/lm-bff/ckpt_paths/log_noembed_SGD_graft/"
model_path_list = [ckpt_pth+f"{task}-prompt-64-0-roberta-base-2-2e-5" for task in task_list]

pretrained_model = initialize_model(modelname)
finetuned_models = [initialize_model(model_path) for model_path in model_path_list]
# `model` is the shell the merged weights are loaded into: its parameters are
# removed here and `names` records where they lived.
model = initialize_model(modelname)
_, names = make_functional(model)

# paramslist[0] holds the pretrained tensors, paramslist[1:] one tuple per
# task vector.  Every tuple is indexed by `names`, so position i always refers
# to names[i].  Taking every state_dict() entry for the pretrained tuple
# instead would also pull in entries that are not named parameters (buffers,
# tied-weight aliases), making the tuples differ in length and potentially
# shifting positions against the task vectors.
_pretrained_state = pretrained_model.state_dict()
paramslist = []
paramslist += [tuple(_pretrained_state[name].detach().requires_grad_().cpu() for name in names)]  # pretrain
task_vectors = [TaskVector(pretrained_model, finetuned_model) for finetuned_model in finetuned_models]
paramslist += [tuple(tv.vector[name].detach().requires_grad_().cpu() for name in names) for tv in task_vectors]  # task vectors

torch.cuda.empty_cache()
adamerging_mtl_model = AdaMerging(paramslist, model, names)

print('init lambda:')
print(adamerging_mtl_model.lambdas())
print('collect_trainable_params:')
print(list(adamerging_mtl_model.collect_trainable_params()))

epochs = 500
# Only the merging coefficients are optimised.
optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)

def softmax_entropy(x):
    """Entropy of ``softmax(x)`` taken along dimension 1.

    Args:
        x: Tensor of logits with at least two dimensions.

    Returns:
        Tensor with dimension 1 reduced, holding
        ``-sum(softmax(x) * log_softmax(x))`` over that dimension.
    """
    probs = torch.softmax(x, dim=1)
    log_probs = torch.log_softmax(x, dim=1)
    return -torch.sum(probs * log_probs, dim=1)

# Datasets used for the adaptation loop, keyed by task name.
test_datasets = {}
for task in task_list:
    print(task)
    # NOTE(review): task_list spells this task "QQP", so the lowercase
    # comparison below matches none of the listed tasks and every dataset is
    # loaded with mode="test" -- confirm whether "dev" was intended for QQP.
    test_datasets[task] = FewShotDataset(
        data_args[task],
        tokenizer=tokenizer,
        cache_dir="/data/common/lm-bff/model_files",
        mode="dev" if task == "qqp" else "test",
        use_demo=False,
    )
        
# Adaptation loop: in every epoch each task contributes exactly one optimizer
# step that minimises the prediction entropy of the merged model on the first
# batch of that task's dataset.
for epoch in range(epochs):
    print(f"Epoch: {epoch}")
    for task in task_list:
        print(f"Task: {task}")

        test_dataset = test_datasets[task]
        # Use the module-level `device` (with its CPU fallback) rather than
        # assuming CUDA is available.
        adamerging_mtl_model.model.label_word_list = torch.tensor(test_dataset.label_word_list).long().to(device)
        trainer = Trainer(model=model, eval_dataset=test_dataset)
        data_collator = trainer._get_collator_with_removed_columns(trainer.data_collator, description="evaluation")
        dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=data_collator, sampler=trainer._get_eval_sampler(test_dataset))

        for i, data in enumerate(tqdm.tqdm(dataloader)):
            # Execute only one step.  Checking before the forward pass avoids
            # running the model on a second batch whose output is discarded.
            if i > 0:
                break

            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            mask_pos = data['mask_pos'].to(device)

            outputs = adamerging_mtl_model(input_ids, attention_mask, mask_pos)[0]
            loss = softmax_entropy(outputs).mean(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    print(list(adamerging_mtl_model.lambdas().data))

    # Every 100 epochs, export the merged weights as a regular checkpoint.
    if (epoch+1) % 100 == 0:
        path = root+f"/ckpt_paths/merged_models/layer_ada_all_0.3_{epoch+1}-merged-roberta-base-2-2e-5"
        state_dict = adamerging_mtl_model.get_model()
        model_to_save = initialize_model(modelname)
        # strict=False: the merged state dict only has the entries listed in
        # `names`; any other key keeps the value it was loaded with.
        model_to_save.load_state_dict(state_dict, strict=False)
        model_to_save.save_pretrained(path)
        tokenizer.save_pretrained(path)
