from __future__ import print_function
from parser import get_parser
import torch
import os
import sys
import math
import pickle
import pdb
import random
from tqdm import tqdm
from shutil import copy
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt
from util import *
from models import *
from loader import *
import time

from matplotlib import pyplot
from scipy.stats import pearsonr


# Parse command-line options and expose only the requested GPU(s).
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialized; "cuda:0" then refers to the first *visible* device.
args = get_parser()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda:0")

# Seed every RNG this script touches (CUDA, torch CPU, numpy, python) with
# the same value so runs are reproducible.
for _seed_fn in (torch.cuda.manual_seed,
                 torch.manual_seed,
                 np.random.seed,
                 random.seed):
  _seed_fn(args.seed)
print(args)


# Prepare result paths.
# NOTE(review): realpath('__file__') resolves the *literal* string "__file__"
# against the current working directory, so file_dir ends up being the CWD,
# not this script's directory. Kept as-is because downstream loaders may rely
# on that; switch to __file__ only after confirming their path assumptions.
args.file_dir = os.path.dirname(os.path.realpath('__file__'))
if args.debug:
  args.save_dir = './results/debug'
elif args.save_dir is None:
  # Other experiment roots used previously: '/447_4152', '/time'.
  args.save_dir = f'./results/447_1232/st-{args.setenc_type}'

# Timing runs ('time' in the path) reuse the model directory of the matching
# non-timing run: the '/time' component is stripped before adding 'model'.
if 'time' in args.save_dir:
  model_dir = os.path.join(args.save_dir.replace('/time', ''), 'model')
else:
  model_dir = os.path.join(args.save_dir, 'model')

# Create output directories idempotently. Previously model_dir was only
# created when save_dir was missing, which (a) raised FileExistsError for a
# timing run whose shared model_dir already existed and (b) left model_dir
# missing when save_dir pre-existed, breaking torch.save later.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(os.path.join(args.save_dir, 'samp_arch'), exist_ok=True)

# Only the graph-space metadata (G_args) is needed to build the model; the
# first two return values are discarded.
_, _, G_args = load_NAS201_graphs(args.data_name, n_types=args.nvt, cn=args.cn)

# Pick the dataset class for this run mode. Test mode builds no datasets;
# rlv mode only needs its sample output directory.
if args.pretrain:
  _dataset_cls = PretrainDataset
elif args.test:
  _dataset_cls = None
elif args.rlv:
  _dataset_cls = None
  _rlv_dir = args.save_dir + '/samp_arch_rlv'
  if not os.path.exists(_rlv_dir):
    os.makedirs(_rlv_dir)
else:
  _dataset_cls = MetaD2A

if _dataset_cls is not None:
  train_data = _dataset_cls(args, mode='tr')
  val_data = _dataset_cls(args, mode='va')

# Prepare the model.
# NOTE(review): eval() executes args.model as arbitrary Python. It is meant to
# name a model class brought in by `from models import *`; only run this
# script with trusted command-line input.
model = eval(args.model)(args, G_args)

# A single Adam parameter group with a fixed learning rate. An earlier
# variant split the parameters into set-encoder / graph-encoder groups with
# args.d_enc_lr and args.g_enc_lr.
params = model.parameters()
optimizer = optim.Adam(params, lr=1e-4)

# Reduce the LR by 10x after 10 steps without a decrease in the value passed
# to scheduler.step() (main() passes the epoch-average training loss).
scheduler = ReduceLROnPlateau(
    optimizer, 'min', factor=0.1, patience=10, verbose=True)

# nn.Module.to() modifies the module in place, so the optimizer built above
# still refers to the moved parameters.
model.to(device)


def train(epoch):
  """Run one training epoch over ``train_data`` in a random order.

  Samples are accumulated into mini-batches of ``args.batch_size``; the last
  batch of the epoch may be smaller. Each batch is encoded either from its
  graphs (``args.pretrain``: graph-to-graph) or from the stacked ``x`` inputs
  via ``model.set_encode``. Then ``model.loss`` is back-propagated and one
  optimizer step is taken.

  Args:
    epoch: current epoch number, used only in progress/summary messages.

  Returns:
    ``(avg_loss, recon_loss, kld_loss)``. ``avg_loss`` is the summed batch
    loss divided by ``len(train_data)``. ``recon_loss`` and ``kld_loss`` are
    un-normalized sums over all batches (``main`` divides them by
    ``len(train_data)`` when logging).
  """
  model.train()
  train_loss = 0
  recon_loss = 0
  kld_loss = 0

  # A fresh random permutation of sample indices for this epoch.
  train_idx = torch.randperm(len(train_data))
  pbar = tqdm(train_idx)
  g_batch = []
  x_batch = []
  dname_batch = []
  
  for i, ridx in enumerate(pbar):
    x, g, _, dname = train_data[ridx]
    g_batch.append(g)
    x_batch.append(x)
    dname_batch.append(dname)
    # Re-read every iteration; in pretrain mode this attribute is re-drawn
    # after each batch (see the end of the batch branch below).
    max_img = train_data.max_img
    # Flush when the batch is full or on the last sample of the epoch.
    if len(g_batch) == args.batch_size or i == len(train_data) - 1:
      optimizer.zero_grad()
      g_batch = model._collate_fn(g_batch)
      x_batch = torch.stack(x_batch).cuda()  # [32, 400, 512]
      if args.pretrain: # g2g
        mu, logvar = model.graph_encode(g_batch)
      else:
        mu, logvar = model.set_encode(x_batch, max_img, dname_batch)
      loss, recon, kld = model.loss(mu, logvar, g_batch)
      # NOTE(review): len(g_batch) here is the length of the *collated* batch
      # returned by model._collate_fn; confirm it equals the batch size.
      msg = f'Epoch {epoch} loss {loss.item() / len(g_batch):0.4f} '
      msg += f'recon {recon.item() / len(g_batch):0.4f} kld {kld.item() / len(g_batch):0.4f}'
      pbar.set_description(msg)
      loss.backward()
      
      # Accumulate plain Python floats for the epoch totals.
      train_loss += float(loss)
      recon_loss += float(recon)
      kld_loss += float(kld)
      optimizer.step()
      # Start a new, empty batch.
      g_batch = []
      x_batch = []
      dname_batch = []

      if args.pretrain:
        # Re-draw train_data.max_img for the next batch: n is uniform in
        # 1..9; cifar100 gets n and cifar10 gets 10*n.
        n = 1 + torch.randperm(9)[0].item()
        train_data.max_img = {'cifar100': n,
                              'cifar10': n*10}     
  # NOTE(review): ZeroDivisionError if train_data is empty.
  avg_loss = train_loss / len(train_data)
  print(f'====> Epoch: {epoch} Avg loss: {avg_loss:.4f}')
  return avg_loss, recon_loss, kld_loss


def infer_MetaD2A(dataset, epoch, tidx=None, dname=None, n_valid=1000,
                  batch_size=100, max_iters=1000000):
  """Sample unique valid NAS-Bench-201 architectures for one meta-test task.

  Repeatedly draws ``dataset[tidx]``. Every ``batch_size`` draws are stacked,
  set-encoded, reparameterized into latent codes and decoded to graphs. Each
  graph that decodes to a non-None NAS-Bench-201 string not seen before is
  kept. Sampling stops once ``n_valid`` unique architectures are collected or
  after ``max_iters`` draws.

  The result is written to
  ``<save_dir>/samp_arch/<dname>/infer-<epoch>-<task_num>.txt``. The file
  holds a ``total cnt N valid cnt M`` header followed by one architecture
  string per line, in discovery order.

  Args:
    dataset: indexable dataset exposing ``task`` and ``max_img``.
    epoch: checkpoint epoch, used only in the output file name.
    tidx: index of the task to condition on. It is used to index
      ``dataset.task`` and ``dataset``, so it must be given in practice.
    dname: name passed to the set encoder and used as the output folder.
    n_valid: number of unique valid architectures to collect.
    batch_size: number of draws per encoder batch.
    max_iters: upper bound on the number of dataset draws.

  Returns:
    Wall-clock seconds spent sampling. File writing is excluded.
  """
  model.eval()

  task_num = dataset.task[tidx]
  dpath = os.path.join(args.save_dir, 'samp_arch', f'{dname}')
  if not os.path.exists(dpath):
    os.makedirs(dpath)
    print(f'create {dpath}')

  nas_str_lst = []  # unique valid architectures, in discovery order
  seen = set()      # O(1) membership mirror of nas_str_lst
  total_cnt = 0
  x_batch, dname_batch = [], []
  done = False

  start = time.time()
  with torch.no_grad():
    for _ in range(max_iters):
      x_batch.append(dataset[tidx])
      dname_batch.append(dname)
      if len(x_batch) != batch_size:
        continue
      x = torch.stack(x_batch).cuda()
      mu, logvar = model.set_encode(x, dataset.max_img, dname_batch)
      z = model.reparameterize(mu, logvar)
      for g in model.graph_decode(z):
        g_str = decode_igraph_to_NAS201_str(g)
        total_cnt += 1
        # None means the decoded graph is not a valid architecture.
        if g_str is not None and g_str not in seen:
          seen.add(g_str)
          nas_str_lst.append(g_str)
          if len(nas_str_lst) == n_valid:
            done = True
            break
      x_batch, dname_batch = [], []
      if done:
        break
  elapsed = time.time() - start

  out_path = os.path.join(dpath, f"infer-{epoch}-{task_num}.txt")
  with open(out_path, 'w') as infer_file:
    print(f'{epoch} {dname} {elapsed}')
    infer_file.write(f'total cnt {total_cnt} valid cnt {len(nas_str_lst)}\n')
    infer_file.writelines(f"{nas_str}\n" for nas_str in nas_str_lst)
  return elapsed


def infer(dataset, epoch, dname=None, n_valid=1000, batch_size=100,
          max_iters=1000000):
  """Sample unique valid NAS-Bench-201 architectures for a test dataset.

  Repeatedly draws ``dataset[0]``. Every ``batch_size`` draws are stacked,
  set-encoded, reparameterized into latent codes and decoded to graphs. Each
  graph that decodes to a non-None NAS-Bench-201 string not seen before is
  kept. Sampling stops once ``n_valid`` unique architectures are collected or
  after ``max_iters`` draws.

  The result is written to ``<save_dir>/samp_arch/<dname>/infer-<epoch>.txt``.
  The file holds a ``total cnt N valid cnt M`` header followed by one
  architecture string per line, in discovery order.

  Args:
    dataset: indexable dataset exposing ``max_img``; only index 0 is drawn.
    epoch: checkpoint epoch, used only in the output file name.
    dname: name passed to the set encoder and used as the output folder.
    n_valid: number of unique valid architectures to collect.
    batch_size: number of draws per encoder batch.
    max_iters: upper bound on the number of dataset draws.

  Returns:
    Wall-clock seconds spent sampling (file writing excluded). Previously
    returned None; callers that ignore the result are unaffected.
  """
  model.eval()

  dpath = os.path.join(args.save_dir, 'samp_arch', f'{dname}')
  if not os.path.exists(dpath):
    os.makedirs(dpath)
    print(f'create {dpath}')

  nas_str_lst = []  # unique valid architectures, in discovery order
  seen = set()      # O(1) membership mirror of nas_str_lst
  total_cnt = 0
  x_batch, dname_batch = [], []
  done = False

  start = time.time()
  with torch.no_grad():
    for _ in range(max_iters):
      x_batch.append(dataset[0])
      dname_batch.append(dname)
      if len(x_batch) != batch_size:
        continue
      x = torch.stack(x_batch).cuda()
      mu, logvar = model.set_encode(x, dataset.max_img, dname_batch)
      z = model.reparameterize(mu, logvar)
      for g in model.graph_decode(z):
        g_str = decode_igraph_to_NAS201_str(g)
        total_cnt += 1
        # None means the decoded graph is not a valid architecture.
        if g_str is not None and g_str not in seen:
          seen.add(g_str)
          nas_str_lst.append(g_str)
          if len(nas_str_lst) == n_valid:
            done = True
            break
      x_batch, dname_batch = [], []
      if done:
        break
  elapsed = time.time() - start

  out_path = os.path.join(dpath, f"infer-{epoch}.txt")
  with open(out_path, 'w') as infer_file:
    print(f'{epoch} {dname} {elapsed}')
    infer_file.write(f'total cnt {total_cnt} valid cnt {len(nas_str_lst)}\n')
    infer_file.writelines(f"{nas_str}\n" for nas_str in nas_str_lst)
  return elapsed



# def infer(dataset, epoch, run, tasks=None, dname=None):
#   model.eval()
#   x_batch, dname_batch = [], []
#   nas_str_lst, gidx = [], []
#   total_cnt, valid_cnt = 0, 0
#   flag = False
#   start=time.time()
#   tasks = range(100000) if tasks is None else tasks
#   with torch.no_grad():
#     for tidx in tasks:
#       x_batch.append(dataset[tidx])
#       dname_batch.append(dname)
#       if len(x_batch) == args.batch_size or len(x_batch) == len(tasks):
#         x_batch = torch.stack(x_batch).cuda()
#         mu, logvar = model.set_encode(x_batch, dataset.max_img, dname_batch)
#         z = model.reparameterize(mu, logvar)
#         g_recon = model.graph_decode(z)
#         for g in g_recon:
#           g_str = decode_igraph_to_NAS201_str(g)
#           total_cnt += 1
#           if g_str is not None:
#             if not g_str in nas_str_lst:
#               valid_cnt += 1
#               nas_str_lst.append(g_str)
#               if valid_cnt == 1000:
#                 flag = True
#                 break
#         x_batch, dname_batch = [], []
#       if flag:
#         break
#   elapsed=time.time()-start

#   with open(os.path.join(args.save_dir, 'samp_arch', f"infer-{dname}-{epoch}-{run}.txt"), 'w') as infer_file:
#     print(f'{epoch} {dname}')
#     infer_file.write(f'total cnt {total_cnt} valid cnt {valid_cnt}\n')
#     for i, nas_str in enumerate(nas_str_lst):
#       infer_file.write(f"{nas_str}\n") 



def main():
  """Train for ``args.epochs`` epochs, logging losses and saving checkpoints.

  After every epoch the LR scheduler is stepped on the epoch-average loss.
  Timing runs (``'time'`` in ``args.save_dir``) append a line to
  ``<save_dir>/train_loss.txt`` only every 50th epoch and never save
  checkpoints. All other runs log every epoch and save
  ``model_checkpoint<epoch>.pth`` into ``model_dir`` every
  ``args.save_interval`` epochs.
  """
  is_timing_run = 'time' in args.save_dir
  log_path = os.path.join(args.save_dir, 'train_loss.txt')
  run_start = time.time()

  for epoch in range(1, args.epochs + 1):
    epoch_start = time.time()
    loss, recon_loss, kld_loss = train(epoch)
    scheduler.step(loss)

    # Timing runs only log on every 50th epoch.
    if is_timing_run and epoch % 50 != 0:
      continue

    with open(log_path, 'a') as log_file:
      n_samples = len(train_data)
      line = (f"ep {epoch:3d} ep time {time.time() - epoch_start:8.2f} "
              f"time {time.time() - run_start:6.2f} "
              f"loss {loss:.2f} recon {recon_loss / n_samples} "
              f"kld {kld_loss / n_samples}\n")
      log_file.write(line)

    if not is_timing_run and epoch % args.save_interval == 0:
      print("save current model...")
      ckpt_path = os.path.join(model_dir, f'model_checkpoint{epoch}.pth')
      torch.save(model.state_dict(), ckpt_path)


if __name__ == '__main__':

  if args.rlv:
    # Random-latent-vector sampling with the epoch-300 checkpoint.
    # NOTE(review): random_latent_vector_infer is not defined in this file;
    # presumably it comes from one of the star imports. Confirm.
    # Datasets used previously:
    #   ['cifar10', 'cifar100', 'te', 'svhn', 'mnist', 'aircraft']
    for dname in ['cifar10']:
      if dname == 'te':
        # Was `mode=d`, an undefined name (NameError). The loop variable is
        # presumably what was meant; confirm D2GDataset accepts 'te'.
        dataset = D2GDataset(max_img=20, dataset=args.dataset, mode=dname)
        tasks = range(len(dataset))
      else:
        dataset = TestDataset(args.max_img, dname)
        tasks = None

      # Sample inside the dataset loop. Previously this loop was dedented,
      # so only the last dataset in the list would have been sampled.
      for epoch in [300]:
        load_module_state(model, os.path.join(model_dir, f'model_checkpoint{epoch}.pth'))
        for run in range(10):
          random_latent_vector_infer(dataset, epoch, run, tasks, dname)

  # Train unless we are only testing or running the rlv sampling above.
  if not (args.test or args.rlv):
    main()

  if not args.pretrain:
    # Sample 1000 valid architectures from the epoch-300 checkpoint.
    # NOTE(review): this section also runs after the rlv branch above;
    # confirm that is intended.
    dname = args.dname
    if dname == 'te':
      # Meta-test tasks: sample for the first 20 tasks and record timings.
      dataset = MetaD2A(args, mode=dname)
      epoch = 300
      load_module_state(model, os.path.join(model_dir, f'model_checkpoint{epoch}.pth'))
      elapsed_lst = [infer_MetaD2A(dataset, epoch, tidx, 'MetaD2A')
                     for tidx in range(20)]
      with open(os.path.join(args.save_dir, 'time.txt'), 'w') as tf:
        for elapsed in elapsed_lst:
          print(elapsed)
          tf.write(f'{elapsed}\n')
        mean_elapsed = np.mean(elapsed_lst)
        print(mean_elapsed)
        tf.write(f'{mean_elapsed}\n')
    else:
      dataset = TestDataset(args.max_img, dname)
      for epoch in [300]:
        load_module_state(model, os.path.join(model_dir, f'model_checkpoint{epoch}.pth'))
        infer(dataset, epoch, dname)


