import pdb
import pickle
import sys
import os
import os.path
import collections
import torch
import argparse
import pandas as pd
from tqdm import tqdm
import itertools
from scipy.spatial.distance import pdist
import matplotlib.pyplot as plt
from sparse_gp import SparseGP
import scipy.stats as sps
import numpy as np
import scipy.io
from scipy.io import loadmat
from scipy.stats import pearsonr
sys.path.append('%s/../software/enas' % os.path.dirname(os.path.realpath(__file__))) 
sys.path.append('%s/..' % os.path.dirname(os.path.realpath(__file__))) 
sys.path.insert(0, '../')
from models_topo import *
from utils import *
from shutil import copy

# Experiment settings
parser = argparse.ArgumentParser(description='Bayesian optimization experiments.')
# must specify
parser.add_argument('--data-name', default='circuit101', help='graph dataset name')
parser.add_argument('--save-appendix', default='', 
                    help='what is appended to data-name as save-name for results')
parser.add_argument('--checkpoint', type=int, default=300, 
                    help="load which epoch's model checkpoint")
parser.add_argument('--res-dir', default='res/', 
                    help='where to save the Bayesian optimization results')
parser.add_argument('--reprocess', action='store_true', default=False,
                    help='if True, reprocess data instead of using prestored .pkl data')
# BO settings
parser.add_argument('--predictor', action='store_true', default=False,
                    help='if True, use the performance predictor instead of SGP')
parser.add_argument('--grad-ascent', action='store_true', default=False,
                    help='if True and predictor=True, perform gradient-ascent with predictor')
parser.add_argument('--BO-rounds', type=int, default=10, 
                    help="how many rounds of BO to perform")
parser.add_argument('--BO-batch-size', type=int, default=50, 
                    help="how many data points to select in each BO round")
parser.add_argument('--sample-dist', default='uniform', 
                    help='from which distribution to sample random points in the latent \
                    space as candidates to select; uniform or normal')
parser.add_argument('--random-baseline', action='store_true', default=False,
                    help='whether to include a baseline that randomly selects points \
                    to compare with Bayesian optimization')
parser.add_argument('--random-as-train', action='store_true', default=False,
                    help='if true, no longer use original train data to initialize SGP \
                    but randomly generates 1000 initial points as train data')
parser.add_argument('--random-as-test', action='store_true', default=False,
                    help='if true, randomly generates 100 points from the latent space \
                    as the additional testing data')
parser.add_argument('--vis-2d', action='store_true', default=False,
                    help='do visualization experiments on 2D space')
parser.add_argument('--emb_dim', type=int, default=128, metavar='N', help='embedding dimension')
parser.add_argument('--v1', type=int, default=1,
                    help='CVAE variant to build: 1 -> CVAE_topo1, anything else -> CVAE_topo')


# can be inferred from the cmd_input.txt file, no need to specify
parser.add_argument('--data-type', default='ENAS',
                    help='ENAS: ENAS-format CNN structures; BN: Bayesian networks')
parser.add_argument('--model', default='DVAE', help='model to use: DVAE, SVAE, \
                    DVAE_fast, DVAE_BN, SVAE_oneshot, DVAE_GCN')
parser.add_argument('--hs', type=int, default=501, metavar='N',
                    help='hidden size of GRUs')
parser.add_argument('--nz', type=int, default=56, metavar='N',
                    help='number of dimensions of latent vectors z')
parser.add_argument('--bidirectional', action='store_true', default=False,
                    help='whether to use bidirectional encoding')

parser.add_argument('--nvt', type=int, default=26, help='number of different node (subgraph) types')
parser.add_argument('--max_n', type=int, default=8,
                    help='max_n of the CVAE models (passed as max_pos to DVAE_topo)')
parser.add_argument('--subg_nvt', type=int, default=10, help='number of subgraph nodes')
parser.add_argument('--subn_nvt', type=int, default=10, help='number of subgraph feat')
parser.add_argument('--ng', type=int, default=10000, help='number of circuits in the dataset')
parser.add_argument('--node_feat_type', type=str, default='discrete', help='node feature type: discrete or continuous')

parser.add_argument('--cuda_id', type=int, default=1, metavar='N',
                    help='id of GPU')
parser.add_argument('--infer-batch-size', type=int, default=128, metavar='N',
                    help='batch size during inference')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--all-gpus', action='store_true', default=False,
                    help='use all available GPUs')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda:{}".format(args.cuda_id))
else:
    device = torch.device("cpu")
np.random.seed(args.seed)
# `random` is not imported at the top of the file; previously it was only
# reachable through the star imports. Import it explicitly so seeding does not
# depend on what `utils` / `models_topo` happen to re-export.
import random  # noqa: E402
random.seed(args.seed)

data_name = args.data_name
save_appendix = args.save_appendix
checkpoint = args.checkpoint
data_type = args.data_type
model_name = args.model
bidir = args.bidirectional
vis_2d = args.vis_2d
hs = args.hs
nz = args.nz
max_n = args.max_n
# Node-type vocabulary size and start/end type ids per model family.
if args.model.startswith('CVAE'):
    nvt, START_TYPE, END_TYPE = 26, 0, 1
else:
    nvt, START_TYPE, END_TYPE = 10, 8, 9

BO_rounds = args.BO_rounds
batch_size = args.BO_batch_size
sample_dist = args.sample_dist
random_baseline = args.random_baseline
random_as_train = args.random_as_train
random_as_test = args.random_as_test

# NOTE: this used to be os.path.dirname(os.path.realpath('__file__')). That
# passes the *string* '__file__', which resolves relative to the current
# working directory, so the base directory has always been the CWD, not the
# script's directory. Keep that behaviour, but say so explicitly.
args.file_dir = os.path.realpath(os.getcwd())
args.res_dir = os.path.join(args.file_dir, 'results/{}{}'.format(args.data_name, args.save_appendix))
args.data_dir = os.path.join(args.file_dir, 'data/{}'.format(args.data_name))
os.makedirs(args.data_dir, exist_ok=True)

pkl_name = os.path.join(args.data_dir, args.data_name + '.pkl')

# Reuse the cached train/test split unless --reprocess is given; otherwise
# regenerate it from the raw .txt file and cache it. The .pkl is written by
# this script itself, so only unpickle from trusted data directories.
if os.path.isfile(pkl_name) and not args.reprocess:
    with open(pkl_name, 'rb') as f:
        train_dataset, test_dataset = pickle.load(f)
else:
    txt_name = os.path.join(args.data_dir, args.data_name + '.txt')
    train_dataset, test_dataset = train_test_generator_topo_simple(ng=args.ng, name=txt_name)
    with open(pkl_name, 'wb') as f:
        pickle.dump((train_dataset, test_dataset), f)

# Each dataset entry carries one graph representation per model family:
# element 0 is used by the CVAE models, element 1 by the others (DVAE_topo).
graph_col = 0 if args.model.startswith('CVAE') else 1
train_data = [item[graph_col] for item in train_dataset]
test_data = [item[graph_col] for item in test_dataset]

def is_same_DAG_filter(g0, g1):
    """Return True if two graphs match vertex-by-vertex.

    The graphs match when their vertex counts are equal and, for every vertex
    index, the 'type' and 'vid' attributes are equal and the sets of
    in-neighbours are equal. Vertex order matters: this is NOT an isomorphism
    check.
    """
    n_vertices = g0.vcount()
    if n_vertices != g1.vcount():
        return False
    # Per vertex: type, then vid, then in-neighbour set (short-circuits).
    return all(
        g0.vs[v]['type'] == g1.vs[v]['type']
        and g0.vs[v]['vid'] == g1.vs[v]['vid']
        and set(g0.neighbors(v, 'in')) == set(g1.neighbors(v, 'in'))
        for v in range(n_vertices)
    )

def BO_filter(gdata):
    """Group the graphs in ``gdata`` that are identical under ``is_same_DAG_filter``.

    Fixes a key-numbering bug. Previously the first group was stored under
    key 0, but new groups were numbered from 2 (``tn`` was incremented before
    use), so key 1 never existed. A duplicate of ``g_list[1]`` then raised
    KeyError, and a duplicate of ``g_list[k]`` for ``k >= 2`` was appended to
    the wrong group. Group ids are now exactly the indices into ``g_list``.

    Args:
        gdata: sequence of graphs.

    Returns:
        topo_ind_map: dict mapping group id ``k`` (0, 1, 2, ...) to the list of
            indices into ``gdata`` whose graph matches ``g_list[k]``.
        g_list: one representative graph per group, in first-seen order.
        tn: number of distinct groups (``len(g_list)``).
    """
    g_list = []
    topo_ind_map = {}
    for i in tqdm(range(len(gdata))):
        g0 = gdata[i]
        # Representatives are pairwise distinct and the comparison is an exact
        # attribute/neighbour-set match, so at most one representative can match.
        for k, g1 in enumerate(g_list):
            if is_same_DAG_filter(g0, g1):
                topo_ind_map[k].append(i)
                break
        else:
            # New topology: its group id is its position in g_list.
            topo_ind_map[len(g_list)] = [i]
            g_list.append(g0)
    return topo_ind_map, g_list, len(g_list)
    

def performance_readout1(num_graphs, gdata, file_dir='circuit', name='ckt_simulation_summary_10000.txt'):
    """Read circuit simulation results and build the per-graph performance table.

    Reads ``num_graphs`` lines from ``<file_dir>/<name>``, one line per graph.
    A line whose second whitespace-separated field is ``'Simulation'`` is
    treated as a failed simulation: all metrics are 0 and ``valid`` is 0.
    Otherwise fields 1, 2 and 3 are scaled by 1/100, 1/-90 and 1/1e9 into the
    ``gain``, ``pm`` and ``bw`` columns, and
    ``fom = 1.2*|gain| + 1.6*pm + 10*bw`` (on the scaled values).
    ``valid`` is then 1. Each metric column is shifted so that its minimum
    becomes 1e-5.

    ``valid2`` is 1 for the graph(s) with the highest fom within each group
    of identical topologies found by ``BO_filter(gdata)``, and 0 otherwise.

    Previously ``num_graphs`` was silently overwritten with 10000, and
    ``valid2`` mixed ``1`` and ``True``. That produced an object column which
    reads back from CSV as strings, so ``== 1`` checks failed. Both issues are
    fixed here.

    Args:
        num_graphs: number of summary lines / graphs to read.
        gdata: graphs aligned with the summary lines (``gdata[i]`` <-> line ``i``).
        file_dir: directory containing the summary file; ``perform.csv`` is
            written there too.
        name: summary file name.

    Returns:
        pandas.DataFrame with columns valid, valid2, gain, pm, bw, fom.
    """
    gain = []
    bw = []
    pm = []
    fom = []
    valid = []
    file_name = os.path.join(file_dir, name)
    with open(file_name, 'r') as f:
        for _ in tqdm(range(num_graphs)):
            row = f.readline().strip().split()
            if row[1] != 'Simulation':
                g = float(row[1]) / 100.0
                p = float(row[2]) / -90.0
                b = float(row[3]) / 1e9
                gain.append(g)
                pm.append(p)
                bw.append(b)
                fom.append(1.2 * np.abs(g) + 1.6 * p + 10 * b)
                valid.append(1)
            else:
                # Failed simulation: zero metrics, flagged invalid.
                gain.append(0)
                pm.append(0)
                bw.append(0)
                fom.append(0)
                valid.append(0)
    # Shift every metric so its minimum is a small positive number.
    gain = np.array(gain) - np.min(gain) + 0.00001
    pm = np.array(pm) - np.min(pm) + 0.00001
    bw = np.array(bw) - np.min(bw) + 0.00001
    fom = np.array(fom) - np.min(fom) + 0.00001

    topo_ind_map, _, _ = BO_filter(gdata)
    # Always store the int 1 so the column keeps an integer dtype.
    valid2 = [0] * num_graphs
    for inds in topo_ind_map.values():
        if len(inds) == 1:
            valid2[inds[0]] = 1
            continue
        vals = [fom[i] for i in inds]
        best = np.max(vals)
        # Ties at the maximum are all kept.
        for i, v in zip(inds, vals):
            if v == best:
                valid2[i] = 1
    perform = {'valid': valid, 'valid2': valid2, 'gain': gain, 'pm': pm, 'bw': bw, 'fom': fom}
    perform_df = pd.DataFrame(perform)
    out_name = os.path.join(file_dir, 'perform.csv')
    perform_df.to_csv(out_name)
    return perform_df



def extract_latent(data, perform_df):
    """Encode the graphs in ``data`` that pass the validity filters.

    The graph at position ``i`` is kept only when
    ``perform_df['valid'][i] == 1`` and ``perform_df['valid2'][i] == 1``.
    Row label ``i`` of ``perform_df`` must therefore describe ``data[i]``.
    Kept graphs are encoded with the module-level ``model`` in batches of
    ``args.infer_batch_size``; the first output of ``model.encode`` (``mu``)
    is used as the latent vector.

    Returns:
        (Z, Y, Gain, BW, PM): the ``mu`` batches concatenated along axis 0, and
        the 'fom', 'gain', 'bw', 'pm' values of the kept rows, in the same order.

    Raises:
        ValueError: if no graph in ``data`` passes the filters.
    """
    model.eval()
    keep = (perform_df['valid'] == 1) & (perform_df['valid2'] == 1)
    Z = []
    Y = []
    Gain = []
    BW = []
    PM = []
    g_batch = []
    last = len(data) - 1
    # Inference only: no autograd graph is needed for the encoder outputs.
    with torch.no_grad():
        for i, g in enumerate(tqdm(data)):
            if keep[i]:
                if args.model.startswith('SVAE'):
                    g_batch.append(g.to(device))
                else:
                    # copy igraph; otherwise original igraphs will save the H
                    # states and consume more GPU memory
                    g_batch.append(g.copy())
                Y.append(perform_df['fom'][i])
                Gain.append(perform_df['gain'][i])
                BW.append(perform_df['bw'][i])
                PM.append(perform_df['pm'][i])
            # Flush a full batch, or whatever is left after the last graph.
            # Never encode an empty batch (which happened before when the
            # last graph was filtered out right after a flush).
            if g_batch and (len(g_batch) == args.infer_batch_size or i == last):
                mu, _ = model.encode(model._collate_fn(g_batch))
                Z.append(mu.cpu().detach().numpy())
                g_batch = []
    if not Z:
        raise ValueError('extract_latent: no graph passed the valid/valid2 filter')
    return np.concatenate(Z, 0), np.array(Y), np.array(Gain), np.array(BW), np.array(PM)


# Extract latent representations Z
def save_latent_representations(epoch, perform_df):
    """Encode train/test graphs and save latents and targets to .pkl and .mat.

    ``perform_df`` is expected to hold one row per graph of
    ``train_data + test_data``, in that order. The module-level table is
    built that way by ``performance_readout1(..., train_data + test_data, ...)``.
    The test graphs are therefore scored against the rows starting at
    ``len(train_data)``. Previously they were looked up against the train
    rows.

    Writes ``<res_dir>/<data_name>_latent_epoch<epoch>.pkl`` containing
    (Z_train, Y_train, Z_test, Y_test), and a ``.mat`` file with the same
    arrays plus the Gain/BW/PM targets.
    """
    Z_train, Y_train, Gain_train, BW_train, PM_train = extract_latent(train_data, perform_df)
    # extract_latent indexes perform_df by position within the list it is
    # given, so pass the test slice re-indexed from 0.
    test_perform_df = perform_df.iloc[len(train_data):].reset_index(drop=True)
    Z_test, Y_test, Gain_test, BW_test, PM_test = extract_latent(test_data, test_perform_df)
    latent_pkl_name = os.path.join(args.res_dir, args.data_name +
                                   '_latent_epoch{}.pkl'.format(epoch))
    latent_mat_name = os.path.join(args.res_dir, args.data_name + 
                                   '_latent_epoch{}.mat'.format(epoch))
    with open(latent_pkl_name, 'wb') as f:
        pickle.dump((Z_train, Y_train, Z_test, Y_test), f)
    print('Saved latent representations to ' + latent_pkl_name)
    scipy.io.savemat(latent_mat_name, 
                     mdict={
                         'Z_train': Z_train, 
                         'Z_test': Z_test, 
                         'Y_train': Y_train, 
                         'Y_test': Y_test,
                         'Gain_train': Gain_train,
                         'Gain_test': Gain_test,
                         'BW_train':BW_train,
                         'BW_test':BW_test,
                         'PM_train':PM_train,
                         'PM_test':PM_test
                         }
                     )



# other BO hyperparameters
lr = 0.0005  # the learning rate to train the SGP model
max_iter = 100  # how many iterations to optimize the SGP each time

# Cached performance table. performance_readout1() writes it to
# <data_dir>/perform.csv. This previously looked for 'perf.csv', which is
# never written, so the cache was never used. It is ignored under
# --reprocess so it cannot go stale relative to a regenerated dataset.
perfor_dir = os.path.join(args.data_dir, 'perform.csv')

if args.reprocess or not os.path.exists(perfor_dir):
    perform_df = performance_readout1(args.ng, train_data + test_data, file_dir=args.data_dir)
else:
    perform_df = pd.read_csv(perfor_dir, index_col=0)
    # Files written by older code may store the flags as a mix of 1 and True,
    # which reads back as strings. Coerce both flag columns to 0/1 ints so
    # the `== 1` checks downstream work.
    for flag_col in ('valid', 'valid2'):
        perform_df[flag_col] = perform_df[flag_col].astype(str).isin(['1', '1.0', 'True']).astype(int)


def _fit_and_score_sgp(label, X_tr, y_tr, X_te, y_te, M):
    """Fit a SparseGP with ``M`` inducing points on (X_tr, y_tr) and score it on (X_te, y_te).

    Training uses ``train_via_ADAM`` with the module-level ``lr`` and
    ``max_iter`` and a minibatch size of ``2 * M``. The second input argument
    is passed as all zeros everywhere, as in the original code (presumably
    the input variances; confirm against SparseGP). Prints '<label> RMSE' and
    '<label> Pearson r'.

    Returns:
        (sgp, pred, uncert, rmse, test_ll, pearson_r)
    """
    sgp = SparseGP(X_tr, 0 * X_tr, y_tr, M)
    sgp.train_via_ADAM(X_tr, 0 * X_tr, y_tr, X_te, X_te * 0, y_te,
                       minibatch_size=2 * M, max_iterations=max_iter, learning_rate=lr)
    pred, uncert = sgp.predict(X_te, 0 * X_te)
    rmse = np.sqrt(np.mean((pred - y_te) ** 2))
    test_ll = np.mean(sps.norm.logpdf(pred - y_te, scale=np.sqrt(uncert)))
    pearson_r = float(pearsonr(pred.reshape(-1,), y_te.reshape(-1,))[0])
    print(label + ' RMSE: ', rmse)
    print(label + ' Pearson r: ', pearson_r)
    return sgp, pred, uncert, rmse, test_ll, pearson_r


for rand_idx in range(1, 6):

    save_dir = os.path.join(args.res_dir, 'sgp_reg_{}_{}/'.format(save_appendix, rand_idx))
    # The results file below is opened in append mode, so its directory must
    # exist; previously nothing created it.
    os.makedirs(save_dir, exist_ok=True)
    # set seed
    random_seed = rand_idx
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)

    # Build the model; --model / --v1 select the architecture.
    if args.model.startswith('CVAE'):
        cvae_cls = CVAE_topo1 if args.v1 == 1 else CVAE_topo
        model = cvae_cls(
            max_n=args.max_n,
            nvt=26,
            subn_nvt=args.subn_nvt,
            START_TYPE=0,
            END_TYPE=1,
            emb_dim=args.emb_dim,
            hs=args.hs,
            nz=args.nz,
        )
    else:  # max_n changes
        model = DVAE_topo(
            max_n=24,
            max_pos=args.max_n,
            nvt=10,
            feat_nvt=args.subn_nvt,
            START_TYPE=8,
            END_TYPE=9,
            hs=args.hs,
            nz=args.nz,
        )

    model.to(device)
    load_module_state(model, os.path.join(args.res_dir, 'model_checkpoint{}.pth'.format(checkpoint)), device=device)
    X_train, Y_train, Gain_train, BW_train, PM_train = extract_latent(train_data, perform_df)
    # NOTE(review): the "test" set is the first 1000 *training* points, so the
    # reported test metrics are in-sample and likely optimistic.
    X_test, Y_test, Gain_test, BW_test, PM_test = X_train[:1000], Y_train[:1000], Gain_train[:1000], BW_train[:1000], PM_train[:1000]

    # Targets are negated, then standardised with the *train* statistics.
    y_train = -Y_train.reshape((-1, 1))
    gain_train = -Gain_train.reshape((-1, 1))
    bw_train = -BW_train.reshape((-1, 1))
    pm_train = -PM_train.reshape((-1, 1))

    mean_y_train, std_y_train = np.mean(y_train), np.std(y_train)
    mean_gain_train, std_gain_train = np.mean(gain_train), np.std(gain_train)
    mean_bw_train, std_bw_train = np.mean(bw_train), np.std(bw_train)
    mean_pm_train, std_pm_train = np.mean(pm_train), np.std(pm_train)

    y_train = (y_train - mean_y_train) / std_y_train
    gain_train = (gain_train - mean_gain_train) / std_gain_train
    bw_train = (bw_train - mean_bw_train) / std_bw_train
    pm_train = (pm_train - mean_pm_train) / std_pm_train

    y_test = (-Y_test.reshape((-1, 1)) - mean_y_train) / std_y_train
    gain_test = (-Gain_test.reshape((-1, 1)) - mean_gain_train) / std_gain_train
    bw_test = (-BW_test.reshape((-1, 1)) - mean_bw_train) / std_bw_train
    pm_test = (-PM_test.reshape((-1, 1)) - mean_pm_train) / std_pm_train

    # Bayesian optimization begins here
    iteration = 0
    best_score = 1e15
    best_arc = None
    best_random_score = 1e15
    best_random_arc = None
    print("Average pairwise distance between train points = {}".format(np.mean(pdist(X_train))))
    print("Average pairwise distance between test points = {}".format(np.mean(pdist(X_test))))

    rmse_file = save_dir + 'Test_RMSE_ll.txt'
    if os.path.exists(rmse_file):
        os.remove(rmse_file)

    M = 500
    sgp_fom, pred_fom, uncert_fom, error_fom, testll_fom, pearson_fom = _fit_and_score_sgp(
        'Fom', X_train, y_train, X_test, y_test, M)
    sgp_gain, pred_gain, uncert_gain, error_gain, testll_gain, pearson_gain = _fit_and_score_sgp(
        'Gain', X_train, gain_train, X_test, gain_test, M)
    sgp_bw, pred_bw, uncert_bw, error_bw, testll_bw, pearson_bw = _fit_and_score_sgp(
        'BW', X_train, bw_train, X_test, bw_test, M)
    sgp_pm, pred_pm, uncert_pm, error_pm, testll_pm, pearson_pm = _fit_and_score_sgp(
        'PM', X_train, pm_train, X_test, pm_test, M)

    with open(rmse_file, 'a') as test_file:
        test_file.write('Fom RMSE: {:.4f}, ll: {:.4f},  Pearson r: {:.4f}\n'.format(error_fom, testll_fom, pearson_fom))
        test_file.write('Gain RMSE: {:.4f}, ll: {:.4f}, Pearson r: {:.4f}\n'.format(error_gain, testll_gain, pearson_gain))
        test_file.write('BW RMSE: {:.4f}, ll: {:.4f}, Pearson r: {:.4f}\n'.format(error_bw, testll_bw, pearson_bw))
        # Was `testll_gain_pm`, an undefined name that raised NameError here.
        test_file.write('PM RMSE: {:.4f}, ll: {:.4f}, Pearson r: {:.4f}\n'.format(error_pm, testll_pm, pearson_pm))
    
























