import os
import sys
sys.path.insert(0, '../')
sys.path.insert(0, '../../')
sys.path.insert(0, '../../../')
import time
import numpy as np
import pandas as pd
import pickle
import copy

from tqdm import tqdm
#load torch 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset

import network_designer.utils as utils

from network_designer.design_space.blox.design_space import BloxDesignSpace
from network_designer.design_space.nasbench201.design_space import NB201LikeDesignSpace
from network_designer.design_space.ZenNAS.design_space import ZenNASDesignSpace

import math
#load model 
from network_designer.models.vgae import VGAE
from sklearn.model_selection import train_test_split

import argparse

class VAEReconstructed_Loss(object):
    """Weighted VAE objective: adjacency + ops reconstruction plus a scaled KL term.

    The two reconstruction criteria are built from the ``loss_adj`` / ``loss_ops``
    factories (``nn.BCEWithLogitsLoss`` by default). ``num_features`` and ``n_nodes``
    are accepted for signature compatibility but are not used.
    """

    def __init__(self, w_adj=1.0, w_ops=1.0, w_kl=0.005, loss_adj=nn.BCEWithLogitsLoss, loss_ops=nn.BCEWithLogitsLoss, num_features=6, n_nodes=9):
        super().__init__()
        # Relative weights of the ops, adjacency and KL terms.
        self.w_ops, self.w_adj, self.w_kl = w_ops, w_adj, w_kl
        # ``loss_ops`` / ``loss_adj`` are callables (classes) instantiated once here.
        self.loss_ops = loss_ops()
        self.loss_adj = loss_adj()

    def __call__(self, inputs, targets, mu, logvar):
        """Return ``w_ops * L_ops + w_adj * L_adj + w_kl * KLD`` for one batch.

        Args:
            inputs: ``(adj_recon, ops_recon)`` predictions; ``ops_recon`` must be rank 3.
            targets: ``(adj, ops)`` ground truth paired with ``inputs``.
            mu: latent means; the KL sum runs over dim 2.
            logvar: latent log-scale values with the same shape as ``mu``.

        The KL expression uses ``2 * logvar`` and ``logvar.exp() ** 2``, i.e. it treats
        ``logvar`` as a log standard deviation. It is scaled by
        ``1 / (ops.shape[0] * ops.shape[1])``.
        """
        adj_recon, ops_recon = inputs[0], inputs[1]
        adj, ops = targets[0], targets[1]
        # The shape unpack doubles as a rank-3 check (raises ValueError otherwise).
        _batch, _nodes, _feats = ops_recon.size()

        adj_term = self.loss_adj(adj_recon, adj)
        ops_term = self.loss_ops(ops_recon, ops)
        reconstruction = self.w_ops * ops_term + self.w_adj * adj_term

        scale = -0.5 / (ops.shape[0] * ops.shape[1])
        kl_terms = 1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2)
        kld = scale * kl_terms.sum(2).mean()
        return reconstruction + self.w_kl * kld
    
class VAEReconstructed_Loss_with_feature_chunks(object):
    """VAE objective where the ops features are reconstructed as separate chunks.

    The ops target is split along dim 2 into ``feature_chunks_size`` pieces. Each piece
    is scored with its own ``nn.BCEWithLogitsLoss`` against the matching predicted chunk,
    and the per-chunk losses are summed. The ``loss_ops`` factory is instantiated but not
    used by ``__call__``, and ``n_nodes`` is unused.
    """

    def __init__(self, w_adj=1.0, w_ops=1.0, w_kl=0.005, loss_adj=nn.BCEWithLogitsLoss, loss_ops=nn.BCEWithLogitsLoss, feature_chunks_size=None, n_nodes=11):
        super().__init__()
        # Relative weights of the ops, adjacency and KL terms.
        self.w_ops, self.w_adj, self.w_kl = w_ops, w_adj, w_kl
        self.loss_ops = loss_ops()
        self.loss_adj = loss_adj()
        self.feature_chunks_size = feature_chunks_size
        # One BCE-with-logits criterion per feature chunk (empty when no chunks are given).
        self.ops_losses = (
            [nn.BCEWithLogitsLoss().cuda() for _ in feature_chunks_size]
            if feature_chunks_size
            else []
        )

    def __call__(self, inputs, targets, mu, logvar):
        """Return ``w_ops * sum(chunk BCE) + w_adj * L_adj + w_kl * KLD``.

        Args:
            inputs: ``(adj_recon, ops_chunk_predictions)``. The second element is an
                iterable of per-chunk logits, paired in order with the target chunks.
            targets: ``(adj, ops)``; ``ops`` must be rank 3 and is split along dim 2.
            mu: latent means; the KL sum runs over dim 2.
            logvar: latent log-scale values with the same shape as ``mu``.
        """
        adj_recon, ops_feature_chunks = inputs[0], inputs[1]
        adj, ops = targets[0], targets[1]
        # The shape unpack doubles as a rank-3 check on the ops target.
        _batch, _nodes, _ = ops.size()
        target_chunks = torch.split(ops, split_size_or_sections=self.feature_chunks_size, dim=2)

        adj_term = self.loss_adj(adj_recon, adj)
        ops_term = sum(
            (self.ops_losses[idx](pred, tar)
             for idx, (pred, tar) in enumerate(zip(ops_feature_chunks, target_chunks))),
            0.,
        )
        reconstruction = self.w_ops * ops_term + self.w_adj * adj_term

        scale = -0.5 / (ops.shape[0] * ops.shape[1])
        kl_terms = 1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2)
        kld = scale * kl_terms.sum(2).mean()
        return reconstruction + self.w_kl * kld
    
def train_autoencoder(model, trainloader, epoch_step, epoch, num_epochs, lr, optimizer, criterion, scala_dim=None, device='cuda'):
    """Run one training epoch of the autoencoder and return the summed batch loss.

    Args:
        model: module exposing ``forward_decoder((adj, ops))`` that returns
            ``(adj_recon, ops_recon, mu, logvar, z)``, plus a ``feature_chunks_size``
            attribute (debug output for batch 0 is printed only when it is falsy).
        trainloader: iterable of ``(adj, ops)`` batches.
        epoch_step: number of iterations per epoch; used only in the progress line.
        epoch: current epoch index (logging only).
        num_epochs: total number of epochs (logging only).
        lr: current learning rate (logging only; the optimizer holds the real value).
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)``
            returning a single-element loss tensor.
        scala_dim: unused; kept for backward compatibility.
        device: device each batch is moved to (defaults to ``'cuda'``).

    Returns:
        float: sum of the per-batch losses over the epoch.
    """
    model.train()
    train_loss = 0.0

    num_params = sum(int(np.prod(p.size())) for p in model.parameters() if p.requires_grad)
    print('|  Number of Trainable Parameters: ' + str(num_params))
    print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, lr))

    for batch_idx, batch in enumerate(trainloader):
        optimizer.zero_grad()
        adj, ops = batch[0].to(device), batch[1].to(device)
        if batch_idx == 0:
            print(adj.size())
            print(ops.size())
        adj_recon, ops_recon, mu, logvar, _ = model.forward_decoder((adj, ops))
        loss = criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        if batch_idx == 0 and not model.feature_chunks_size:
            # Debug dump: thresholded reconstruction vs. target for the first graph.
            print((torch.sigmoid(adj_recon[0]) > 0.5).int())
            print(adj[0].int())
            print((torch.sigmoid(ops_recon[0]) > 0.5).int())
            print(ops[0].int())

        # ``.item()`` handles both 0-d and single-element losses and yields a Python float,
        # so no tensors are kept alive across the epoch.
        batch_loss = loss.item()
        train_loss += batch_loss

        sys.stdout.write('\r')
        sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f \t\t'
                         % (epoch, num_epochs, batch_idx + 1, epoch_step, batch_loss))
        sys.stdout.flush()

    return train_loss
        
def test_autoencoder(model, test_loder, test_set, epoch, criterion, feature_chunks_size=None, device='cuda'):
    """Evaluate the autoencoder on a held-out loader and report reconstruction metrics.

    Args:
        model: module exposing ``forward_decoder((adj, ops))`` that returns
            ``(adj_recon, ops_recon, mu, logvar, z)``.
        test_loder: iterable of ``(adj, ops)`` batches.
        test_set: dataset whose ``len`` normalises the accumulated loss and metrics.
        epoch: epoch index (logging only).
        criterion: loss callable with the same signature used in training.
        feature_chunks_size: when truthy, accuracy comes from
            ``utils.get_accuracy_with_chunk_size``; otherwise from ``utils.get_accuracy``.
        device: device each batch is moved to (defaults to ``'cuda'``).

    Returns:
        tuple: ``(correct_ops_ave, correct_adj_ave, avg_loss)``. ``avg_loss`` is the sum of
        per-batch losses divided by ``len(test_set)``, kept as a tensor.
    """
    model.eval()
    total_loss = 0.0
    correct_ops_ave, mean_correct_adj_ave, mean_false_positive_adj_ave, correct_adj_ave = 0, 0, 0, 0

    # ``torch.no_grad`` restores grad mode even if an exception escapes the loop.
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loder)):
            N, I, _ = batch[1].size()
            adj, ops = batch[0].to(device), batch[1].to(device)
            adj_recon, ops_recon, mu, logvar, _ = model.forward_decoder((adj, ops))
            loss = criterion((adj_recon, ops_recon), (adj, ops), mu, logvar)

            if batch_idx == 0:
                if feature_chunks_size:
                    # Debug dump: argmax of each thresholded predicted chunk vs. target chunk.
                    targets_chunks = torch.split(ops, split_size_or_sections=model.feature_chunks_size, dim=2)
                    for pred, tar in zip(ops_recon, targets_chunks):
                        print(pred.size())
                        print(tar.size())
                        print(torch.argmax((torch.sigmoid(pred) > 0.5).int(), dim=-1))
                        print(torch.argmax(tar, dim=-1))
                        print('----------------------------')
                else:
                    print((torch.sigmoid(adj_recon) > 0.5).int())
                    print(adj.int())
                    print((torch.sigmoid(ops_recon) > 0.5).int())
                    print(ops.int())

            # Keep the running total separate from the per-batch ``loss`` so every batch
            # contributes exactly once.
            total_loss += loss.detach()

            if feature_chunks_size:
                correct_ops, mean_correct_adj, mean_false_positive_adj, correct_adj = utils.get_accuracy_with_chunk_size(
                    (ops_recon, adj_recon), (ops, adj), N, I, feature_chunks_size)
            else:
                correct_ops, mean_correct_adj, mean_false_positive_adj, correct_adj = utils.get_accuracy(
                    (ops_recon, adj_recon), (ops, adj))

            correct_ops_ave += correct_ops
            mean_correct_adj_ave += mean_correct_adj
            mean_false_positive_adj_ave += mean_false_positive_adj
            correct_adj_ave += correct_adj

    n_samples = len(test_set)
    avg_loss = total_loss / n_samples
    correct_ops_ave = round(correct_ops_ave / n_samples, 4)
    mean_correct_adj_ave = round(mean_correct_adj_ave.item() / n_samples, 4)
    mean_false_positive_adj_ave = round(mean_false_positive_adj_ave.item() / n_samples, 4)
    correct_adj_ave = round(correct_adj_ave / n_samples, 4)

    print()
    print(f'Average loss of test set {epoch}: {avg_loss} | Op Acc:{correct_ops_ave} | Mean Correct Adj:{mean_correct_adj_ave} | Mean FP Adj: {mean_false_positive_adj_ave} | Adj Acc:{correct_adj_ave}')

    return correct_ops_ave, correct_adj_ave, avg_loss

def autoencoder_train(model, trainloader, traindataset, testloader, 
                      testdataset, save=True, save_label='', save_path='',
                      batch_size=64, train_budget=250, lr=0.00035, feature_chunks_size=None, num_features=0, n_nodes=0):
    """Train ``model`` within an iteration budget, evaluating and checkpointing each epoch.

    ``train_budget`` (a number of iterations) is converted to epochs using the number of
    batches per epoch, ``ceil(len(traindataset) / batch_size)``. The LR follows a cosine
    schedule over those epochs. After every epoch the model is evaluated with
    ``test_autoencoder``; from epoch 36 on, the returned test loss feeds
    ``utils.EarlyStopping``. Whenever both op and adjacency accuracy are at least as good
    as the best so far, the state dict is written to ``save_path/vgae_<save_label>``
    (when ``save`` is true).

    Args:
        model: the autoencoder to train (moved to its device by the caller).
        trainloader, traindataset: training loader and its dataset.
        testloader, testdataset: evaluation loader and its dataset.
        save, save_label, save_path: checkpointing controls.
        batch_size: batch size of ``trainloader``; used to size an epoch.
        train_budget: total number of training iterations.
        lr: initial learning rate for AdamW.
        feature_chunks_size: selects the chunked ops loss when truthy.
        num_features, n_nodes: forwarded to ``VAEReconstructed_Loss``.
    """
    print('|  Initial Learning Rate: ' + str(lr))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-4)
    if feature_chunks_size:
        criterion = VAEReconstructed_Loss_with_feature_chunks(w_adj=1.0, w_ops=1.0, feature_chunks_size=feature_chunks_size)
    else:
        criterion = VAEReconstructed_Loss(w_adj=1.0, w_ops=1.0, num_features=num_features, n_nodes=n_nodes)

    # Batches per epoch follow the real ``batch_size`` rather than a fixed 64;
    # ``max(1, ...)`` avoids a zero divisor for an empty training set.
    epoch_step = max(1, math.ceil(len(traindataset) / batch_size))
    epoch_num = train_budget // epoch_step
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epoch_num), eta_min=0)
    # Patience is 10% of the epochs, but never fewer than 35.
    es = utils.EarlyStopping(mode='min', patience=max(35, math.floor(epoch_num * 0.1)))

    highest_acc = None
    for epoch in range(1, epoch_num + 1):
        lr = scheduler.get_last_lr()[0]
        # ``train_autoencoder`` puts the model in train mode itself.
        train_autoencoder(model, trainloader, epoch_step, epoch, epoch_num, lr, optimizer, criterion=criterion)

        correct_ops_ave, correct_adj_ave, avg_loss = test_autoencoder(model, testloader, testdataset, epoch, criterion=criterion, feature_chunks_size=feature_chunks_size)
        scheduler.step()

        # ``torch.as_tensor(...).cpu()`` accepts both tensor and plain-float losses.
        if epoch > 35 and es.step(torch.as_tensor(avg_loss).cpu()):
            print('Early stopping criterion is met, stop training now.')
            print(f'Best result yet {highest_acc}.. VGAE model')
            break

        if highest_acc is None or (correct_ops_ave >= highest_acc[0] and correct_adj_ave >= highest_acc[1]):
            highest_acc = (correct_ops_ave, correct_adj_ave)
            if save:
                torch.save(model.state_dict(), os.path.join(save_path, 'vgae_{}'.format(save_label)))
                print(f'Best result yet {highest_acc}.. VGAE model saved.')

    print(f'Best result yet {highest_acc}.. VGAE model')
  
def train_pipeline(trainset, testset, train_portion, cluster_eps,
                   num_features=6, num_layers=4, num_hidden=256, num_latent=32, 
                   batch_size=64, train_budget=0, lr=0.00035, save='', design_space=None, n_nodes=0):
    """Build graph datasets, loaders and a VGAE, then train it with ``autoencoder_train``.

    Args:
        trainset, testset: DataFrames with ``adj_matrix`` and ``ops_features`` columns.
            They are deep-copied before use.
        train_portion: used only as the checkpoint label.
        cluster_eps: unused here.
        num_features, num_layers, num_hidden, num_latent: VGAE hyper-parameters.
        batch_size: training batch size (the test loader always uses batch size 1).
        train_budget: total number of training iterations.
        lr: initial learning rate.
        save: directory the best checkpoint is written to.
        design_space: an already-constructed design-space object exposing
            ``feature_chunks_size`` is used as is. Anything else (e.g. the CLI name string)
            falls back to building ``DARTSDesignSpace(args=args)`` from the module-level ``args``.
        n_nodes: forwarded to the loss.
    """
    current_trainset = copy.deepcopy(trainset)
    current_testset = copy.deepcopy(testset)

    if hasattr(design_space, 'feature_chunks_size'):
        space = design_space
    else:
        # NOTE(review): DARTSDesignSpace does not appear among this file's imports, so this
        # branch raises NameError as written. Import it, or pass a design-space instance.
        space = DARTSDesignSpace(args=args)
    feature_chunks_size = space.feature_chunks_size

    train_dataset = utils.GraphDataset(adj_matrix=current_trainset['adj_matrix'].values,
                                       ops_feature=current_trainset['ops_features'].values)
    test_dataset = utils.GraphDataset(adj_matrix=current_testset['adj_matrix'].values,
                                      ops_feature=current_testset['ops_features'].values)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    gvae = VGAE(num_features=num_features, num_layers=num_layers, num_hidden=num_hidden, num_latent=num_latent, feature_chunks_size=feature_chunks_size)
    gvae = gvae.cuda()

    autoencoder_train(gvae, train_loader, train_dataset,
                      test_loader, test_dataset, save_label=str(train_portion),
                      batch_size=batch_size, train_budget=train_budget, lr=lr, save_path=save, feature_chunks_size=feature_chunks_size, num_features=num_features, n_nodes=n_nodes)

parser = argparse.ArgumentParser("vgae")

#devices
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--seed', type=int, default=0, help='random seed')
#dataset
parser.add_argument('--train_portion', type=float, default=1.0, help='dataset portion for training VGAE')
parser.add_argument('--dataset_path', type=str, default="../../exp/benchmarks/nasbench201/sampled_graph_with_vec.pkl", help='dataset path that stores all the graphs and their zc-vectors')
parser.add_argument('--batch_size', type=int, default=64, help='batch size of train loader')

#model parameters
parser.add_argument('--num_features', type=int, default=6, help='number of ops in graph nodes')
parser.add_argument('--num_layers', type=int, default=4, help='number of layers in VAE model')
parser.add_argument('--num_hidden', type=int, default=256, help='len of intermediate feature in VAE')
parser.add_argument('--num_latent', type=int, default=32, help='len of output features from encoder')
parser.add_argument('--ops_activation', type=str, default='sigmoid')
parser.add_argument('--ops_loss', type=str, default='BECLoss')
parser.add_argument('--n_nodes', type=int, default=9)
#train parameters
parser.add_argument('--train_budget', type=int, default=550000)
parser.add_argument('--lr', type=float ,default=0.00035)

#exp 
parser.add_argument('--exp', type=str, default='Test')
parser.add_argument('--exp_root', type=str, default='../../../experiments')

#clustering 
parser.add_argument('--cluster_eps', type=float ,default=1.0)

parser.add_argument('--design_space', type=str, default='ImageNet800M', help='use scala ops features')
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name')
parser.add_argument('--ss_path', default='', type=str, help='path to ss design space ')

# Only consume the real command line when run as a script. When imported (e.g. from a
# test or another script), use the defaults instead of parsing the importer's sys.argv.
args = parser.parse_args(None if __name__ == '__main__' else [])

if __name__ == '__main__':
    cudnn.benchmark = True
    cudnn.enabled = True
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    exp_path = "{}/{}/step_1/".format(args.exp_root, args.exp)
    os.makedirs(exp_path, exist_ok=True)

    # Device selection: CPU without CUDA, otherwise utils.pick_gpu_lowest_memory() for
    # --gpu=auto or the explicit id.
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
        torch.cuda.set_device(device)
    print("Using device", device)

    # NOTE: pickle.load can execute arbitrary code; only point --dataset_path at trusted files.
    with open(args.dataset_path, 'rb') as f:
        data = pickle.load(f)
    dataset = pd.DataFrame(data)

    trainset, testset = train_test_split(dataset, test_size=0.1)

    train_pipeline(trainset, testset, args.train_portion, cluster_eps=args.cluster_eps,
                   num_features=args.num_features, num_layers=args.num_layers,
                   num_hidden=args.num_hidden, num_latent=args.num_latent,
                   batch_size=args.batch_size, train_budget=args.train_budget,
                   lr=args.lr, save=exp_path, design_space=args.design_space, n_nodes=args.n_nodes)