from __future__ import print_function
from parser import get_parser
import torch
import os
import sys
import math
import pdb
import random
from tqdm import tqdm
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt
from util import *
from models import *
from loader import *
import time

from matplotlib import pyplot
from scipy.stats import pearsonr


# Parse the command-line options and restrict CUDA to the GPU(s) selected there.
args = get_parser()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda:0")
# Seed every random number generator this script touches with the same seed.
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
print(args)


'''Prepare result path'''
# NOTE(review): '__file__' is a string literal here, not the module attribute,
# so this resolves against the current working directory rather than this
# script's location. Left unchanged in case other code relies on it -- confirm.
args.file_dir = os.path.dirname(os.path.realpath('__file__'))
if args.debug:
  args.save_dir = './results/debug'
elif args.save_dir is None:
  # Build a default directory name that encodes the main experiment settings.
  args.save_dir = './results'
  args.save_dir += '/447_4152'
  args.save_dir += '/time'
  # args.save_dir += '/447_1232'
  args.save_dir += f'/it-{args.input_type}'
  args.save_dir += f'_st-{args.setenc_type}'
  args.save_dir += f'_bi-{args.bidirectional}'
  args.save_dir += f'_nd-{args.pp_abl_nd}'
  # args.save_dir += f'_nz-{args.nz}'
  # args.save_dir += f'_lr-{args.pred_lr}'
  if args.pretrain:
    args.save_dir += '/pretrain'
  else:
    args.save_dir += '/pt-True' if args.load_pretrain else '/pt-False'
# Both encoders share the predictor's learning rate.
args.g_enc_lr = args.d_enc_lr = args.pred_lr
model_dir = os.path.join(args.save_dir, 'model')

# Create the output directories unconditionally. The previous check only
# created them when save_dir itself was missing, so an existing save_dir
# without 'model/' or 'pred/' made later checkpoint/plot writes fail.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(args.save_dir + '/pred', exist_ok=True)

# Only the third value returned by the loader is used by this script.
_, _, G_args = load_NAS201_graphs(args.data_name, n_types=args.nvt, cn=args.cn)


if args.pretrain:
  dname_lst = ['ImageNet16-120']
  train_data = PretrainDataset(args, mode='tr', dname_lst=dname_lst)
  val_data = PretrainDataset(args, mode='va', dname_lst=dname_lst)
elif args.test:
  # No train/val datasets are built in test mode, so train()/val() (which
  # read train_data / val_data) must not be called in this configuration.
  pass
else:
  train_data = MetaD2A(args, mode='tr')
  val_data = MetaD2A(args, mode='va')

'''Prepare the model'''
# model
# NOTE(review): eval() executes whatever string is passed as the model name.
# Acceptable for a trusted command line; do not feed it untrusted input.
model = eval(args.model)(args, G_args)

# Split parameters into three groups by name so that each group can be given
# its own learning rate: set (dataset) encoder, predictor, and everything else
# (treated as the graph encoder).
G_enc_lst, D_enc_lst, pred_lst = [], [], []
for name, param in model.named_parameters():
  if ('set_trans' in name) or ('set_fc' in name):
    D_enc_lst.append(param)
  elif 'pred' in name:
    pred_lst.append(param)
  else:
    G_enc_lst.append(param)

params = [{'params': D_enc_lst, 'lr': args.d_enc_lr},
          {'params': G_enc_lst, 'lr': args.g_enc_lr},
          {'params': pred_lst, 'lr': args.pred_lr}]
# The top-level lr is only a fallback; every group above sets its own.
optimizer = optim.Adam(params, lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10, verbose=True)

model.to(device)

if args.load_pretrain:
  # The pretrained checkpoint lives in the sibling 'pretrain' directory of the
  # current 'pt-True' result directory.
  # Fixed: this previously referenced the undefined name `arg` (NameError).
  load_dir = args.save_dir.replace('pt-True', 'pretrain')
  load_module_state(model, load_dir+'/model/model_checkpoint100.pth')
  print("Model has been loaded")



def train(epoch):
  ''' Train the model for one epoch on `train_data`.

  Samples are visited in a fresh random order and accumulated by hand into
  mini-batches of `args.batch_size`; the final batch may be smaller.

  Args:
    epoch: epoch number, used only in the progress-bar and console messages.

  Returns:
    (avg_loss, corr): the sum of the per-batch losses divided by the number of
    training samples, and the Pearson correlation between the targets and the
    predictions collected over the whole epoch.
  '''
  model.train()
  train_loss = 0

  train_idx = torch.randperm(len(train_data))
  pbar = tqdm(train_idx)
  g_batch = []
  y_batch = []
  x_batch = []
  dname_batch = []
  y_all = []
  y_pred_all = []

  for i, ridx in enumerate(pbar):
    x, g, y, dname = train_data[ridx]
    g_batch.append(g)
    y_batch.append(y)
    x_batch.append(x)
    dname_batch.append(dname)
    # Re-read on every sample: in pretrain mode it is re-drawn after each batch.
    max_img = train_data.max_img
    if len(g_batch) == args.batch_size or i == len(train_data) - 1:
      optimizer.zero_grad()
      g_batch = model._collate_fn(g_batch)
      x_batch = torch.stack(x_batch).cuda()  # [32, 400, 512]
      D_mu = model.set_encode(x_batch, max_img, dname_batch)
      G_mu = model.graph_encode(g_batch)
      y_batch = torch.FloatTensor(y_batch).unsqueeze(1).to(device)
      y_pred = model.predictor(D_mu, G_mu)
      loss = model.mseloss(y_pred, y_batch)

      # Flatten with reshape(-1) rather than squeeze(): squeeze() turns a
      # batch of one into a 0-d tensor, whose tolist() is a bare float and
      # breaks both the list concatenation and the [0] indexing below.
      y_batch = y_batch.detach().cpu().reshape(-1).tolist()
      y_pred = y_pred.detach().cpu().reshape(-1).tolist()
      y_all += y_batch
      y_pred_all += y_pred
      # g_batch is the collated object at this point, so len() is whatever
      # model._collate_fn returns.
      msg = f'Epoch {epoch}, loss {loss.item()/len(g_batch):0.4f} '
      msg += f'pacc {y_pred[0]:0.4f}({y_pred[0]*100.0*train_data.std+train_data.mean:0.4f}) ' + \
              f'acc {y_batch[0]:0.4f}({y_batch[0]*100.0*train_data.std+train_data.mean:0.4f})'
      pbar.set_description(msg)
      loss.backward()

      train_loss += float(loss)
      optimizer.step()
      # Start a new, empty mini-batch.
      g_batch = []
      y_batch = []
      x_batch = []
      dname_batch = []

      if args.pretrain:
        if train_data.mode == 'tr':
          # Draw a new value in [1, 10] for the next batch and apply it to
          # every pretraining dataset.
          n = 1 + torch.randperm(10)[0].item()
          train_data.max_img = {}
          for dname in dname_lst:
            train_data.max_img[dname] = n

  y_all = np.array(y_all)
  y_pred_all = np.array(y_pred_all)
  # calculate Pearson's correlation
  corr, _ = pearsonr(y_all, y_pred_all)
  avg_loss = train_loss / len(train_data)
  print(f'====> Epoch: {epoch} Avg loss: {avg_loss:.4f} Pearons corr {corr:.3f}')
  return avg_loss, corr
  

def val(epoch):
  ''' Evaluate the model on `val_data` and plot predictions against targets.

  Runs without gradients, batching samples by hand in the same way as
  train(). Two scatter plots are written to `<save_dir>/pred/`: one in the
  normalised units the model works in, one rescaled with the training set's
  `std` and `mean`.

  Args:
    epoch: epoch number, used in messages and in the plot file names.

  Returns:
    (avg_loss, corr): the sum of the per-batch losses divided by the number of
    validation samples, and the Pearson correlation between the targets and
    the predictions.
  '''
  model.eval()
  valoss = 0

  idx = torch.randperm(len(val_data))
  pbar = tqdm(idx)
  g_batch = []
  y_batch = []
  x_batch = []
  dname_batch = []
  y_all = []
  y_pred_all = []
  # De-normalisation uses the training set's statistics.
  mean = train_data.mean
  std = train_data.std

  with torch.no_grad():
    for i, ridx in enumerate(pbar):
      x, g, y, dname = val_data[ridx]
      g_batch.append(g)
      y_batch.append(y)
      x_batch.append(x)
      dname_batch.append(dname)
      max_img = val_data.max_img
      if len(g_batch) == args.batch_size or i == len(val_data) - 1:
        g_batch = model._collate_fn(g_batch)
        x_batch = torch.stack(x_batch).cuda()  # [32, 400, 512]
        D_mu = model.set_encode(x_batch, max_img, dname_batch)
        G_mu = model.graph_encode(g_batch)
        y_batch = torch.FloatTensor(y_batch).unsqueeze(1).to(device)
        y_pred = model.predictor(D_mu, G_mu)
        loss = model.mseloss(y_pred, y_batch)

        # Flatten with reshape(-1) rather than squeeze(): squeeze() turns a
        # batch of one into a 0-d tensor, whose tolist() is a bare float and
        # breaks both the list concatenation and the [0] indexing below.
        y_batch = y_batch.detach().cpu().reshape(-1).tolist()
        y_pred = y_pred.detach().cpu().reshape(-1).tolist()
        y_all += y_batch
        y_pred_all += y_pred
        # NOTE(review): this message rescales with an extra factor of 100.0,
        # while the plot below uses std/mean only -- confirm which is intended.
        msg = f'[Val] Epoch {epoch}, loss {loss.item()/len(g_batch):0.4f}'
        msg += f' pacc {y_pred[0]:0.4f}({y_pred[0]*100.0*std+mean:0.4f}) ' + \
                f'acc {y_batch[0]:0.4f}({y_batch[0]*100.0*std+mean:0.4f})'
        pbar.set_description(msg)

        valoss += float(loss)
        # Start a new, empty mini-batch.
        g_batch = []
        y_batch = []
        x_batch = []
        dname_batch = []

  y_all = np.array(y_all)
  y_pred_all = np.array(y_pred_all)
  # calculate Pearson's correlation
  corr, _ = pearsonr(y_all, y_pred_all)
  avg_loss = valoss / len(val_data)
  print(f'====> [Val] Epoch: {epoch} Avg loss: {avg_loss:.4f} Pearons corr {corr:.3f}')

  # Plot 1: normalised targets vs. normalised predictions.
  pyplot.scatter(y_all, y_pred_all)
  pyplot.title(f'corr {corr:.3f}')
  pyplot.savefig(os.path.join(args.save_dir, 'pred', f'ep_{epoch}_corr_norm.png'))
  pyplot.clf()
  # Plot 2: the same points rescaled with the training std/mean.
  pyplot.title(f'corr {corr:.3f}')
  pyplot.scatter(y_all*std+mean, y_pred_all*std+mean)
  pyplot.savefig(os.path.join(args.save_dir, 'pred', f'ep_{epoch}_corr.png'))
  # Clear the figure again; otherwise these points are still on the canvas
  # when the next call draws its first plot.
  pyplot.clf()
  return avg_loss, corr



def test(epoch, dname, test_data):
  ''' Report the best test accuracy among the top-k sampled architectures.

  Behaviour depends on `args.smp_type`:
    1 -- shuffle the stored test accuracies and print the maximum among the
         first 10 / 20 / 50 / 100 of them.
    2 -- sample 1000 architectures, score each with the predictor, sort by
         predicted score (descending) and print the maximum true test
         accuracy among the first 10 / 20 / 50 / 100.
  Any other value does nothing.

  Args:
    epoch: unused; kept for interface compatibility.
    dname: dataset name; must be 'cifar10' or 'cifar100'.
    test_data: indexable container; only `test_data[0]` is used, as the set
      input paired with every candidate architecture.

  Raises:
    NotImplementedError: if `dname` is not one of the supported datasets.
    ValueError: if the number of predictions differs from the number of
      sampled architectures (smp_type 2).
  '''
  if dname not in ['cifar10', 'cifar100']:
    raise NotImplementedError

  nasbench201_path = '/w14/dataset/MetaD2A/predictor/nasbench201.pt'
  top_ks = [10, 20, 50, 100]

  if args.smp_type == 1: # random
    nasbench201 = torch.load(nasbench201_path)
    test_acc_lst = nasbench201[dname]['test-accuracy']
    ridx = torch.randperm(len(test_acc_lst))
    test_acc_lst = test_acc_lst[ridx]
    for top in top_ks:
      final_acc = torch.max(test_acc_lst[:top])
      print(f'Top {top} acc {final_acc:.4f}')

  elif args.smp_type == 2: # random + performance predictor
    num_candidates = 1000
    nasbench201 = torch.load(nasbench201_path)
    test_acc_lst = nasbench201[dname]['test-accuracy']
    g_str_lst = nasbench201[dname]['nas201_8_str2']
    g_lst = nasbench201[dname]['nas201_8_igraph']

    # Draw the same random subset from all three parallel containers.
    ridx_lst = torch.randperm(len(g_lst))[:num_candidates]
    test_acc_lst = test_acc_lst[ridx_lst]
    g_str_lst = g_str_lst[ridx_lst]
    g_lst = g_lst[ridx_lst]

    # (A leftover pdb.set_trace() breakpoint was removed here.)

    x_batch, g_batch, dname_batch = [], [], []
    y_pred_all = []

    model.eval()
    with torch.no_grad():
      for i, g in enumerate(g_lst):
        x_batch.append(test_data[0])
        g_batch.append(g)
        dname_batch.append(dname)
        # Flush on a full batch or on the last graph. The end-of-data check
        # must use the length of g_lst (what is being iterated); it used
        # len(test_data) before, which could drop the final partial batch.
        if len(g_batch) == args.batch_size or i == len(g_lst) - 1:
          g_batch = model._collate_fn(g_batch)
          x_batch = torch.stack(x_batch).cuda()  # [32, 400, 512]

          # NOTE(review): `max_img` is not defined in this function or at
          # module level in this file; unless a star-import provides it this
          # raises NameError. The intended value cannot be determined from
          # here (train()/val() use <dataset>.max_img) -- confirm and fix.
          D_mu = model.set_encode(x_batch, max_img, dname_batch)
          G_mu = model.graph_encode(g_batch)
          y_pred = model.predictor(D_mu, G_mu)
          # reshape(-1) keeps a batch of one as a one-element list, whereas
          # squeeze() would collapse it to a bare float.
          y_pred_all += y_pred.detach().cpu().reshape(-1).tolist()

          x_batch, g_batch, dname_batch = [], [], []

    if len(y_pred_all) != num_candidates:
      raise ValueError(len(y_pred_all))

    # Rank candidates by predicted score, best first.
    y_pred_all, sidx = torch.sort(torch.tensor(y_pred_all), descending=True)
    test_acc_lst = test_acc_lst[sidx]
    g_str_lst = g_str_lst[sidx]
    for top in top_ks:
      final_acc = torch.max(test_acc_lst[:top])
      print(f'Top {top} acc {final_acc:.4f}')



def main():
  ''' Run the full training loop for `args.epochs` epochs.

  After every epoch the scheduler is stepped on the training loss. Logging
  then depends on the result directory name:
    * if it contains 'time', a line is appended to train_loss.txt every 70th
      epoch and nothing else happens;
    * otherwise a line is appended to train_loss.txt every epoch, and every
      `args.save_interval` epochs the model is validated (logged to
      val_loss.txt) and a checkpoint is written to `model_dir`.
  '''
  run_start = time.time()

  def _format_entry(ep, ep_start, loss_value, corr_value):
    # One log line: epoch, time spent in this epoch, total elapsed time, metrics.
    timing = f"ep {ep:3d} ep time {time.time() - ep_start:8.2f} time {time.time() - run_start:6.2f} "
    metrics = f"loss {loss_value:.2f} corr {corr_value:.4f}\n"
    return timing + metrics

  for epoch in range(1, args.epochs + 1):
    epoch_start = time.time()
    tr_loss, tr_corr = train(epoch)
    scheduler.step(tr_loss)

    timing_only = 'time' in args.save_dir

    # Training log: every epoch normally, every 70th epoch for timing runs.
    if not timing_only or epoch % 70 == 0:
      with open(os.path.join(args.save_dir, 'train_loss.txt'), 'a') as log_file:
        log_file.write(_format_entry(epoch, epoch_start, tr_loss, tr_corr))

    # Timing runs never validate or checkpoint.
    if timing_only:
      continue
    if epoch % args.save_interval != 0:
      continue

    with open(os.path.join(args.save_dir, 'val_loss.txt'), 'a') as log_file:
      va_loss, va_corr = val(epoch)
      log_file.write(_format_entry(epoch, epoch_start, va_loss, va_corr))
    print("save current model...")
    checkpoint_path = os.path.join(model_dir, f'model_checkpoint{epoch}.pth')
    torch.save(model.state_dict(), checkpoint_path)


# Start training only when this file is executed as a script. The setup code
# above still runs at import time.
if __name__ == '__main__':
  main()
