import copy
import shutil
import argparse
import logging
import math
import os
import random
import sys
import torch
import datasets
import transformers
from accelerate import Accelerator, DistributedType
from tqdm.auto import tqdm
from networks import prompt
from transformers import (
    MODEL_MAPPING,
    AdamW,
    AutoTokenizer,
    AutoConfig,
    RobertaTokenizer,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    get_scheduler,
    SchedulerType,
    set_seed,
)
import numpy as np
from transformers import RobertaForMaskedLM, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from networks import prune_model
from networks.roberta_model import MyRoberta, MyRobertaForMaskedLM
from approaches.my_optimizer import MyAdamW
# sys.path.append("..")
from utils import utils

# Module logger, plus the configuration classes / model-type names known to the
# installed `transformers` MODEL_MAPPING.
logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = [*MODEL_MAPPING.keys()]
MODEL_TYPES = tuple([config_class.model_type for config_class in MODEL_CONFIG_CLASSES])


class Appr(object):
    """Post-training approach with importance-based soft-masking of gradients.

    Importance scores for attention heads, intermediate (FFN) neurons and output
    neurons are read from ``head_importance.npy``, ``intermediate_importance.npy``
    and ``output_importance.npy`` in each directory of ``args.saved_output_dir``.
    While training, the gradients of the matching RoBERTa encoder parameters are
    multiplied by ``1 - importance`` before every optimizer step, so units that
    were important for earlier runs are updated less.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args  # run configuration; its attributes are read throughout the class
        self.tanh = torch.nn.Tanh()
        self.sigmoid = torch.nn.Sigmoid()  # not used by the methods of this class

    def importance_norm(self, importance, exponent=2):
        """Normalize an importance tensor according to ``args.pipline_norm``.

        * contains ``'default'``: divide each row (layer) by its L-``exponent`` norm
          over the last dimension, then min-max scale the whole tensor to [0, 1].
        * contains ``'standard_norm'``: z-score each layer (first-dim slice) independently.
        * otherwise: return ``importance`` unchanged.

        The input tensor is not modified in place. Degenerate inputs (all values
        equal / zero standard deviation) map to 0 instead of NaN.

        Args:
            importance: importance scores; presumably (num_layers, units_per_layer) — TODO confirm.
            exponent: order of the norm used in the ``'default'`` mode.

        Returns:
            The normalized tensor, same shape as ``importance``.
        """
        mode = self.args.pipline_norm
        if 'default' in mode:
            norm_by_layer = importance.pow(exponent).sum(-1).pow(1 / exponent)
            importance = importance / (norm_by_layer.unsqueeze(-1) + 1e-20)
            low = importance.min()
            value_range = (importance.max() - low).clamp(min=1e-20)
            importance = (importance - low) / value_range
        elif 'standard_norm' in mode:
            # Each first-dim slice (a layer) is standardized over all of its elements.
            per_layer = importance.reshape(importance.size(0), -1)
            centered = per_layer - per_layer.mean(dim=1, keepdim=True)
            scaled = centered / per_layer.std(dim=1, keepdim=True).clamp(min=1e-20)
            importance = scaled.reshape(importance.shape)
        return importance

    def _load_importance(self, importance_dir, file_name):
        """Load one stored importance array to GPU, normalize it and squash it with ``|tanh|``."""
        importance = torch.Tensor(np.load(os.path.join(importance_dir, file_name))).cuda()
        importance = self.importance_norm(importance)
        return self.tanh(importance).abs()

    def use_heads_importance(self, accelerator):
        """Load the importance of every previous run and merge the runs with an element-wise max.

        Args:
            accelerator: unused; kept for interface compatibility.

        Returns:
            ``(head, intermediate, output, head_list, intermediate_list, output_list)``.
            Each of the first three is the element-wise maximum over the tensors
            in the corresponding list (one tensor per ``args.saved_output_dir`` entry).

        Raises:
            ValueError: if ``args.saved_output_dir`` is empty.
        """
        importance_dirs = self.args.saved_output_dir
        if not importance_dirs:
            raise ValueError('args.saved_output_dir must list at least one importance directory')

        head_importance_list = [self._load_importance(d, "head_importance.npy") for d in importance_dirs]
        intermediate_importance_list = [
            self._load_importance(d, "intermediate_importance.npy") for d in importance_dirs]
        output_importance_list = [self._load_importance(d, "output_importance.npy") for d in importance_dirs]

        # Max over runs: a unit stays protected if it was important for *any* previous task.
        # (torch.max on a Python list is not element-wise, hence the stack.)
        head_importance = torch.stack(head_importance_list).max(0).values
        intermediate_importance = torch.stack(intermediate_importance_list).max(0).values
        output_importance = torch.stack(output_importance_list).max(0).values

        return (head_importance, intermediate_importance, output_importance,
                head_importance_list, intermediate_importance_list, output_importance_list)

    def _load_general_importance(self):
        """Load the (normalized, not tanh-squashed) importance stored in ``saved_output_dir[0]``.

        Returns:
            ``(head, intermediate, output)`` importance tensors on GPU.
        """
        base_dir = self.args.saved_output_dir[0]
        return tuple(
            self.importance_norm(torch.Tensor(np.load(os.path.join(base_dir, file_name))).cuda())
            for file_name in ("head_importance.npy", "intermediate_importance.npy", "output_importance.npy")
        )

    @staticmethod
    def _head_grad_masks(head_importance, head_size):
        """Expand per-head importance to per-feature gradient multipliers ``1 - importance``.

        Each head's score is repeated ``head_size`` times contiguously along the last
        dimension, matching the [head0 features, head1 features, ...] layout.
        """
        return 1 - head_importance.repeat_interleave(head_size, dim=-1)

    @staticmethod
    def _scale_grad(param, mask):
        """Multiply ``param.grad`` in place by ``mask`` (skipped when the parameter has no gradient)."""
        if param.grad is not None:
            param.grad *= mask

    def _soft_mask_gradients(self, model_ori, head_masks, intermediate_masks, output_masks):
        """Soft-mask the gradients of every RoBERTa encoder layer of ``model_ori``.

        ``nn.Linear`` weights are (out_features, in_features):
          * query/key/value *produce* the per-head features, so a head owns rows
            of their weights (mask via ``unsqueeze(1)``) and entries of their biases;
          * attention.output.dense *consumes* the concatenated head outputs, so a
            head owns columns of its weight (mask broadcast over the last dim);
          * intermediate/output dense layers are masked per output neuron (rows).
        """
        encoder_layers = model_ori.model.roberta.encoder.layer
        for layer_idx in range(model_ori.model.config.num_hidden_layers):
            layer = encoder_layers[layer_idx]

            head_mask = head_masks[layer_idx]
            attention = layer.attention
            for projection in (attention.self.query, attention.self.key, attention.self.value):
                self._scale_grad(projection.weight, head_mask.unsqueeze(1))
                self._scale_grad(projection.bias, head_mask)
            self._scale_grad(attention.output.dense.weight, head_mask)
            self._scale_grad(attention.output.dense.bias, head_mask)

            intermediate_mask = intermediate_masks[layer_idx]
            self._scale_grad(layer.intermediate.dense.weight, intermediate_mask.unsqueeze(1))
            self._scale_grad(layer.intermediate.dense.bias, intermediate_mask)

            output_mask = output_masks[layer_idx]
            self._scale_grad(layer.output.dense.weight, output_mask.unsqueeze(1))
            self._scale_grad(layer.output.dense.bias, output_mask)

    def _build_optimizer(self, model):
        """Create AdamW with two groups: weight decay for most weights, none for biases/LayerNorm."""
        no_decay = ["bias", "LayerNorm.weight"]
        decay_named, no_decay_named = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            target = no_decay_named if any(nd in name for nd in no_decay) else decay_named
            target.append((name, param))

        optimizer_grouped_parameters = [
            {
                'name': [n for n, _ in decay_named],
                "params": [p for _, p in decay_named],
                "weight_decay": self.args.weight_decay,
                "lr": self.args.learning_rate
            },
            {
                'name': [n for n, _ in no_decay_named],
                "params": [p for _, p in no_decay_named],
                "weight_decay": 0.0,
                "lr": self.args.learning_rate
            }
        ]
        return AdamW(optimizer_grouped_parameters)

    def train(self, model, model_prune, accelerator, train_dataset, tokenizer, train_dataloader_prune,
              train_dataloader_prune_dup, train_dataloader):
        """Post-train ``model`` with importance-based gradient soft-masking.

        1. Task 0 with a ``'proxy'`` baseline: compute "pre" importance on
           ``model_prune`` with ``prune_model.compute_heads_importance(position='pre')``.
        2. Task > 0, or the proxy case: load the merged importance of previous runs
           (``use_heads_importance``) and the "general" importance from
           ``args.saved_output_dir[0]``.
        3. Train. Before every optimizer step, encoder gradients are multiplied by
           ``1 - importance``. A ``'contrast'`` baseline also passes the importance
           to the model and adds ``outputs.contrast_loss`` to the loss.
        4. Save the model and compute post-training importance (``after_training_op``).
           A KeyboardInterrupt stops training and returns without step 4.

        Raises:
            ValueError: if a ``'contrast'`` baseline is requested but no importance
                scores are available (task 0 without ``'proxy'``).
        """
        # ********************************* before training *********************************
        # Runs without stored importance (task 0, no proxy) keep these as None.
        pre_head_importance = pre_intermediate_importance = pre_output_importance = None
        head_importance_list = intermediate_importance_list = output_importance_list = None
        general_head_importance = general_intermediate_importance = general_output_importance = None

        use_proxy = self.args.task == 0 and 'proxy' in self.args.baseline
        if use_proxy:  # pre: proxy distillation
            model_prune, train_dataloader_prune = accelerator.prepare(model_prune, train_dataloader_prune)
            config = accelerator.unwrap_model(model_prune).model.config
            prune_model.compute_heads_importance(args=self.args, config=config, model=model_prune,
                                                 eval_dataloader=train_dataloader_prune, accelerator=accelerator,
                                                 position='pre', run_distill=True)
            accelerator.wait_for_everyone()

        if use_proxy or self.args.task > 0:  # nothing to use in the first task without a proxy
            (pre_head_importance, pre_intermediate_importance, pre_output_importance,
             head_importance_list, intermediate_importance_list,
             output_importance_list) = self.use_heads_importance(accelerator)
            (general_head_importance, general_intermediate_importance,
             general_output_importance) = self._load_general_importance()

            if accelerator.is_main_process:
                # The two entry paths historically print slightly different labels.
                if use_proxy:
                    print('general_head_mask: ', general_head_importance)
                    print('general_intermediate_mask: ', general_intermediate_importance)
                else:
                    print('general_head_importance: ', general_head_importance)
                    print('general_intermediate_importance: ', general_intermediate_importance)
                print('general_output_importance: ', general_output_importance)

                print('pre_head_importance: ', pre_head_importance)
                print('pre_intermediate_importance: ', pre_intermediate_importance)
                print('pre_output_importance: ', pre_output_importance)

        if 'pre_as_general' in self.args.baseline:
            general_head_importance = pre_head_importance
            general_intermediate_importance = pre_intermediate_importance
            general_output_importance = pre_output_importance

        use_contrast = 'contrast' in self.args.baseline
        if use_contrast and general_head_importance is None:
            # Fail before touching the optimizer or wiping the tensorboard directory.
            raise ValueError("The 'contrast' baseline needs importance scores: "
                             "run with task > 0 or with a 'proxy' baseline.")

        has_importance = pre_head_importance is not None
        # ********************************* before training *********************************

        optimizer = self._build_optimizer(model)

        # Prepare everything with our `accelerator`.
        model, optimizer, train_dataloader, train_dataloader_prune = accelerator.prepare(
            model, optimizer, train_dataloader, train_dataloader_prune
        )

        # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
        if accelerator.distributed_type == DistributedType.TPU:
            model.tie_weights()

        # The training dataloader must be prepared before taking its length
        # (it is shorter in multi-process runs).
        grad_accum = self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)
        total_batch_size = self.args.per_device_train_batch_size * accelerator.num_processes * grad_accum

        if self.args.max_samples is not None:
            self.args.max_train_steps = self.args.max_samples // total_batch_size

        if self.args.max_train_steps is None:
            self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch
        else:
            self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch)

        # TODO: Warm up can be important
        self.args.num_warmup_steps = int(float(self.args.warmup_proportion) * float(self.args.max_train_steps))

        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.args.num_warmup_steps,
            num_training_steps=self.args.max_train_steps,
        )

        if accelerator.is_main_process:
            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {len(train_dataset)}")
            logger.info(f"  Num Epochs = {self.args.num_train_epochs}")
            logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
            logger.info(f"  Total optimization steps = {self.args.max_train_steps}")
            logger.info(f"  Total samples = {self.args.max_train_steps * total_batch_size}")
            logger.info(
                f"  Learning Rate = {self.args.learning_rate}, Warmup Num = {self.args.num_warmup_steps}, Pre-trained Model = {self.args.model_name_or_path}")
            logger.info(
                f"  Seq ID = {self.args.idrandom}, Task id = {self.args.task}, dataset name = {self.args.dataset_name}")
            logger.info(f"  Baseline = {self.args.baseline}, Smax = {self.args.smax}")

        # Only show the progress bar once on each machine.
        progress_bar = tqdm(range(self.args.max_train_steps), disable=not accelerator.is_local_main_process)
        completed_steps = 0
        global_step = 0  # counts micro-batches; used as the tensorboard x-axis

        if accelerator.is_main_process:
            tensorboard_file = os.path.join(self.args.output_dir, str(self.args.dataset_name) + '_log')
            print('tensorboard_file: ', tensorboard_file)
            if os.path.isdir(tensorboard_file):
                shutil.rmtree(tensorboard_file)
            writer = utils.setup_writer(tensorboard_file)

        try:
            model_ori = accelerator.unwrap_model(model)
            if has_importance:
                if accelerator.is_main_process:
                    print('head usage: ', (pre_head_importance.sum() / pre_head_importance.numel()).item())
                    print('intermediate usage: ',
                          (pre_intermediate_importance.sum() / pre_intermediate_importance.numel()).item())
                    print('output usage: ', (pre_output_importance.sum() / pre_output_importance.numel()).item())

                # The masks depend only on the fixed "pre" importance: build them once.
                model_config = model_ori.model.config
                head_size = model_config.hidden_size // model_config.num_attention_heads
                head_masks = self._head_grad_masks(pre_head_importance, head_size)
                intermediate_masks = 1 - pre_intermediate_importance
                output_masks = 1 - pre_output_importance

            for _epoch in range(self.args.num_train_epochs):
                model.train()
                for step, inputs in enumerate(train_dataloader):
                    if use_contrast:
                        outputs = model(inputs,
                                        general_head_mask=general_head_importance,
                                        general_intermediate_mask=general_intermediate_importance,
                                        general_output_mask=general_output_importance,
                                        all_head_mask=head_importance_list,
                                        all_intermediate_mask=intermediate_importance_list,
                                        all_output_mask=output_importance_list
                                        )
                    else:
                        outputs = model(inputs)

                    loss = outputs.loss  # MLM loss
                    if use_contrast:
                        loss = loss + outputs.contrast_loss
                    loss = loss / grad_accum
                    accelerator.backward(loss)
                    global_step += 1

                    # Update after every `grad_accum` micro-batches and on the last batch.
                    # (`step % grad_accum == 0` also stepped right after the first micro-batch.)
                    if (step + 1) % grad_accum == 0 or step == len(train_dataloader) - 1:
                        if has_importance:
                            # Mask the *accumulated* gradient once per update; masking after every
                            # backward re-scaled earlier micro-batches repeatedly.
                            self._soft_mask_gradients(model_ori, head_masks, intermediate_masks, output_masks)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        progress_bar.update(1)
                        completed_steps += 1
                        progress_bar.set_description(
                            'Train Iter (loss=%5.3f)' % loss.item())  # show the loss, mean while

                        if accelerator.is_main_process:
                            utils.log_loss(writer, scalar_value=loss.item(), global_step=global_step)
                            utils.log_loss(writer, loss_name=' MLM loss', scalar_value=outputs.loss.item(),
                                           global_step=global_step)
                            if use_contrast:
                                utils.log_loss(writer, loss_name=' contrastive loss',
                                               scalar_value=outputs.contrast_loss.item(), global_step=global_step)

                    if completed_steps >= self.args.max_train_steps:
                        break
                # Also stop the epoch loop, so no stray update happens in a later epoch.
                if completed_steps >= self.args.max_train_steps:
                    break

            if use_proxy:  # need to do twice, for the first task
                train_dataloader_prune_dup = accelerator.prepare(train_dataloader_prune_dup)
                self.after_training_op(accelerator, model, tokenizer, train_dataloader_prune_dup)
            else:
                self.after_training_op(accelerator, model, tokenizer, train_dataloader_prune)

        except KeyboardInterrupt:
            # NOTE(review): an older comment said the model should still be saved on Ctrl-C,
            # but an early `return` always made that unreachable; the effective behaviour
            # (stop without saving) is kept.
            return
        finally:
            progress_bar.close()

    def after_training_op(self, accelerator, model, tokenizer, train_dataloader_prune):
        """Save the trained model/tokenizer and produce the post-training importance scores.

        The main process saves ``unwrap_model(model).model`` and ``tokenizer`` to
        ``args.output_dir``. With a ``'random'`` baseline, uniform random importance
        tensors shaped from the model config are saved as ``.npy`` files. Otherwise
        ``prune_model.compute_heads_importance(position='post')`` is run on
        ``train_dataloader_prune``.
        """
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:  # only the inner `.model` is saved
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.model.save_pretrained(self.args.output_dir)
            tokenizer.save_pretrained(self.args.output_dir)

        config = accelerator.unwrap_model(model).model.roberta.config

        if 'random' in self.args.baseline:
            # Random mask, no need to run compute_heads_importance. Shapes come from the
            # config (previously hard-coded to 12 layers x 12 heads).
            n_layers = config.num_hidden_layers
            head_importance = torch.rand(n_layers, config.num_attention_heads).cuda()
            intermediate_importance = torch.rand(n_layers, config.intermediate_size).cuda()
            output_importance = torch.rand(n_layers, config.hidden_size).cuda()

            if accelerator.is_main_process:
                for file_name, importance in (("head_importance.npy", head_importance),
                                              ("intermediate_importance.npy", intermediate_importance),
                                              ("output_importance.npy", output_importance)):
                    np.save(os.path.join(self.args.output_dir, file_name), importance.detach().cpu().numpy())
        else:
            prune_model.compute_heads_importance(args=self.args, config=config, model=model,
                                                 eval_dataloader=train_dataloader_prune, accelerator=accelerator,
                                                 position='post')

    def eval(self, model, eval_dataloader, eval_dataset, accelerator):
        """Return the perplexity ``exp(mean loss)`` of ``model`` on ``eval_dataloader``.

        Each batch loss is repeated ``per_device_eval_batch_size`` times and gathered
        across processes. The result is truncated to ``len(eval_dataset)``, presumably
        to drop samples padded in by distributed sampling. Returns ``inf`` on overflow.
        Note that MLM masking makes the value stochastic.
        """
        model, eval_dataloader = accelerator.prepare(model, eval_dataloader)

        model.eval()
        losses = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(batch)
            losses.append(accelerator.gather(outputs.loss.repeat(self.args.per_device_eval_batch_size)))

        losses = torch.cat(losses)[: len(eval_dataset)]
        try:
            perplexity = math.exp(torch.mean(losses))
        except OverflowError:
            perplexity = float("inf")

        return perplexity

    def save_ppl(self, t, u, perplexity):
        """Record ``perplexity`` at entry ``[t][u]`` of the progressive perplexity matrix.

        ``progressive_ppl_<seed>`` (``ntasks x ntasks``, tab separated, 4 decimals) is
        created or updated. When ``u`` is the last task, the last row is written to
        ``ppl_<seed>`` and the diagonal to ``forward_ppl_<seed>``, one value per line.
        Files go to ``output_dir`` when ``pt_task == -1`` and to ``output_dir + '../'``
        otherwise.
        """
        logger.info(f"Pre-trained: {t}, Test: {u}, perplexity: {perplexity}")

        if self.args.pt_task == -1:  # base, no post-train
            result_dir = self.args.output_dir
        else:
            # NOTE(review): plain string concatenation; assumes output_dir ends with a separator.
            result_dir = self.args.output_dir + '../'
        progressive_ppl_path = os.path.join(result_dir, 'progressive_ppl_' + str(self.args.seed))
        print('progressive_ppl_path: ', progressive_ppl_path)

        if os.path.exists(progressive_ppl_path):
            print('loading *************')
            # ndmin=2: with ntasks == 1 loadtxt would otherwise return a 0-d array.
            ppls = np.loadtxt(progressive_ppl_path, ndmin=2)
        else:
            ppls = np.zeros((self.args.ntasks, self.args.ntasks), dtype=np.float32)

        ppls[t][u] = perplexity
        np.savetxt(progressive_ppl_path, ppls, '%.4f', delimiter='\t')

        if u == self.args.ntasks - 1:  # last ft task, we need a final one
            final_f1 = os.path.join(result_dir, 'ppl_' + str(self.args.seed))
            forward_f1 = os.path.join(result_dir, 'forward_ppl_' + str(self.args.seed))
            print('final_f1: ', final_f1)

            n_cols = ppls.shape[1]
            with open(final_f1, 'w') as f1_file:
                f1_file.write(''.join(str(ppls[-1][j]) + '\n' for j in range(n_cols)))

            with open(forward_f1, 'w') as f1_file:
                f1_file.write(''.join(str(ppls[j][j]) + '\n' for j in range(n_cols)))
