#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import argparse
import functools
import itertools
import math
import os
import pickle
import random
import sys
import time
import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader
torch.set_printoptions(threshold=100000)

from gpu import (
    add_gpu_params, 
    parse_gpu, 
    distributed_opt, 
    distributed_gather, 
    distributed_sync, 
    cleanup
)
from optimizer import (
    create_adam_optimizer, 
    create_optimizer_scheduler, 
    add_optimizer_params, 
    create_adam_optimizer_from_args
)

from data_utils import FT_Dataset
from model import GPT2Config, GPT2LMModel
from exp_utils import create_exp_dir

import loralib as lora

parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')

# GPU / distributed and optimizer options are registered by shared helpers.
add_gpu_params(parser)
add_optimizer_params(parser)

parser.add_argument('--train_data', required=True, help='location of training data corpus')

parser.add_argument('--valid_data', required=True, help='location of validation data corpus')

parser.add_argument('--train_batch_size', type=int, default=8, help='training batch size')

parser.add_argument('--valid_batch_size', type=int, default=4, help='validation batch size')

parser.add_argument('--grad_acc', type=int, default=1, help='gradient accumulation steps')

parser.add_argument('--clip', type=float, default=0.0, help='gradient clip')

parser.add_argument('--seq_len', type=int, default=512, help='number of tokens to predict.')

parser.add_argument('--model_card', default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                    help='model names')

parser.add_argument('--init_checkpoint', default=None, help='pretrained checkpoint path')

parser.add_argument('--fp16', action='store_true', help='train model with fp16')

parser.add_argument('--log_interval', type=int, default=100, help='log interval')

parser.add_argument('--eval_interval', type=int, default=2000, help='eval interval')

parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                    help='working folder.')

parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')

parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')

parser.add_argument('--obj', default='clm', choices=['jlm', 'clm'],
                    help='language model training objective')

parser.add_argument('--lora_dropout', default=0.0, type=float,
                    help='dropout probability for lora layers')

parser.add_argument('--label_smooth', default=0.0, type=float, help='label smoothing')

parser.add_argument('--roll_interval', type=int, default=-1, help='rolling interval')

parser.add_argument('--roll_lr', type=float, default=0.00001, help='rolling learning rate')

parser.add_argument('--roll_step', type=int, default=100, help='rolling step')

parser.add_argument('--eval_epoch', type=int, default=1, help='eval per number of epochs')

parser.add_argument('--lora_path', default=None, help="The file path of LoRA parameters.")

# `choices` makes argparse reject typos: an unknown kind would otherwise
# silently skip all LoRA initialisation in the __main__ block.
parser.add_argument('--lora_kind', type=str, default='LoRA', choices=['LoRA', 'VeRA', 'DoRA'],
                    help="One of [LoRA, VeRA, DoRA]")

parser.add_argument('--adaptive_ranks', default=False, action='store_true',
                    help="Redistribute ranks according to explained variances of PCA.")

# Validated by argparse instead of only by a runtime `assert` (stripped under -O).
parser.add_argument('--redist_metric', type=str, default='ratio',
                    choices=['ratio', 'raw', 'sum', 'max'],
                    help="One of ratio, raw, sum, or max")

parser.add_argument('--exp_var_threshold', type=float, default=0.9,
                    help="explained variance threshold for adaptive ranks")

parser.add_argument('--whiten_pca', default=False, action='store_true',
                    help='Whether to whiten PCA so that components have unit variance.')

# Training utilities: argument printing, running-average meter, optimizer step, evaluation and the train/validate loop.
def print_args(args):
    """Print every parsed argument, framed by separator lines.

    Only the process with ``args.rank == 0`` prints, so a multi-process run
    shows the configuration once.
    """
    if args.rank != 0:
        return
    separator = '=' * 100
    print(separator)
    for name, value in vars(args).items():
        print(f'        - {name} : {value}')
    print(separator)


class AverageMeter(object):
    """Keep the most recent value and the running weighted mean of a metric.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running total, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update=True):
    """Back-propagate one (micro-)batch loss and optionally apply an update.

    Gradients from ``_loss`` are always accumulated. When ``is_update`` is
    true, they are norm-clipped to ``args.clip`` (only if it is positive),
    applied with ``_optimizer.step()`` and then cleared. ``_schedule``, when
    given, is stepped on every call, whether or not an update happened.

    With ``args.fp16`` the backward pass and the clipping go through apex
    (``amp.scale_loss`` / ``amp.master_params``), which requires a
    module-level ``amp`` name.
    """
    if not args.fp16:
        _loss.backward()
    else:
        with amp.scale_loss(_loss, _optimizer) as scaled_loss:
            scaled_loss.backward()

    if is_update:
        if args.clip > 0:
            clip_targets = (
                amp.master_params(_optimizer) if args.fp16 else _model.parameters()
            )
            torch.nn.utils.clip_grad_norm_(clip_targets, args.clip)
        _optimizer.step()
        _optimizer.zero_grad()

    if _schedule is not None:
        _schedule.step()


def evaluate(model, valid_loader, args):
    """Compute the average validation loss and perplexity of ``model``.

    Switches ``model`` to eval mode and leaves it there; the caller is
    responsible for calling ``model.train()`` again. Runs under
    ``torch.no_grad()``.

    Args:
        model: called as ``model(input, lm_labels=target, lm_mask=mask)`` and
            expected to return ``(logits, loss)``. ``loss`` is reduced with
            ``.mean()`` per batch.
        valid_loader: iterable of dicts holding 'input', 'target' and 'mask'
            tensors.
        args: namespace providing ``device``.

    Returns:
        ``(avg_loss, exp(avg_loss))``. ``avg_loss`` is the mean of the
        per-batch losses weighted by batch size (``input.size(0)``). An empty
        loader yields ``(0.0, 1.0)``.
    """
    model.eval()
    loss_sum = 0.
    num_samples = 0

    with torch.no_grad():
        for idx, data in enumerate(valid_loader):
            _input = data['input'].to(args.device)
            _target = data['target'].to(args.device)
            _msk = data['mask'].to(args.device)

            _lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
            loss = _loss.mean()

            # Weight by batch size so that a smaller final batch does not
            # count as much as a full one in the average.
            batch_size = _input.size(0)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size

            if idx % 100 == 0:
                print('eval samples:', idx, 'loss:', loss.float())

    avg_loss = loss_sum / num_samples if num_samples else 0.
    print('average loss', avg_loss)
    return avg_loss, math.exp(avg_loss)


def train_validate(
    model, 
    optimizer, 
    scheduler, 
    train_loader, 
    valid_loader, 
    args, 
    train_step=0, 
    epoch=0,
    best_val_ppl=None
):
    """Run one training epoch with periodic logging, evaluation and checkpointing.

    Args:
        model: language model called as ``model(input, lm_labels=..., lm_mask=...,
            label_smooth=...)`` and expected to return ``(logits, loss)``.
        optimizer, scheduler: forwarded to ``optimizer_step``; ``scheduler``
            may be None.
        train_loader: yields dicts with 'input', 'target' and 'mask' tensors;
            its ``sampler`` must provide ``set_epoch``.
        valid_loader: passed to ``evaluate``.
        args: parsed namespace. Reads device, grad_acc, label_smooth,
            log_interval, eval_interval, max_step, rank and work_dir.
            ``args.logging`` is only called on rank 0.
        train_step: global micro-batch counter carried across epochs.
        epoch: current epoch; used for the sampler and in log lines.
        best_val_ppl: best validation perplexity seen so far, or None.

    Returns:
        ``(train_step, best_val_ppl)`` after this epoch, or after reaching
        ``args.max_step``.
    """
    model.train()
    avg_lm_loss = AverageMeter()
    print('start to train the model................', epoch)
    log_start_time = time.time()

    train_loader.sampler.set_epoch(epoch)

    for idx, data in enumerate(train_loader):
        _input = data['input'].to(args.device)
        _target = data['target'].to(args.device)
        _msk = data['mask'].to(args.device)

        _lm_logits, _lm_loss = model(
            _input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth
        )
        _lm_loss = _lm_loss.mean()

        train_step += 1
        # Apply the optimizer only every `grad_acc` micro-batches. The loss is
        # divided by `grad_acc` so the accumulated gradient is an average.
        is_update = train_step % args.grad_acc == 0
        avg_lm_loss.update(_lm_loss.item())
        optimizer_step(
            _lm_loss / args.grad_acc, optimizer, model, scheduler, args, is_update=is_update
        )

        if train_step % args.log_interval == 0:
            elapsed = time.time() - log_start_time
            lr = optimizer.param_groups[0]['lr']
            log_str = f'| epoch {epoch:3d} step {train_step:>8d} | { idx + 1:>6d} batches | ' \
                      f'lr {lr:.3g} | ms/batch {elapsed * 1000 / args.log_interval:5.2f} | ' \
                      f'loss {avg_lm_loss.val:5.2f} | avg loss {avg_lm_loss.avg:5.2f} | ' \
                      f'ppl {math.exp(avg_lm_loss.avg):5.2f}'

            # Only rank 0 logs: the experiment logger may not exist on other ranks.
            if args.rank == 0:
                args.logging(log_str)
                print(log_str)
            log_start_time = time.time()
            avg_lm_loss.reset()

        # evaluation interval (and always at the final step)
        if train_step % args.eval_interval == 0 or train_step == args.max_step:
            eval_start_time = time.time()

            # NOTE(review): if valid_loader uses a DistributedSampler, each rank
            # evaluates only its shard, so valid_ppl/best_val_ppl are per-rank.
            valid_loss, valid_ppl = evaluate(model, valid_loader, args)

            current_is_best = best_val_ppl is None or valid_ppl < best_val_ppl
            if current_is_best:
                best_val_ppl = valid_ppl

            log_str = f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | ' \
                      f'time: {time.time() - eval_start_time:5.2f}s | valid loss {valid_loss:5.2f} | ' \
                      f'valid ppl {valid_ppl:5.2f} | best ppl {best_val_ppl:5.2f}'

            if args.rank == 0:
                args.logging(log_str)
                print('-' * 100)
                print(log_str)
                print('-' * 100)

                if current_is_best:
                    best_model_path = os.path.join(args.work_dir, 'best_model.pt')
                    print('saving best checkpoint', best_model_path)
                    torch.save({'model_state_dict': model.state_dict()}, best_model_path)
                    with open(os.path.join(args.work_dir, 'best_train_step.txt'), "w") as f:
                        f.write(str(train_step))

            # evaluate() left the model in eval mode.
            model.train()
            distributed_sync(args)

        if train_step == args.max_step:
            break

    distributed_sync(args)
    return train_step, best_val_ppl


# Architecture hyper-parameters for each --model_card choice.
_MODEL_CARDS = {
    'gpt2.sm': dict(n_embd=768, n_layer=12, n_head=12),
    'gpt2.md': dict(n_embd=1024, n_layer=24, n_head=16),
    'gpt2.lg': dict(n_embd=1280, n_layer=36, n_head=20),
}


def _build_config(args):
    """Build the GPT2Config for ``args.model_card`` with the LoRA options from ``args``."""
    return GPT2Config(
        **_MODEL_CARDS[args.model_card],
        lora_dim=args.lora_dim,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        enable_lora_attn=[True, True, True, True],  # Q,K,V,out_proj
        enable_lora_mlp=True,
        enable_lora_head=False,
    )


def _lora_owner(model, param_key):
    """Return the submodule of ``model`` that owns the parameter named ``param_key``."""
    return functools.reduce(getattr, param_key.split('.')[:-1], model)


def _whiten_rows(weight, variances):
    """Divide row i of ``weight`` by sqrt(variances[i]).

    Only the first ``weight.shape[0]`` entries of ``variances`` are used. The
    result has ``weight``'s dtype and device; ``weight`` is not modified.
    Assumes ``variances`` is 1-D (e.g. a numpy array) — TODO confirm.
    """
    n_rows = weight.shape[0]
    denom = torch.as_tensor(variances).sqrt().unsqueeze(-1)[:n_rows, :]
    return weight / denom.to(device=weight.device, dtype=weight.dtype)


def _load_explained_variances(lora_path, lora_modules):
    """Load the PCA statistics pickled next to ``lora_path`` (".bin" -> ".pkl").

    The pickle's keys are module names. They are filtered to ``lora_modules``
    and re-keyed to the matching "<module>.lora_A" state-dict entry.

    Raises:
        FileNotFoundError: if the companion .pkl file does not exist.
    """
    pkl_path = lora_path.replace(".bin", ".pkl")
    if not os.path.exists(pkl_path):
        raise FileNotFoundError("No explained variances found, re-run pre_compute_init.py with --pca_on_acts flag!")
    # NOTE(security): pickle can execute arbitrary code; only load trusted files.
    with open(pkl_path, "rb") as f:
        exp_vars = pickle.load(f)
    return {k + ".lora_A": v for k, v in exp_vars.items() if k in lora_modules}


def _adaptive_rank_state_dict(lm_net, lora_state_dict, exp_vars, args):
    """Redistribute LoRA ranks across modules and resize the modules of ``lm_net``.

    Modules left without an entry by ``lora.redistribute_ranks`` get rank 0.
    Modules whose new rank differs from ``args.lora_dim`` are re-ranked.
    Returns the new "<module>.lora_A" tensors to load.
    """
    metric_vars = {k: v[args.redist_metric] for k, v in exp_vars.items()}
    new_state_dict = lora.redistribute_ranks(
        lora_state_dict, metric_vars, args.exp_var_threshold, args.lora_dim, from_scratch=True
    )
    for key in lora_state_dict:
        if 'lora_A' not in key:
            continue
        module = _lora_owner(lm_net, key)
        if key not in new_state_dict:  # rank = 0
            module.change_lora_rank(0)
            continue
        new_rank = new_state_dict[key].shape[0]
        if args.whiten_pca:
            # Whiten with the raw explained variances (as the fixed-rank path
            # does), whether or not the module's rank changed.
            new_state_dict[key] = _whiten_rows(new_state_dict[key], exp_vars[key]['raw'])
        if new_rank != args.lora_dim:
            # re-initialize lora module with its new rank
            module.change_lora_rank(new_rank)
        if isinstance(module, lora.MergedLinear):
            # Repeat A once per enabled entry of the merged layer's enable_lora.
            new_state_dict[key] = torch.cat([new_state_dict[key]] * sum(module.enable_lora), dim=0)
    # TODO
    # plot_rank_heatmap(lora_state_dict, new_state_dict, model_args, data_args, config, kind=model_args.redist_metric)
    return new_state_dict


def _fixed_rank_state_dict(lm_net, lora_state_dict, exp_vars, args):
    """Truncate every LoRA A matrix to ``args.lora_dim`` rows, optionally whitened."""
    new_state_dict = {}
    for key, weight in lora_state_dict.items():
        if 'lora_A' not in key:
            continue
        weight = weight[:args.lora_dim, :]
        if args.whiten_pca:
            weight = _whiten_rows(weight, exp_vars[key]['raw'])
        module = _lora_owner(lm_net, key)
        if isinstance(module, lora.MergedLinear):
            weight = torch.cat([weight] * sum(module.enable_lora), dim=0)
        new_state_dict[key] = weight
    return new_state_dict


def _init_lora_from_file(lm_net, args):
    """Initialise the LoRA A matrices of ``lm_net`` from ``args.lora_path``.

    With ``--adaptive_ranks``, per-module ranks are redistributed using the PCA
    statistics; modules may change rank or be disabled (rank 0). Otherwise
    every A matrix is truncated to ``args.lora_dim`` rows. ``--whiten_pca``
    divides each kept row by the square root of its raw explained variance.

    Raises:
        ValueError: on an unknown ``--redist_metric`` in adaptive mode.
        FileNotFoundError: if PCA statistics are needed but missing.
    """
    if args.adaptive_ranks and args.redist_metric not in ("ratio", "raw", "sum", "max"):
        raise ValueError("redist metric must be either of raw, ratio, sum, max")

    lora_state_dict = torch.load(args.lora_path)
    # fix for differing naming
    lora_state_dict = {
        ("lm_head.decoder." + k.split('.')[-1] if "lm_head" in k else k): v
        for k, v in lora_state_dict.items()
    }
    # Keep only entries that belong to modules of lm_net that have a lora_A.
    lora_modules = {'.'.join(n.split('.')[:-1]) for n in lm_net.state_dict() if 'lora_A' in n}
    lora_state_dict = {
        k: v for k, v in lora_state_dict.items()
        if '.'.join(k.split('.')[:-1]) in lora_modules
    }

    # PCA statistics are only needed to redistribute ranks or to whiten.
    exp_vars = {}
    if args.adaptive_ranks or args.whiten_pca:
        exp_vars = _load_explained_variances(args.lora_path, lora_modules)

    if args.adaptive_ranks:
        new_state_dict = _adaptive_rank_state_dict(lm_net, lora_state_dict, exp_vars, args)
    else:
        new_state_dict = _fixed_rank_state_dict(lm_net, lora_state_dict, exp_vars, args)

    lm_net.load_state_dict(new_state_dict, strict=False)
    del new_state_dict
    del lora_state_dict
    torch.cuda.empty_cache()


if __name__ == '__main__':

    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '49152'
    # os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
    # os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
    # os.environ['OMPI_COMM_WORLD_RANK'] = '0'
    # os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    args = parser.parse_args()
    parse_gpu(args)
    print_args(args)

    if args.fp16:
        try:
            # Module-level import: binds the global `amp` used by optimizer_step.
            from apex import amp
        except ImportError:
            # Every fp16 path (amp.scale_loss, amp.master_params, amp.initialize)
            # needs apex; fall back to fp32 instead of failing later with NameError.
            warnings.warn('Could not import amp, apex may not be installed; falling back to fp32 training')
            args.fp16 = False

    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)

    # exist_ok: several ranks may try to create the directory concurrently.
    os.makedirs(args.work_dir, exist_ok=True)

    if args.rank == 0:
        args.logging = create_exp_dir(args.work_dir)
    else:
        # Only rank 0 owns the experiment log; other ranks get a no-op so an
        # unguarded args.logging(...) call cannot raise AttributeError.
        args.logging = lambda *_a, **_kw: None

    train_data = FT_Dataset(
        args.train_data, args.train_batch_size, args.seq_len,
        joint_lm=args.obj == 'jlm'
    )

    valid_data = FT_Dataset(
        args.valid_data, args.valid_batch_size, args.seq_len,
    )

    train_loader = DataLoader(
        train_data, batch_size=args.train_batch_size, num_workers=0,
        shuffle=False, pin_memory=False, drop_last=True,
        sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed)
    )

    valid_loader = DataLoader(
        valid_data, batch_size=args.valid_batch_size, num_workers=0,
        shuffle=False, pin_memory=False, drop_last=False,
        sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
    )

    lm_net = GPT2LMModel(_build_config(args))

    if args.lora_kind in ["LoRA", "VeRA", "DoRA"]:
        if args.lora_path is not None:
            _init_lora_from_file(lm_net, args)
        else:
            lm_net.apply(lora.init_scaling)

    if args.init_checkpoint is not None:
        print('loading model pretrained weight.')
        lm_net.load_weight(torch.load(args.init_checkpoint))

    lm_net = lm_net.cuda()

    if args.lora_dim > 0:
        lora.mark_only_lora_as_trainable(lm_net)

    optimizer = create_adam_optimizer_from_args(lm_net, args)

    if args.max_step is None:
        # max_epoch passes over num_batches, split (rounding up) across ranks.
        args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
        print('set max_step:', args.max_step)

    scheduler = create_optimizer_scheduler(optimizer, args)
    if args.fp16:
        lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
    lm_net, optimizer = distributed_opt(args, lm_net, optimizer, grad_acc=args.grad_acc)

    try:
        train_step = 0
        best_val_ppl = None
        for epoch in itertools.count(start=1):
            train_step, best_val_ppl = train_validate(
                lm_net, optimizer, scheduler, train_loader, valid_loader, args,
                train_step=train_step, epoch=epoch, best_val_ppl=best_val_ppl
            )

            if train_step >= args.max_step or (args.max_epoch is not None and epoch >= args.max_epoch):
                if args.rank == 0:
                    print('-' * 100)
                    print('End of training')
                break
    except KeyboardInterrupt:
        if args.rank == 0:
            print('-' * 100)
            print('Exiting from training early')

    distributed_sync(args)
    print('cleanup dist ...')
    cleanup(args)
