import os
from datetime import datetime
from omegaconf import OmegaConf
import pickle

import numpy as np
import matplotlib.pyplot as plt

from source.experiments.eval_utils import load_concept_activation
from source.experiments.visualize_clusters import show_train_patches
from source.data.imagenet import create_dataset

# Short label for each supported backbone checkpoint name. `main` looks up
# `cfg_data.params.representation_model_ckpt` here and uses the label as a
# sub-directory of the visualization output folder.
MODEL_ID_DICT = {
    'vit_base_patch16_224.augreg_in1k': 'FS',
    'vit_base_patch16_224.dino': 'DINO',
    'vit_base_patch16_clip_224.openai': 'CLIP',
    'vit_base_patch16_224.mae': 'MAE',
    # 'vit_base_patch16_224.augreg_in21k': 'ViTin21k',
}

# Maps the cluster discovery method stored in the run config
# (`vcl.cluster.discovery`) to the `cluster_assignment` name that `main`
# passes to `load_concept_activation` when loading the soft assignments.
ASSIGNMENT_DICT = {
    'hdbscan': 'hdbscan',
    'mcd': 'projection',
    'pca': 'projection',
    'kmeans': 'centroid_distance',
}

def _find_dir_containing(root, filename):
    """Return the first directory below *root* (in `os.walk` order) that directly contains *filename*.

    Raises:
        FileNotFoundError: if no directory under *root* holds such a file
            (this includes the case where *root* itself does not exist).
    """
    for dir_path, _, file_names in os.walk(root):
        if filename in file_names:
            return dir_path
    raise FileNotFoundError(f"no '{filename}' found below '{root}'")


def main(cfg):
    """Save one SVG figure of example training patches per selected cluster.

    The run to visualize is located under
    ``<cfg.exp_dir>/<cfg.run_id>/job_results``. Its stored ``config.yaml``
    provides the dataset settings and the cluster discovery method; the
    cluster assignments are loaded through `load_concept_activation` and the
    figures are drawn by `show_train_patches`.

    Args:
        cfg: configuration object. Fields read here: ``exp_dir``, ``run_id``,
            ``cls``, ``align_experiment``, ``visualization_dir``,
            ``run_id_align`` (only when ``align_experiment`` is set),
            ``data_root``, ``n_examples`` and ``n_rows``.

    Raises:
        FileNotFoundError: if the run directory contains no ``config.yaml``
            or no ``sample_idx.npy`` below ``clustering/results``.
        KeyError: if the run's discovery method or (outside align
            experiments) its model checkpoint name has no entry in
            ``ASSIGNMENT_DICT`` / ``MODEL_ID_DICT``.
    """
    run_dir = os.path.join(cfg.exp_dir, str(cfg.run_id), 'job_results')
    config_path = _find_dir_containing(run_dir, 'config.yaml')
    # Only printed below, but the lookup doubles as a check that the run
    # actually produced clustering results.
    result_path = _find_dir_containing(
        os.path.join(run_dir, 'clustering', 'results'), 'sample_idx.npy')
    print('run config path', config_path)
    print('run result path', result_path)

    # Recover the dataset settings and discovery method used by the run.
    cfg_data = OmegaConf.load(os.path.join(config_path, 'config.yaml'))
    cluster_discovery = cfg_data.vcl.cluster.discovery
    cfg_data = cfg_data.dataset

    if cfg.cls:
        token_idx = None
    else:
        token_idx = load_concept_activation(
            config_path, None, train=False, cluster_assignment='',
            filename_root='token_idx.npy', take_parent=False)
    soft_assignments = load_concept_activation(
        config_path, None, train=True,
        cluster_assignment=ASSIGNMENT_DICT[cluster_discovery],
        filename_root='clustering.npy', take_parent=False)
    try:
        hard_assignments = load_concept_activation(
            config_path, None, train=True, cluster_assignment='hard_clustering',
            filename_root='clustering.npy', take_parent=False)
    except Exception:
        # Best-effort fallback when no hard clustering could be loaded: use the
        # highest-scoring cluster per row. `Exception` (not a bare `except`)
        # so that KeyboardInterrupt / SystemExit still propagate.
        hard_assignments = soft_assignments.argmax(axis=1)

    if cfg.align_experiment:
        vis_selection_file = os.path.join(
            cfg.visualization_dir, f'{cfg.run_id_align}_{cfg.run_id}.pkl')
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point this at selection files produced by trusted runs.
        with open(vis_selection_file, 'rb') as f:
            vis_selection_dict = pickle.load(f)
    else:
        # No selection given: visualize every cluster (one per column of the
        # soft assignment matrix).
        vis_selection_dict = {'all_cluster': np.arange(soft_assignments.shape[1])}

    # Fill in parameters that the stored config may lack.
    for key, default in (('pretrained', True),
                         ('model_ckpt_path', ''),
                         ('only_noise', False)):
        if key not in cfg_data.params:
            cfg_data.params[key] = default

    cfg_data.params.root = cfg.data_root
    # Remember the run's feature layer for the output folder name; the dataset
    # itself is created with feature_layer = 0.
    feature_layer = cfg_data.params.feature_layer
    cfg_data.params.feature_layer = 0
    dataset, _ = create_dataset(cfg_data, return_label=True, cuda=False,
                                train=True, indices_subsample=None)
    sample_idx = dataset.indices
    # Repeat each image index as often as tokens were selected from that image.
    # NOTE(review): the condition uses 121 patches while the repeat count uses
    # 196/49 -- the constants look inconsistent; confirm the intended values.
    if not cfg.cls and int(cfg_data.subsample_ratio*121) > 1:
        sample_idx = np.repeat(sample_idx, repeats=int(cfg_data.subsample_ratio*196/49))
    print(sample_idx)

    n_cols = cfg.n_examples // cfg.n_rows

    if cfg.align_experiment:
        visualization_dir = os.path.join(
            cfg.visualization_dir, str(cfg.run_id_align), str(cfg.run_id))
    else:
        model_id = MODEL_ID_DICT[cfg_data.params.representation_model_ckpt]
        visualization_dir = os.path.join(
            cfg.visualization_dir, str(model_id),
            f'{feature_layer}_{cluster_discovery}_{cfg.run_id}')
    # exist_ok: re-running the script for the same run must not crash;
    # existing figures with the same name are overwritten.
    os.makedirs(visualization_dir, exist_ok=True)

    # Projection-based discoveries draw examples from all samples.
    # NOTE(review): the original comment was truncated ("some concepts are not
    # assi..."); presumably some concepts receive no hard assignment -- confirm.
    select_from_all_samples = cluster_discovery in ('pca', 'mcd')

    for meta_cluster_idx, vis_selection in vis_selection_dict.items():
        for i, cluster_idx in enumerate(vis_selection):
            if cluster_idx is None:
                continue
            # squeeze=False always yields a 2-D axes array, so flatten() also
            # works for a 1x1 grid (where subplots would return a bare Axes).
            fig, ax = plt.subplots(cfg.n_rows, n_cols,
                                   figsize=(5*n_cols//2, 5*n_cols//2),
                                   squeeze=False)
            ax = ax.flatten()
            show_train_patches(cluster_idx,
                               soft_assignments,
                               hard_assignments,
                               token_idx=token_idx,
                               sample_idx=sample_idx,
                               n_samples=cfg.n_examples,
                               cfg_data=cfg_data,
                               n_patches=121,
                               select_from_all_samples=select_from_all_samples,
                               random=False,
                               title=False,
                               ax=ax)
            fig.subplots_adjust(wspace=0, hspace=0)
            fig.tight_layout()
            if cfg.align_experiment:
                file_name = f'{cfg.run_id_align}_{cfg.run_id}_{meta_cluster_idx}_{i}_{cluster_idx}.svg'
            else:
                file_name = f'{cluster_idx}.svg'
            fig.savefig(os.path.join(visualization_dir, file_name))

            # Close explicitly so figures do not accumulate across clusters.
            plt.close(fig)


if __name__ == "__main__":
    # Script entry point: assemble the configuration from three sources and
    # hand it to main().
    now = datetime.now()
    timestamp_conf = OmegaConf.create({"now_dir": f"{now:%Y-%m-%d}/{now:%H-%M-%S}"})
    yaml_conf = OmegaConf.load("./source/conf/cluster_visualization.yaml")
    cli_overrides = OmegaConf.from_cli()
    # Merge order: timestamp entry, then the YAML file, then command-line values.
    main(OmegaConf.merge(timestamp_conf, yaml_conf, cli_overrides))
