import numpy as np
import argparse
import os
# import hdbscan
from cuml.cluster import HDBSCAN, KMeans
from sklearn.cluster import KMeans
import random
from tqdm import tqdm
import h5py
import scipy.io as sio

from pathlib import Path
import sys
# Debug output: the three nearest ancestor directories of this script.
_script_path = Path(__file__)
for _depth in range(3):
    print(_script_path.parents[_depth])
# Make the directory two levels above this script's folder importable.
path_root = _script_path.parents[2]
print(path_root)
sys.path.append(str(path_root))
del _script_path, _depth


def parse_args():
    """Build the command-line parser for this clustering script.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args()``
        on it to obtain the options.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--method', type=str,
                        default='kmeans',
                        choices=['hdbscan', 'kmeans'],
                        help='What clustering method to use')
    parser.add_argument('--train-dir', type=str,
                        default='../datasets/flowers/sam_seg_compressed_feats',
                        help='input dir')
    parser.add_argument('--exp-dir', type=str,
                        default='exps/kmeans_flowers',
                        help='exp dir')
    parser.add_argument('--class-names-filepath', type=str,
                        default=None,
                        help='class names filepath')
    # NOTE: accepted for CLI compatibility; main() does not read it.
    parser.add_argument('--batch-size', type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--train-size', type=int,
                        default=-1,
                        help='max number of feature files to load per class (-1 = all)')

    # kmeans
    parser.add_argument('--n-clusters', type=int,
                        default=100,
                        help='number of clusters')
    parser.add_argument('--n-init', type=int,
                        default=1,
                        help='number of inits')

    # hdbscan
    parser.add_argument('--min-cluster-size', type=int,
                        default=60,
                        help='min cluster size')
    parser.add_argument('--min-samples', type=int,
                        default=5,
                        help='min samples')
    parser.add_argument('--cluster-selection-epsilon', type=float,
                        default=0.5,
                        help='cluster selection epsilon')
    # HDBSCAN's alpha is a float distance-scaling factor; parsing it as int
    # rejected fractional values such as 0.5.
    parser.add_argument('--alpha', type=float,
                        default=1.0,
                        help='alpha. the higher the tighter the stricter')
    # Restrict to the selection methods HDBSCAN accepts so typos fail at parse time.
    parser.add_argument('--cluster-selection-method', type=str,
                        default='leaf',
                        choices=['eom', 'leaf'],
                        help='HDBSCAN cluster selection method')

    # misc
    parser.add_argument('--seed', type=int,
                        default=42,
                        help='random seed (-1 disables seeding)')

    return parser

def _seed_everything(seed):
    """Seed Python's ``random`` module and NumPy's global RNG.

    A ``seed`` of -1 means "do not seed", matching the ``--seed`` option.
    """
    if seed == -1:
        return
    np.random.seed(seed)
    random.seed(seed)


def _load_segments(train_dir, class_names_filepath, train_size):
    """Load segment feature arrays from ``train_dir/<class_dir>/<file>``.

    Every file is read with ``np.load``, and each row of the loaded array is
    treated as one segment.

    Args:
        train_dir: root directory holding one sub-directory per class.
        class_names_filepath: optional text file listing the class
            sub-directories to use, one per line. When None, every entry of
            ``train_dir`` is used, in sorted order.
        train_size: maximum number of files to load per class; -1 (or any
            negative value) loads every file.

    Returns:
        Tuple ``(X, segment_paths, segment_idxs, rel_paths, class_names)``.
        ``X`` is all segments concatenated along axis 0. The four lists have
        one entry per row of ``X``: the full feature-file path, the row index
        within that file, ``<class_dir>/<file name without '.npy'>``, and the
        class directory name.

    Raises:
        ValueError: if no feature files were found.
    """
    if class_names_filepath is None:
        dirnames = sorted(os.listdir(train_dir))
    else:
        with open(class_names_filepath, 'rt') as input_file:
            # The first character of every line is dropped. Presumably the file
            # prefixes each name with a marker character -- confirm the format.
            dirnames = [line.strip()[1:] for line in input_file]

    feats = []
    segment_paths, segment_idxs, rel_paths, class_names = [], [], [], []
    for dirname in tqdm(dirnames):
        src_cls_dir = os.path.join(train_dir, dirname)
        # Sorted so the sample order (and hence seeded k-means output) does not
        # depend on filesystem listing order; the limit applies per class.
        filenames = sorted(os.listdir(src_cls_dir))
        if train_size >= 0:
            filenames = filenames[:train_size]
        for filename in tqdm(filenames):
            segment_path = os.path.join(src_cls_dir, filename)
            segments = np.load(segment_path)
            n_segments = len(segments)
            feats.append(segments)
            segment_paths.extend([segment_path] * n_segments)
            segment_idxs.extend(range(n_segments))
            rel_path = os.path.join(dirname, filename.replace('.npy', ''))
            rel_paths.extend([rel_path] * n_segments)
            class_names.extend([dirname] * n_segments)

    if not feats:
        raise ValueError(f'No feature files found under {train_dir!r}')
    X = np.concatenate(feats, axis=0)
    return X, segment_paths, segment_idxs, rel_paths, class_names


def _cluster(X, args):
    """Cluster the rows of ``X`` with the method chosen by ``args.method``.

    Returns:
        Tuple ``(labels, model)``: ``labels`` is a list of ints, one per row
        of ``X``; ``model`` is the fitted estimator.

    Raises:
        ValueError: if ``args.method`` is neither 'kmeans' nor 'hdbscan'.
    """
    if args.method == 'kmeans':
        print('Start kmeans ...')
        # `KMeans` is scikit-learn's class here: its import comes after (and
        # shadows) the cuml one. scikit-learn rejects negative seeds, so the
        # script's "-1 = unseeded" is mapped to None.
        model = KMeans(n_clusters=args.n_clusters,
                       random_state=None if args.seed == -1 else args.seed,
                       n_init=args.n_init,
                       verbose=6).fit(X)
        print('Kmeans done.')
    elif args.method == 'hdbscan':
        print('Start hdbscan ...')
        model = HDBSCAN(min_cluster_size=args.min_cluster_size,
                        min_samples=args.min_samples,
                        cluster_selection_epsilon=args.cluster_selection_epsilon,
                        alpha=args.alpha,
                        cluster_selection_method=args.cluster_selection_method,
                        verbose=6)
        model.fit(X)
        print('HDBSCAN done.')
    else:
        raise ValueError('Unknown method')
    return model.labels_.tolist(), model


def _write_cluster_files(exp_dir, segment_paths, segment_idxs, labels):
    """Write ``segment_paths.txt`` and ``cluster_labels.txt`` into ``exp_dir``."""
    print('Save segment paths ...')
    with open(os.path.join(exp_dir, 'segment_paths.txt'), 'wt') as output_file:
        for segment_path, segment_idx in tqdm(zip(segment_paths, segment_idxs),
                                              total=len(segment_paths)):
            output_file.write(f'{segment_path}\t{segment_idx}\n')
    print('Save cluster labels ...')
    with open(os.path.join(exp_dir, 'cluster_labels.txt'), 'wt') as output_file:
        output_file.writelines(f'{label}\n' for label in labels)


def _write_segment_table(exp_dir, rel_paths, segment_idxs, labels, class_names):
    """Write ``cluster_labels_segment.tsv`` (header + one row per segment)."""
    print('Saving image paths with cluster labels...')
    with open(os.path.join(exp_dir, 'cluster_labels_segment.tsv'), 'w') as output_file:
        output_file.write('image_path\tsegment_idx\tcluster_label\tclass_label\n')
        rows = zip(rel_paths, segment_idxs, labels, class_names)
        for rel_path, segment_idx, label, class_name in tqdm(rows, total=len(labels)):
            output_file.write(f'{rel_path}\t{segment_idx}\t{label}\t{class_name}\n')


def main():
    """Load segment features, cluster them, and write the results to --exp-dir.

    Files written to ``args.exp_dir``:
      * ``segment_paths.txt``: ``<feature file path>\\t<row index>`` per segment
      * ``cluster_labels.txt``: one integer cluster label per segment
      * ``kmeans.pkl.npy``: (k-means only) the pickled fitted estimator
      * ``cluster_labels_segment.tsv``: image path, segment index, cluster
        label and class directory per segment, with a header row
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg, value in vars(args).items():
        # Report the parsed value's type (the option name is always a str).
        print(arg, value, '\t', type(value))

    _seed_everything(args.seed)
    os.makedirs(args.exp_dir, exist_ok=True)

    X, segment_paths, segment_idxs, rel_paths, class_names = _load_segments(
        args.train_dir, args.class_names_filepath, args.train_size)
    print(X.shape)

    labels, model = _cluster(X, args)
    _write_cluster_files(args.exp_dir, segment_paths, segment_idxs, labels)

    if args.method == 'kmeans':
        print('Save_centers ...')
        # np.save wraps the fitted estimator in a 0-d object array (pickled)
        # and appends '.npy', so this writes 'kmeans.pkl.npy'; reload it with
        # np.load(path, allow_pickle=True).item().
        np.save(os.path.join(args.exp_dir, 'kmeans.pkl'), model)

    # HDBSCAN labels noise points -1; count the segments that got a cluster.
    n_clustered = sum(1 for label in labels if label != -1)
    print(f'{n_clustered} / {len(labels)} segments assigned to a cluster (label != -1)')

    # Build the combined table from in-memory results instead of re-reading
    # the files from a path derived from the exp-dir basename.
    _write_segment_table(args.exp_dir, rel_paths, segment_idxs, labels, class_names)

    

if __name__ == "__main__":
    # Run the clustering pipeline only when executed as a script, not on import.
    main()
