import argparse
import glob
import itertools
import json
import os
import random
import time

import numpy as np
import torch

def get_datapoints(data_ids, data_file_path):
    """Return the records of ``data_file_path`` whose ``"id"`` is among ``data_ids``.

    The file is read as JSON and iterated as a sequence of records indexed by
    ``"id"`` (presumably a list of dicts -- confirm against the data files).
    Each entry of ``data_ids`` may carry a ``"::"``-separated suffix; only the
    part before the first ``"::"`` is compared with the records' ``"id"``.

    Args:
        data_ids: iterable of selected id strings.
        data_file_path: path to the JSON data file.

    Returns:
        The matching records, in the order they appear in the data file.
    """
    with open(data_file_path, "r") as f:
        data = json.load(f)

    # A set gives O(1) membership tests instead of rescanning a list for
    # every record (previously O(len(data) * len(data_ids))).
    relevant_ids = {data_id.split("::")[0] for data_id in data_ids}
    return [record for record in data if record["id"] in relevant_ids]

def _load_json(path):
    """Read and return the JSON document at ``path``, closing the file afterwards."""
    with open(path, "r") as f:
        return json.load(f)


def _group_by_cluster(cluster_vals):
    """Invert ``{data_id: cluster_id}`` into ``{cluster_id: [data_id, ...]}``.

    Members keep the order in which they appear in ``cluster_vals``.
    """
    cluster_dict = {}
    for data_id, cluster_id in cluster_vals.items():
        cluster_dict.setdefault(cluster_id, []).append(data_id)
    return cluster_dict


def _random_pick(positions, k):
    """Default picker: ``k`` of ``positions`` uniformly at random, without replacement."""
    return np.random.choice(positions, k, replace=False).tolist()


def _allocate_across_clusters(labels, n, *, largest_first, pick=_random_pick, min_count=3):
    """Choose up to ``n`` positions of ``labels``, spreading the budget over clusters.

    Clusters with at least ``min_count`` members are visited in size order
    (descending if ``largest_first``, else ascending). Each visited cluster
    gets an equal share ``n // clusters_left`` of the remaining budget, or
    all of its members if it has no more than that share. Any budget still
    left afterwards is drawn from the pooled members of the clusters smaller
    than ``min_count``.

    Args:
        labels: cluster label of every item (any array-like; flattened).
        n: number of positions wanted.
        largest_first: visiting order of the equal-share pass.
        pick: ``pick(positions_array, k)`` -> list of ``k`` chosen positions.
        min_count: clusters smaller than this only serve the final top-up.

    Returns:
        List of integer positions into ``labels``. It has fewer than ``n``
        entries (with a printed warning) when the data cannot supply ``n``.
    """
    labels = np.asarray(labels).ravel()
    clusters, counts = np.unique(labels, return_counts=True)
    order = np.argsort(counts, kind="stable")
    order = order[counts[order] >= min_count]
    if largest_first:
        order = order[::-1]

    selected = []
    for visited, cluster_pos in enumerate(order):
        share = n // (len(order) - visited)
        members = np.flatnonzero(labels == clusters[cluster_pos])
        if len(members) > share:
            selected.extend(pick(members, share))
            n -= share
        else:
            selected.extend(members.tolist())
            n -= len(members)

    if n > 0:
        leftovers = np.flatnonzero(np.isin(labels, clusters[counts < min_count]))
        take = min(n, len(leftovers))
        if take < n:
            print(f"Warning: could only select {take} of the {n} remaining points")
        if take > 0:
            selected.extend(pick(leftovers, take))
    return selected


def _take_smallest_clusters(cluster_dict, amount):
    """Method 2: take whole clusters, smallest first, until ``amount`` ids are chosen.

    The cluster that would overflow the budget is randomly subsampled to fill
    it exactly. Returns every id if ``amount`` exceeds the total.
    """
    selected = []
    for cluster in sorted(cluster_dict, key=lambda c: len(cluster_dict[c])):
        budget = amount - len(selected)
        if budget <= 0:
            break
        members = cluster_dict[cluster]
        if len(members) <= budget:
            selected.extend(members)
        else:
            selected.extend(random.sample(members, budget))
    return selected


def _softmax_of_zscores(values):
    """Return ``exp(z) / sum(exp(z))`` where ``z`` is the z-score of ``values``.

    Falls back to uniform weights when ``values`` has zero or non-finite
    spread, because the z-score is undefined there.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    std = values.std()
    if not np.isfinite(std) or std == 0:
        return np.full(values.size, 1.0 / values.size)
    weights = np.exp((values - values.mean()) / std)
    return weights / weights.sum()


def _sample_by_metric(cluster_dict, amount, metric_file_path):
    """Method 3: draw ``amount`` ids, favouring high-metric clusters and points.

    Each draw first picks a cluster that still has unselected points, with
    probability proportional to ``exp(z)`` of the cluster's mean metric
    (z-scored across all clusters). It then picks one of that cluster's
    unselected points, with probability proportional to ``exp(z)`` of the
    point's metric (z-scored within those remaining points). Stops early
    once every point has been taken.
    """
    if metric_file_path is None:
        raise ValueError("If args.method == 3, then args.metric_file_path cannot be None")
    metric = _load_json(metric_file_path)

    cluster_ids = list(cluster_dict)
    cluster_weights = _softmax_of_zscores(
        [np.mean([metric[i] for i in cluster_dict[c]]) for c in cluster_ids]
    )
    weight_of = dict(zip(cluster_ids, cluster_weights))

    # Pools of not-yet-selected points; a cluster leaves ``live`` once its
    # pool is empty, so the loop always terminates.
    remaining = {c: list(cluster_dict[c]) for c in cluster_ids}
    live = [c for c in cluster_ids if remaining[c]]
    selected = []
    while len(selected) < amount and live:
        cluster = random.choices(live, weights=[weight_of[c] for c in live], k=1)[0]
        pool = remaining[cluster]
        if len(pool) == 1:
            pick_at = 0
        else:
            point_weights = _softmax_of_zscores([metric[i] for i in pool])
            pick_at = random.choices(range(len(pool)), weights=point_weights, k=1)[0]
        selected.append(pool.pop(pick_at))
        if not pool:
            live.remove(cluster)
    return selected


def _stack_features(ids, msa_embeds_dict):
    """Stack the embeddings of ``ids`` into a float32 matrix, replacing NaNs with 0."""
    features = torch.stack([torch.tensor(msa_embeds_dict[i]) for i in ids])
    features[torch.isnan(features)] = 0
    # faiss expects a C-contiguous float32 array.
    return np.ascontiguousarray(features.numpy(), dtype=np.float32)


def _feature_kmeans(ids, n, msa_embeds_dict, n_centroids=10, niter=100):
    """Select ``n`` of ``ids`` by k-means over their MSA embeddings.

    The points are split into ``n_centroids`` k-means clusters. The count is
    capped at the number of points because faiss refuses fewer training
    points than centroids. ``n`` points are then spread over those clusters,
    smallest cluster first.
    """
    import faiss

    features = _stack_features(ids, msa_embeds_dict)
    k = max(1, min(n_centroids, len(ids)))
    kmeans = faiss.Kmeans(features.shape[1], k, niter=niter, verbose=True)
    kmeans.train(features)
    # Each point's nearest centroid is its k-means label.
    _, assignments = kmeans.index.search(features, 1)
    positions = _allocate_across_clusters(assignments, n, largest_first=False)
    return [ids[p] for p in positions]


def _feature_hdbscan(ids, n, msa_embeds_dict, attn_stability_vals, *,
                     min_persistence=None, max_noise_ratio=None,
                     small_blob_size=350, fraction=0.01, min_floor=2):
    """Select ``n`` of ``ids`` via HDBSCAN on their MSA embeddings (experimental).

    NOTE(review): not called by ``main`` yet. The thresholds are still
    undecided (original TODO), and a ``None`` threshold disables that check.

    Points with the lowest ``attn_stability_vals[id]`` are preferred. For
    blobs smaller than ``small_blob_size``, the clustering is treated as
    ineffective, and the ``n`` lowest-stability points of the whole blob are
    returned, when any of these hold:
    - no clusters were found;
    - ``num_clusters * min_cluster_size == num_samples``;
    - the mean persistence is below ``min_persistence``;
    - the noise ratio exceeds ``max_noise_ratio``.
    Otherwise ``n`` points are spread over the HDBSCAN labels (noise label
    included), smallest cluster first.
    """
    if not ids or n <= 0:
        return []

    import hdbscan

    features = _stack_features(ids, msa_embeds_dict)
    print("this is the shape of the features: ", features.shape)

    num_samples = len(ids)
    min_cluster_size = max(min_floor, int(num_samples * fraction))
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size, min_samples=min_cluster_size, metric="euclidean"
    )
    clusterer.fit(features)
    labels = clusterer.labels_
    num_clusters = int(labels.max()) + 1

    def lowest_stability(positions, k):
        return sorted(positions, key=lambda p: attn_stability_vals[ids[p]])[:k]

    # Ineffective clustering is only expected for small blobs.
    if num_samples < small_blob_size:
        persistence = clusterer.cluster_persistence_
        avg_persistence = persistence.mean() if len(persistence) else 0.0
        noise_ratio = np.sum(labels == -1) / num_samples
        ineffective = (
            num_clusters == 0
            or num_clusters * min_cluster_size == num_samples
            or (min_persistence is not None and avg_persistence < min_persistence)
            or (max_noise_ratio is not None and noise_ratio > max_noise_ratio)
        )
        if ineffective:
            return [ids[p] for p in lowest_stability(range(num_samples), n)]

    positions = _allocate_across_clusters(
        labels, n, largest_first=False,
        pick=lambda members, k: lowest_stability(members.tolist(), k),
    )
    return [ids[p] for p in positions]


def _sample_by_embeddings(cluster_dict, amount, msa_embeds_dir):
    """Method 4: spread ``amount`` over clusters, choosing points by k-means.

    Clusters with more than 3 members are visited largest first. Each one
    gets ``remaining_budget // clusters_left`` points: all of its members if
    they fit, otherwise a k-means-based subset. Unused budget rolls over to
    later clusters. Clusters with 3 or fewer members are never sampled.
    """
    if msa_embeds_dir is None:
        raise ValueError("If args.method == 4, then args.msa_embeds_dir cannot be None")

    feature_files = glob.glob(os.path.join(msa_embeds_dir, "last_layer_msa_embeds_") + "*.json")
    if not feature_files:
        raise FileNotFoundError(f"No last_layer_msa_embeds_*.json files found in {msa_embeds_dir}")

    print("Loading Features!!")
    msa_embeds_dict = {}
    for path in feature_files:
        msa_embeds_dict.update(_load_json(path))

    eligible = sorted(
        (c for c, members in cluster_dict.items() if len(members) > 3),
        key=lambda c: len(cluster_dict[c]),
        reverse=True,
    )
    n = amount
    sampled_ids = []
    for visited, cluster in enumerate(eligible):
        print("# datapoints sampled: ", len(sampled_ids))
        share = n // (len(eligible) - visited)
        if share <= 0:
            continue  # nothing to take here; avoid a pointless k-means run
        points = cluster_dict[cluster]
        if share >= len(points):
            chosen = list(points)
        else:
            chosen = _feature_kmeans(points, share, msa_embeds_dict)
        sampled_ids.extend(chosen)
        # Subtract what was actually taken so any shortfall rolls over.
        n -= len(chosen)
    return sampled_ids


def main(args):
    """Select ``args.amount`` datapoints with ``args.method`` and save them.

    Reads the ``{data_id: cluster_id}`` JSON at ``args.clusters_file_path``
    and selects ids with one of these methods:
    1. an even spread over clusters, largest first;
    2. whole clusters, smallest first;
    3. metric-weighted sampling;
    4. k-means over MSA embeddings within each cluster.
    It then looks up the matching records in ``args.data_file_path`` and
    writes them as JSON to ``args.save_path``.

    Raises:
        ValueError: for a negative amount, an unknown method, or a missing
            method-specific path argument.
    """
    amount = args.amount
    if amount < 0:
        raise ValueError(f"amount must be non-negative, got {amount}")

    cluster_vals = _load_json(args.clusters_file_path)
    cluster_dict = _group_by_cluster(cluster_vals)

    if args.method == 1:
        all_ids = list(cluster_vals)
        positions = _allocate_across_clusters(
            list(cluster_vals.values()), amount, largest_first=True
        )
        data_ids = [all_ids[p] for p in positions]
    elif args.method == 2:
        data_ids = _take_smallest_clusters(cluster_dict, amount)
    elif args.method == 3:
        data_ids = _sample_by_metric(cluster_dict, amount, args.metric_file_path)
    elif args.method == 4:
        data_ids = _sample_by_embeddings(cluster_dict, amount, args.msa_embeds_dir)
    else:
        raise ValueError(f"Unknown method {args.method}; expected 1, 2, 3 or 4")

    datapoints = get_datapoints(data_ids, args.data_file_path)
    print("Number of sampled datapoints: ", len(datapoints))

    with open(args.save_path, "w") as f:
        json.dump(datapoints, f, indent=4)

    print("File saved at: ", args.save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Select a subset of datapoints using cluster assignments.")
    parser.add_argument('--method', type=int, required=True, choices=[1, 2, 3, 4])
    parser.add_argument('--amount', type=int, required=True, help="number of datapoints to select")
    parser.add_argument('--clusters_file_path', type=str, required=True)
    parser.add_argument('--metric_file_path', type=str, required=False)
    parser.add_argument('--msa_embeds_dir', type=str, required=False)
    parser.add_argument('--data_sel_metric_file_path', type=str, required=False)
    parser.add_argument('--data_file_path', type=str, required=True)
    parser.add_argument('--save_path', type=str, required=True)

    args = parser.parse_args()

    main(args)