import os
import sys
sys.path.insert(0, '../')
sys.path.insert(0, '../../')
sys.path.insert(0, '../../../')
import time
import numpy as np
import pandas as pd
import pickle
import copy

from tqdm import tqdm
#load torch 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset

import network_designer.utils as utils

from network_designer.design_space.blox.design_space import BloxDesignSpace
from network_designer.design_space.nasbench201.design_space import NB201LikeDesignSpace
from network_designer.design_space.ZenNAS.design_space import ZenNASDesignSpace

import math
#load model 
from network_designer.models.vgae import VGAE
from sklearn.model_selection import train_test_split

import argparse

class VAEReconstructed_Loss(object):
    """Graph VAE loss: adjacency reconstruction + op-feature reconstruction + weighted KL term.

    Args:
        w_adj: Weight of the adjacency reconstruction loss.
        w_ops: Weight of the op-feature reconstruction loss.
        w_kl: Weight of the KL-divergence term.
        loss_adj: Loss class (instantiated with no arguments) applied to the
            reconstructed adjacency. The default ``nn.BCELoss`` expects
            probabilities in [0, 1].
        loss_ops: Loss class (instantiated with no arguments) applied to the
            reconstructed op features. The default ``nn.BCEWithLogitsLoss``
            expects raw (pre-sigmoid) scores.
    """

    def __init__(self, w_adj=1.0, w_ops=1.0, w_kl=0.005, loss_adj=nn.BCELoss, loss_ops=nn.BCEWithLogitsLoss):
        super().__init__()
        self.w_ops = w_ops
        self.w_adj = w_adj
        self.w_kl = w_kl
        # The supplied class is honoured; the default keeps BCE-with-logits for op features.
        self.loss_ops = loss_ops()
        self.loss_adj = loss_adj()

    def __call__(self, inputs, targets, mu, logvar):
        """Return the weighted total loss as a scalar tensor.

        Args:
            inputs: ``(adj_recon, ops_recon)`` reconstructions from the model.
            targets: ``(adj, ops)`` ground-truth adjacency and op features.
            mu: Latent means; the KL sum runs over dim 2, so it must be >= 3-D.
            logvar: Latent log-scale, same shape as ``mu``.
        """
        adj_recon, ops_recon = inputs[0], inputs[1]
        adj, ops = targets[0], targets[1]
        loss_adj = self.loss_adj(adj_recon, adj)
        loss_ops = self.loss_ops(ops_recon, ops)
        loss = self.w_ops * loss_ops + self.w_adj * loss_adj
        # KL term normalised by ops.shape[0] * ops.shape[1]. The expression uses
        # 2*logvar and exp(logvar)**2, i.e. it treats `logvar` as log(sigma).
        kld = -0.5 / (ops.shape[0] * ops.shape[1]) * torch.mean(
            torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 2))
        return loss + (self.w_kl * kld)
    
class VAEReconstructed_Loss_with_feature_chunks(object):
    """Graph VAE loss where op features are scored chunk by chunk.

    The op-feature target is split along dim 2 into pieces of
    ``feature_chunks_size`` and each piece is compared with the matching
    reconstruction chunk using ``BCEWithLogitsLoss``. The first two chunks use
    a positive-class weight of ``0.5 * chunk_size``; later chunks are
    unweighted.

    Args:
        w_adj: Weight of the adjacency reconstruction loss.
        w_ops: Weight of the summed per-chunk op losses.
        w_kl: Weight of the KL-divergence term.
        loss_adj: Loss class (no-arg constructor) for the adjacency.
        loss_ops: Loss class stored as ``self.loss_ops``. NOTE(review): it is
            not used by ``__call__``, which relies on ``self.ops_losses``.
        feature_chunks_size: Sizes of the op-feature chunks along dim 2.
    """

    def __init__(self, w_adj=1.0, w_ops=1.0, w_kl=0.005, loss_adj=nn.BCELoss, loss_ops=nn.BCEWithLogitsLoss, feature_chunks_size=None):
        super().__init__()
        self.w_ops = w_ops
        self.w_adj = w_adj
        self.w_kl = w_kl
        self.loss_ops = loss_ops()
        self.loss_adj = loss_adj()
        self.ops_losses = []
        self.feature_chunks_size = feature_chunks_size
        for i, chunk_size in enumerate(feature_chunks_size or []):
            if i > 1:
                self.ops_losses.append(nn.BCEWithLogitsLoss())
            else:
                # A (chunk_size,) weight broadcasts over any (batch, nodes, chunk_size)
                # target, so the number of graph nodes is not baked in.
                pos_weight = torch.full((chunk_size,), chunk_size * 0.5)
                self.ops_losses.append(nn.BCEWithLogitsLoss(pos_weight=pos_weight))

    def __call__(self, inputs, targets, mu, logvar):
        """Return the weighted total loss as a scalar tensor.

        Args:
            inputs: ``(adj_recon, ops_feature_chunks)`` where the second item is
                a sequence of per-chunk logit tensors.
            targets: ``(adj, ops)``; ``ops`` must be 3-D and its dim 2 must sum
                to ``sum(feature_chunks_size)``.
            mu: Latent means (KL sums over dim 2).
            logvar: Latent log-scale, same shape as ``mu``.

        Raises:
            ValueError: if the loss was built without ``feature_chunks_size``.
        """
        if not self.feature_chunks_size:
            raise ValueError('feature_chunks_size is required to split op features into chunks')
        adj_recon = inputs[0]
        ops_feature_chunks = inputs[1]
        adj, ops = targets[0], targets[1]
        b, n, _ = ops.size()
        loss_adj = self.loss_adj(adj_recon, adj)
        target_chunks = torch.split(ops, split_size_or_sections=self.feature_chunks_size, dim=2)
        loss_ops = 0.
        for chunk_loss, pred, tar in zip(self.ops_losses, ops_feature_chunks, target_chunks):
            # Module.to moves pos_weight in place to the prediction's device (CPU or GPU).
            loss_ops += chunk_loss.to(pred.device)(pred, tar)
        loss = self.w_ops * loss_ops + self.w_adj * loss_adj
        # KL term normalised by b * n; like the non-chunked loss it treats
        # `logvar` as log(sigma) (2*logvar and exp(logvar)**2).
        kld = -0.5 / (b * n) * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 2))
        return loss + (self.w_kl * kld)
    
def train_autoencoder(model, trainloader, epoch_step, epoch, num_epochs, lr, optimizer, criterion, scala_dim=None):
    """Run one training epoch of the graph autoencoder.

    Args:
        model: Module exposing ``forward_decoder((adj, ops))`` that returns
            ``(adj_recon, ops_recon, mu, logvar, z)``.
        trainloader: Iterable of ``(adj, ops)`` batches.
        epoch_step: Iterations per epoch; only shown in the progress line.
        epoch: Current epoch number (progress output only).
        num_epochs: Total number of epochs (progress output only).
        lr: Current learning rate (progress output only).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)``
            returning a single-element loss tensor.
        scala_dim: Unused; kept for call compatibility.

    Returns:
        Sum of the per-batch losses as a detached 0-d tensor, or ``0`` when the
        loader yields no batches.
    """
    model.train()
    train_loss = 0
    # Send batches to wherever the model's parameters live rather than assuming CUDA.
    device = next(model.parameters()).device

    num_trainable = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print('|  Number of Trainable Parameters: ' + str(num_trainable))
    print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, lr))

    for batch_idx, batch in enumerate(trainloader):
        optimizer.zero_grad()
        adj, ops = batch[0].to(device), batch[1].to(device)
        adj_recon, ops_recon, mu, logvar, _ = model.forward_decoder((adj, ops))
        loss = criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        # Detach so the running total does not hold on to the autograd graph;
        # reshape(()) accepts both 0-d and single-element losses.
        train_loss += loss.detach().reshape(())

        sys.stdout.write('\r')
        sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f \t\t'
                         % (epoch, num_epochs, batch_idx + 1,
                            epoch_step + 1, loss.item()))
        sys.stdout.flush()

    return train_loss
        
def test_autoencoder(model, test_loder, test_set, epoch, criterion,design_space=None):
    """Evaluate the autoencoder and print reconstruction metrics.

    For the first batch, the target and reconstructed graphs are decoded with
    ``design_space.transfer_graph_to_str`` and printed as a sanity check. The
    ``.squeeze(0)`` calls there assume the loader yields batches of size 1.

    Args:
        model: Module exposing ``forward_decoder((adj, ops))``; it is left in
            eval mode on return.
        test_loder: Iterable of ``(adj, ops)`` batches.
        test_set: Dataset whose length normalises the accumulated metrics.
        epoch: Epoch number, used only in the printed summary.
        criterion: Loss callable with the same contract as in training.
        design_space: Required despite the ``None`` default: its
            ``feature_chunks_size`` and ``transfer_graph_to_str`` are used.

    Returns:
        ``(correct_ops_ave, correct_adj_ave, avg_loss)`` where ``avg_loss`` is a
        0-d tensor holding the sum of batch losses divided by ``len(test_set)``.
    """
    feature_chunks_size = design_space.feature_chunks_size
    device = next(model.parameters()).device
    # Kept separate from the per-batch loss so batches actually accumulate.
    total_loss = 0.0
    correct_ops_ave, mean_correct_adj_ave, mean_false_positive_adj_ave, correct_adj_ave = 0, 0, 0, 0

    model.eval()
    # no_grad restores the previous grad mode even if a batch raises.
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loder)):
            N, I, _ = batch[1].size()
            adj, ops = batch[0].to(device), batch[1].to(device)
            adj_recon, ops_recon, mu, logvar, _ = model.forward_decoder((adj, ops))

            batch_loss = criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)
            total_loss += batch_loss.item()

            # ops_recon arrives as a sequence of chunk tensors; join them and
            # threshold the sigmoid probabilities at 0.5.
            ops_prob = torch.sigmoid(torch.cat(ops_recon, dim=-1))
            ops_pred = (ops_prob > 0.5).int()

            if batch_idx == 0:
                target_arch = design_space.transfer_graph_to_str(adj.cpu().numpy().squeeze(0), ops.int().cpu().numpy().squeeze(0), print_debug=True)
                print(target_arch)
                try:
                    pred_arch = design_space.transfer_graph_to_str(adj_recon.cpu().numpy().squeeze(0), ops_pred.cpu().numpy().squeeze(0), print_debug=True)
                    print(pred_arch)
                except Exception:
                    # Best effort: a reconstruction may not decode to a valid graph.
                    print('invalid graph')

            correct_ops, mean_correct_adj, mean_false_positive_adj, correct_adj = utils.get_accuracy_zenNAS(
                (ops_pred, adj_recon), (ops, adj), N, I, feature_chunks_size, design_space)
            correct_ops_ave += correct_ops
            mean_correct_adj_ave += mean_correct_adj
            mean_false_positive_adj_ave += mean_false_positive_adj
            correct_adj_ave += correct_adj

    num_samples = len(test_set)
    # Returned as a tensor because the caller invokes .cpu() on it.
    avg_loss = torch.tensor(total_loss / num_samples)
    correct_ops_ave = round((correct_ops_ave / num_samples), 4)
    mean_correct_adj_ave = round((mean_correct_adj_ave.item() / num_samples), 4)
    mean_false_positive_adj_ave = round((mean_false_positive_adj_ave.item() / num_samples), 4)
    correct_adj_ave = round((correct_adj_ave / num_samples), 4)

    print()
    print(f'Average loss of test set {epoch}: {avg_loss.item()} | Op Acc:{correct_ops_ave} | Mean Correct Adj:{mean_correct_adj_ave} | Mean FP Adj: {mean_false_positive_adj_ave} | Adj Acc:{correct_adj_ave}')

    return correct_ops_ave, correct_adj_ave, avg_loss

def autoencoder_train(model, trainloader, traindataset, testloader,
                      testdataset, save=True, save_label='', save_path='',
                      batch_size=64, epoch=250, lr=0.00035, design_space=None):
    """Train the VGAE with AdamW and cosine annealing, keeping the best checkpoint.

    The epoch count is derived from a fixed iteration budget (``train_iter``)
    divided by the iterations per epoch. NOTE(review): the ``epoch`` argument
    is currently unused.

    After each epoch the model is evaluated. Whenever op accuracy improves and
    ``save`` is true, the weights are written to
    ``os.path.join(save_path, 'vgae_<save_label>')``. Early stopping on the
    test loss is consulted only after epoch 50.

    Args:
        model: Model to train.
        trainloader: Training batch loader.
        traindataset: Training dataset; its length sets iterations per epoch.
        testloader: Evaluation batch loader.
        testdataset: Evaluation dataset; its length normalises the metrics.
        save: Whether to save improved weights.
        save_label: Suffix of the checkpoint file name.
        save_path: Directory for checkpoints.
        batch_size: Batch size of ``trainloader``.
        epoch: Unused (see note above).
        lr: Initial learning rate.
        design_space: When truthy, use the chunked op-feature loss built from
            ``design_space.feature_chunks_size``.
    """
    print('|  Initial Learning Rate: ' + str(lr))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-4)
    highest_acc = None
    if design_space:
        criterion = VAEReconstructed_Loss_with_feature_chunks(w_adj=1.0, w_ops=1.0, feature_chunks_size=design_space.feature_chunks_size)
    else:
        criterion = VAEReconstructed_Loss(w_adj=1.0, w_ops=1.0)

    # Total optimisation-step budget; the number of epochs is derived from it.
    train_iter = 5500000

    # Iterations per epoch, using the same batch size as the training loader.
    epoch_step = (len(traindataset) // batch_size) + 1
    epoch_num = train_iter // epoch_step
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epoch_num), eta_min=0)
    # Patience is 10% of the epochs, but never fewer than 35.
    es = utils.EarlyStopping(mode='min', patience=max(math.floor(epoch_num * 0.1), 35))
    for epoch_idx in range(1, 1 + epoch_num):
        current_lr = scheduler.get_last_lr()[0]
        model.train()
        train_autoencoder(model, trainloader, epoch_step, epoch_idx, epoch_num, current_lr, optimizer, criterion=criterion)

        correct_ops_ave, correct_adj_ave, avg_loss = test_autoencoder(model, testloader, testdataset, epoch_idx, criterion=criterion, design_space=design_space)
        scheduler.step()

        if epoch_idx > 50 and es.step(avg_loss.cpu()):
            print('Early stopping criterion is met, stop training now.')
            print(f'Best result yet {highest_acc}.. VGAE model')
            break

        if highest_acc is None or correct_ops_ave > highest_acc[0]:
            highest_acc = (correct_ops_ave, correct_adj_ave)
            if save:
                torch.save(model.state_dict(), os.path.join(save_path, 'vgae_{}'.format(save_label)))
                print(f'Best result yet {highest_acc}.. VGAE model saved.')

    print(f'Best result yet {highest_acc}.. VGAE model')
  
def train_pipeline(trainset, testset, train_portion, cluster_eps,
                   num_features=6, num_layers=4, num_hidden=256, num_latent=32,
                   batch_size=64, epoch=250, lr=0.00035, save='', design_space=None):
    """Build the design space, datasets, loaders and VGAE, then train it.

    NOTE(review): the design space is selected from the module-level
    ``args.design_space``. The ``design_space`` argument is overwritten, and
    ``trainset`` and ``cluster_eps`` are not used: training graphs come from
    ``utils.OnlineGraphDataset``.

    Args:
        trainset: Unused (see note).
        testset: DataFrame with ``adj_matrix`` and ``ops_features`` columns.
        train_portion: Used as the checkpoint label.
        cluster_eps: Unused.
        num_features, num_layers, num_hidden, num_latent: VGAE hyper-parameters.
        batch_size: Training batch size.
        epoch: Forwarded to ``autoencoder_train``.
        lr: Initial learning rate.
        save: Directory where checkpoints are written.
        design_space: Overwritten (see note).
    """
    # Work on a private copy so dataset construction cannot touch the caller's frame.
    test_frame = copy.deepcopy(testset)

    space_name = args.design_space
    if space_name == 'nasbench201':
        design_space = NB201LikeDesignSpace(args=args)
    elif space_name == 'blox':
        design_space = BloxDesignSpace(args=args)
    else:
        design_space = ZenNASDesignSpace(args=args)
    # Only the ZenNAS branch hands chunk sizes to the model; the others pass None.
    uses_chunks = space_name not in ('nasbench201', 'blox')
    feature_chunks_size = design_space.feature_chunks_size if uses_chunks else None

    train_dataset = utils.OnlineGraphDataset(design_space)
    test_dataset = utils.GraphDataset(
        adj_matrix=test_frame['adj_matrix'].values,
        ops_feature=test_frame['ops_features'].values,
    )

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Batch size 1: test_autoencoder squeezes the batch dim when decoding graphs.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    model_kwargs = dict(
        num_features=num_features,
        num_layers=num_layers,
        num_hidden=num_hidden,
        num_latent=num_latent,
        feature_chunks_size=feature_chunks_size,
    )
    gvae = VGAE(**model_kwargs).cuda()

    autoencoder_train(
        gvae, train_loader, train_dataset, test_loader, test_dataset,
        save_label=str(train_portion),
        batch_size=batch_size, epoch=epoch, lr=lr,
        save_path=save, design_space=design_space,
    )

parser = argparse.ArgumentParser("vgae")

# devices
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--seed', type=int, default=0, help='random seed')
# dataset
parser.add_argument('--train_portion', type=float, default=1.0, help='dataset portion for training VGAE')
parser.add_argument('--dataset_path', type=str, default="../../exp/benchmarks/nasbench201/sampled_graph_with_vec.pkl", help='dataset path that stores all the graphs and their zc-vectors')
parser.add_argument('--batch_size', type=int, default=64, help='batch size of train loader')

# model parameters
parser.add_argument('--num_features', type=int, default=6, help='number of ops in graph nodes')
parser.add_argument('--num_layers', type=int, default=4, help='number of layers in VAE model')
parser.add_argument('--num_hidden', type=int, default=256, help='length of intermediate features in VAE')
parser.add_argument('--num_latent', type=int, default=32, help='length of output features from encoder')
# NOTE(review): --ops_activation / --ops_loss are not read in this file; they may be
# consumed by the design-space classes, which receive `args` — confirm.
parser.add_argument('--ops_activation', type=str, default='sigmoid')
parser.add_argument('--ops_loss', type=str, default='BECLoss')
# train parameters
parser.add_argument('--epoch', type=int, default=250)
parser.add_argument('--lr', type=float, default=0.00035)

# experiment name (outputs go to ../../experiments/<exp>/step_1/)
parser.add_argument('--exp', type=str, default='Test')

# clustering
parser.add_argument('--cluster_eps', type=float, default=1.0)

parser.add_argument('--design_space', type=str, default='ImageNet800M', help="design space to use: 'nasbench201', 'blox', or any other value for ZenNAS")
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name')
parser.add_argument('--ss_path', default='', type=str, help='path to ss design space ')

# Parsed at import time: train_pipeline reads this module-level `args`.
args = parser.parse_args()

if __name__ == '__main__':
    cudnn.benchmark = True
    cudnn.enabled = True
    # Seed NumPy and torch. train_test_split below has no random_state, so
    # (per the scikit-learn docs) it draws from NumPy's global RNG seeded here.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    exp_path = "../../experiments/{}/step_1/".format(args.exp)
    os.makedirs(exp_path, exist_ok=True)

    # set devices
    # NOTE(review): the CPU branch only affects this print; train_pipeline moves
    # the model with .cuda(), so CPU-only runs are not supported end to end.
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
        torch.cuda.set_device(device)
    print("Using device", device)

    # NOTE: pickle.load can execute arbitrary code; only point --dataset_path at trusted files.
    with open(args.dataset_path, 'rb') as f:
        data = pickle.load(f)
    dataset = pd.DataFrame(data)

    trainset, testset = train_test_split(dataset, test_size=0.1)

    train_pipeline(trainset, testset, args.train_portion, cluster_eps=args.cluster_eps,
                   num_features=args.num_features, num_layers=args.num_layers,
                   num_hidden=args.num_hidden, num_latent=args.num_latent,
                   batch_size=args.batch_size, epoch=args.epoch, lr=args.lr, save=exp_path, design_space=args.design_space)