import argparse
import time
import math
import os, sys

import torch
torch.set_printoptions(threshold=100000)

import numpy as np

from gpu import add_gpu_params, parse_gpu, distributed_opt, distributed_gather, distributed_sync, cleanup
from optimizer import create_adam_optimizer, create_optimizer_scheduler, add_optimizer_params, create_adam_optimizer_from_args


from data_utils import FT_Dataset # BinCorpus, BinLMOrderedIterator
from model import GPT2Config, GPT2LMModel, Attention
from exp_utils import create_exp_dir

import itertools

from torch.utils.data import DataLoader


# Command-line interface for this script. GPU/distributed flags and optimizer
# flags are registered by the project helpers; the remaining flags follow.
parser = argparse.ArgumentParser(description='PyTorch GPT2 evaluation script.')

add_gpu_params(parser)
add_optimizer_params(parser)

# --- data locations and batch sizes ---
parser.add_argument(
    '--train_data',
    type=str,
    default='../data/wikitext-103',
    help='location of training data corpus',
)
parser.add_argument(
    '--valid_data',
    type=str,
    default='../data/wikitext-103',
    help='location of validation data corpus',
)
parser.add_argument(
    '--train_batch_size',
    type=int,
    default=8,
    help='training batch size',
)
parser.add_argument(
    '--valid_batch_size',
    type=int,
    default=4,
    help='validation batch size',
)

# --- optimisation step shaping ---
parser.add_argument(
    '--grad_acc',
    type=int,
    default=1,
    help='gradient accumlate',
)
parser.add_argument(
    '--clip',
    type=float,
    default=0.0,
    help='gradient clip',
)
parser.add_argument(
    '--seq_len',
    type=int,
    default=512,
    help='number of tokens to predict.',
)

# --- model selection and initialisation ---
parser.add_argument(
    '--model_card',
    default='gpt2.sm',
    choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
    help='model names.',
)
parser.add_argument(
    '--init_checkpoint',
    type=str,
    default=None,
    help='initial checkpoint.',
)
parser.add_argument(
    '--fp16',
    action='store_true',
    help='model fp16.',
)

# --- logging / evaluation / checkpoint cadence (counted in train steps) ---
parser.add_argument(
    '--log_interval',
    type=int,
    default=100,
    help='log interval.',
)
parser.add_argument(
    '--eval_interval',
    type=int,
    default=2000,
    help='eval interval.',
)
parser.add_argument(
    '--save_interval',
    type=int,
    default=500,
    help='save interval.',
)
# The default output folder can be overridden through the PT_OUTPUT_DIR
# environment variable.
parser.add_argument(
    '--work_dir',
    type=str,
    default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
    help='working folder.',
)

# --- LoRA / objective / prompt options ---
parser.add_argument(
    '--lora_dim',
    type=int,
    default=0,
    help='lora attn dimension.',
)
parser.add_argument(
    '--lora_alpha',
    type=int,
    default=128,
    help='lora attn alpha.',
)
parser.add_argument(
    '--obj',
    default='clm',
    choices=['jlm', 'clm'],
    help='language model training objective.',
)
parser.add_argument(
    '--lora_dropout',
    type=float,
    default=0.0,
    help='dropout probability for lora layers.',
)
parser.add_argument(
    '--label_smooth',
    type=float,
    default=0.0,
    help='label smoothing.',
)
parser.add_argument(
    '--prefix_len',
    type=int,
    default=0,
    help='prefix length.',
)
parser.add_argument(
    '--infix_len',
    type=int,
    default=0,
    help='infix length.',
)

# --- rolling / epoch-level evaluation / compression options ---
parser.add_argument(
    '--roll_interval',
    type=int,
    default=-1,
    help='rolling interval.',
)
parser.add_argument(
    '--roll_lr',
    type=float,
    default=0.00001,
    help='rolling learning rate.',
)
parser.add_argument(
    '--roll_step',
    type=int,
    default=100,
    help='rolling step.',
)
parser.add_argument(
    '--eval_epoch',
    type=int,
    default=1,
    help='eval per number of epochs.',
)
# Number of iterations of the compression loop run by the script entry point.
parser.add_argument(
    '--compress_step',
    type=int,
    default=1000,
    help='compress_step',
)


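# NOTE(review): legacy comment kept for reference -- "influence model, calculate the influence score between two samples." Nothing in the visible code computes influence scores; it looks stale.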
def print_args(args):
  """Print every attribute of ``args`` between two banner lines.

  Only the process whose ``args.rank`` is 0 prints; other ranks return
  silently so a multi-process run does not repeat the dump.
  """
  if args.rank != 0:
    return

  banner = '=' * 100
  print(banner)
  for name, value in vars(args).items():
    print('    - {} : {}'.format(name, value))
  print(banner)

import random
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` through ``os.environ`` and switches cuDNN
    to deterministic mode with autotuning (``benchmark``) disabled.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator and the generators of all CUDA devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

class AverageMeter(object):
  """Keep the latest value and the running weighted mean of a series.

  Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
  """
  def __init__(self):
    self.reset()

  def reset(self):
    """Set the last value, mean, weighted sum and sample count back to 0."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as seen ``n`` times and refresh the running mean."""
    self.val = val
    self.count += n
    self.sum += val * n
    self.avg = self.sum / self.count

def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update = True):
  """Back-propagate ``_loss`` and, when requested, apply one optimizer update.

  Args:
    _loss: scalar loss to differentiate.
    _optimizer: optimizer whose ``step``/``zero_grad`` are called on update.
    _model: model whose parameters are clipped in the non-fp16 path.
    _schedule: learning-rate scheduler, or ``None`` to skip scheduling.
    args: namespace providing ``fp16`` and ``clip``.
    is_update: when false, gradients are only accumulated (no clip/step/zero).

  The scheduler is stepped on every call, including calls that only
  accumulate gradients.

  NOTE(review): ``amp`` is not imported anywhere in the visible source, so the
  fp16 branches would raise NameError as written; presumably NVIDIA apex's
  ``amp`` is expected -- confirm.
  """
  mixed_precision = args.fp16

  # Backward pass (loss-scaled through amp when running in fp16).
  if not mixed_precision:
    _loss.backward()
  else:
    with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
      _scaled_loss.backward()

  if is_update:
    if args.clip > 0:
      # In fp16 the gradients to clip are taken from amp's master params.
      clip_targets = amp.master_params(_optimizer) if mixed_precision else _model.parameters()
      torch.nn.utils.clip_grad_norm_(clip_targets, args.clip)

    _optimizer.step()
    _optimizer.zero_grad()

  if _schedule is not None:
    _schedule.step()

def evaluate(model, valid_loader, args):
  """Compute the mean validation loss and its exponential (perplexity).

  Args:
    model: callable as ``model(input, lm_labels=..., lm_mask=...)`` returning
      ``(logits, loss)``; it is switched to eval mode and NOT switched back.
    valid_loader: iterable of batches indexable by ``'input'``, ``'target'``
      and ``'mask'``; each value must support ``.to(device)``.
    args: namespace providing ``device``.

  Returns:
    ``(avg_loss, math.exp(avg_loss))`` where ``avg_loss`` is the unweighted
    mean of the per-batch mean losses seen by this process. With an empty
    loader the meter is never updated, so the result is ``(0, 1.0)``.
  """
  model.eval()
  avg_lm_loss = AverageMeter()

  with torch.no_grad():
    for idx, data in enumerate(valid_loader):
      _input = data['input'].to(args.device)
      _target = data['target'].to(args.device)
      _msk = data['mask'].to(args.device)

      # The logits are not needed here; only the loss is aggregated.
      _lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
      loss = _loss.mean()

      avg_lm_loss.update(loss.item())

      # Progress line every 100 batches (printed by every process).
      if idx % 100 == 0:
        print('eval samples:', idx, 'loss:', loss.float())

  print('average loss', avg_lm_loss.avg)
  return avg_lm_loss.avg, math.exp(avg_lm_loss.avg)


def train_validate(model, optimizer, scheduler, train_loader, valid_loader, args, train_step = 0, epoch = 0):
  """Run one pass over ``train_loader`` with periodic logging, saving and evaluation.

  Args:
    model: called as ``model(input, lm_labels=..., lm_mask=..., label_smooth=...)``
      and expected to return ``(logits, loss)``.
    optimizer, scheduler: forwarded to ``optimizer_step``.
    train_loader: loader whose ``sampler`` exposes ``set_epoch``.
    valid_loader: forwarded to ``evaluate``.
    args: namespace providing ``device``, ``label_smooth``, ``grad_acc``,
      ``log_interval``, ``save_interval``, ``eval_interval``, ``max_step``,
      ``rank`` and ``work_dir``.
    train_step: global step counter to continue from.
    epoch: epoch index, used for sampler shuffling and log lines.

  Returns:
    The global step counter after this pass.

  The best validation perplexity is tracked only within this call. A
  checkpoint is always written when the pass ends.
  """
  def _write_checkpoint(step):
    # Rank 0 writes the file; every rank then calls distributed_sync
    # (project helper -- presumably a barrier, confirm).
    if args.rank == 0:
      model_path = os.path.join(args.work_dir, 'model.'+str(step)+'.pt')
      print('saving checkpoint', model_path)
      torch.save({'model_state_dict': model.state_dict()}, model_path)
    distributed_sync(args)

  model.train()
  loss_meter = AverageMeter()
  print('start to train the model................', epoch)
  interval_start = time.time()
  best_val_ppl = None

  train_loader.sampler.set_epoch(epoch)

  for batch_idx, batch in enumerate(train_loader):
    batch = dict(batch.items())

    inputs = batch['input'].to(args.device)
    targets = batch['target'].to(args.device)
    mask = batch['mask'].to(args.device)
    _, lm_loss = model(inputs, lm_labels=targets, lm_mask=mask, label_smooth=args.label_smooth)
    lm_loss = lm_loss.mean()

    train_step += 1
    loss_meter.update(lm_loss.item())
    # The loss is pre-divided by grad_acc; weights only change on every
    # grad_acc-th step.
    optimizer_step(
      lm_loss / args.grad_acc, optimizer, model, scheduler, args,
      is_update=(train_step % args.grad_acc == 0),
    )

    # Periodic training log; the meter restarts after each report.
    if train_step % args.log_interval == 0:
      elapsed = time.time() - interval_start
      train_template = ('| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g}'
                        '| ms/batch {:5.2f} | loss {:5.2f} | avg loss {:5.2f} | ppl {:5.2f}')
      log_str = train_template.format(
        epoch, train_step, batch_idx + 1, optimizer.param_groups[0]['lr'],
        elapsed * 1000 / args.log_interval, loss_meter.val, loss_meter.avg,
        math.exp(loss_meter.avg))

      if args.rank == 0:
        print(log_str)
      interval_start = time.time()
      loss_meter.reset()

    if train_step % args.save_interval == 0:
      _write_checkpoint(train_step)

    # Periodic validation.
    if train_step % args.eval_interval == 0:
      eval_start_time = time.time()

      valid_loss, valid_ppl = evaluate(model, valid_loader, args)
      best_val_ppl = valid_ppl if best_val_ppl is None else min(best_val_ppl, valid_ppl)

      eval_template = ('| Eval {:3d} at step {:>8d} | time: {:5.2f}s '
                       '| valid loss {:5.2f} | valid ppl {:5.2f} | best ppl {:5.2f} ')
      log_str = eval_template.format(
        train_step // args.eval_interval, train_step,
        (time.time() - eval_start_time), valid_loss, valid_ppl, best_val_ppl)

      if args.rank == 0:
        separator = '-' * 100
        print(separator)
        print(log_str)
        print(separator)

      # evaluate() leaves the model in eval mode.
      model.train()
      distributed_sync(args)

    if train_step == args.max_step:
      break

  _write_checkpoint(train_step)
  return train_step
if __name__ == '__main__':
  # Script entry point. Flow, as visible in this file:
  #   1. parse arguments, seed RNGs, set up the GPU/distributed environment;
  #   2. build datasets, loaders, model, optimizer and scheduler;
  #   3. run `args.compress_step` alternating updates that refit the Q/V
  #      adapter factors and re-threshold the S_Q / S_V buffers;
  #   4. save the per-step history of the first attention layer's S_V;
  #   5. re-initialise the adapters and turn S_Q / S_V into 0/1 masks.
  # train_validate() is not called here and no model checkpoint is written
  # afterwards in the visible portion of the file.
  args = parser.parse_args()
  set_seed(args.random_seed)
  parse_gpu(args)
  print_args(args)

  if args.rank == 0:
    args.logging = create_exp_dir(args.work_dir)

  train_data = FT_Dataset(
    args.train_data, args.train_batch_size, args.seq_len,
    joint_lm=args.obj=='jlm', prefix_len=args.prefix_len, infix_len=args.infix_len,
  )
  valid_data = FT_Dataset(
    args.valid_data, args.valid_batch_size, args.seq_len,
    prefix_len=args.prefix_len, infix_len=args.infix_len,
  )

  train_loader = DataLoader(
    train_data, batch_size=args.train_batch_size, num_workers=0, shuffle=False,
    pin_memory=False, drop_last=True,
    sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed),
  )
  valid_loader = DataLoader(
    valid_data, batch_size=args.valid_batch_size, num_workers=0, shuffle=False,
    pin_memory=False, drop_last=False,
    sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed),
  )

  # Architecture sizes per model card; all other config values come from args.
  model_dims = {
    'gpt2.sm': dict(n_embd=768, n_layer=12, n_head=12),
    'gpt2.md': dict(n_embd=1024, n_layer=24, n_head=16),
    'gpt2.lg': dict(n_embd=1280, n_layer=36, n_head=20),
  }
  config = GPT2Config(
    lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha, lora_dropout=args.lora_dropout,
    prefix_len=args.prefix_len, infix_len=args.infix_len,
    **model_dims[args.model_card],
  )

  lm_net = GPT2LMModel(config)
  if args.init_checkpoint is not None:
    print('loading model pretrained weight.')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    lm_net.load_weight(torch.load(args.init_checkpoint))

  lm_net = lm_net.cuda()

  if args.lora_dim == 0:
    optimizer = create_adam_optimizer_from_args(lm_net, args)
  else:
    # Freeze everything whose name does not contain 'adapter'.
    for n, p in lm_net.named_parameters():
      if 'adapter' in n:
        print(f'{n}, shape: {p.shape}')
      else:
        p.requires_grad = False

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in lm_net.named_parameters() if 'adapter' in n or 'embedding' in n],
        }
    ]
    optimizer = create_adam_optimizer_from_args(None, args, grouped_parameters=optimizer_grouped_parameters)

  if args.max_step is None:
    # Ceiling division of the total batch count across processes.
    args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
    print('set max_step:', args.max_step)

  scheduler = create_optimizer_scheduler(optimizer, args)
  if args.fp16:
    # NOTE(review): `amp` is not imported in the visible source, so this line
    # would raise NameError; presumably NVIDIA apex's `amp` is expected -- confirm.
    lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
  lm_net, optimizer = distributed_opt(args, lm_net, optimizer, grad_acc=args.grad_acc)

  # Collect the attention layers once instead of re-walking the module tree
  # on every compression step.
  attn_modules = [m for _, m in lm_net.named_modules() if isinstance(m, Attention)]

  # Start S_Q / S_V from zero. The buffers are sized from the actual Q / V
  # weight slices rather than a fixed 1024x1024, so every model card works.
  for module in attn_modules:
    q_shape = module.c_attn.weight[:, :module.split_size].shape
    v_shape = module.c_attn.weight[:, 2*module.split_size:].shape
    module.S_Q.data = torch.zeros(q_shape).to(module.S_Q.device)
    module.S_V.data = torch.zeros(v_shape).to(module.S_Q.device)

  # Number of entries targeted to survive thresholding in each S matrix.
  keep_entries = 128

  S_V = []
  # Pure weight surgery: no gradients are needed for any of this.
  with torch.no_grad():
    for _step in range(args.compress_step):
      for module in attn_modules:
        Q_weight = module.c_attn.weight[:, :module.split_size]
        V_weight = module.c_attn.weight[:, 2*module.split_size:]

        # The same update is applied to the Q and V projections. Sharing the
        # loop body guarantees each projection uses its own weight slice (the
        # V branch previously recomputed its factor from Q_weight).
        for weight, sparse, adapter1, adapter2 in (
            (Q_weight, module.S_Q, module.q_proj_adapter1, module.q_proj_adapter2),
            (V_weight, module.S_V, module.v_proj_adapter1, module.v_proj_adapter2),
        ):
          target = weight - sparse.data
          # Orthonormal basis from QR of the target projected by adapter1.
          U = torch.qr(target @ adapter1.weight.data.T)[0]
          adapter2.weight.data = U
          V = U.T @ target
          adapter1.weight.data = V
          # NOTE(review): S is set to -U @ V rather than the residual
          # weight - U @ V; kept as in the original -- confirm this is intended.
          sparse.data = -U @ V

          # Zero every entry whose magnitude is strictly below the
          # (numel - keep_entries)-th smallest one; without ties this leaves
          # keep_entries + 1 non-zeros, hence the tolerance in the check below.
          threshold = torch.kthvalue(sparse.data.abs().view(-1), sparse.data.numel() - keep_entries)[0]
          sparse.data[sparse.data.abs() < threshold] = 0

      # Only the first attention layer's S_V is recorded per step.
      S_V.append(attn_modules[0].S_V.data.detach().cpu().numpy())

  S_V = np.stack(S_V)
  # NOTE: np.save appends '.npy' to names not ending in it, so the file on
  # disk is 'S_V_0.npz.npy' (a plain .npy array, not an .npz archive).
  np.save("S_V_0.npz", S_V)

  import torch.nn as nn
  for module in attn_modules:
    # Fresh adapter initialisation: random adapter1, zero adapter2.
    nn.init.normal_(module.q_proj_adapter1.weight, std=0.02)
    module.q_proj_adapter2.weight.data.zero_()
    nn.init.normal_(module.v_proj_adapter1.weight, std=0.02)
    module.v_proj_adapter2.weight.data.zero_()
    # Keep only the support of S_V / S_Q as 0/1 masks.
    module.S_V.data = (module.S_V.data.abs() > 0).float()
    module.S_Q.data = (module.S_Q.data.abs() > 0).float()
    # Sanity check on both masks (the second check previously repeated S_V).
    assert abs(module.S_V.data.sum() - keep_entries) < 3
    assert abs(module.S_Q.data.sum() - keep_entries) < 3


  distributed_sync(args)
  print('cleanup dist ...')
  cleanup(args)

