import time
import argparse
import numpy as np
import os.path as osp
from data_utils import *
import torch
import torch.nn.functional as F
import torch.optim as optim
from args import parse_args
import nets.models as models
import data_utils as utils
import data_load
import random
import ipdb
import copy
import statistics
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score, f1_score
import optuna

#from torch.utils.tensorboard import SummaryWriter

# Training setting
# Command-line configuration, read as a module-level global by the functions
# below (parse_args comes from the local `args` module).
args = parse_args()

# NOTE(review): hard-coded to the first GPU; the models and tensors in
# main()/detailed_main() are moved with .cuda() rather than to this value.
device='cuda:0'

# Disabled global seeding, kept as a no-op string literal. main() and
# detailed_main() instead reseed torch/numpy/random at the start of every
# repetition.
'''
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
'''
def regularizer(pred, gt, trial, args, imb_map):
    """Return the prediction-marginal regularization term added to the loss.

    The softmax of ``pred`` is averaged over rows and compared against a
    target distribution; the result is scaled by a coefficient read from
    ``trial``.

    Args:
        pred: Logits of the training nodes; softmax is taken over the last
            dimension (presumably ``(num_nodes, num_classes)`` -- confirm).
        gt: Labels of the same nodes. Unused; kept for interface
            compatibility.
        trial: Optuna trial (or frozen trial). ``regularizer_coef`` and
            ``kl_power`` are always suggested, even when no regularizer is
            active, so every trial records the same parameter set.
        args: Parsed CLI arguments; reads ``reg_coef`` and ``kl_power``
            (``[low, high]`` log-scale search ranges) and ``regularizer``.
        imb_map: One weight per class, used by ``'kl_reverse'`` as the
            (un-normalized after ``** kl_power``) target.

    Returns:
        A scalar tensor for ``'kl'`` / ``'kl_reverse'``; the int ``0`` when
        ``args.regularizer`` is ``None``.

    Raises:
        ValueError: if ``args.regularizer`` is not ``'kl'``,
            ``'kl_reverse'`` or ``None``.
    """
    regularizer_coef = trial.suggest_float(
        "regularizer_coef",
        args.reg_coef[0],
        args.reg_coef[1],
        log=True,
    )
    kl_power = trial.suggest_float(
        "kl_power",
        args.kl_power[0],
        args.kl_power[1],
        log=True,
    )

    kind = args.regularizer
    if kind is None:
        return 0
    if kind not in ('kl', 'kl_reverse'):
        raise ValueError(f"Unknown regularizer: {kind!r}")

    # Mean predicted class distribution over the given nodes.
    pred_mean = F.softmax(pred, dim=-1).mean(dim=0)
    if kind == 'kl':
        # sum_c p_c * log(p_c * C) == KL(p || uniform); 1e-7 guards log(0).
        log_ratio = torch.log((pred_mean + 1e-7) * len(pred_mean))
    else:
        # sum_c p_c * log(p_c / imb_map_c ** kl_power).
        target = imb_map.to(pred.device) ** kl_power
        log_ratio = torch.log((pred_mean + 1e-7) / target)
    return (regularizer_coef * (pred_mean * log_ratio)).sum()

def main(trial):
    """Optuna objective: train the selected model and return its validation score.

    Each repetition reseeds every RNG, builds the (optionally artificially
    imbalanced) training split, constructs the encoder / classifier / decoder
    selected by ``args.model`` and ``args.setting``, and trains for up to 999
    epochs with early stopping (patience 100) on the validation score
    ``(accuracy + macro-F1) / 2``.

    Args:
        trial: Optuna trial; forwarded to ``regularizer`` every epoch, which
            samples ``regularizer_coef`` and ``kl_power`` from it.

    Returns:
        Mean over repetitions of the best validation ``(accuracy + macro-F1) / 2``.

    Raises:
        ValueError: if ``args.model`` is not ``'sage'``, ``'gcn'`` or ``'GAT'``.
    """
    import os  # only needed to create the checkpoint directory

    ## Load Dataset ##
    dataset = args.dataset
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset)
    dataset = get_dataset(dataset, path, split_type='public')
    data = dataset[0]

    n_cls = data.y.max().item() + 1
    repeatition = 1
    seed = 100
    avg_val_acc_f1 = []
    for r in range(repeatition):
        ## Fix seed ##
        torch.cuda.empty_cache()
        seed += 1
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)

        ## Split masks ##
        # For these datasets the masks carry one column per split (r % 10).
        if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']:
            split = r % 10
            idx_train = data.train_mask[:, split].clone()
            idx_val = data.val_mask[:, split].clone()
            idx_test = data.test_mask[:, split].clone()
        else:
            idx_train = data.train_mask.clone()
            idx_val = data.val_mask.clone()
            idx_test = data.test_mask.clone()

        ## Data statistic ##
        stats = data.y[idx_train]
        n_data = [int((stats == i).sum().item()) for i in range(n_cls)]
        idx_info = get_idx_info(data.y, n_cls, idx_train)

        # Artificial imbalance: only the last im_class_num classes are
        # capped at max_num / imb_ratio training samples.
        im_class_num = n_cls // 2
        max_num = np.max(n_data[:n_cls - im_class_num])
        class_num_list = []
        for i, count in enumerate(n_data):
            if args.imb_ratio > 1 and i > n_cls - 1 - im_class_num:
                class_num_list.append(min(int(max_num * (1. / args.imb_ratio)), count))
            else:
                class_num_list.append(count)

        if args.imb_ratio > 1:
            idx_train, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device)

        # Boolean masks -> node-index tensors.
        idx_train = torch.nonzero(idx_train).squeeze(1)
        idx_val = torch.nonzero(idx_val).squeeze(1)
        idx_test = torch.nonzero(idx_test).squeeze(1)
        adj, features, labels = data_load.load_data(data)

        # Method 1: oversampling in the input domain.
        if args.setting == 'upsampling':
            adj, features, labels, idx_train = utils.src_upsample(adj, features, labels, idx_train, portion=args.up_scale, im_class_num=im_class_num)
        if args.setting == 'smote':
            adj, features, labels, idx_train = utils.src_smote(adj, features, labels, idx_train, portion=args.up_scale, im_class_num=im_class_num)

        ## Model and optimizers ##
        # Oversampling in the embedding space ('embed_up') uses the *_En2
        # encoders together with the generic models.Classifier head.
        embed_up = args.setting == 'embed_up'
        if args.model == 'sage':
            encoder_cls = models.Sage_En2 if embed_up else models.Sage_En
            classifier_cls = models.Classifier if embed_up else models.Sage_Classifier
        elif args.model == 'gcn':
            encoder_cls = models.GCN_En2 if embed_up else models.GCN_En
            classifier_cls = models.Classifier if embed_up else models.GCN_Classifier
        elif args.model == 'GAT':
            encoder_cls = models.GAT_En2 if embed_up else models.GAT_En
            classifier_cls = models.Classifier if embed_up else models.GAT_Classifier
        else:
            raise ValueError(f"Unknown model: {args.model!r}")

        n_out = labels.max().item() + 1
        encoder = encoder_cls(nfeat=features.shape[1], nhid=args.nhid, nembed=args.nhid, dropout=args.dropout)
        classifier = classifier_cls(nembed=args.nhid, nhid=args.nhid, nclass=n_out, dropout=args.dropout)
        decoder = models.Decoder(nembed=args.nhid, dropout=args.dropout)

        optimizer_en = optim.Adam(encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        optimizer_cls = optim.Adam(classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        optimizer_de = optim.Adam(decoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        ## Transfer to CUDA ##
        encoder = encoder.cuda()
        classifier = classifier.cuda()
        decoder = decoder.cuda()
        features = features.cuda()
        adj = adj.cuda()
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()

        # Class ids for the per-class weights passed to regularizer().
        class_ids = data.y.unique().to(labels.device)

        def train(epoch):
            """Run one optimization step on the whole graph."""
            encoder.train()
            classifier.train()
            decoder.train()
            optimizer_en.zero_grad()
            optimizer_cls.zero_grad()
            optimizer_de.zero_grad()

            embed = encoder(features, adj)

            if args.setting in ('recon_newG', 'recon', 'newG_cls'):
                ori_num = labels.shape[0]
                embed, labels_new, idx_train_new, adj_up = utils.recon_upsample(embed, labels, idx_train, adj=adj.detach().to_dense(), portion=args.up_scale, im_class_num=im_class_num)
                generated_G = decoder(embed)
                # Reconstruction loss on the edges among the original nodes.
                loss_rec = utils.adj_mse_loss(generated_G[:ori_num, :ori_num], adj.detach().to_dense())

                if args.opt_new_G:
                    adj_new = generated_G
                else:
                    # Binarize the decoded graph at 0.5; no gradient flows through it.
                    adj_new = generated_G.detach().clone()
                    threshold = 0.5
                    adj_new[adj_new < threshold] = 0.0
                    adj_new[adj_new >= threshold] = 1.0

                # Keep only edges allowed by adj_up, then restore the original
                # edges among the existing nodes.
                adj_new = torch.mul(adj_up, adj_new)
                adj_new[:ori_num, :ori_num] = adj.detach().to_dense()
                if not args.opt_new_G:
                    adj_new = adj_new.detach()

                if args.setting == 'newG_cls':
                    # Compute the loss on the original training nodes only.
                    idx_train_new = idx_train

            elif args.setting == 'embed_up':
                # SMOTE in the embedding space.
                embed, labels_new, idx_train_new = utils.recon_upsample(embed, labels, idx_train, portion=args.up_scale, im_class_num=im_class_num)
                adj_new = adj
            else:
                labels_new = labels
                idx_train_new = idx_train
                adj_new = adj

            output = classifier(embed, adj_new)
            train_logits = output[idx_train_new]
            train_labels = labels_new[idx_train_new]

            if args.setting == 'reweight':
                # Up-weight the last im_class_num (minority) classes.
                weight = features.new(n_out).fill_(1)
                weight[-im_class_num:] = 1 + args.up_scale
                loss_train = F.cross_entropy(train_logits, train_labels, weight=weight)
            else:
                loss_train = F.cross_entropy(train_logits, train_labels)

            if args.setting == 'recon_newG':
                loss = loss_train + loss_rec * args.rec_weight
            elif args.setting == 'recon':
                loss = loss_rec + 0 * loss_train
            else:
                loss = loss_train

            # Inverse class frequencies normalized to sum to 1; a class with no
            # training samples counts as 1 so it cannot divide by zero.
            counts = (train_labels.unsqueeze(0) == class_ids.unsqueeze(1)).sum(dim=1)
            imb_map = 1.0 / counts.clamp(min=1).float()
            imb_map = imb_map / imb_map.sum()
            loss = loss + regularizer(train_logits, train_labels, trial, args, imb_map)
            loss.backward()

            if args.setting == 'newG_cls':
                # Only the classifier is updated in this setting.
                optimizer_en.zero_grad()
                optimizer_de.zero_grad()
            else:
                optimizer_en.step()
            optimizer_cls.step()
            if args.setting in ('recon_newG', 'recon'):
                optimizer_de.step()

        def test(epoch=0):
            """Return ``(accs, baccs, f1s)``, each ordered (train, val)."""
            encoder.eval()
            classifier.eval()
            decoder.eval()
            with torch.no_grad():
                output = classifier(encoder(features, adj), adj)
            labels_cpu = labels.cpu()
            accs, baccs, f1s = [], [], []
            for mask in (idx_train, idx_val):
                mask_cpu = mask.cpu()
                y_pred = output[mask].max(1)[1].cpu().numpy()
                y_true = labels_cpu[mask_cpu].numpy()
                # mask holds node indices (not booleans): normalize by its length.
                accs.append(float((y_pred == y_true).sum()) / len(mask_cpu))
                baccs.append(balanced_accuracy_score(y_true, y_pred))
                f1s.append(f1_score(y_true, y_pred, average='macro'))
            return accs, baccs, f1s

        def save_model(epoch):
            """Save encoder/decoder/classifier weights under smote_checkpoints/<dataset>/."""
            ckpt_path = 'smote_checkpoints/{}/{}_{}_{}_{}.pth'.format(args.dataset, args.setting, epoch, args.opt_new_G, args.im_ratio)
            os.makedirs(osp.dirname(ckpt_path), exist_ok=True)
            torch.save({
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'classifier': classifier.state_dict(),
            }, ckpt_path)

        def load_model(filename):
            """Load weights saved by save_model from smote_checkpoints/<dataset>/<filename>.pth."""
            loaded_content = torch.load('smote_checkpoints/{}/{}.pth'.format(args.dataset, filename), map_location=lambda storage, loc: storage)
            encoder.load_state_dict(loaded_content['encoder'])
            decoder.load_state_dict(loaded_content['decoder'])
            classifier.load_state_dict(loaded_content['classifier'])
            print("successfully loaded: " + filename)

        ## Train ##
        best_val_acc_f1 = 0
        if args.load is not None:
            load_model(args.load)

        patient = 0
        for epoch in range(1, 1000):
            train(epoch)
            accs, _, f1s = test()
            val_acc_f1 = (accs[1] + f1s[1]) / 2  # index 1 == validation split

            if val_acc_f1 > best_val_acc_f1:
                best_val_acc_f1 = val_acc_f1
                if epoch % 100 == 0:
                    save_model(0)
                patient = 0
            else:
                patient += 1
            if patient > 100:
                break

        avg_val_acc_f1.append(best_val_acc_f1)
        print(avg_val_acc_f1)

    return statistics.mean(avg_val_acc_f1)

def detailed_main(trial):
    """Re-train with the hyper-parameters of ``trial`` and log test metrics.

    Runs 10 seeded repetitions (1 for ``args.setting == 'recon'``), each for
    1999 epochs without early stopping. For every repetition, the test
    metrics at the epoch with the best validation ``(accuracy + macro-F1) / 2``
    are recorded. They are appended to ``./graphsmote.md`` together with the
    configuration.

    Args:
        trial: Optuna trial (typically ``study.best_trial``). ``regularizer``
            reads ``regularizer_coef`` and ``kl_power`` from it, and they are
            read again here for the log line.

    Returns:
        Mean over repetitions of the best validation ``(accuracy + macro-F1) / 2``.

    Raises:
        ValueError: if ``args.model`` is not ``'sage'``, ``'gcn'`` or ``'GAT'``.
    """
    import os  # only needed to create the checkpoint directory

    ## Load Dataset ##
    dataset = args.dataset
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset)
    dataset = get_dataset(dataset, path, split_type='public')
    data = dataset[0]

    n_cls = data.y.max().item() + 1
    repeatition = 10 if args.setting != 'recon' else 1
    seed = 100
    avg_val_acc_f1, avg_test_acc, avg_test_bacc, avg_test_f1 = [], [], [], []
    for r in range(repeatition):
        ## Fix seed ##
        torch.cuda.empty_cache()
        seed += 1
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)

        ## Split masks ##
        # For these datasets the masks carry one column per split (r % 10).
        if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']:
            split = r % 10
            idx_train = data.train_mask[:, split].clone()
            idx_val = data.val_mask[:, split].clone()
            idx_test = data.test_mask[:, split].clone()
        else:
            idx_train = data.train_mask.clone()
            idx_val = data.val_mask.clone()
            idx_test = data.test_mask.clone()

        ## Data statistic ##
        stats = data.y[idx_train]
        n_data = [int((stats == i).sum().item()) for i in range(n_cls)]
        idx_info = get_idx_info(data.y, n_cls, idx_train)

        # Artificial imbalance: only the last im_class_num classes are
        # capped at max_num / imb_ratio training samples.
        im_class_num = n_cls // 2
        max_num = np.max(n_data[:n_cls - im_class_num])
        class_num_list = []
        for i, count in enumerate(n_data):
            if args.imb_ratio > 1 and i > n_cls - 1 - im_class_num:
                class_num_list.append(min(int(max_num * (1. / args.imb_ratio)), count))
            else:
                class_num_list.append(count)

        if args.imb_ratio > 1:
            idx_train, idx_info = split_semi_dataset(len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device)

        # Per-class node counts; columns are (train, val, test).
        class_num_mat = torch.zeros((n_cls, 3))
        for col, split_mask in enumerate((idx_train, idx_val, idx_test)):
            class_num_mat[:, col] = torch.tensor([data.y[split_mask].eq(i).sum() for i in range(n_cls)])

        # Boolean masks -> node-index tensors.
        idx_train = torch.nonzero(idx_train).squeeze(1)
        idx_val = torch.nonzero(idx_val).squeeze(1)
        idx_test = torch.nonzero(idx_test).squeeze(1)
        adj, features, labels = data_load.load_data(data)

        # Method 1: oversampling in the input domain.
        if args.setting == 'upsampling':
            adj, features, labels, idx_train = utils.src_upsample(adj, features, labels, idx_train, portion=args.up_scale, im_class_num=im_class_num)
        if args.setting == 'smote':
            adj, features, labels, idx_train = utils.src_smote(adj, features, labels, idx_train, portion=args.up_scale, im_class_num=im_class_num)

        ## Model and optimizers ##
        # Oversampling in the embedding space ('embed_up') uses the *_En2
        # encoders together with the generic models.Classifier head.
        embed_up = args.setting == 'embed_up'
        if args.model == 'sage':
            encoder_cls = models.Sage_En2 if embed_up else models.Sage_En
            classifier_cls = models.Classifier if embed_up else models.Sage_Classifier
        elif args.model == 'gcn':
            encoder_cls = models.GCN_En2 if embed_up else models.GCN_En
            classifier_cls = models.Classifier if embed_up else models.GCN_Classifier
        elif args.model == 'GAT':
            encoder_cls = models.GAT_En2 if embed_up else models.GAT_En
            classifier_cls = models.Classifier if embed_up else models.GAT_Classifier
        else:
            raise ValueError(f"Unknown model: {args.model!r}")

        n_out = labels.max().item() + 1
        encoder = encoder_cls(nfeat=features.shape[1], nhid=args.nhid, nembed=args.nhid, dropout=args.dropout)
        classifier = classifier_cls(nembed=args.nhid, nhid=args.nhid, nclass=n_out, dropout=args.dropout)
        decoder = models.Decoder(nembed=args.nhid, dropout=args.dropout)

        optimizer_en = optim.Adam(encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        optimizer_cls = optim.Adam(classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        optimizer_de = optim.Adam(decoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        ## Transfer to CUDA ##
        encoder = encoder.cuda()
        classifier = classifier.cuda()
        decoder = decoder.cuda()
        features = features.cuda()
        adj = adj.cuda()
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()

        # Class ids for the per-class weights passed to regularizer().
        class_ids = data.y.unique().to(labels.device)

        def train(epoch):
            """Run one optimization step and print per-class validation accuracy."""
            encoder.train()
            classifier.train()
            decoder.train()
            optimizer_en.zero_grad()
            optimizer_cls.zero_grad()
            optimizer_de.zero_grad()

            embed = encoder(features, adj)

            if args.setting in ('recon_newG', 'recon', 'newG_cls'):
                ori_num = labels.shape[0]
                embed, labels_new, idx_train_new, adj_up = utils.recon_upsample(embed, labels, idx_train, adj=adj.detach().to_dense(), portion=args.up_scale, im_class_num=im_class_num)
                generated_G = decoder(embed)
                # Reconstruction loss on the edges among the original nodes.
                loss_rec = utils.adj_mse_loss(generated_G[:ori_num, :ori_num], adj.detach().to_dense())

                if args.opt_new_G:
                    adj_new = generated_G
                else:
                    # Binarize the decoded graph at 0.5; no gradient flows through it.
                    adj_new = generated_G.detach().clone()
                    threshold = 0.5
                    adj_new[adj_new < threshold] = 0.0
                    adj_new[adj_new >= threshold] = 1.0

                # Keep only edges allowed by adj_up, then restore the original
                # edges among the existing nodes.
                adj_new = torch.mul(adj_up, adj_new)
                adj_new[:ori_num, :ori_num] = adj.detach().to_dense()
                if not args.opt_new_G:
                    adj_new = adj_new.detach()

                if args.setting == 'newG_cls':
                    # Compute the loss on the original training nodes only.
                    idx_train_new = idx_train

            elif args.setting == 'embed_up':
                # SMOTE in the embedding space.
                embed, labels_new, idx_train_new = utils.recon_upsample(embed, labels, idx_train, portion=args.up_scale, im_class_num=im_class_num)
                adj_new = adj
            else:
                labels_new = labels
                idx_train_new = idx_train
                adj_new = adj

            output = classifier(embed, adj_new)
            train_logits = output[idx_train_new]
            train_labels = labels_new[idx_train_new]

            if args.setting == 'reweight':
                # Up-weight the last im_class_num (minority) classes.
                weight = features.new(n_out).fill_(1)
                weight[-im_class_num:] = 1 + args.up_scale
                loss_train = F.cross_entropy(train_logits, train_labels, weight=weight)
            else:
                loss_train = F.cross_entropy(train_logits, train_labels)

            if args.setting == 'recon_newG':
                loss = loss_train + loss_rec * args.rec_weight
            elif args.setting == 'recon':
                loss = loss_rec + 0 * loss_train
            else:
                loss = loss_train

            # Inverse class frequencies normalized to sum to 1; a class with no
            # training samples counts as 1 so it cannot divide by zero.
            counts = (train_labels.unsqueeze(0) == class_ids.unsqueeze(1)).sum(dim=1)
            imb_map = 1.0 / counts.clamp(min=1).float()
            imb_map = imb_map / imb_map.sum()
            loss = loss + regularizer(train_logits, train_labels, trial, args, imb_map)
            loss.backward()

            if args.setting == 'newG_cls':
                # Only the classifier is updated in this setting.
                optimizer_en.zero_grad()
                optimizer_de.zero_grad()
            else:
                optimizer_en.step()
            optimizer_cls.step()
            if args.setting in ('recon_newG', 'recon'):
                optimizer_de.step()

            utils.print_class_acc(output[idx_val], labels[idx_val], class_num_mat[:, 1])

        def test(epoch=0):
            """Return ``(accs, baccs, f1s)``, each ordered (val, test)."""
            encoder.eval()
            classifier.eval()
            decoder.eval()
            with torch.no_grad():
                output = classifier(encoder(features, adj), adj)
            y_all = data.y.cpu()
            accs, baccs, f1s = [], [], []
            for mask in (idx_val, idx_test):
                mask_cpu = mask.cpu()
                y_pred = output[mask].max(1)[1].cpu().numpy()
                y_true = y_all[mask_cpu].numpy()
                # mask holds node indices (not booleans): normalize by its length.
                accs.append(float((y_pred == y_true).sum()) / len(mask_cpu))
                baccs.append(balanced_accuracy_score(y_true, y_pred))
                f1s.append(f1_score(y_true, y_pred, average='macro'))
            return accs, baccs, f1s

        def save_model(epoch, model_name):
            """Save encoder/decoder/classifier weights under smote_checkpoints/<dataset>/."""
            ckpt_path = 'smote_checkpoints/{}/{}_{}_{}_{}_{}.pth'.format(args.dataset, model_name, args.setting, epoch, args.opt_new_G, args.im_ratio)
            os.makedirs(osp.dirname(ckpt_path), exist_ok=True)
            torch.save({
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'classifier': classifier.state_dict(),
            }, ckpt_path)

        def load_model(filename):
            """Load weights saved by save_model from smote_checkpoints/<dataset>/<filename>.pth."""
            loaded_content = torch.load('smote_checkpoints/{}/{}.pth'.format(args.dataset, filename), map_location=lambda storage, loc: storage)
            encoder.load_state_dict(loaded_content['encoder'])
            decoder.load_state_dict(loaded_content['decoder'])
            classifier.load_state_dict(loaded_content['classifier'])
            print("successfully loaded: " + filename)

        ## Train (no early stopping: every epoch runs) ##
        best_val_acc_f1 = 0
        # Reset per repetition so a run that never improves cannot report
        # the previous repetition's metrics.
        test_acc = test_bacc = test_f1 = 0.0
        if args.load is not None:
            load_model(args.load)

        for epoch in range(1, 2000):
            train(epoch)
            accs, baccs, f1s = test()
            # Select the epoch on the validation split (index 0) only; index 1
            # is the test split and is just recorded.
            val_acc_f1 = (accs[0] + f1s[0]) / 2
            if epoch == 1900:
                save_model(0, args.model)
            if val_acc_f1 > best_val_acc_f1:
                best_val_acc_f1 = val_acc_f1
                test_acc, test_bacc, test_f1 = accs[1], baccs[1], f1s[1]

        avg_val_acc_f1.append(best_val_acc_f1)
        avg_test_acc.append(test_acc)
        avg_test_bacc.append(test_bacc)
        avg_test_f1.append(test_f1)

    ## Calculate statistics ##
    regularizer_coef = trial.suggest_float(
                    "regularizer_coef",
                    args.reg_coef[0],
                    args.reg_coef[1],
                    log=True
                )
    kl_power = trial.suggest_float(
                    "kl_power",
                    args.kl_power[0],
                    args.kl_power[1],
                    log=True
                )
    with open('./graphsmote.md', 'a') as file:
        file.write(f'dataset:{args.dataset},net:{args.model},tam:{args.tam},regularizer:{args.regularizer},imb_ratio:{args.imb_ratio}\n,regularizer_coef:{regularizer_coef:.4f},kl_power:{kl_power:.4f}\n')
        file.write(f'avg_test_acc: {[float(f"{num:.4f}") for num in avg_test_acc]}\n')
        file.write(f'avg_test_bacc: {[float(f"{num:.4f}") for num in avg_test_bacc]}\n')
        file.write(f'avg_test_f1: {[float(f"{num:.4f}") for num in avg_test_f1]}\n\n\n')

    print(avg_val_acc_f1)
    return statistics.mean(avg_val_acc_f1)

if __name__ == '__main__':
    # Short run description assembled from selected CLI arguments
    # (built for reference; not consumed below).
    tracked_args = ['num_trial', 'dataset', 'imb_ratio', 'net', 'tam', 'ens', 'regularizer', 'run_type', 'repeatition']
    name = ''.join(
        f'{key}: {value}, '
        for key, value in vars(args).items()
        if key in tracked_args
    )

    # Hyper-parameter search: maximize the objective returned by main().
    study = optuna.create_study(direction="maximize")
    study.optimize(main, n_trials=args.num_trial)

    # Re-run the best configuration and append its test metrics to
    # ./graphsmote.md.
    trial = study.best_trial
    detailed_main(trial)
    