import torch
import copy
import time
import math
import os,sys
current_dir = os.path.dirname(os.path.abspath(__file__))
top_level_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(top_level_dir)
from gpu import (
    add_gpu_params,
    parse_gpu,
    distributed_opt,
    distributed_gather,
    average_model,
    distributed_sync,
    cleanup
)
from optimizer import (
    create_adam_optimizer,
    create_optimizer_scheduler,
    add_optimizer_params,
    create_adam_optimizer_from_args
)
from gpt2_beam import beam
from gpt2_decode import decode_func


class AverageMeter(object):
    """Computes and stores the average and current value.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes:
        val: the most recent value passed to :meth:`update`.
        sum: weighted sum of all values (each weighted by its ``n``).
        count: total weight seen so far.
        avg: ``sum / count``; stays at its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on an empty meter would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count

class Model():
    """Bilevel LoRA trainer around ``model`` for multi-rank training.

    Trainable parameters are split into two groups by name:

    * upper-level ("A" / x): trainable parameters whose name contains neither
      ``lora_C`` nor ``lora_D``; :meth:`average_model` averages these across ranks.
    * lower-level ("B" / y): trainable parameters whose name contains
      ``lora_C`` or ``lora_D``; :meth:`average_model` leaves these untouched.
    """

    # Name substrings that mark a parameter as lower-level (group B).
    _LOWER_LEVEL_KEYS = ('lora_C', 'lora_D')

    def __init__(self, args, model):
        super().__init__()
        self.args = args
        self.model = model
        self.create_opt()

    @classmethod
    def _is_lower_level(cls, name):
        """Return True if parameter ``name`` belongs to the lower-level (B) group."""
        return any(key in name for key in cls._LOWER_LEVEL_KEYS)

    def create_opt(self):
        """Partition trainable parameters into the A/B groups and log parameter counts.

        Populates ``self.trainable_params`` (all trainable), ``self.trainable_params_A``
        (upper-level) and ``self.trainable_params_B`` (lower-level), in the order of
        ``self.model.named_parameters()``. Despite the name, no optimizer is built
        here; the optimizers are created in :meth:`train`.
        """
        self.trainable_params_A = []
        self.trainable_params_B = []
        self.trainable_params = []
        num_trainable_params = 0
        all_param = 0
        for name, param in self.model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel
            # NOTE(review): presumably Params4bit packs two 4-bit values per stored
            # element, hence the doubling — confirm against bitsandbytes.
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2
            all_param += num_params
            if param.requires_grad:
                self.trainable_params.append(param)
                if self._is_lower_level(name):
                    self.trainable_params_B.append(param)
                else:
                    self.trainable_params_A.append(param)
                num_trainable_params += num_params
        if self.args.rank % self.args.world_size == 0:
            # Guard the percentage so a parameter-less model cannot divide by zero.
            trainable_pct = 100 * num_trainable_params / all_param if all_param else 0.0
            print(
                f"trainable params: {num_trainable_params:,d} || all params: {all_param:,d} || trainable%: {trainable_pct}")

    def evaluate(self, valid_loader, valid_data_text):
        """Generate with beam search, decode, and score the predictions.

        The raw predictions are also saved to
        ``eval/prediction_<ref_type>_<method>_<rank>.pkl``.

        Returns:
            The result of the dataset-specific ``evaluating`` function (the WebNLG
            scorer when ``args.ref_type == 'webnlg'``, otherwise the E2E scorer).
        """
        print('Generate outputs ...')
        self.model.eval()
        all_prediction = beam(self.model, valid_loader, self.args)
        # Keep the on-disk copy, but decode the in-memory object directly: reloading
        # it was redundant, and newer torch.load defaults (weights_only=True) may
        # refuse to unpickle arbitrary Python objects.
        torch.save(all_prediction, f'eval/prediction_{self.args.ref_type}_{self.args.method}_{self.args.rank}.pkl')
        print('Decode the outputs ...')
        decode_func(self.args, all_prediction, valid_data_text)
        print('Evaluating ...')
        if self.args.ref_type == 'webnlg':
            from eval.GenerationEval.eval import evaluating
            output_ref_file = os.path.join(self.args.output_ref_file, 'reference')
            score = evaluating(self.args, output_ref_file, self.args.output_pred_file)
        else:
            from eval.e2e.measure_scores import evaluating
            score = evaluating(self.args, self.args.output_ref_file, self.args.output_pred_file)
        return score

    def average_model(self):
        """Average the upper-level (A) parameters across ranks, in place.

        Each A parameter is summed onto rank 0, divided by ``world_size`` and
        broadcast back, so every rank must call this (collective operations).
        """
        dist = self.args.dist
        for name, param in self.model.named_parameters():
            if param.requires_grad and not self._is_lower_level(name):
                dist.reduce(param.data, dst=0, op=dist.ReduceOp.SUM)
                param.data /= self.args.world_size
                dist.broadcast(param.data, src=0)

    def train(self, train_loader, valid_loader, valid_data_text, train_low_loader):
        """Run bilevel training, then evaluate and report rank-averaged metrics.

        Upper-level (A) parameters use an optimizer at ``args.lr`` and lower-level
        (B) parameters one at ``args.lr_in``. Each batch goes through
        :meth:`learn_simple`. Every ``args.com_interval`` steps, losses are gathered,
        A parameters are averaged across ranks and one communication round is
        counted. Training stops after ``args.com_rounds`` rounds or
        ``args.max_epoch`` epochs. Rank 0 then writes ``<save_path>.pkl`` (loss
        history) and ``<save_path>.txt`` (config and final log line).

        ``train_low_loader`` is accepted for interface compatibility but not used here.
        """
        start_time = time.time()
        self.optimizer_outer = create_adam_optimizer_from_args(self.trainable_params_A, self.args, self.args.lr)
        lr_scheduler = create_optimizer_scheduler(self.optimizer_outer, self.args, self.args.lr)
        self.optimizer_inner = create_adam_optimizer_from_args(self.trainable_params_B, self.args, self.args.lr_in)
        lr_scheduler_inner = create_optimizer_scheduler(self.optimizer_inner, self.args, self.args.lr_in)
        train_step = 0
        comm_round = 0  # completed communication rounds (avoids shadowing builtin `round`)
        loss_list = []
        self.model.train()

        for epoch in range(self.args.max_epoch):
            print('start to train the model................', epoch)
            avg_lm_loss = AverageMeter()
            for idx, data in enumerate(train_loader):
                _input = data['input'].to(self.args.device)
                _target = data['target'].to(self.args.device)
                _msk = data['mask'].to(self.args.device)
                self.optimizer_outer.zero_grad()
                loss = self.learn_simple(_input, _target, _msk)
                avg_lm_loss.update(loss.item())
                train_step += 1
                self.optimizer_outer.step()
                lr_scheduler.step()
                lr_scheduler_inner.step()
                torch.cuda.empty_cache()

                if train_step % self.args.com_interval == 0:
                    avg_loss = distributed_gather(self.args, torch.tensor(avg_lm_loss.avg).to(self.args.device)).mean(0)
                    avg_loss_val = distributed_gather(self.args, torch.tensor(avg_lm_loss.val).to(self.args.device)).mean(0)
                    # average the upper-level variables
                    self.average_model()
                    elapsed = time.time() - start_time
                    lr = self.optimizer_outer.param_groups[0]['lr']
                    log_str = f'| epoch {epoch:3d} step {train_step:>8d} | {idx + 1:>6d} batches | ' \
                              f'lr {lr:.3g} |  lr_in {self.args.lr_in:.3g} | times {elapsed:5.2f}s | ' \
                              f'loss {avg_loss_val:5.2f} | avg loss {avg_loss:5.2f} | ' \
                              f'ppl {math.exp(avg_loss):5.2f}'
                    if self.args.rank == 0:
                        print(log_str)
                        loss_list.append(avg_loss)
                    avg_lm_loss.reset()
                    distributed_sync(self.args)
                    comm_round += 1
                if comm_round >= self.args.com_rounds:
                    break
            if comm_round >= self.args.com_rounds:
                break

        distributed_sync(self.args)
        self.average_model()
        results = self.evaluate(valid_loader, valid_data_text)
        distributed_sync(self.args)
        gathered_score = distributed_gather(self.args, torch.tensor(results).to(self.args.device))
        if self.args.rank == 0:
            avg_score = gathered_score.mean(dim=0)
            if self.args.ref_type == 'webnlg':
                metric_names = ['BLEU', 'MET', 'TER', 'ROUGE_L']
            else:
                metric_names = ['BLEU', 'NIST', 'METEOR', 'ROUGE_L', 'CIDEr']
            print('The average metrics: \n')
            for i, metric in enumerate(metric_names):
                print('%s: %.4f' % (metric, avg_score[i]))
            log_str = f'| Eval {train_step // self.args.eval_interval:3d} at step {train_step:>8d} | ' \
                      f'time: {time.time() - start_time:5.2f}s | evaluation result: {avg_score}'
        distributed_sync(self.args)

        if self.args.rank == 0:
            print('-' * 100)
            print(log_str)
            print('-' * 100)
            torch.save(loss_list, self.args.save_path + '.pkl')
            with open(self.args.save_path + '.txt', 'w') as f:
                f.write(str({'Exp config': str(self.args), 'log_info': str(log_str)}))

    def _first_batch(self, loader):
        """Return ``(input, mask, target)`` from a fresh iterator over ``loader``, on ``args.device``."""
        # A new iterator is created per call, so repeated calls restart the loader
        # (they differ only if the loader shuffles).
        batch = next(iter(loader))
        return (batch['input'].to(self.args.device),
                batch['mask'].to(self.args.device),
                batch['target'].to(self.args.device))

    def learn(self, input_ids, labels, attention_mask, train_low_loader, lr_in):
        """Bilevel step using a one-step lower-level lookahead on a deep copy of the model.

        1. Deep-copy the model and take one SGD step (lr ``args.lr_in``) on the copy's
           lower-level (y) parameters, using a batch from ``train_low_loader``.
        2. Evaluate the given batch on the updated copy to obtain ``F_x`` / ``F_y``.
        3. On the original model, compute ``G_y`` (gradient of a ``train_low_loader``
           batch loss w.r.t. B, kept differentiable) and ``JVP = d/dx <G_y, F_y>``.
        4. Set each A parameter's ``.grad`` to ``F_x - args.lr_in * JVP`` and copy the
           copy's updated y values back into ``self.model``.

        :meth:`train` uses :meth:`learn_simple` instead of this method.

        NOTE(review): the ``lr_in`` argument is unused; ``self.args.lr_in`` is used
        throughout — confirm which one is intended.

        Returns:
            Mean loss of the given batch on the updated copy.
        """
        model_copy = copy.deepcopy(self.model)
        model_copy.train()
        trainable_params_y = []
        trainable_params_x = []
        for n, p in model_copy.named_parameters():
            if p.requires_grad:
                if self._is_lower_level(n):
                    trainable_params_y.append(p)
                else:
                    trainable_params_x.append(p)
        optimizer_2 = torch.optim.SGD(trainable_params_y, lr=self.args.lr_in, weight_decay=self.args.weight_decay)

        # one_step learning for inner_loops (on the copy only)
        input_ids_in, attention_mask_in, labels_in = self._first_batch(train_low_loader)
        _, loss_2 = model_copy(input_ids_in, lm_mask=attention_mask_in, lm_labels=labels_in, label_smooth=self.args.label_smooth)
        optimizer_2.zero_grad()
        loss_2.mean().backward()
        optimizer_2.step()
        model_copy.zero_grad()

        _, loss_out = model_copy(input_ids, lm_mask=attention_mask, lm_labels=labels, label_smooth=self.args.label_smooth)
        loss_out.mean().backward()
        # The copy is discarded after this call, so its grads are not cleared here:
        # an in-place zero_grad would also wipe the F_x / F_y tensors captured below.
        F_y = [y.grad for y in trainable_params_y]
        F_x = [x.grad for x in trainable_params_x]

        # Lower-level gradient on the original model, differentiable for the JVP.
        input_ids_in, attention_mask_in, labels_in = self._first_batch(train_low_loader)
        _, loss_in = self.model(input_ids_in, lm_mask=attention_mask_in, lm_labels=labels_in, label_smooth=self.args.label_smooth)
        G_y = torch.autograd.grad(loss_in.mean(), self.trainable_params_B, create_graph=True)
        JVP = torch.autograd.grad(G_y, self.trainable_params_A, grad_outputs=F_y)
        self.optimizer_outer.zero_grad()
        torch.cuda.empty_cache()
        for p, f_x, jvp in zip(self.trainable_params_A, F_x, JVP):
            p.grad = f_x - self.args.lr_in * jvp

        # Copy the updated lower-level variables to the original model
        # (deepcopy preserves parameter order, so the two lists line up).
        lower_params = [p for n, p in self.model.named_parameters()
                        if p.requires_grad and self._is_lower_level(n)]
        for p, y in zip(lower_params, trainable_params_y):
            p.data = y.detach().clone()
        return loss_out.mean()

    def learn_simple(self, input_ids, labels, attention_mask):
        """Single-batch bilevel step on ``self.model``.

        With L the mean loss on this batch, x the A parameters and y the B parameters:

        * y is updated in place by ``self.optimizer_inner`` using ``dL/dy``;
        * each x parameter's ``.grad`` is set to
          ``dL/dx - args.lr_in * d/dx <dL/dy, v>`` with ``v = dL/dy`` held constant.
          The caller steps ``self.optimizer_outer``.

        Returns:
            The mean loss as a detached 0-d tensor.
        """
        _, loss_out = self.model(input_ids, lm_mask=attention_mask, lm_labels=labels,
                                 label_smooth=self.args.label_smooth)
        loss = loss_out.mean()

        params_x = self.trainable_params_A
        params_y = self.trainable_params_B
        # autograd.grad (instead of backward(create_graph=True)) keeps the
        # double-backward graph out of the .grad fields, avoiding the
        # parameter<->gradient reference cycle that leaks memory.
        grads = torch.autograd.grad(loss, params_x + params_y, create_graph=True)
        F_x = grads[:len(params_x)]
        F_y = grads[len(params_x):]
        # Mixed second-derivative term d/dx <dL/dy, v>, with v = dL/dy held constant.
        JVP = torch.autograd.grad(F_y, params_x, grad_outputs=[g.detach() for g in F_y])
        torch.cuda.empty_cache()

        for p, f_x, jvp in zip(params_x, F_x, JVP):
            p.grad = (f_x - self.args.lr_in * jvp).detach()
        # The lower-level gradients must still be present when the inner optimizer
        # steps; previously they were cleared first, so y was never updated.
        for p, f_y in zip(params_y, F_y):
            p.grad = f_y.detach()
        self.optimizer_inner.step()
        self.optimizer_inner.zero_grad()
        return loss.detach()
