from __future__ import print_function
import nni, os, sys, numpy
import argparse
import warnings
import torch

sys.path.append('../src')
os.chdir("../src")
from utils import *
from dataset.dataset import *
from clustering.algorithms import *

# Silence every UserWarning for the whole process so experiment output stays readable.
warnings.simplefilter("ignore", UserWarning)


# Default MeanShift centroid merge threshold. MeanShift_Phylogeny runs are pinned
# to this value because it is not part of their log-directory suffix.
_DEFAULT_CENTROID_MERGE_THRESHOLD = 1e-3


def _build_arg_parser():
    """Create the command-line parser for a single clustering experiment."""
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--exp_group', type=str, default="")
    parser.add_argument('--task', type=str, default="")
    parser.add_argument('--seed', type=int, default=1314, metavar='S',
                        help='random seed (default: 1314)')
    parser.add_argument('--kfold_split', type=int, default=5,
                        help='number of folds used to validate')
    parser.add_argument('--pmodel', type=str, default="",
                        help='name of parent model set')
    parser.add_argument('--cmodel', type=str, default="",
                        help='name of child model set')
    parser.add_argument('--method', type=str, default="",
                        help='name of clustering method')

    parser.add_argument('--z', type=str, default="",
                        help='z distribution used in clustering')

    # DBSCAN
    parser.add_argument('--eps', type=float, default=0.5,
                        help='DBSCAN eps')
    # KMeans
    parser.add_argument('--tol', type=float, default=1e-5,
                        help='KMeans tol')
    # KMeans & MeanShift
    parser.add_argument('--max_iter', type=int, default=100,
                        help='KMeans / MeanShift max_iter')
    # MeanShift
    parser.add_argument('--bandwidth', type=float, default=1.0,
                        help='MeanShift bandwidth')
    parser.add_argument('--centroid_merge_threshold', type=float,
                        default=_DEFAULT_CENTROID_MERGE_THRESHOLD,
                        help='MeanShift centroid_merge_threshold')
    # KMeans Phylogeny & MeanShift Phylogeny
    parser.add_argument('--alpha', type=float, default=1.0,
                        help='Phylogeny alpha')
    # KMeans Phylogeny
    parser.add_argument('--pcthreshold', type=float, default=1.0,
                        help='KMeans Phylogeny pcthreshold')
    return parser


def _experiment_suffix(args):
    """Return the method-specific hyper-parameter suffix of the log directory.

    Raises:
        ValueError: if ``args.method`` is not a supported clustering method, or
            if MeanShift_Phylogeny is requested with a non-default
            ``centroid_merge_threshold`` (that value is not encoded in its
            suffix, so it is restricted to the default).
    """
    method = args.method
    if method == "DBSCAN":
        return f"_{args.eps}"
    if method == "GMM":
        return ""
    if method == "KMeans":
        return f"_{args.tol}_{args.max_iter}"
    if method == "MeanShift":
        return f"_{args.bandwidth}_{args.centroid_merge_threshold}_{args.max_iter}"
    if method == "KMeans_Phylogeny":
        return f"_{args.tol}_{args.max_iter}_{args.alpha}_{args.pcthreshold}"
    if method == "MeanShift_Phylogeny":
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if args.centroid_merge_threshold != _DEFAULT_CENTROID_MERGE_THRESHOLD:
            raise ValueError(
                "MeanShift_Phylogeny only supports "
                f"--centroid_merge_threshold={_DEFAULT_CENTROID_MERGE_THRESHOLD}, "
                f"got {args.centroid_merge_threshold}")
        return f"_{args.bandwidth}_{args.max_iter}_{args.alpha}"
    raise ValueError(f"method {args.method} is not supported")


def main():
    """Cluster parent and child model weights and return one-hot assignments.

    Parses CLI arguments, seeds randomness via ``randomness_control``, creates
    the log directory ``./log/<pmodel>__<cmodel>/<method>/<z>_<seed><suffix>``
    (stored in ``args.exp_id``), loads ``PhylogenyDataset`` and runs the
    algorithm returned by ``get_alg(args)`` on the concatenation of parent and
    child weights.

    Returns:
        torch.Tensor: ``one_hot(prediction, N).T`` where ``N`` is
        ``prediction.shape[0]``; row ``k`` marks the inputs assigned to label
        ``k`` and inputs labelled ``-1`` are placed in the last row. (Assumes
        the algorithm returns a 1-D label vector, giving an ``(N, N)`` result.)

    Raises:
        ValueError: for an unsupported ``--method`` or an invalid
            MeanShift_Phylogeny configuration.
    """
    # Explicit import: `logging.INFO` is used below, and relying on
    # `from utils import *` to re-export the module is fragile.
    import logging

    args = _build_arg_parser().parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    randomness_control(args.seed)

    exp_specify = (f"{args.pmodel}__{args.cmodel}/{args.method}/{args.z}_{args.seed}"
                   + _experiment_suffix(args))

    args.exp_id = "./log/" + exp_specify
    os.makedirs(args.exp_id, exist_ok=True)

    logger, _formatter = get_logger(args.exp_id, None, "log.log", level=logging.INFO)

    phydataset = PhylogenyDataset(args.pmodel, args.cmodel)
    # Labels are not needed for clustering, so only the weights are moved to the device.
    weight_p, weight_c, _labels = phydataset.complete_dataset
    weight_p, weight_c = weight_p.to(device), weight_c.to(device)

    args.n_clusters = phydataset.num_p

    clalg = get_alg(args)

    with torch.no_grad():
        prediction = clalg.forward(torch.cat((weight_p, weight_c), dim=0))

        # one_hot only accepts int64 class indices; a no-op if already int64.
        prediction = prediction.long()
        # One row per input is an upper bound on the number of distinct labels
        # (assuming one label per input row).
        num_classes = prediction.shape[0]
        # Unassigned/noise inputs (label -1) are collected in the last row.
        prediction[prediction == -1] = num_classes - 1
        one_hot_prediction = torch.nn.functional.one_hot(
            prediction, num_classes=num_classes).T

    logger.info("Clustering process was finished")

    return one_hot_prediction

if __name__ == "__main__":
    # Script entry point; the tensor returned by main() is discarded here.
    main()
