import torch
import copy
import time
import math
import os,sys
import torch.nn as nn
import torch.nn.functional as F
current_dir = os.path.dirname(os.path.abspath(__file__))
top_level_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(top_level_dir)
from gpu import (
    add_gpu_params,
    parse_gpu,
    distributed_opt,
    distributed_gather,
    average_model,
    distributed_sync,
    cleanup
)
from optimizer import (
    create_adam_optimizer,
    create_optimizer_scheduler,
    add_optimizer_params,
    create_adam_optimizer_from_args
)
from gpt2_beam import beam
from gpt2_decode import decode_func
# from eval.e2e.measure_scores import evaluating
from eval.GenerationEval.eval import evaluating

class AverageMeter(object):
    """Keep the latest observed value together with a running weighted mean.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the mean.

        Raises ZeroDivisionError if the accumulated count is still 0.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class Model():
    def __init__(self, args, model):
        super(Model, self).__init__()
        self.args = args
        self.model = model
        self.current_rank = self.args.lora_dim
        self.create_opt()
        self.past_regularization = self.regularization()
        # self.layered_rank = [self.args.rank_mat for name, param in self.model.named_parameters() if 'lora_A' in name or 'lora_B' in name]

    def regularization(self):
        x = 0.0
        for (b, a) in zip(self.trainable_params_B, self.trainable_params_A):
            x += (a[math.floor(a.size(0) * self.args.gamma):a.size(0), :]).norm() *\
                 (b[:, math.floor(b.size(1) * self.args.gamma):b.size(1)]).norm()
        return x

    def create_opt(self):
        self.trainable_params_A = []
        self.trainable_params_B = []
        num_trainable_params = 0
        all_param = 0
        for name, param in self.model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel
            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params

            if param.requires_grad:
                if 'lora_A' in name :
                    self.trainable_params_A.append(param)
                    # print(f'A: {param.size()}')
                if 'lora_B' in name:
                    self.trainable_params_B.append(param)
                    # print(f'B: {param.size()}')
                num_trainable_params += num_params

        distributed_sync(self.args)
        num_trainable_params = distributed_gather(self.args, torch.tensor(float(num_trainable_params)).to(self.args.device)).mean(0)
        distributed_sync(self.args)
        if self.args.rank % self.args.world_size == 0:
            print(f"{ self.args.rank}: trainable params: {num_trainable_params} || all params: {all_param:,d} || trainable%: {100 * num_trainable_params / all_param}")

    def evaluate(self, valid_loader, valid_data_text):
        print('Generate outputs ...')
        self.model.eval()
        all_prediction = beam(self.model, valid_loader, self.args)
        torch.save(all_prediction, f'eval/prediction_{self.args.ref_type}_{self.args.method}_{self.args.rank}.pkl')
        print('Decode the outputs ...')
        all_prediction = torch.load(f'eval/prediction_{self.args.ref_type}_{self.args.method}_{self.args.rank}.pkl')
        decode_func(self.args, all_prediction, valid_data_text)
        print('Evaluating ...')
        if self.args.ref_type == 'webnlg':
            from eval.GenerationEval.eval import evaluating
            output_ref_file = os.path.join(self.args.output_ref_file, 'reference')
            score = evaluating(self.args, output_ref_file, self.args.output_pred_file)
        else:
            from eval.e2e.measure_scores import evaluating
            score = evaluating(self.args, self.args.output_ref_file, self.args.output_pred_file)
        return score


    def start_local_steps(self):
        # get the regularization at the beginning of each local steps.
        self.past_regularization = self.regularization()


    def _get_module_and_param_name(self, name):
        parts = name.split('.')
        module = self.model
        for part in parts[:-1]:
            module = getattr(module, part)
        return module, parts[-1]

    def end_local_step(self):
        x = self.regularization()

        # prune the lora_A and lora_B if the regularization becomes smaller
        if  self.past_regularization - x >  0 and self.current_rank > self.args.rank_min:
            print(f'{self.args.rank} pruning')

            i,j = 0, 0
            for name, param in self.model.named_parameters():
                if 'lora_A' in name:
                    self.current_rank_A = math.floor(param.size(0) // 2 * self.args.gamma) * 2
                    with torch.no_grad():
                        param.data = param.data[:self.current_rank_A, :]
                    # self.trainable_params_A.append(new_layer.data)
                    new_param = nn.Parameter(param.data)
                    self.trainable_params_A[i] = new_param
                    parent_module, param_name = self._get_module_and_param_name(name)
                    parent_module.register_parameter(param_name, new_param)
                    i += 1
                if 'lora_B' in name:
                    self.current_rank = math.floor(param.size(1) * self.args.gamma)
                    with torch.no_grad():
                        param.data = param.data[:, :self.current_rank]
                    new_param = nn.Parameter(param.data)
                    self.trainable_params_B[j] = new_param
                    parent_module, param_name = self._get_module_and_param_name(name)
                    parent_module.register_parameter(param_name, new_param)
                    j += 1

        print(f'rank: {self.args.rank}, r: {self.current_rank}')


    def train(self, train_loader, valid_loader, valid_data_text,):
        start_time = time.time()
        optimizer_grouped_parameters = [
            {
                "params": self.trainable_params_A,
                "weight_decay": self.args.weight_decay,
                "lr": self.args.lr,
            },
            {
                "params": self.trainable_params_B,
                "weight_decay": self.args.weight_decay,
                "lr": self.args.lr,
            }
        ]
        self.optimizer_outer = torch.optim.AdamW(optimizer_grouped_parameters)
        lr_scheduler = create_optimizer_scheduler(self.optimizer_outer, self.args, self.args.lr)

        loss_list = []
        round = 0
        train_step = 0
        self.model.train()
        for epoch in range(self.args.max_epoch):
            print('start to train the model................', epoch)
            avg_lm_loss = AverageMeter()
            for idx, data in enumerate(train_loader):
                data = {key: value for key, value in data.items()}
                _input = data['input'].to(self.args.device)
                _target = data['target'].to(self.args.device)
                _msk = data['mask'].to(self.args.device)
                self.optimizer_outer.zero_grad()
                _lm_logits, _lm_loss = self.model(_input, lm_labels=_target, lm_mask=_msk, label_smooth=self.args.label_smooth)
                _lm_loss = _lm_loss.mean() + self.args.lamb * self.regularization()
                avg_lm_loss.update(_lm_loss.item())
                train_step += 1
                _lm_loss.backward()
                self.optimizer_outer.step()
                self.optimizer_outer.zero_grad()
                lr_scheduler.step()

                if train_step % self.args.com_interval == 0:
                    avg_loss = distributed_gather(self.args, torch.tensor(avg_lm_loss.avg).to(self.args.device)).mean(0)
                    avg_loss_val = distributed_gather(self.args, torch.tensor(avg_lm_loss.val).to(self.args.device)).mean(0)
                    self.end_local_step()
                    norm_BA = [F.conv1d(a.unsqueeze(0), b.unsqueeze(-1), groups=2 ).squeeze(0).norm() \
                               for (b, a) in zip(self.trainable_params_B, self.trainable_params_A)]
                    personal_norm_BA = copy.copy(norm_BA)
                    # compute the sum of all the clients' BA norm
                    for i in range(len(norm_BA)):
                        self.args.dist.reduce(norm_BA[i], dst=0, op=self.args.dist.ReduceOp.SUM)
                        self.args.dist.broadcast(norm_BA[i], src=0)
                    i, j = 0, 0
                    for n, p in self.model.named_parameters():
                        if p.requires_grad == True:
                            if 'lora_A' in n:
                                zero_tensor = torch.zeros(self.args.rank_max*2-p.size(0), p.size(1), dtype=p.dtype, device=p.device)
                                send_p = torch.cat([p, zero_tensor], dim=0)/self.args.world_size
                                send_p *= personal_norm_BA[i]/norm_BA[i]
                                i += 1
                            else:
                                zero_tensor = torch.zeros(p.size(0), self.args.rank_max-p.size(1),  dtype=p.dtype, device=p.device)
                                send_p = torch.cat([p, zero_tensor], dim=1)/self.args.world_size
                                send_p *= personal_norm_BA[j]/norm_BA[j]
                                j += 1

                            self.args.dist.reduce(send_p.data, dst=0, op=self.args.dist.ReduceOp.SUM)
                            self.args.dist.broadcast(send_p.data, src=0)

                            if 'lora_A' in n:
                                p.data = send_p.data[:p.size(0), :]

                            else:
                                p.data = send_p.data[:, :p.size(1)]


                    elapsed = time.time() - start_time
                    lr = self.optimizer_outer.param_groups[0]['lr']
                    log_str = f'| epoch {epoch:3d} step {train_step:>8d} | {idx + 1:>6d} batches | ' \
                              f'lr {lr:.3g} | time {elapsed:5.2f}s | ' \
                              f'loss {avg_loss_val:5.2f} | avg loss {avg_loss:5.2f} | ' \
                              f'ppl {math.exp(avg_loss):5.2f}'
                    if self.args.rank == 0:
                        print(log_str)
                        loss_list.append(avg_lm_loss.avg)
                    avg_lm_loss.reset()
                    distributed_sync(self.args)
                    round += 1
                    # if batch_idx % 20 == 0:
                    self.start_local_steps()
                if round >= self.args.com_rounds:
                    break
            if round >= self.args.com_rounds:
                break

            # Add a synchronization barrier before dist.all_gather
        distributed_sync(self.args)
        self.average_model()
        results = self.evaluate(valid_loader, valid_data_text)
        distributed_sync(self.args)
        gathered_score = distributed_gather(self.args, torch.tensor(results).to(self.args.device))
        if self.args.rank == 0:
            avg_score = gathered_score.mean(dim=0)
            if self.args.ref_type == 'webnlg':
                metric_names = ['BLEU', 'MET', 'TER', 'ROUGE_L']
            else:
                metric_names = ['BLEU', 'NIST', 'METEOR', 'ROUGE_L', 'CIDEr']
            print('The average metrics: \n')
            for i, metric in enumerate(metric_names):
                print('%s: %.4f' % (metric, avg_score[i]))
            log_str = f'| Eval {train_step // self.args.eval_interval:3d} at step {train_step:>8d} | ' \
                      f'time: {time.time() - start_time:5.2f}s | evaluation result: {avg_score}'
        distributed_sync(self.args)

        if self.args.rank == 0:
            print('-' * 100)
            print(log_str)
            print('-' * 100)
            torch.save(loss_list, self.args.save_path + '.pkl')
            with open(self.args.save_path + '.txt', 'w') as f:
                f.write(str({'Exp config': str(self.args), 'log_info': str(log_str)}))

    def average_model(self):
        norm_BA = [F.conv1d(a.unsqueeze(0), b.unsqueeze(-1), groups=2).squeeze(0).norm() \
                   for (b, a) in zip(self.trainable_params_B, self.trainable_params_A)]
        personal_norm_BA = copy.copy(norm_BA)
        # compute the sum of all the clients' BA norm
        for i in range(len(norm_BA)):
            self.args.dist.reduce(norm_BA[i], dst=0, op=self.args.dist.ReduceOp.SUM)
            self.args.dist.broadcast(norm_BA[i], src=0)
        i, j = 0, 0
        for n, p in self.model.named_parameters():
            if p.requires_grad == True:
                if 'lora_A' in n:
                    zero_tensor = torch.zeros(self.args.rank_max*2 - p.size(0), p.size(1), dtype=p.dtype,
                                              device=p.device)
                    send_p = torch.cat([p, zero_tensor], dim=0) / self.args.world_size
                    send_p *= personal_norm_BA[i] / norm_BA[i]
                    i += 1
                else:
                    # print('zero_B')
                    zero_tensor = torch.zeros(p.size(0), self.args.rank_max - p.size(1), dtype=p.dtype,
                                              device=p.device)
                    send_p = torch.cat([p, zero_tensor], dim=1) / self.args.world_size
                    send_p *= personal_norm_BA[j] / norm_BA[j]
                    j += 1

                self.args.dist.reduce(send_p.data, dst=0, op=self.args.dist.ReduceOp.SUM)
                self.args.dist.broadcast(send_p.data, src=0)

                if 'lora_A' in n:
                    p.data = send_p.data[:p.size(0), :]
                else:
                    p.data = send_p.data[:, :p.size(1)]

