import numpy as np
import argparse
import os
# import hdbscan
from cuml.cluster import HDBSCAN, KMeans
from sklearn.cluster import KMeans
import random
from tqdm.auto import tqdm
import h5py
import scipy.io as sio

from pathlib import Path
import sys
# Echo the three nearest ancestor directories of this script (debug output),
# then register the third one -- two levels above the script's own folder --
# on sys.path, presumably so that project-level packages can be imported.
_script_path = Path(__file__)
for _level in (0, 1, 2):
    print(_script_path.parents[_level])
path_root = _script_path.parents[2]
print(path_root)
sys.path.append(str(path_root))


def parse_args():
    """Build the command-line parser for the segment clustering script.

    Despite its name, this function does not parse anything: it returns the
    configured parser, and the caller invokes ``.parse_args()`` on it.

    The number of KMeans clusters is not a command-line option; ``main``
    sweeps over a fixed range of cluster counts.

    Returns:
        argparse.ArgumentParser: parser with the path, KMeans, HDBSCAN and
        miscellaneous options registered.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--method', type=str,
                        default='kmeans',
                        choices=['hdbscan', 'kmeans'],
                        help='What clustering method to use')
    parser.add_argument('--train-dir', type=str,
                        default='',
                        help='input dir')
    parser.add_argument('--exp-dir', type=str,
                        default='exps/kmeans_flowers',
                        help='exp dir')
    parser.add_argument('--class-names-filepath', type=str,
                        default=None,
                        help='class names filepath')
    parser.add_argument('--batch-size', type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--train-size', type=int,
                        default=-1,
                        help='number of examples to use per class '
                             '(-1 uses all of them)')

    # kmeans
    parser.add_argument('--n-init', type=int,
                        default=1,
                        help='number of inits')

    # hdbscan
    parser.add_argument('--min-cluster-size', type=int,
                        default=60,
                        help='min cluster size')
    parser.add_argument('--min-samples', type=int,
                        default=5,
                        help='min samples')
    parser.add_argument('--cluster-selection-epsilon', type=float,
                        default=0.5,
                        help='cluster selection epsilon')
    # alpha is a real-valued distance scaling factor; parsing it as int made
    # fractional values such as 1.5 impossible to pass on the command line.
    parser.add_argument('--alpha', type=float,
                        default=1.0,
                        help='HDBSCAN distance scaling parameter (float)')
    parser.add_argument('--cluster-selection-method', type=str,
                        default='leaf',
                        choices=['eom', 'leaf'],
                        help='how HDBSCAN selects clusters from the tree')

    # misc
    parser.add_argument('--seed', type=int,
                        default=42,
                        help='random seed (-1 disables seeding)')

    return parser

def _list_class_dirnames(train_dir, class_names_filepath):
    """Return the class sub-directory names to load, in processing order.

    Without a class-names file, every entry of ``train_dir`` is used, sorted
    by name. Otherwise one name is read per line of the file with its first
    character dropped (the file format presumably carries a one-character
    prefix -- TODO confirm); lines that yield an empty name are skipped.
    """
    if class_names_filepath is None:
        return sorted(os.listdir(train_dir))
    with open(class_names_filepath, 'rt') as input_file:
        names = (line.strip()[1:] for line in input_file)
        return [name for name in names if name]


def _load_segments(train_dir, dirnames, train_size):
    """Load the segment arrays of every class directory.

    Args:
        train_dir: root directory containing one sub-directory per class.
        dirnames: class sub-directory names, relative to ``train_dir``.
        train_size: maximum number of files to load per class; a negative
            value loads every file.

    Returns:
        Tuple ``(X, segment_paths, segment_idxs)``: the arrays concatenated
        along axis 0, plus, for every row of ``X``, the path of the file it
        came from and its index inside that file.

    Raises:
        ValueError: if no file was loaded at all.
    """
    features = []
    segment_paths = []
    segment_idxs = []
    for dirname in dirnames:
        src_cls_dir = os.path.join(train_dir, dirname)
        # Sorted so that the selected subset and the row order do not depend
        # on the filesystem's listing order.
        filenames = sorted(os.listdir(src_cls_dir))
        if train_size >= 0:
            # The limit applies to each class separately.
            filenames = filenames[:train_size]
        for filename in tqdm(filenames):
            segment_path = os.path.join(src_cls_dir, filename)
            segments = np.load(segment_path)
            n_segments = len(segments)
            features.append(segments)
            segment_paths.extend([segment_path] * n_segments)
            segment_idxs.extend(range(n_segments))
    if not features:
        raise ValueError(f'No segment files found under {train_dir!r}')
    return np.concatenate(features, axis=0), segment_paths, segment_idxs


def _fit_clusterer(args, X, num_clusters):
    """Fit the clustering model selected by ``args.method`` on ``X``.

    Returns:
        Tuple ``(model, labels)`` with the fitted estimator and its per-row
        cluster labels as a Python list.

    Raises:
        ValueError: if ``args.method`` is not a known method.
    """
    if args.method == 'kmeans':
        print('Start kmeans ...')
        # A seed of -1 means "unseeded"; -1 itself is not a valid
        # random_state, so pass None in that case.
        random_state = None if args.seed == -1 else args.seed
        # NOTE: the file imports KMeans from both cuml and sklearn; the later
        # sklearn import is the one bound to the name here.
        model = KMeans(n_clusters=num_clusters,
                       random_state=random_state,
                       n_init=args.n_init,
                       verbose=6).fit(X)
        print('Kmeans done.')
    elif args.method == 'hdbscan':
        print('Start hdbscan ...')
        model = HDBSCAN(min_cluster_size=args.min_cluster_size,
                        min_samples=args.min_samples,
                        cluster_selection_epsilon=args.cluster_selection_epsilon,
                        alpha=args.alpha,
                        cluster_selection_method=args.cluster_selection_method,
                        verbose=6)
        model.fit(X)
        print('HDBSCAN done.')
    else:
        raise ValueError('Unknown method')
    return model, model.labels_.tolist()


def _describe_segment_file(segment_path, train_dir):
    """Return ``(class_name, image_path)`` for one segment file.

    ``class_name`` is the file's directory relative to ``train_dir`` and
    ``image_path`` is that directory joined with the file name stripped of a
    trailing ``.npy``.
    """
    class_name = os.path.relpath(os.path.dirname(segment_path), train_dir)
    if class_name == os.curdir:
        class_name = ''
    image_name = os.path.basename(segment_path)
    if image_name.endswith('.npy'):
        image_name = image_name[:-len('.npy')]
    return class_name, os.path.join(class_name, image_name)


def _write_cluster_outputs(exp_dir, train_dir, segment_paths, segment_idxs,
                           labels):
    """Write the per-segment result files of one clustering run.

    Creates three files in ``exp_dir``: ``segment_paths.txt`` (file path and
    segment index per row), ``cluster_labels.txt`` (one label per row) and
    ``cluster_labels_segment.tsv`` (image path, segment index, cluster label
    and class label per row, with a header line).
    """
    print('Save segment paths ...')
    with open(os.path.join(exp_dir, 'segment_paths.txt'), 'wt') as output_file:
        output_file.writelines(
            f'{path}\t{idx}\n'
            for path, idx in zip(segment_paths, segment_idxs))

    print('Save cluster labels ...')
    with open(os.path.join(exp_dir, 'cluster_labels.txt'), 'wt') as output_file:
        output_file.writelines(f'{label}\n' for label in labels)

    # Number of segments that received a real cluster (-1 is the label
    # excluded by this count, i.e. unassigned points).
    print(sum(1 for label in labels if label != -1))

    print('Saving image paths with cluster labels...')
    described = {}  # segment file path -> (class_name, image_path)
    tsv_path = os.path.join(exp_dir, 'cluster_labels_segment.tsv')
    with open(tsv_path, 'w') as output_file:
        output_file.write('image_path\tsegment_idx\tcluster_label\tclass_label\n')
        rows = zip(segment_paths, segment_idxs, labels)
        for path, idx, label in tqdm(rows, total=len(segment_paths)):
            if path not in described:
                described[path] = _describe_segment_file(path, train_dir)
            class_name, image_path = described[path]
            output_file.write(f'{image_path}\t{idx}\t{label}\t{class_name}\n')


def main():
    """Cluster precomputed segment arrays over a sweep of cluster counts.

    Parses the command line, optionally seeds the NumPy and Python RNGs,
    loads every segment array once, and then, for each cluster count in
    ``range(30, 330, 30)``, fits the selected method and writes its outputs
    to ``<exp-dir>_cl<count>``.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for name, value in vars(args).items():
        # Report the type of the parsed value, not of the option name.
        print(name, value, '\t', type(value))

    if args.seed != -1:
        np.random.seed(args.seed)
        random.seed(args.seed)

    # The input data does not depend on the cluster count, so load it once
    # instead of on every iteration of the sweep.
    dirnames = _list_class_dirnames(args.train_dir, args.class_names_filepath)
    X, segment_paths, segment_idxs = _load_segments(
        args.train_dir, dirnames, args.train_size)
    print(X.shape)

    n_cluster_range = list(range(30, 330, 30))

    # NOTE(review): HDBSCAN does not use num_clusters, so with
    # --method hdbscan every iteration refits with the same settings and
    # only the output directory differs -- confirm this is intended.
    for num_clusters in tqdm(n_cluster_range):
        exp_dir = args.exp_dir + f'_cl{num_clusters}'
        os.makedirs(exp_dir, exist_ok=True)

        model, y = _fit_clusterer(args, X, num_clusters)
        _write_cluster_outputs(exp_dir, args.train_dir, segment_paths,
                               segment_idxs, y)

        if args.method == 'kmeans':
            print('Save_centers ...')
            # NOTE(review): np.save appends '.npy', so this creates
            # 'kmeans.pkl.npy' holding the pickled estimator in an object
            # array (read back with np.load(..., allow_pickle=True).item()).
            # Kept as-is so existing consumers of that file keep working.
            np.save(os.path.join(exp_dir, 'kmeans.pkl'), model)

    

if __name__ == '__main__':
    main()
