import argparse
import os
import csv
import pickle
import numpy as np
from tqdm import tqdm
from sklearn.cluster import SpectralClustering, KMeans
import torch
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.celebA_dataset import CelebADataset
from wilds.datasets.coco_places_dataset import COCOonPlacesDataset
from examples.utils.print_logger import get_logger
# Module-level logger from the project's print_logger helper, created with level="DEBUG".
LOGGER = get_logger(__name__, level="DEBUG")
# Environment label written for every image that is neither in the clustered
# train split nor (with --pred_val) the val split, when --ignore_labels is set;
# otherwise such images get their first metadata column instead (see main()).
IGNORE_LABEL: int = 255

def _parse_args():
    """Build the command-line interface and return the parsed configuration.

    ``--pred_val`` is rejected together with ``--clustering_algo spectral``:
    only the KMeans branch can label the val split, and that combination
    previously crashed with a ``NameError`` only after all clustering work.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='waterbirds',
                        help="Dataset to load: 'waterbirds', 'celebA' or 'coco_on_places'")
    parser.add_argument('--root_dir', required=True,
                        help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
    # Required: every code path reads feature files and index lists from here
    # (os.listdir(None) would otherwise silently list the current directory).
    parser.add_argument('--env_dir', type=str, required=True,
                        help='Directory holding the saved per-layer feature files and train/val index lists')
    parser.add_argument("--layers", nargs='+', default=['features.5'], help="Layers to select for clustering")
    # Required: without it no clustering runs and there are no labels to write.
    parser.add_argument('--clustering_algo', type=str, choices=['kmeans', 'spectral'], required=True)
    parser.add_argument('--n_clusters', type=int, default=2)
    parser.add_argument("--random_proj", type=int, default=0,
                        help="If > 0 (kmeans only): L2-normalise each layer's features and randomly project them "
                             "to int(RANDOM_PROJ * ln(n_samples)) dimensions")
    parser.add_argument("--pred_val", action="store_true", default=False,
                        help="Pred environments also on val set (kmeans only)")
    parser.add_argument("--ignore_labels", action="store_true", default=False,
                        help="Write IGNORE_LABEL instead of ground-truth metadata for images without a predicted env")
    config = parser.parse_args()
    if config.pred_val and config.clustering_algo != 'kmeans':
        parser.error('--pred_val is only supported with --clustering_algo kmeans '
                     '(spectral clustering cannot label unseen samples)')
    return config


def _load_dataset(config):
    """Instantiate the WILDS dataset selected by ``config.dataset``; raise ValueError if unknown."""
    if config.dataset == 'waterbirds':
        LOGGER.info('Loading Waterbirds dataset')
        return WaterbirdsDataset(root_dir=config.root_dir, get_img_idx=True)
    if config.dataset == 'celebA':
        LOGGER.info('Loading CelebA dataset')
        return CelebADataset(root_dir=config.root_dir, get_img_idx=True)
    if config.dataset == 'coco_on_places':
        LOGGER.info('Loading COCO-on-Places dataset')
        return COCOonPlacesDataset(root_dir=config.root_dir, get_img_idx=True)
    raise ValueError(f"Dataset {config.dataset} not recognized")


def _list_feature_files(env_dir, layers, split):
    """Return file names in ``env_dir`` containing ``split`` and any of ``layers`` as substrings."""
    return [name for name in os.listdir(env_dir)
            if any(layer in name for layer in layers) and split in name]


def _read_indices(path):
    """Read one integer dataset index per non-blank line of ``path``, in file order."""
    with open(path) as handle:
        return [int(line) for line in handle if line.strip()]


def _build_projections(env_dir, train_files, random_proj):
    """Draw one random projection matrix per training feature file.

    For a feature tensor with ``shape[0] = n`` rows and ``shape[1] = d`` columns the
    matrix has shape ``(d, k0)`` with ``k0 = int(random_proj * ln(n))``; each entry is
    ``+1/sqrt(k0)`` or ``-1/sqrt(k0)`` with probability 0.5. Keys are the file names.
    """
    project = {}
    for name in tqdm(sorted(train_files)):
        shape = torch.load(os.path.join(env_dir, name), map_location='cpu').shape
        k0 = int(random_proj * np.log(shape[0]))
        signs = 2. * (torch.rand(shape[1], k0) < .5).float() - 1.
        LOGGER.debug(f'Create projection matrix of size {list(signs.size())}')
        project[name] = signs / np.sqrt(k0)
    LOGGER.debug(f'Projection shapes: {[p.shape for p in project.values()]}')
    return project


def _stack_features(env_dir, files, project=None, key_for=None):
    """Load per-layer feature tensors and concatenate them column-wise (``dim=1``).

    Files are processed in sorted name order. When ``project`` is given, each tensor
    is L2-normalised row-wise and multiplied by ``project[key_for(name)]``
    (``key_for`` defaults to the identity). Raises FileNotFoundError if ``files`` is empty.
    """
    if not files:
        raise FileNotFoundError(f'No matching feature files found in {env_dir}')
    blocks = []
    for name in tqdm(sorted(files)):
        # map_location='cpu' so tensors saved from a GPU still combine with the
        # CPU projection matrices and can be handed to scikit-learn.
        feats = torch.load(os.path.join(env_dir, name), map_location='cpu')
        if project is not None:
            key = name if key_for is None else key_for(name)
            feats = torch.nn.functional.normalize(feats, dim=1) @ project[key]
        blocks.append(feats)
    return torch.cat(blocks, dim=1)


def _kmeans_clustering(config, train_files, val_files):
    """Fit KMeans on the stacked training features and pickle it to ``env_dir``.

    Returns ``(model, val_labels)``; ``val_labels`` is ``None`` unless ``--pred_val``.
    """
    project = None
    if config.random_proj:
        project = _build_projections(config.env_dir, train_files, config.random_proj)

    LOGGER.info(f'Applying KMeans clustering on features with {config.n_clusters} clusters')
    LOGGER.debug(f'Training feature files: {train_files}')
    train_features = _stack_features(config.env_dir, train_files, project)
    clustering = KMeans(n_clusters=config.n_clusters).fit(train_features.cpu())
    LOGGER.info(f'Number of training labels {clustering.labels_.shape[0]}')

    with open(os.path.join(config.env_dir, 'kmeans_clustering_model.pkl'), 'wb') as f:
        pickle.dump(clustering, f)

    val_labels = None
    if config.pred_val:
        # Val files reuse the projection drawn for the matching train file, whose
        # name is obtained by replacing the first 3 characters (assumed 'val') with 'train'.
        val_features = _stack_features(config.env_dir, val_files, project,
                                       key_for=lambda name: 'train' + name[3:])
        LOGGER.warning('Predicting also on val set')
        val_labels = clustering.predict(val_features.cpu())
        LOGGER.info(f'Number of validation labels {val_labels.shape[0]}')
    return clustering, val_labels


def _spectral_clustering(config, train_files, n_samples):
    """Spectral clustering on the sum of per-layer cosine-similarity matrices.

    Also writes per-layer similarity statistics to ``layer_stats.csv`` and pickles
    the fitted model to ``spectral_clustering_model.pkl``, both in ``config.env_dir``.
    """
    LOGGER.info('Computing similarity matrix from gram matrices')
    results = {'layer': [], 'min': [], 'max': [], 'mean': [], 'median': [], 'std': []}
    similarity_matrix = torch.zeros((n_samples, n_samples))
    for name in tqdm(sorted(train_files)):
        feats = torch.load(os.path.join(config.env_dir, name), map_location='cpu')
        # normalize() clamps the norm at a small eps, so all-zero rows yield 0
        # similarity instead of NaN (which SpectralClustering cannot handle).
        feats_norm = torch.nn.functional.normalize(feats, dim=1)
        layerwise_similarity = feats_norm @ feats_norm.T
        similarity_matrix += layerwise_similarity

        # name[10:-3] strips a fixed-length prefix and 3-char suffix; TODO confirm naming scheme.
        results['layer'].append(name[10:-3])
        results['min'].append(layerwise_similarity.min().item())
        results['max'].append(layerwise_similarity.max().item())
        results['mean'].append(layerwise_similarity.mean().item())
        results['median'].append(layerwise_similarity.median().item())
        results['std'].append(layerwise_similarity.std().item())

    stats_path = os.path.join(config.env_dir, 'layer_stats.csv')
    LOGGER.info(f'Saving similarity stats per layer in {stats_path}')
    with open(stats_path, "w", newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(results.keys())
        writer.writerows(zip(*results.values()))

    LOGGER.info(f'Applying spectral clustering based on similarity matrix with {config.n_clusters} clusters')
    clustering = SpectralClustering(n_clusters=config.n_clusters,
                                    affinity="precomputed").fit(similarity_matrix.numpy())
    with open(os.path.join(config.env_dir, 'spectral_clustering_model.pkl'), 'wb') as f:
        pickle.dump(clustering, f)
    return clustering


def _build_env_rows(dataset, train_ind_list, train_labels, val_ind_list, val_labels, ignore_labels):
    """Build one ``[img_id, img_filename, env]`` row per dataset image.

    Images whose index is in ``train_ind_list`` receive ``train_labels`` in ascending
    index order; likewise for val when ``val_labels`` is given. All other images get
    ``IGNORE_LABEL`` (if ``ignore_labels``) or their first metadata column.
    NOTE(review): this assumes feature rows were saved in ascending index order,
    i.e. that the index files are sorted; confirm against the feature extractor.
    """
    # Sets give O(1) membership; the original list scans were O(N*M) overall.
    train_set = set(train_ind_list)
    val_set = set(val_ind_list) if val_labels is not None else set()
    rows = []
    train_pos = val_pos = 0
    for idx, filename in enumerate(tqdm(dataset._input_array)):
        if idx in train_set:
            env = train_labels[train_pos]
            train_pos += 1
        elif idx in val_set:
            env = val_labels[val_pos]
            val_pos += 1
        elif ignore_labels:
            env = IGNORE_LABEL
        else:
            env = dataset.metadata_array[idx, 0].item()
        rows.append([idx, filename, env])
    if train_pos != len(train_labels):
        LOGGER.warning(f'{len(train_labels)} training cluster labels but only {train_pos} '
                       f'training images matched; environments may be misaligned')
    return rows


def main():
    """Cluster saved per-layer features into pseudo-environments and write them to CSV.

    Reads ``train``-split feature files and ``train_indices.txt`` from ``--env_dir``,
    clusters the training samples with KMeans or spectral clustering, optionally
    predicts clusters for the val split (KMeans only), and writes one
    ``img_id,img_filename,env`` row per dataset image into ``--env_dir``.
    """
    config = _parse_args()
    dataset = _load_dataset(config)

    train_vect_cluster_list = _list_feature_files(config.env_dir, config.layers, 'train')
    if not train_vect_cluster_list:
        raise FileNotFoundError(f'No training feature files for layers {config.layers} in {config.env_dir}')
    train_ind_list = _read_indices(os.path.join(config.env_dir, 'train_indices.txt'))
    val_vect_cluster_list, val_ind_list = [], []
    if config.pred_val:
        val_vect_cluster_list = _list_feature_files(config.env_dir, config.layers, 'val')
        val_ind_list = _read_indices(os.path.join(config.env_dir, 'val_indices.txt'))

    val_labels = None
    if config.clustering_algo == 'kmeans':
        clustering, val_labels = _kmeans_clustering(config, train_vect_cluster_list, val_vect_cluster_list)
    else:
        clustering = _spectral_clustering(config, train_vect_cluster_list, len(train_ind_list))

    LOGGER.info('Saving predicted environments for each image')
    output_inv = _build_env_rows(dataset, train_ind_list, clustering.labels_,
                                 val_ind_list, val_labels, config.ignore_labels)

    # NOTE: 'layer28' is hard-coded and does not reflect --layers; kept so that
    # downstream consumers still find the file under its existing name.
    suffix = '_predval' if config.pred_val else ''
    env_file_path = os.path.join(config.env_dir, f'env_labels_{config.clustering_algo}_layer28{suffix}.csv')
    with open(env_file_path, mode='w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=',')
        csv_writer.writerow(['img_id', 'img_filename', 'env'])
        csv_writer.writerows(output_inv)


if __name__ == "__main__":
    main()

