import argparse
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import numpy as np
import torch
from train import train_model
from eval import eval_FCSC
from load_data import *
from divide import kfold_split, K_Fold, setup_seed, cross_validate
from transform import *
from bottleneck.loss import SupConLoss
from bottleneck.model import Alternately_Attention_Bottlenecks, Attention_Bottlenecks
from dataloader import DataLoader
from torch.utils.data import Subset
from torch_geometric.data import DenseDataLoader
from torch.optim.lr_scheduler import StepLR
import json
parser = argparse.ArgumentParser(description='FC and SC Classification')
parser.add_argument('--seed', type=int, default=777, help='random seed')
parser.add_argument('--dataset_random_seed', type=int, default=1, help='random seed')
parser.add_argument('--repetitions', type=int, default=10, help='number of repetitions (default: 10)')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--threshold', type=float, default=0.2, help='threshold')
parser.add_argument('--sc_features', type=int, default=90, help='sc_features')
parser.add_argument('--fc_features', type=int, default=90, help='fc_features')
parser.add_argument('--num_classes', type=int, default=2, help='the number of classes (HC/MDD)')
parser.add_argument('--hidden_dim', type=int, default=128, help='hidden size')
parser.add_argument('--dropout', type=float, default=0.3, help='dropout ratio')
parser.add_argument('--num_layers', type=int, default=4, help='the numbers of convolution layers')
parser.add_argument('--fusion_layers', type=int, default=3, help='the numbers of fusion layers')
parser.add_argument('--num_bottlenecks', type=int, default=8, help='the numbers of bottlenecks')
parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--dim_feedforward', type=int, default=2048, help='maximum number of epochs')
parser.add_argument('--patience', type=int, default=400, help='patience for early stopping')
parser.add_argument('--dataset', type=str, default='HCP_MBT', help="XX_SCFC/ZD_SCFC/HCP_SCFC")
parser.add_argument('--path', type=str, default='/home/yehongting/yehongting_datasets/multi_zhongda_xinxiang/merge_dataset', help='path of dataset')
parser.add_argument('--result_path', type=str, default='./result/ZDXX.txt', help='path of dataset')
parser.add_argument('--use_cuda', type=bool, default=True, help='specify cuda devices')
parser.add_argument('--temperature', type=float, default=0.03, help='dropout ratio')
parser.add_argument('--negative_weight', type=float, default=0.8, help='dropout ratio')
parser.add_argument('--num_atom_type', type=int, default=90, help='value for num_atom_type')
parser.add_argument('--num_edge_type', type=int, default=90, help='value for num_edge_type')
parser.add_argument('--num_heads', type=int, default=4, help='value for num_heads')
parser.add_argument('--in_feat_dropout', type=float, default=0.5, help='value for in_feat_dropout')
parser.add_argument('--readout', type=str, default='mean', help="mean/sum/max")
parser.add_argument('--layer_norm', type=bool, default=True, help="Please give a value for layer_norm")
parser.add_argument('--batch_norm', type=bool, default=False, help="Please give a value for batch_norm")
parser.add_argument('--residual', type=bool, default=True, help="Please give a value for residual")
# parser.add_argument('--lap_pos_enc', type=bool, default=True, help="Please give a value for lap_pos_enc")
# parser.add_argument('--wl_pos_enc', type=bool, default=False, help="Please give a value for wl_pos_enc")
parser.add_argument('--pos_enc', choices=[None, 'diffusion', 'pstep', 'adj'], default='pstep')
parser.add_argument('--pos_enc_dim', type=int, default=32, help='hidden size')
parser.add_argument('--normalization', choices=[None, 'sym', 'rw'], default='sym',
                        help='normalization for Laplacian')
parser.add_argument('--beta', type=float, default=1.0,
                        help='bandwidth for the diffusion kernel')
parser.add_argument('--p', type=int, default=2, help='p step random walk kernel')
parser.add_argument('--zero_diag', action='store_true', help='zero diagonal for PE matrix')
parser.add_argument('--lappe', action='store_true', help='use laplacian PE',default=True)
parser.add_argument('--lap_dim', type=int, default=32, help='dimension for laplacian PE')
parser.add_argument('--h', type=int, default=1, help='dimension for laplacian PE')
parser.add_argument('--max_nodes_per_hop', type=int, default=5, help='dimension for laplacian PE')
args = parser.parse_args()

params = json.load(open("./param.json"))
for key, item in params.items():
    args.__setattr__(key, item)


params_datasets = json.load(open("./param_dataset.json"))
for key, item in params_datasets[args.dataset].items():
    args.__setattr__(key, item)

if __name__ == '__main__':
    
    
    for threshold in np.arange(0.22, 0.32, 0.02):
        args.threshold = threshold
        
        args.result_path = "./result/ZDXX_threshold_2.txt"
        
        acc = []
        loss = []
        sen = []
        spe = []
        f1 = []
        auc = []
        setup_seed(args.seed)
        
        random_s = np.array([25, 50, 100, 125, 150, 175, 200, 225, 250, 275], dtype=int)
        # random_s = np.array([125], dtype=int)
        print(args)
        for random_seed in random_s:
            # myDataset = FSDataset_GT(args)
            transform = HHopSubgraphs(h=args.h, max_nodes_per_hop=args.max_nodes_per_hop, node_label='hop', use_rd=False, subgraph_pretransform=LapEncoding(dim=4))
            
            args.dataset_random_seed = random_seed
            myDataset = MyOwnDataset("preprocessed_data/{}_{}_{}".format(args.dataset, args.h, threshold), pre_transform=transform, args=args)
            # myDataset = MultiModalDataset(args, pre_transform=transform)

            acc_iter = []
            loss_iter = []
            sen_iter = []
            spe_iter = []
            f1_iter = []
            auc_iter = []
            for i, (train_split, valid_split, test_split) in enumerate(zip(*cross_validate(args.repetitions, myDataset))):
                            
                train_subset, valid_subset, test_subset = myDataset[train_split], myDataset[valid_split], myDataset[test_split]
                
                # train_loader = DenseDataLoader(train_subset, batch_size=args.batch_size, shuffle=True)
                # val_loader = DenseDataLoader(valid_subset, batch_size=args.batch_size, shuffle=False)
                # test_loader = DenseDataLoader(test_subset, batch_size=args.batch_size, shuffle=False)
                train_loader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=True)
                val_loader = DataLoader(valid_subset, batch_size=args.batch_size, shuffle=False)
                test_loader = DataLoader(test_subset, batch_size=args.batch_size, shuffle=False)

                # Model initialization
                model = Alternately_Attention_Bottlenecks(args)
                if args.use_cuda:
                    model = model.cuda()
                # model = ASAP_multi(args).to(args.device)
                optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
                scheduler = StepLR(optimizer, step_size=50, gamma=0.8)
                # sup_con_loss = SupConLoss()

                # Model training
                best_model = train_model(args, model, optimizer, scheduler, train_loader, val_loader, test_loader, i)

                # Restore model for testing
                model.load_state_dict(torch.load('ckpt/{}/{}_fold_best_model.pth'.format(args.dataset, i)))
                test_acc, test_loss, test_sen, test_spe, test_f1, test_auc,  y, pred = eval_FCSC(args, model, test_loader)
                acc_iter.append(test_acc)
                loss_iter.append(test_loss)
                sen_iter.append(test_sen)
                spe_iter.append(test_spe)
                f1_iter.append(test_f1)
                auc_iter.append(test_auc)
                print('Test set results, best_epoch = {:.1f}  loss = {:.6f}, accuracy = {:.6f}, sensitivity = {:.6f}, '
                    'specificity = {:.6f}, f1_score = {:.6f}, auc_score = {:.6f}'.format(0, test_loss, test_acc, test_sen, test_spe, test_f1, test_auc))
                # with open(args.result_path, 'a+') as f:
                #             f.write("fold:{:04d}  accuracy:{:.6f}     sensitivity:{:.6f}     specificity:{:.6f}     f1_score:{:.6f}     auc_score:{:.6f}\n".format(
                #                 i,test_acc, test_sen,test_spe,test_f1, test_auc))
                # print(y)
                # print(pred)
            acc.append(np.mean(acc_iter))
            sen.append(np.mean(sen_iter))
            spe.append(np.mean(spe_iter))        
            f1.append(np.mean(f1_iter))
            auc.append(np.mean(auc_iter))

            print('Average test set results, accuracy = {:.2f}±{:.2f}, sen = {:.2f}±{:.2f},'
            'spe = {:.2f}±{:.2f}, f1 = {:.2f}±{:.2f}, auc = {:.2f}±{:.2f}'.format(np.mean(acc_iter)*100, np.std(acc_iter)*100, np.mean(sen_iter)*100, np.std(sen_iter)*100,
                                                        np.mean(spe_iter)*100, np.std(spe_iter)*100, np.mean(f1_iter)*100, np.std(f1_iter)*100, np.mean(auc_iter)*100, np.std(auc_iter)*100))
            # with open(args.result_path, 'a+') as f:
            #                 f.write("seed:{:03d} AVERAGE acc:{:.2f}±{:.2f}, sen:{:.2f}±{:.2f}, spe:{:.2f}±{:.2f}, f1:{:.2f}±{:.2f}, auc:{:.2f}±{:.2f}\n".format(
            #                     random_seed,
            #                     np.mean(acc_iter)*100, np.std(acc_iter)*100, np.mean(sen_iter)*100, np.std(sen_iter)*100,
            #                     np.mean(spe_iter)*100, np.std(spe_iter)*100, np.mean(f1_iter)*100, np.std(f1_iter)*100, 
            #                     np.mean(auc_iter)*100, np.std(auc_iter)*100))
        
        print(args)
        print('Total test set results, accuracy : {}'.format(acc))
        print('Average test set results, mean accuracy = {:.6f}, std = {:.6f}, mean_sen = {:.6f}, std_sen = {:.6f}, '
            'mean_spe = {:.6f}, std_spe = {:.6f}, mean_f1 = {:.6f}, std_f1 = {:.6f}, mean_auc = {:.6f}, std_auc = {:.6f}'.format(np.mean(acc), np.std(acc), np.mean(sen), np.std(sen),
                                                        np.mean(spe), np.std(spe), np.mean(f1), np.std(f1), np.mean(auc), np.std(auc)))
        with open(args.result_path, 'a+') as f:
                            f.write("threshold: {:.2f}, Final AVERAGE acc:{:.2f}±{:.2f}, sen:{:.2f}±{:.2f}, spe:{:.2f}±{:.2f}, f1:{:.2f}±{:.2f}, auc:{:.2f}±{:.2f}\n".format(
                                threshold,
                                np.mean(acc)*100, np.std(acc)*100, np.mean(sen)*100, np.std(sen)*100,
                                np.mean(spe)*100, np.std(spe)*100, np.mean(f1)*100, np.std(f1)*100, 
                                np.mean(auc)*100, np.std(auc)*100))