import os
# print("SLURM_JOB_ID:{0}".format(os.environ["SLURM_JOB_ID"]))
import sys
import math
import time
import argparse
import torch
import ast
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from copy import deepcopy
from datetime import datetime
from sklearn.metrics import roc_auc_score,average_precision_score,roc_curve,RocCurveDisplay,precision_recall_curve,PrecisionRecallDisplay

sys.path.append("/set/your/path/")

from pycls.al.ActiveLearning import ActiveLearning
import pycls.core.builders as model_builder
from pycls.core.config import cfg, dump_cfg
import pycls.core.losses as losses
import pycls.core.optimizer as optim
from pycls.datasets.data import Data
import pycls.utils.checkpoint as cu
import pycls.utils.logging as lu
import pycls.utils.metrics as mu
import pycls.utils.net as nu
from pycls.utils.meters import TestMeter
from pycls.utils.meters import TrainMeter
from pycls.utils.meters import ValMeter
import pycls.datasets.utils as ds_utils
from sklearn.metrics import matthews_corrcoef
from pycls.datasets.custom_datasets import Q1_data

# a=time.time()
# print(time.time()-a)

def str2bool(v):
    """Convert a command-line token into a boolean.

    Booleans are returned unchanged. Strings are compared case-insensitively
    against a fixed set of true/false spellings.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def argparser():
    """Build the command-line parser for the active-learning script.

    Boolean options are parsed with ``str2bool`` so that values such as
    ``--flag false`` become a real ``False`` rather than a truthy string.

    Returns:
        argparse.ArgumentParser: The configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(description='Active Learning - Image Classification')
    parser.add_argument('--cfg', default='/set/path/of/configs/template.yaml', help='Config file', type=str)
    parser.add_argument('--exp-name', default='name_of_exp', type=str)
    parser.add_argument('--initial_size', default=0, type=int)
    parser.add_argument('--seed', help='Random seed', default=1, type=int)
    parser.add_argument('--finetune', help='Whether to continue with existing model between rounds', type=str2bool, default=True)
    parser.add_argument('--k_logistic', default=50, type=int)
    parser.add_argument('--a_logistic', default=0.8, type=float)
    parser.add_argument('--ratio', help='Whether to use ratio strategy', default=False, type=str2bool)
    parser.add_argument('--prob', help='Whether to use prob strategy', default=False, type=str2bool)
    # Without an explicit type, "--linear_from_features False" was stored as the
    # non-empty (hence truthy) string 'False'; parse it like the other flags.
    parser.add_argument('--linear_from_features', help='Whether to use a linear layer from self-supervised features', default=True, type=str2bool)
    parser.add_argument('--logging', help='logging', default=False, type=str2bool)

    parser.add_argument('--initial_delta', help='Relevant only for ProbCover and DCoM', default=0.75, type=float)
    parser.add_argument('--budget', default=10, type=int)
    parser.add_argument('--dataset', default='phishing', type=str)    #  TRPB_balanced,  'optdigits', 'phishing', CIFAR10
    parser.add_argument('--data_type', default='embed', type=str)
    parser.add_argument('--num_classes', default=10, type=int)
    parser.add_argument('--al', default='activesilhouette', type=str)   # activesilhouette
    parser.add_argument('--random_initial', default=False, type=str2bool)   # whether to random select samples as initial labelled pool
    parser.add_argument('--medoid_initial', default=True, type=str2bool)   # whether to medoid select samples as initial labelled pool
    parser.add_argument('--pure_nn', default=True, type=str2bool)   # whether to use 1nn or silhou_1nn
    parser.add_argument('--random_seed', default=1, type=int)
    parser.add_argument('--maxiter', default=9, type=int)
    parser.add_argument('--fix_gamma', default=False, type=str2bool)

    return parser



def is_eval_epoch(cur_epoch):
    """Tell whether evaluation should run after the given (0-based) epoch.

    True when the 1-based epoch number is a multiple of
    ``cfg.TRAIN.EVAL_PERIOD`` or equals ``cfg.OPTIM.MAX_EPOCH``.
    """
    epoch_number = cur_epoch + 1
    if epoch_number % cfg.TRAIN.EVAL_PERIOD == 0:
        return True
    return epoch_number == cfg.OPTIM.MAX_EPOCH

@torch.no_grad()
def silh_dist(x, y, gamma=0.1):
    """Return pairwise distances between rows of x and y, shifted down by gamma.

    The distance matrix comes from ``torch.cdist(x, y)``; ``gamma`` is
    subtracted from every entry and negative results are clipped to zero.
    """
    pairwise = torch.cdist(x, y)
    shifted = pairwise - gamma
    return torch.clamp(shifted, min=0)

@torch.no_grad()
def one_nn_silhou(cfg, ufeat, lfeat, ufeat_labels, lfeat_labels, best_gamma, p=2, batch_size=8192, device='cuda'):
    """Classify each row of ufeat with the label of its nearest row in lfeat.

    Queries are processed in chunks of ``batch_size`` rows; distances are
    computed with ``torch.cdist`` using the ``p``-norm. ``cfg``,
    ``best_gamma`` and ``device`` are accepted but not used by this
    implementation (the gamma-shifted distance is not applied here).

    Returns:
        tuple: ``(accuracy, mcc)`` where ``accuracy`` is the fraction of
        predictions equal to ``ufeat_labels`` (a Python float) and ``mcc`` is
        the value returned by ``matthews_corrcoef`` for the same predictions.
    """
    num_batches = math.ceil(ufeat.shape[0] / batch_size)

    nearest_labels = []
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        stop = (batch_idx + 1) * batch_size
        # Rows index labelled samples, columns index the queries in this chunk.
        dist_to_labelled = torch.cdist(lfeat, ufeat[start:stop], p=p)
        nearest = torch.argmin(dist_to_labelled, dim=0)
        nearest_labels.append(lfeat_labels[nearest])

    pred = torch.hstack(nearest_labels)
    accuracy = ((pred == ufeat_labels).sum() / ufeat.shape[0]).item()
    mcc = matthews_corrcoef(ufeat_labels.cpu().numpy(), pred.cpu().numpy())
    return accuracy, mcc



@torch.no_grad()
def one_nn(cfg, ufeat, lfeat, ufeat_labels, lfeat_labels, p=2, batch_size=8192, device='cuda'):
    """Evaluate a 1-nearest-neighbour classifier built from the labelled set.

    Every row of ``ufeat`` receives the label of the closest row of ``lfeat``
    under the ``p``-norm distance from ``torch.cdist``. Queries are handled in
    chunks of ``batch_size`` rows. ``cfg`` and ``device`` are accepted but not
    used by this implementation.

    Returns:
        tuple: ``(accuracy, mcc)`` where ``accuracy`` is the fraction of
        predictions equal to ``ufeat_labels`` (a Python float) and ``mcc`` is
        the value returned by ``matthews_corrcoef`` for the same predictions.
    """
    total = ufeat.shape[0]
    per_batch_labels = [
        lfeat_labels[
            # argmin over dim 0 picks, for each query column, the nearest labelled row.
            torch.cdist(lfeat, ufeat[k * batch_size:(k + 1) * batch_size], p=p).argmin(dim=0)
        ]
        for k in range(math.ceil(total / batch_size))
    ]
    pred = torch.hstack(per_batch_labels)

    accuracy = ((pred == ufeat_labels).sum() / ufeat.shape[0]).item()
    mcc = matthews_corrcoef(ufeat_labels.cpu().numpy(), pred.cpu().numpy())
    return accuracy, mcc


def main(cfg):
    """Run the active-learning experiment and report 1-NN accuracy per episode.

    The function creates the experiment directory, loads the dataset named by
    ``cfg.DATASET.NAME``, builds the initial labelled/unlabelled index sets and
    then, for ``cfg.ACTIVE_LEARNING.MAX_ITER + 1`` episodes, evaluates a
    nearest-neighbour classifier on the current labelled set and queries new
    samples through ``ActiveLearning.sample_from_uSet``.

    Args:
        cfg: Project configuration node. Besides the config-file keys, the
            ad-hoc attributes ``logging``, ``medoid_initial`` and ``pure_nn``
            are read.

    Raises:
        ValueError: If ``cfg.DATASET.NAME`` is not one of the supported names.

    NOTE: ``args.initial_size``, ``args.random_initial`` and ``args.fix_gamma``
    are read from the module-level ``args`` created in the ``__main__`` guard,
    so this function only works when the file is run as a script.
    """
    logger = lu.get_logger(__name__) if cfg.logging else None

    # Fall back to the CPU instead of crashing when CUDA is not available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Auto assign a RNG_SEED when not supplied a value
    if cfg.RNG_SEED is None:
        cfg.RNG_SEED = np.random.randint(100)

    out_dir = "/set/your/out/dir"
    dataset_out_dir = os.path.join(out_dir, cfg.DATASET.NAME, cfg.MODEL.TYPE)
    os.makedirs(dataset_out_dir, exist_ok=True)

    if cfg.EXP_NAME == 'auto':
        now = datetime.now()
        exp_dir = f'{now.year}_{now.month}_{now.day}_{now.hour:02}{now.minute:02}{now.second:02}_{now.microsecond}'
    else:
        exp_dir = cfg.EXP_NAME

    exp_dir = os.path.join(dataset_out_dir, exp_dir)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)
        print("Experiment Directory is {}.\n".format(exp_dir))
    else:
        print("Experiment Directory Already Exists: {}. Reusing it may lead to loss of old logs in the directory.\n".format(exp_dir))
    cfg.EXP_DIR = exp_dir
    # Save the config file in EXP_DIR
    dump_cfg(cfg)
    # Setup Logger
    lu.setup_logging(cfg)

    cfg.DATASET.ROOT_DIR = "/set/your/data/dir"
    data_obj = Data(cfg)

    image_datasets = ['CIFAR10']
    embed_datasets = ['CIFAR10_imbalanced', 'CIFAR10_all_imbalanced', 'TRPB_balanced', 'optdigits', 'phishing']

    if cfg.DATASET.NAME in image_datasets:
        train_data, train_size = data_obj.getDataset(save_dir=cfg.DATASET.ROOT_DIR, isTrain=True, isDownload=True)
        cfg.ACTIVE_LEARNING.INIT_L_RATIO = args.initial_size / train_size
        print("\nDataset {} Loaded Sucessfully.\nTotal Train Size: {} \n".format(cfg.DATASET.NAME, train_size))
        if cfg.logging:
            logger.info("Dataset {} Loaded Sucessfully. Total Train Size: {} \n".format(cfg.DATASET.NAME, train_size))

        torch.manual_seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        n_dataPoints = len(train_data)
        all_idx = list(range(n_dataPoints))
        np.random.shuffle(all_idx)
        if args.random_initial:
            # Hard-coded initial pool size for the random-initialisation setting.
            lSet_length = 40
        else:
            lSet_length = int(cfg.ACTIVE_LEARNING.INIT_L_RATIO * n_dataPoints)
        train_splitIdx = lSet_length

        # The tail of the shuffled indices is held out and used for evaluation below.
        val_splitIdx = int((1 - cfg.DATASET.VAL_RATIO) * n_dataPoints)
        lSet = np.array(all_idx[:train_splitIdx], dtype=np.ndarray)
        uSet = np.array(all_idx[train_splitIdx:val_splitIdx], dtype=np.ndarray)
        valSet = np.array(all_idx[val_splitIdx:], dtype=np.ndarray)

    elif cfg.DATASET.NAME in embed_datasets:
        # Features are loaded normalized only for 'phishing'.
        normalized = cfg.DATASET.NAME in ['phishing']
        train_data, train_label = ds_utils.load_features_labelexclu(cfg.DATASET.NAME, seed=1, train=True, normalized=normalized)
        test_data, test_label = ds_utils.load_features_labelexclu(cfg.DATASET.NAME, seed=1, train=False, normalized=normalized)
        print('train_data:', train_data.shape)
        print('test_data:', test_data.shape)
        torch.manual_seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        n_dataPoints = len(train_data)
        all_idx = list(range(n_dataPoints))
        np.random.shuffle(all_idx)
        if args.random_initial:
            lSet_length = cfg.ACTIVE_LEARNING.BUDGET_SIZE
        else:
            lSet_length = int(cfg.ACTIVE_LEARNING.INIT_L_RATIO * n_dataPoints)
        train_splitIdx = lSet_length

        lSet = np.array(all_idx[:train_splitIdx], dtype=np.ndarray)
        uSet = np.array(all_idx[train_splitIdx:], dtype=np.ndarray)
        valSet = []

    else:
        # Fail early with a clear message instead of a NameError on lSet further down.
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(cfg.DATASET.NAME, image_datasets + embed_datasets)
        )

    model = model_builder.build_model(cfg).to(device)

    # Initialised up front: when the initial labelled pool is non-empty (e.g. with
    # random initialisation) the block below is skipped, yet both names are used
    # inside the episode loop.
    best_gamma = None
    dele_sample = []

    if len(lSet) == 0:
        al_obj = ActiveLearning(data_obj, cfg)
        if cfg.medoid_initial:
            activeSet, new_uSet = al_obj.sample_initial_medoid(lSet, uSet)
            print(f'Initial Pool is {activeSet}')
        else:
            if cfg.ACTIVE_LEARNING.SAMPLING_FN.lower() in ["activesilhouette"]:
                activeSet, new_uSet, best_gamma = al_obj.sample_from_uSet(model, lSet, uSet, train_data, dele_sample)
            else:
                activeSet, new_uSet = al_obj.sample_from_uSet(model, lSet, uSet, train_data, dele_sample)
            print(f'Initial Pool is {activeSet}')
            if cfg.logging:
                logger.info(f'Active set is {activeSet}')

        lSet = np.append(lSet, activeSet)
        uSet = new_uSet

    if cfg.logging:
        logger.info("Labeled Set: {}, Unlabeled Set: {}, Validation Set: {}\n".format(len(lSet), len(uSet), len(valSet)))

    if cfg.logging:
        logger.info("model: 1-n-n classifier\n")

    print("AL Query Method: {}\nMax AL Episodes: {}\n".format(cfg.ACTIVE_LEARNING.SAMPLING_FN, cfg.ACTIVE_LEARNING.MAX_ITER))
    if cfg.logging:
        logger.info("AL Query Method: {}\nMax AL Episodes: {}\n".format(cfg.ACTIVE_LEARNING.SAMPLING_FN, cfg.ACTIVE_LEARNING.MAX_ITER))

    acc_list = []
    MCC_list = []

    times_all = []
    for cur_episode in range(0, cfg.ACTIVE_LEARNING.MAX_ITER + 1):
        print("======== EPISODE {} BEGINS ========\n".format(cur_episode))
        if cfg.logging:
            logger.info("======== EPISODE {} BEGINS ========\n".format(cur_episode))
        # Creating output directory for the episode
        episode_dir = os.path.join(cfg.EXP_DIR, f'episode_{cur_episode}')
        os.makedirs(episode_dir, exist_ok=True)
        cfg.EPISODE_DIR = episode_dir

        labelled_idx = np.array(lSet).astype(int)
        if cfg.DATASET.NAME in image_datasets:
            # Evaluate on the held-out validation split of the training data.
            val_idx = np.array(valSet).astype(int)
            labels = np.array(train_data.targets)
            ufeat = torch.tensor(train_data.features[val_idx]).to(device)
            lfeat = torch.tensor(train_data.features[labelled_idx]).to(device)
            ufeat_labels = torch.tensor(labels[val_idx]).to(device)
            lfeat_labels = torch.tensor(labels[labelled_idx]).to(device)
        elif cfg.DATASET.NAME in embed_datasets:
            # Evaluate on the separately loaded test features.
            ufeat = torch.tensor(test_data).float().to(device)
            lfeat = torch.tensor(train_data[labelled_idx]).float().to(device)
            ufeat_labels = torch.tensor(test_label).to(device)
            lfeat_labels = torch.tensor(train_label[labelled_idx]).to(device)

        if cfg.pure_nn:
            test_acc, MCC = one_nn(cfg, ufeat, lfeat, ufeat_labels, lfeat_labels, batch_size=1028, device=device.type)
        else:
            test_acc, MCC = one_nn_silhou(cfg, ufeat, lfeat, ufeat_labels, lfeat_labels, best_gamma, batch_size=1028, device=device.type)

        print("Test Accuracy: {}.\n".format(round(test_acc, 4)))
        acc_list.append(round(test_acc, 4) * 100)
        MCC_list.append(round(MCC, 4) * 100)
        if cfg.logging:
            logger.info("EPISODE {} Test Accuracy {}.\n".format(cur_episode, test_acc))

        # Active Sample
        if cfg.logging:
            logger.info("======== ACTIVE SAMPLING ========\n")
        al_obj = ActiveLearning(data_obj, cfg)

        start_time = time.time()

        if cfg.ACTIVE_LEARNING.SAMPLING_FN.lower() in ["activesilhouette"]:
            if args.fix_gamma:
                # Re-use the gamma returned by the previous sampling call.
                activeSet, new_uSet, best_gamma = al_obj.sample_from_uSet(model, lSet, uSet, train_data, dele_sample, best_gamma)
            else:
                activeSet, new_uSet, best_gamma = al_obj.sample_from_uSet(model, lSet, uSet, train_data, dele_sample)
            print("best gamma: ", best_gamma)
        else:
            activeSet, new_uSet = al_obj.sample_from_uSet(model, lSet, uSet, train_data, dele_sample)

        # Measure once so the printed and the recorded durations agree.
        elapsed = time.time() - start_time
        print("======================== time =====================\n")
        print("running time is:", elapsed)
        times_all.append(elapsed)
        if cfg.logging:
            logger.info(f'Active set is {activeSet}\n')

        # Add activeSet to lSet, save new_uSet as uSet and update dataloader for the next episode
        lSet = np.append(lSet, activeSet)
        uSet = new_uSet

        # NOTE(review): the loader is not used afterwards in this function; the call is
        # kept in case getSequentialDataLoader has side effects — confirm and remove if not.
        lSet_loader = data_obj.getSequentialDataLoader(indexes=lSet, batch_size=cfg.TRAIN.BATCH_SIZE, data=train_data)

        print("Active Sampling Complete. After Episode {}:\nNew Labeled Set: {}, New Unlabeled Set: {}, Active Set: {}\n".format(cur_episode, len(lSet), len(uSet), len(activeSet)))
        if cfg.logging:
            logger.info("Active Sampling Complete. After Episode {}:\nNew Labeled Set: {}, New Unlabeled Set: {}, Active Set: {}\n".format(cur_episode, len(lSet), len(uSet), len(activeSet)))

        if cfg.logging:
            logger.info("================================\n\n")

    print('Test Accuracy: ', acc_list)
    print("All of the selected samples are: ", lSet)

    # NOTE(review): 'TRPB' is not a dataset name accepted above ('TRPB_balanced' is);
    # confirm whether MCC should also be printed for 'TRPB_balanced'.
    if cfg.DATASET.NAME in ['CIFAR10_imbalanced', 'CIFAR10_all_imbalanced', 'TRPB']:
        print('MCC(Matthews Correlation Coefficient): ', MCC_list)
    print("times_all: ", times_all)
    print('####################################Finish!################################################')




if __name__ == "__main__":
    # `args` must stay a module-level name: main() reads it directly.
    args = argparser().parse_args()
    cfg.merge_from_file(args.cfg)

    # Command-line overrides for the ACTIVE_LEARNING section.
    for option, value in (
        ("SAMPLING_FN", args.al),
        ("BUDGET_SIZE", args.budget),
        ("INITIAL_DELTA", args.initial_delta),
        ("A_LOGISTIC", args.a_logistic),
        ("K_LOGISTIC", args.k_logistic),
        ("RATIO", args.ratio),
        ("MAX_ITER", args.maxiter),
    ):
        setattr(cfg.ACTIVE_LEARNING, option, value)

    # Command-line overrides stored directly on the root config node.
    for option, value in (
        ("EXP_NAME", args.exp_name),
        ("RNG_SEED", args.seed),
        ("prob", args.prob),
        ("logging", args.logging),
        ("random_initial", args.random_initial),
        ("medoid_initial", args.medoid_initial),
        ("pure_nn", args.pure_nn),
        ("random_SEED", args.random_seed),
        ("data_type", args.data_type),
    ):
        setattr(cfg, option, value)

    # Model and dataset overrides.
    cfg.MODEL.LINEAR_FROM_FEATURES = args.linear_from_features
    cfg.MODEL.NUM_CLASSES = args.num_classes
    cfg.DATASET.NAME = args.dataset

    main(cfg)
