import os
import sys
sys.path.insert(0, '../')
sys.path.insert(0, '../../')
sys.path.insert(0, '../../../')
import time
import numpy as np
import pandas as pd
import pickle
import copy

from tqdm import tqdm
#load torch 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset

import network_designer.utils as utils
from pytorch_metric_learning import losses, miners, testers
from network_designer.design_space.blox.design_space import BloxDesignSpace
from network_designer.design_space.nasbench201.design_space import NB201LikeDesignSpace
from network_designer.design_space.ZenNAS.design_space import ZenNASDesignSpace

import math
#load model 
from network_designer.models.vae import VAE
from sklearn.model_selection import train_test_split

import argparse
def custom_collate(batch):
    """Collate ``((adj, ops), label)`` samples into batched tensors.

    Adjacency and op tensors are stacked along a new leading batch dimension.
    If any sample carries a ``None`` label, the batch label is ``None``;
    otherwise all labels are stacked as well.

    Returns:
        ``((adj_batch, ops_batch), label_batch)`` where ``label_batch`` may be ``None``.
    """
    # Single pass over the batch so generator inputs are consumed only once.
    samples = [(adj, ops, label) for (adj, ops), label in batch]

    adj_batch = torch.stack([sample[0] for sample in samples])
    ops_batch = torch.stack([sample[1] for sample in samples])

    targets = [sample[2] for sample in samples]
    has_missing_label = any(target is None for target in targets)
    label_batch = None if has_missing_label else torch.stack(targets)

    return (adj_batch, ops_batch), label_batch

class VAEReconstructed_Loss(object):
    """Weighted reconstruction + KL-divergence loss for a graph autoencoder.

    The total loss is::

        w_ops * loss_ops(ops_recon, ops) + w_adj * loss_adj(adj_recon, adj) + w_kl * KLD

    Args:
        w_adj: weight of the adjacency reconstruction term.
        w_ops: weight of the op-feature reconstruction term.
        w_kl: weight of the KL term.
        loss_adj: loss *class* (not instance) for the adjacency term; it is
            instantiated with no arguments and moved with ``.cuda()``.
        loss_ops: loss *class* (not instance) for the op-feature term; it is
            instantiated with no arguments.
        num_features, n_nodes: accepted for call compatibility; not used here.
    """

    def __init__(self, w_adj=1.0, w_ops=1.0, w_kl=0.05, loss_adj=nn.BCEWithLogitsLoss, loss_ops=nn.BCEWithLogitsLoss, num_features=6, n_nodes=9):
        super().__init__()
        self.w_ops = w_ops
        self.w_adj = w_adj
        self.w_kl = w_kl
        self.loss_ops = loss_ops()
        self.loss_adj = loss_adj().cuda()

    def __call__(self, inputs, targets, mu, logvar):
        """Compute the loss.

        Args:
            inputs: ``(adj_recon, ops_recon)`` predictions (logits for the default losses).
            targets: ``(adj, ops)`` ground truth, shaped like the predictions.
            mu: latent means.
            logvar: latent log-scale tensor, same shape as ``mu``.

        Returns:
            Scalar loss tensor.
        """
        adj_recon, ops_recon = inputs[0], inputs[1]
        adj, ops = targets[0], targets[1]
        # No rank requirement on ops_recon: the previous unused `b, n, l = ops_recon.size()`
        # unpacking rejected any non-3D op tensor.
        loss_adj = self.loss_adj(adj_recon, adj)
        loss_ops = self.loss_ops(ops_recon, ops)
        loss = self.w_ops * loss_ops + self.w_adj * loss_adj
        # 1 + 2*logvar - mu^2 - exp(logvar)^2 is the closed-form Gaussian KL term when
        # `logvar` holds log(sigma) rather than log(sigma^2) — NOTE(review): confirm
        # against the model's encoder. Summed over dim 1, averaged over the rest.
        KLD = -0.5 * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), dim=1))

        return loss + (self.w_kl * KLD)
    
def train_autoencoder(model, trainloader, epoch_step, epoch, num_epochs, lr, optimizer, criterion, scala_dim=None, loss_func=None, miner=None):
    """Run one training epoch of the graph autoencoder.

    Args:
        model: autoencoder exposing ``forward_decoder((adj_flat, ops_flat))`` that
            returns ``(adj_recon, ops_recon, mu, logvar, _)``.
        trainloader: iterable yielding ``(adj, ops)`` batches; both are moved with ``.cuda()``.
        epoch_step: expected iterations per epoch (progress display only).
        epoch: current epoch number (display only).
        num_epochs: total number of epochs (display only).
        lr: current learning rate (display only).
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: callable ``criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)``
            returning a single-element loss tensor.
        scala_dim, loss_func, miner: unused; kept for call compatibility.

    Returns:
        float: sum of the per-batch losses over the epoch.
    """
    model.train()
    train_loss = 0.0

    n_params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print('|  Number of Trainable Parameters: ' + str(n_params))
    print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, lr))

    for batch_idx, (adj, ops) in enumerate(trainloader):
        optimizer.zero_grad()
        adj = adj.cuda()
        ops = ops.cuda()

        # The model consumes per-graph flattened inputs; the criterion compares the
        # reconstructions against the un-flattened targets.
        adj_recon, ops_recon, mu, logvar, _ = model.forward_decoder(
            (adj.flatten(start_dim=1), ops.flatten(start_dim=1)))
        loss = criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        # .item() handles both 0-dim and 1-element tensors, replacing the legacy
        # `loss.data[0]` / `.data` reshape workaround.
        batch_loss = loss.item()
        train_loss += batch_loss

        # The reconstruction loss *is* the total loss here, so it is shown twice.
        sys.stdout.write('\r')
        sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f \t\tRecLoss: %.4f'
                         % (epoch, num_epochs, batch_idx + 1,
                            epoch_step, batch_loss, batch_loss))
        sys.stdout.flush()

    return train_loss
        
def test_autoencoder(model, test_loder, test_set, epoch, criterion):
    """Evaluate reconstruction loss and accuracy of the autoencoder.

    Args:
        model: autoencoder exposing ``forward_decoder((adj, ops))`` returning
            ``(adj_recon, ops_recon, mu, logvar, _)``.
        test_loder: iterable of ``(adj, ops)`` batches; both are moved with ``.cuda()``.
        test_set: used only for ``len()``. Accumulated sums are divided by it, which
            gives a per-sample average only if every batch holds one sample.
        epoch: epoch number (printed only).
        criterion: same loss callable as used in training.

    Returns:
        ``(correct_ops_ave, correct_adj_ave, avg_loss)``.
    """
    loss = 0.
    model.eval()
    correct_ops_ave, mean_correct_adj_ave, mean_false_positive_adj_ave, correct_adj_ave = 0, 0, 0, 0
    for batch_idx, inputs in enumerate(tqdm(test_loder)):
        # no_grad restores the caller's grad mode even if the forward pass raises,
        # unlike the previous manual set_grad_enabled(False)/(True) pair.
        with torch.no_grad():
            inputs = (inputs[0].cuda(), inputs[1].cuda())
            adj_recon, ops_recon, mu, logvar, _ = model.forward_decoder(inputs)
            rec_loss = criterion((adj_recon, ops_recon), (inputs[0], inputs[1]), mu, logvar)
            loss += rec_loss.item()
            # Debug dump of one sample's thresholded reconstruction vs. its target.
            if batch_idx == 600:
                print((torch.nn.Sigmoid()(adj_recon) > 0.5).int())
                print(inputs[0].int())

                print((torch.nn.Sigmoid()(ops_recon) > 0.5).int())
                print(inputs[1].int())

        correct_ops, mean_correct_adj, mean_false_positive_adj, correct_adj = utils.get_accuracy((ops_recon, adj_recon), (inputs[1], inputs[0]))

        correct_ops_ave += correct_ops
        mean_correct_adj_ave += mean_correct_adj
        mean_false_positive_adj_ave += mean_false_positive_adj
        correct_adj_ave += correct_adj

    avg_loss = loss / len(test_set)
    correct_ops_ave = round((correct_ops_ave / len(test_set)), 4)
    # These two accumulators are expected to be tensors (hence .item()).
    mean_correct_adj_ave = round((mean_correct_adj_ave.item() / len(test_set)), 4)
    mean_false_positive_adj_ave = round((mean_false_positive_adj_ave.item() / len(test_set)), 4)
    correct_adj_ave = round((correct_adj_ave / len(test_set)), 4)

    # Was `print(f'/n')`, which printed a literal "/n" instead of a line break.
    print('\n')
    print(f'Average loss of test set {epoch}: {avg_loss} | Op Acc:{correct_ops_ave} | Mean Correct Adj:{mean_correct_adj_ave} | Mean FP Adj: {mean_false_positive_adj_ave} | Adj Acc:{correct_adj_ave}')

    return correct_ops_ave, correct_adj_ave, avg_loss

def autoencoder_train(model, trainloader, traindataset, testloader, 
                      testdataset, save=True, save_label='', save_path='',
                      batch_size=64, train_budget=250, lr=0.00035, num_features=0, n_nodes=0):
    """Train the graph autoencoder for an iteration budget and checkpoint the best epoch.

    ``train_budget`` (in iterations) is converted into a number of epochs using the
    iterations per epoch implied by ``len(traindataset)`` and ``batch_size``. The
    learning rate follows a cosine schedule over those epochs. From epoch 36 on,
    ``utils.EarlyStopping`` on the test loss may end training early. Whenever the
    test loss improves and ``save`` is true, the weights are written to
    ``<save_path>/vgae_<save_label>.pt``.

    Args:
        model: autoencoder to train.
        trainloader, testloader: batch iterables for training / evaluation.
        traindataset: used only for ``len()`` to size an epoch.
        testdataset: passed to ``test_autoencoder`` for averaging.
        save, save_label, save_path: checkpointing controls.
        batch_size: batch size of ``trainloader``.
        train_budget: total number of training iterations.
        lr: initial Adam learning rate.
        num_features, n_nodes: forwarded to ``VAEReconstructed_Loss``.
    """
    print('|  Initial Learning Rate: ' + str(lr))
    optimizer = optim.Adam(model.parameters(), lr=lr)
    highest_acc = None
    lowest_loss = None
    criterion = VAEReconstructed_Loss(w_adj=1.0, w_ops=1.0, num_features=num_features, n_nodes=n_nodes)

    # Iterations per epoch must follow the real batch size; this was hard-coded to 64,
    # which mis-sized the epoch count for any other batch size. max(1, ...) guards an
    # empty dataset against division by zero below.
    epoch_step = max(1, math.ceil(len(traindataset) / batch_size))
    epoch_num = train_budget // epoch_step
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epoch_num), eta_min=0)

    # Patience is 10% of the epoch count, but never fewer than 35 epochs.
    es = utils.EarlyStopping(mode='min', patience=max(35, math.floor(epoch_num * 0.1)))
    for epoch in range(1, 1 + epoch_num):
        current_lr = scheduler.get_last_lr()[0]
        model.train()
        train_autoencoder(model, trainloader, epoch_step, epoch, epoch_num, current_lr, optimizer, criterion=criterion)

        correct_ops_ave, correct_adj_ave, avg_loss = test_autoencoder(model, testloader, testdataset, epoch, criterion=criterion)
        scheduler.step()

        if epoch > 35:
            if es.step(avg_loss):
                print('Early stopping criterion is met, stop training now.')
                print(f'Best result yet {highest_acc}.. VGAE model')
                break

        if lowest_loss is None or avg_loss < lowest_loss:
            lowest_loss = avg_loss
            highest_acc = (correct_ops_ave, correct_adj_ave)
            # state_dict() references the live tensors; it is serialized immediately,
            # so later training steps do not affect the saved checkpoint.
            best_vgae_weight = model.state_dict()
            if save:
                torch.save(best_vgae_weight, os.path.join(save_path, 'vgae_{}.pt'.format(save_label)))
                print(f'Best result yet {highest_acc}.. VGAE model saved.')

    print(f'Best result yet {highest_acc}.. VGAE model')
  
def train_pipeline(trainset, testset, train_portion, cluster_eps,
                   num_features=6, num_layers=4, num_hidden=256, num_latent=32, 
                   batch_size=64, train_budget=0, lr=0.00035, save='', design_space=None, n_nodes=0):
    """Build datasets, loaders and the autoencoder, then run training.

    Args:
        trainset, testset: DataFrames with ``adj_matrix`` and ``ops_features`` columns.
        train_portion: used only as the checkpoint label.
        cluster_eps: unused here; kept for call compatibility.
        num_features, num_layers, num_hidden, num_latent, n_nodes: model hyperparameters.
        batch_size: training batch size (evaluation always uses batch size 1).
        train_budget: total training iterations.
        lr: initial learning rate.
        save: directory where checkpoints are written.
        design_space: design-space name ('nasbench201', 'nasbench101', 'blox'; any
            other value selects ZenNAS). When None, falls back to the module-level
            ``args.design_space``. Previously this parameter was ignored and
            ``args.design_space`` was always used.
    """
    current_trainset = copy.deepcopy(trainset)
    current_testset = copy.deepcopy(testset)

    space_name = args.design_space if design_space is None else design_space
    # The design-space constructors still read the module-level `args`. The resulting
    # object is not used further in this function.
    if space_name == 'nasbench201':
        design_space = NB201LikeDesignSpace(args=args)
    elif space_name == 'nasbench101':
        pass
    elif space_name == 'blox':
        design_space = BloxDesignSpace(args=args)
    else:
        design_space = ZenNASDesignSpace(args=args)

    train_dataset = utils.UnlabelledGraphDataset(adj_matrix=current_trainset['adj_matrix'].values,
                                                 ops_feature=current_trainset['ops_features'].values)

    test_dataset = utils.UnlabelledGraphDataset(adj_matrix=current_testset['adj_matrix'].values,
                                                ops_feature=current_testset['ops_features'].values)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # NOTE(review): VGAE is not imported at the top of this file (only VAE from
    # network_designer.models.vae is); confirm the intended model class.
    gvae = VGAE(num_features=num_features, num_layers=num_layers, num_hidden=num_hidden, num_latent=num_latent, num_node=n_nodes)
    gvae = gvae.cuda()

    autoencoder_train(gvae, train_loader, train_dataset,
                      test_loader, test_dataset, save_label=str(train_portion),
                      batch_size=batch_size, train_budget=train_budget, lr=lr, save_path=save, num_features=num_features, n_nodes=n_nodes)

# Command-line configuration. NOTE: arguments are parsed at import time and the
# resulting module-level `args` is read by train_pipeline.
parser = argparse.ArgumentParser("vgae")

#devices
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--seed', type=int, default=0, help='random seed')
#dataset
parser.add_argument('--train_portion', type=float, default=1.0, help='dataset portion for training VGAE')
parser.add_argument('--dataset_path', type=str, default="../../exp/benchmarks/nasbench201/sampled_graph_with_vec.pkl", help='dataset path that store all the graph and its zc-vector')
parser.add_argument('--batch_size', type=int, default=64, help='batch size of train loader')

#model parameters
parser.add_argument('--num_features', type=int, default=5, help='number of ops in graph nodes')
parser.add_argument('--num_layers', type=int, default=4, help='number of layers in VAE model')
parser.add_argument('--num_hidden', type=int, default=256, help='length of intermediate features in VAE')
parser.add_argument('--num_latent', type=int, default=32, help='length of output features from encoder')
parser.add_argument('--ops_activation', type=str, default='sigmoid')
parser.add_argument('--ops_loss', type=str, default='BECLoss')
parser.add_argument('--n_nodes', type=int, default=8)
#train parameters
parser.add_argument('--train_budget', type=int, default=550000)
parser.add_argument('--lr', type=float, default=0.00035)

#exp 
parser.add_argument('--exp', type=str, default='Test')
parser.add_argument('--exp_root', type=str, default='../../../experiments')

#clustering 
parser.add_argument('--cluster_eps', type=float, default=1.0)

parser.add_argument('--design_space', type=str, default='ImageNet800M',
                    help='design space: nasbench201, nasbench101, blox, or any other value for ZenNAS')
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name')
parser.add_argument('--ss_path', default='', type=str, help='path to ss design space ')

args = parser.parse_args()

if __name__ == '__main__':
    cudnn.benchmark = True
    cudnn.enabled = True
    # Seeding numpy's global RNG also makes train_test_split below reproducible,
    # since it is called without an explicit random_state.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    exp_path = "{}/{}/step_1/".format(args.exp_root, args.exp)
    os.makedirs(exp_path, exist_ok=True)

    #set devices 
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
        torch.cuda.set_device(device)
    print("Using device", device)

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --dataset_path at trusted files. The `with` block closes the handle, which
    # was previously left open.
    with open(args.dataset_path, 'rb') as f:
        data = pickle.load(f)
    dataset = pd.DataFrame(data)

    trainset, testset = train_test_split(dataset, test_size=0.01)

    # NOTE(review): the full `dataset` (a superset of `testset`) is passed as the
    # training data and `trainset` is unused, so test metrics are measured on graphs
    # also seen during training. Confirm this is intended.
    train_pipeline(dataset, testset, args.train_portion, cluster_eps=args.cluster_eps,
                   num_features=args.num_features, num_layers=args.num_layers,
                   num_hidden=args.num_hidden, num_latent=args.num_latent,
                   batch_size=args.batch_size, train_budget=args.train_budget,
                   lr=args.lr, save=exp_path, design_space=args.design_space, n_nodes=args.n_nodes)