import copy
import shutil
import argparse
import logging
import math
import os
import random
import sys
import torch
import datasets
import transformers
from accelerate import Accelerator, DistributedType
from tqdm.auto import tqdm
from networks import prompt
from transformers import (
    MODEL_MAPPING,
    AdamW,
    AutoTokenizer,
    AutoConfig,
    RobertaTokenizer,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    get_scheduler,
    SchedulerType,
    set_seed,
)
import numpy as np
from transformers import RobertaForMaskedLM, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from networks import prune_model
from networks.roberta_model import MyRoberta, MyRobertaForMaskedLM
from approaches.my_optimizer import MyAdamW
# sys.path.append("..")
from utils import utils
from networks import fisher_model,hat_model,derpp_model,demix_model
from networks.buffer import Buffer as Buffer

logger = logging.getLogger(__name__)

# All config classes registered in transformers' MODEL_MAPPING, and the model-type
# identifier string exposed by each of them (same order).
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)


class Appr(object):
    """Post-training / evaluation driver for the continual-learning baselines in this file.

    The behaviour of every method is selected by substring tests on ``args.baseline``
    (e.g. ``'ewc'``, ``'adapter_hat'``, ``'transformer_hat'``, ``'derpp'``, ``'adapter_demix'``).
    """

    # Baseline name fragments that use HAT-style task masks (BCL and CLASSIC include HAT).
    _HAT_BASELINES = ('adapter_hat', 'transformer_hat', 'adapter_bcl', 'adapter_classic')

    # Adapter sub-modules (relative to an encoder layer) whose fc1/fc2 carry HAT masks.
    _ADAPTER_HAT_MODULES = ('attention.output.moe.adapters', 'output.moe.adapters')

    # Transformer sub-modules gated by the per-head mask in 'transformer_hat'.
    _HEAD_MASKED_MODULES = ('attention.self.query', 'attention.self.key',
                            'attention.self.value', 'attention.output.dense')

    # All transformer sub-modules that carry a HAT mask in 'transformer_hat'.
    _TRANSFORMER_HAT_MODULES = _HEAD_MASKED_MODULES + ('intermediate.dense', 'output.dense')

    def __init__(self, args):
        """Keep the run configuration ``args`` and build the activation modules used by the approach."""
        super().__init__()
        self.args = args
        self.tanh = torch.nn.Tanh()
        self.sigmoid = torch.nn.Sigmoid()

    # ------------------------------------------------------------------ helpers

    def _is_hat(self):
        """Return True when ``args.baseline`` names a HAT-style (task-mask) baseline."""
        return any(name in self.args.baseline for name in self._HAT_BASELINES)

    @staticmethod
    def _is_hat_embedding(n):
        """Return True for parameter names treated as HAT task embeddings by this file."""
        return 'adapters.e' in n or 'model.e' in n

    @staticmethod
    def _layer_prefix(layer_id):
        """Parameter-name prefix of encoder layer ``layer_id``, including the trailing dot."""
        return 'module.model.roberta.encoder.layer.' + str(layer_id) + '.'

    @staticmethod
    def _load_to_cuda(path):
        """``torch.load`` a mapping from ``path`` and move every value to CUDA in place."""
        state = torch.load(path)
        for key in state:
            state[key] = state[key].cuda()
        return state

    # -------------------------------------------------------------------- masks

    def mask(self, model, accelerator):
        """Collect the current HAT masks, keyed by module name (no ``.weight``/``.bias`` suffix).

        * ``adapter_hat``: for every layer, the ``(fc1, fc2)`` pair returned by each adapter's
          ``mask()`` for ``attention.output.moe.adapters`` and ``output.moe.adapters``.
        * ``transformer_hat``: from ``model.transformer_mask()``; the per-head mask of each layer is
          repeated ``head_size`` times (presumably one value per head -> one per hidden unit; confirm)
          for query/key/value/attention-output, plus the intermediate and output masks as returned.
        * any other baseline: an empty dict.
        """
        model_ori = accelerator.unwrap_model(model)
        baseline = self.args.baseline
        masks = {}

        if 'transformer_hat' in baseline:
            head_mask, intermediate_mask, output_mask = model_ori.model.transformer_mask()

        for layer_id in range(model_ori.config.num_hidden_layers):
            prefix = self._layer_prefix(layer_id)

            if 'adapter_hat' in baseline:
                layer = model_ori.model.roberta.encoder.layer[layer_id]
                attn_key = prefix + 'attention.output.moe.adapters'
                masks[attn_key + '.fc1'], masks[attn_key + '.fc2'] = layer.attention.output.moe.adapters.mask()
                out_key = prefix + 'output.moe.adapters'
                masks[out_key + '.fc1'], masks[out_key + '.fc2'] = layer.output.moe.adapters.mask()

            elif 'transformer_hat' in baseline:
                head_size = int(model_ori.model.config.hidden_size / model_ori.model.config.num_attention_heads)
                for module in self._HEAD_MASKED_MODULES:
                    # A fresh tensor per key, so the entries never alias each other.
                    masks[prefix + module] = head_mask[layer_id].unsqueeze(-1).repeat((1, head_size)).flatten()
                masks[prefix + 'intermediate.dense'] = intermediate_mask[layer_id]
                masks[prefix + 'output.dense'] = output_mask[layer_id]

        return masks

    def get_view_for(self, n, p, masks, config):
        """Return the HAT mask for parameter ``n`` shaped to be applied to ``p``, or None.

        Args:
            n: full parameter name (``module.model.roberta.encoder.layer.<i>.<sub-module>.<weight|bias>``).
            p: the parameter tensor; weight views are expanded to its shape.
            masks: dict as produced by :meth:`mask`.
            config: object exposing ``num_hidden_layers``; layers outside that range are not matched.

        Returns:
            * ``*.weight``: the module's mask as a column (one value per row of ``p``), expanded to ``p``;
            * ``*.bias``: the module's mask flattened;
            * adapter ``fc2.weight``: elementwise min of the fc2 mask (per row) and fc1 mask (per column);
            * ``None`` when ``n`` is not a masked parameter of the active baseline.
        """
        baseline = self.args.baseline
        if 'adapter_hat' in baseline:
            view_fn = self._adapter_view
        elif 'transformer_hat' in baseline:
            view_fn = self._transformer_view
        else:
            return None

        for layer_id in range(config.num_hidden_layers):
            prefix = self._layer_prefix(layer_id)
            if n.startswith(prefix):
                # Prefixes end with '.', so at most one layer id can match this name.
                return view_fn(prefix, n[len(prefix):], p, masks)
        return None

    def _adapter_view(self, prefix, suffix, p, masks):
        """Mask view for an adapter fc1/fc2 weight or bias of one layer; None if ``suffix`` is not one."""
        for module in self._ADAPTER_HAT_MODULES:
            fc1_key = prefix + module + '.fc1'
            fc2_key = prefix + module + '.fc2'
            if suffix == module + '.fc1.weight':
                return masks[fc1_key].data.view(-1, 1).expand_as(p)
            if suffix == module + '.fc1.bias':
                return masks[fc1_key].data.view(-1)
            if suffix == module + '.fc2.weight':
                post = masks[fc2_key].data.view(-1, 1).expand_as(p)
                pre = masks[fc1_key].data.view(1, -1).expand_as(p)
                return torch.min(post, pre)
            if suffix == module + '.fc2.bias':
                return masks[fc2_key].data.view(-1)
        return None

    def _transformer_view(self, prefix, suffix, p, masks):
        """Mask view for a masked transformer weight or bias of one layer; None if ``suffix`` is not one."""
        for module in self._TRANSFORMER_HAT_MODULES:
            if suffix == module + '.weight':
                return masks[prefix + module].data.view(-1, 1).expand_as(p)
            # Bias branch for every module, including key and value, which the old code never reached.
            if suffix == module + '.bias':
                return masks[prefix + module].data.view(-1)
        return None

    # ----------------------------------------------------------------- training

    def _forward(self, model, accelerator, inputs, self_fisher, mask_pre, buffer):
        """Run one forward pass with the extra arguments required by the active baseline."""
        baseline = self.args.baseline
        if 'ewc' in baseline:
            return model(inputs, self_fisher=self_fisher)
        if 'adapter_hat' in baseline or 'adapter_bcl' in baseline or 'adapter_classic' in baseline:
            masks = self.mask(model, accelerator)
            return model(inputs, masks=masks, mask_pre=mask_pre)
        if 'derpp' in baseline:
            return model(inputs, buffer=buffer)
        if 'transformer_hat' in baseline:
            model_ori = accelerator.unwrap_model(model)
            head_importance, intermediate_importance, output_importance = model_ori.model.transformer_mask()
            return model(inputs, head_mask=head_importance, intermediate_mask=intermediate_importance,
                         output_mask=output_importance)
        return model(inputs)

    @staticmethod
    def _restrict_gradients(model, mask_back):
        """Multiply each gradient by its ``mask_back`` entry (parameters without an entry are untouched)."""
        for n, p in model.named_parameters():
            if n in mask_back and p.grad is not None:
                p.grad.data *= mask_back[n]

    def _compensate_embedding_gradients(self, model):
        """Rescale HAT task-embedding gradients by ``smax / s * (cosh(clamp(s*e)) + 1) / (cosh(e) + 1)``."""
        s = self.args.s
        smax = self.args.smax
        thres = self.args.thres_cosh
        for n, p in model.named_parameters():
            # Grouping matters: the grad check must cover both name patterns.
            if self._is_hat_embedding(n) and p.grad is not None:
                num = torch.cosh(torch.clamp(s * p.data, -thres, thres)) + 1
                den = torch.cosh(p.data) + 1
                p.grad.data *= smax / s * num / den

    def _constrain_embeddings(self, model):
        """Clamp HAT task embeddings to ``[-thres_emb, thres_emb]`` after an optimizer step."""
        thres = self.args.thres_emb
        for n, p in model.named_parameters():
            if self._is_hat_embedding(n):
                p.data = torch.clamp(p.data, -thres, thres)

    def train(self, model,model_prune, accelerator, train_dataset, tokenizer, train_dataloader_prune,train_dataloader_prune_dup, train_dataloader_prune_dataset,train_dataloader):
        """Post-train ``model`` on the current task, then run :meth:`after_training_op`.

        Before training, state left by earlier tasks is loaded from ``<output_dir>/../`` depending on
        the baseline: the EWC fisher, the HAT ``mask_pre``/``mask_back``, or the DER++ buffer. For
        ``adapter_demix`` at task > 0, the model is first passed through ``demix_model.demix_compute``.

        Args:
            model: model wrapper; called as ``model(inputs, ...)`` and must return an object with ``.loss``.
            model_prune: unused here (kept for interface compatibility).
            accelerator: ``accelerate.Accelerator`` used for preparation, backward and coordination.
            train_dataset: only used for the "Num examples" log line.
            tokenizer: saved next to the model by :meth:`after_training_op`.
            train_dataloader_prune: handed to the demix / fisher / DER++ helpers.
            train_dataloader_prune_dup: unused here.
            train_dataloader_prune_dataset: handed to ``demix_model.demix_compute``.
            train_dataloader: the training batches.

        Raises:
            FileNotFoundError: HAT-style baseline at task > 0 without saved ``mask_pre`` from the
                previous task (training would otherwise fail at the first backward pass).

        On KeyboardInterrupt the method returns immediately; :meth:`after_training_op` (which saves
        the model) is NOT run on that path.
        """
        baseline = self.args.baseline
        prev_dir = self.args.output_dir + '../'

        # ********************************* before training *********************************
        # Bind every per-baseline state up front so later checks never hit an unbound name.
        self_fisher = None
        mask_pre = None
        mask_back = None
        buffer = None

        if 'ewc' in baseline:
            fisher_path = os.path.join(prev_dir, 'fisher')
            if os.path.exists(fisher_path):
                print('load fisher matrix **************')
                self_fisher = self._load_to_cuda(fisher_path)

        elif self._is_hat():  # BCL included HAT
            print('load mask matrix **************')
            mask_pre_path = os.path.join(prev_dir, 'mask_pre')
            if os.path.exists(mask_pre_path):
                mask_pre = self._load_to_cuda(mask_pre_path)
                mask_back = self._load_to_cuda(os.path.join(prev_dir, 'mask_back'))
            elif self.args.task > 0:
                raise FileNotFoundError(
                    f"HAT baseline at task {self.args.task} needs 'mask_pre'/'mask_back' "
                    f"from the previous task in {prev_dir}")

        elif 'derpp' in baseline:
            if self.args.task == 0:
                buffer = Buffer(int(50 * 8), 'cuda')
            else:
                buffer = torch.load(os.path.join(prev_dir, 'buffer'))

        elif self.args.task > 0 and 'adapter_demix' in baseline:  # initialize the new adapter using the nearest adapter
            model = demix_model.demix_compute(train_dataloader_prune, train_dataloader_prune_dataset, model, accelerator, self.args)

        # ********************************* optimizer *********************************
        # Split weights in groups: with weight decay, without it, and the prompt pool (own lr).
        no_decay = ["bias", "LayerNorm.weight"]
        prompt_w = "prompt_embed_pool"

        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        decay_params = [(n, p) for n, p in trainable
                        if not any(nd in n for nd in no_decay) and prompt_w not in n]
        no_decay_params = [(n, p) for n, p in trainable
                           if any(nd in n for nd in no_decay) and prompt_w not in n]

        optimizer_grouped_parameters = [
            {
                'name': [n for n, _ in decay_params],
                "params": [p for _, p in decay_params],
                "weight_decay": self.args.weight_decay,
                "lr": self.args.learning_rate
            },
            {
                'name': [n for n, _ in no_decay_params],
                "params": [p for _, p in no_decay_params],
                "weight_decay": 0.0,
                "lr": self.args.learning_rate
            },
            {
                "params": [p for n, p in trainable if prompt_w in n],
                "lr": 0.3,  # must use a higher lr
            }
        ]

        optimizer = AdamW(optimizer_grouped_parameters)

        # Prepare everything with our `accelerator`.
        model, optimizer, train_dataloader, train_dataloader_prune = accelerator.prepare(
            model, optimizer, train_dataloader, train_dataloader_prune
        )

        # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
        if accelerator.distributed_type == DistributedType.TPU:
            model.tie_weights()

        # The dataloader must be prepared before taking its length (it is shorter in multiprocess).
        gas = self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gas)

        if self.args.max_samples is not None:
            self.args.max_train_steps = self.args.max_samples // (
                    self.args.per_device_train_batch_size * accelerator.num_processes * gas)

        if self.args.max_train_steps is None:
            self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch
        else:
            self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch)

        # TODO: Warm up can be important
        self.args.num_warmup_steps = int(float(self.args.warmup_proportion) * float(self.args.max_train_steps))  # 0.1

        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.args.num_warmup_steps,
            num_training_steps=self.args.max_train_steps,
        )

        # Train!
        total_batch_size = self.args.per_device_train_batch_size * accelerator.num_processes * gas

        if accelerator.is_main_process:
            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {len(train_dataset)}")
            logger.info(f"  Num Epochs = {self.args.num_train_epochs}")
            logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
            logger.info(f"  Total optimization steps = {self.args.max_train_steps}")
            logger.info(f"  Total samples = {self.args.max_train_steps * total_batch_size}")
            logger.info(
                f"  Learning Rate = {self.args.learning_rate}, Warmup Num = {self.args.num_warmup_steps}, Pre-trained Model = {self.args.model_name_or_path}")
            logger.info(
                f"  Seq ID = {self.args.idrandom}, Task id = {self.args.task}, dataset name = {self.args.dataset_name}")
            logger.info(f"  Baseline = {self.args.baseline}, Smax = {self.args.smax}")

        # Only show the progress bar once on each machine.
        progress_bar = tqdm(range(self.args.max_train_steps), disable=not accelerator.is_local_main_process)
        completed_steps = 0
        global_step = 0  # This will be used by CLMOE if we choose 'auto_encoder' as the route type.

        writer = None  # only the main process writes tensorboard logs
        if accelerator.is_main_process:
            tensorboard_file = os.path.join(self.args.output_dir, str(self.args.dataset_name) + '_log')
            print('tensorboard_file: ', tensorboard_file)
            if os.path.isdir(tensorboard_file):
                shutil.rmtree(tensorboard_file)
            writer = utils.setup_writer(tensorboard_file)

        is_hat = self._is_hat()

        try:
            for _ in range(self.args.num_train_epochs):
                model.train()
                num_batches = len(train_dataloader)
                for step, inputs in enumerate(train_dataloader):
                    # Anneal the HAT gate scale from 1/smax to smax over the epoch.
                    self.args.s = (self.args.smax - 1 / self.args.smax) * step / num_batches + 1 / self.args.smax

                    outputs = self._forward(model, accelerator, inputs, self_fisher, mask_pre, buffer)

                    loss = outputs.loss  # loss 1
                    if 'distill' in baseline:
                        loss = loss + outputs.distill_loss
                    if 'simcse' in baseline:
                        loss = loss + outputs.simcse_loss

                    loss = loss / gas
                    # add model needs to be careful! make sure it is in parameters and please double check its gradient
                    accelerator.backward(loss)  # sync

                    if is_hat:
                        # Restrict layer gradients in backprop (protect units used by earlier tasks).
                        if self.args.task > 0 and mask_back is not None:
                            self._restrict_gradients(model, mask_back)
                        self._compensate_embedding_gradients(model)

                    global_step += 1

                    # Step after every `gas` micro-batches (and at the end of the epoch).
                    if (step + 1) % gas == 0 or step == num_batches - 1:
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        progress_bar.update(1)
                        completed_steps += 1

                        if is_hat:
                            self._constrain_embeddings(model)

                        progress_bar.set_description(
                            'Train Iter (loss=%5.3f)' % loss.item())  # show the loss, mean while

                        if accelerator.is_main_process:
                            utils.log_loss(writer, scalar_value=loss.item(), global_step=global_step)
                            utils.log_loss(writer, loss_name=' MLM loss', scalar_value=outputs.loss.item(),global_step=global_step)
                            if 'distill' in baseline:
                                utils.log_loss(writer, loss_name=' distill loss', scalar_value=outputs.distill_loss.item(),global_step=global_step)
                            if 'simcse' in baseline:
                                utils.log_loss(writer, loss_name=' simcse loss', scalar_value=outputs.simcse_loss.item(),global_step=global_step)

                    if completed_steps >= self.args.max_train_steps:
                        break

                # Also stop the epoch loop, so remaining epochs do not run extra steps.
                if completed_steps >= self.args.max_train_steps:
                    break

            # after_training_op dispatches on the baseline and only reads the matching state.
            self.after_training_op(accelerator, model, tokenizer, train_dataloader_prune,
                                   self_fisher=self_fisher, mask_pre=mask_pre, buffer=buffer)

        except KeyboardInterrupt:
            # Ctrl-C ends training quietly. NOTE: after_training_op (model saving) is skipped here.
            return

    def after_training_op(self, accelerator, model, tokenizer, train_dataloader_prune, self_fisher=None, mask_pre=None, buffer=None):
        """Save the model/tokenizer (main process) and run the baseline's post-task computation.

        * ``ewc``: ``fisher_model.fisher_compute`` with ``self_fisher``;
        * HAT-style: set ``args.s = smax``, recompute masks and call ``hat_model.hat_compute``;
        * ``derpp``: ``derpp_model.derpp_compute`` on the main process only.
        """
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:  # only the wrapped `.model` is saved
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.model.save_pretrained(self.args.output_dir)
            tokenizer.save_pretrained(self.args.output_dir)

        baseline = self.args.baseline
        if 'ewc' in baseline:
            fisher_model.fisher_compute(train_dataloader_prune, model, self_fisher, accelerator, self.args)
        elif self._is_hat():
            # Masks are taken at the final (maximum) gate scale.
            self.args.s = self.args.smax
            mask = self.mask(model, accelerator)
            hat_model.hat_compute(model, accelerator, mask_pre, mask, self.get_view_for, self.args)
        elif 'derpp' in baseline:
            # add data to the buffer
            if accelerator.is_main_process:  # only find some to keep, no training
                derpp_model.derpp_compute(train_dataloader_prune, model, buffer, self.args)

    # --------------------------------------------------------------- evaluation

    def eval(self, model, eval_dataloader, eval_dataset, accelerator):
        """Return ``exp(mean loss)`` of ``model`` over ``eval_dataloader`` (``inf`` on overflow).

        Note (from the original author): MLM has randomness, so repeated calls may differ.
        """
        model, eval_dataloader = accelerator.prepare(model, eval_dataloader)

        model.eval()
        losses = []

        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(batch)
            loss = outputs.loss
            # One entry per sample in the batch, gathered across processes.
            losses.append(accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        # Truncate to the dataset size (presumably drops samples padded in by distributed sampling).
        losses = losses[: len(eval_dataset)]
        try:
            perplexity = math.exp(torch.mean(losses))
        except OverflowError:
            perplexity = float("inf")

        return perplexity

    def save_ppl(self, t, u, perplexity):
        """Record ``perplexity`` at row ``t`` (post-trained task), column ``u`` (tested task).

        The ``ntasks x ntasks`` matrix is kept in ``progressive_ppl_<seed>`` (in ``output_dir`` for the
        base run with ``pt_task == -1``, otherwise in ``output_dir/../``). When ``u`` is the last task,
        the last row is written to ``ppl_<seed>`` and the diagonal to ``forward_ppl_<seed>``.
        """
        logger.info(f"Pre-trained: {t}, Test: {u}, perplexity: {perplexity}")

        if self.args.pt_task == -1:  # base, no post-train
            result_dir = self.args.output_dir
        else:
            result_dir = self.args.output_dir + '../'
        seed = str(self.args.seed)

        progressive_ppl_path = os.path.join(result_dir, 'progressive_ppl_' + seed)
        print('progressive_ppl_path: ', progressive_ppl_path)

        if os.path.exists(progressive_ppl_path):
            print('loading *************')
            # ndmin=2: a 1x1 matrix (ntasks == 1) would otherwise load as a 0-d array.
            ppls = np.loadtxt(progressive_ppl_path, ndmin=2)
        else:
            ppls = np.zeros((self.args.ntasks, self.args.ntasks), dtype=np.float32)

        ppls[t][u] = perplexity
        np.savetxt(progressive_ppl_path, ppls, '%.4f', delimiter='\t')

        if u == self.args.ntasks - 1:  # last ft task, we need a final one
            final_f1 = os.path.join(result_dir, 'ppl_' + seed)
            forward_f1 = os.path.join(result_dir, 'forward_ppl_' + seed)
            print('final_f1: ', final_f1)

            with open(final_f1, 'w') as f1_file:
                for j in range(ppls.shape[1]):
                    f1_file.write(str(ppls[-1][j]) + '\n')

            with open(forward_f1, 'w') as f1_file:
                for j in range(ppls.shape[1]):
                    f1_file.write(str(ppls[j][j]) + '\n')
