import torch
from torch.utils.data import DataLoader, Subset, ConcatDataset
import argparse
from argparse import Namespace
import sys
import os
sys.path.insert(1, os.path.dirname(os.getcwd()))
import parameters    as par
import metrics
import datasets
import criteria
import architectures as archs
from datasets.basic_dataset_scaffold import BaseDataset
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
import sklearn
from utilities import logger
from tqdm import tqdm
import math
import pandas as pd
import numpy as np
import json
import inspect
import itertools

"""============================================================================"""
############ INPUT ARGUMENTS ##############
# Command-line options are declared by par.downstream_parameters.
parser = argparse.ArgumentParser()
parser = par.downstream_parameters(parser)

args = parser.parse_args()

# `--attribute` may arrive as a list of tokens (presumably nargs="+" in
# par.downstream_parameters — confirm); join it into one space-separated name.
# A plain string (e.g. an unchanged default) is kept as-is, because joining a
# str would interleave spaces between its characters ("class" -> "c l a s s").
if not isinstance(args.attribute, str):
    args.attribute = ' '.join(args.attribute)
logger.print_args(args)

"""============================================================================"""
############ HELPFUL UTIL FOR JSON SERIALIZATION OF NUMPY ARRAYS ################
def default(obj):
    """Fallback serializer for ``json.dump(..., default=default)``.

    NumPy arrays become (nested) Python lists and NumPy scalars (``np.int64``,
    ``np.float32``, ``np.bool_``, ...) become the equivalent Python scalars, so
    hyperparameter dicts holding NumPy values can be written as JSON.

    Args:
        obj: an object the ``json`` encoder could not serialize by itself.

    Returns:
        A JSON-serializable stand-in for ``obj``.

    Raises:
        TypeError: for any other type, as the ``json`` module expects from a
            ``default`` hook.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.generic is the base class of every NumPy scalar type; .item() returns
    # the corresponding built-in Python value.
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError('Not serializable')
    
"""============================================================================"""
########### HELPFUL UTIL FOR CONVERTING KMEANS CLUSTERS TO NUMPY ARRAYS #########
def convert(kmeans_preds, targets):
    """Relabel KMeans cluster ids as class labels by order of first appearance.

    The i-th distinct cluster id encountered in ``kmeans_preds`` is mapped to
    the i-th distinct label encountered in ``targets``. This is a positional
    heuristic, not an optimal (e.g. Hungarian) cluster-to-class matching.

    Args:
        kmeans_preds: 1-D sequence of cluster ids. The ids need not be
            contiguous or start at 0.
        targets: 1-D sequence of ground-truth labels. Only the order in which
            distinct labels first appear is used.

    Returns:
        np.ndarray with one entry per element of ``kmeans_preds``: the target
        label assigned to that element's cluster.

    Raises:
        ValueError: if there are more distinct clusters than distinct target
            labels, so some cluster would have no label to map to.
    """
    kmeans_preds = np.asarray(kmeans_preds)
    targets = np.asarray(targets)

    # clusters: sorted distinct ids; cluster_first[j]: first position of
    # clusters[j]; cluster_pos: for every sample, the index j of its cluster.
    clusters, cluster_first, cluster_pos = np.unique(
        kmeans_preds, return_index=True, return_inverse=True
    )
    _, target_first = np.unique(targets, return_index=True)
    targets_by_appearance = targets[np.sort(target_first)]

    if len(clusters) > len(targets_by_appearance):
        raise ValueError(
            "convert: {} distinct clusters but only {} distinct targets".format(
                len(clusters), len(targets_by_appearance)
            )
        )

    # appearance_rank[j] = how many distinct clusters appeared before clusters[j]
    # (the first-appearance positions are distinct, so argsort-of-argsort ranks them).
    appearance_rank = np.argsort(np.argsort(cluster_first))
    label_per_cluster = targets_by_appearance[appearance_rank]
    return label_per_cluster[cluster_pos.reshape(-1)]

"""============================================================================"""
############ Iterate through path provided ##############
# For every checkpoint in args.filepath: rebuild the network, embed the
# evaluation/testing splits (and their concatenation), write per-subset metrics
# from metric_computer to embed.csv, fit downstream models on the embeddings
# and write their scikit-learn scores to downstream.csv, and dump opt/args as JSON.

# `import sklearn` alone does not guarantee that the `sklearn.metrics`
# submodule is loaded; the metric lookups below need it.
import sklearn.metrics

# Job ids processed so far in this run; later checkpoints of the same job are skipped.
jobs = []

for file in os.listdir(args.filepath):
    print(file)
    PATH = os.path.join(args.filepath, file)

    ################### BOOKKEEPING DATA ########################
    # Check for a regular file *before* parsing its name, so directories that do
    # not follow the naming scheme cannot raise an IndexError.
    if not os.path.isfile(PATH):
        continue
    # NOTE(review): assumes names like "<prefix>_<job_id>" plus an 8-character
    # suffix (e.g. ".pth.tar") — confirm against the checkpoint writer.
    job_id = file[:-8].split("_")[1]
    if job_id in jobs:
        continue

    # Load onto the CPU so GPU-saved checkpoints also open on CPU-only hosts;
    # the model is moved to the evaluation device below.
    ld = torch.load(PATH, map_location="cpu")
    opt = Namespace(**ld["opt"])

    #### NEED AN EXTRA CHECK FOR CUB200
    # Imbalanced CUB200 runs are only evaluated for the attribute they were trained on.
    if opt.dataset == "cub200" and opt.imbalance and opt.config["attribute"]["type"] != args.attribute:
        continue
    jobs.append(job_id)

    ##################### NETWORK SETUP ##################
    if "parade" in opt.method:
        opt.method = "saparade"
        net = 'multifeature_resnet50' if 'resnet' in opt.arch else 'multifeature_bninception'
    else:
        net = opt.arch
    model = archs.select(net, opt)
    if args.evaluate_on_gpu and args.gpu is not None:
        args.device = torch.device("cuda")
        model = model.to(args.device)
    else:
        args.device = torch.device("cpu")
    model.load_state_dict(ld["model_state_dict"])

    #################### DATALOADER SETUPS ##################
    opt.use_tv_split = False
    dataloaders = {}

    ####### REBALANCE THE DATASET FOR TESTING
    # Evaluation data is loaded with imbalance switched off; the flag is then
    # restored because the rest of the loop branches on the run's setting.
    flag = False
    if opt.imbalance:
        opt.imbalance = False
        flag = True
    dsets = datasets.select(opt.dataset, opt, opt.source_path)
    if flag:
        opt.imbalance = True

    dataloaders["train"] = DataLoader(dsets["evaluation"], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)
    dataloaders["test"] = DataLoader(dsets["testing"], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)

    # "combined" = evaluation followed by testing. The second element of each
    # testing image_dict entry is offset by len(evaluation.image_list) so it
    # keeps pointing at the right position in the concatenation.
    dset = ConcatDataset([dsets["evaluation"], dsets["testing"]])
    dset.image_list = dsets["evaluation"].image_list + dsets["testing"].image_list
    dset.image_paths = dset.image_list
    dset.image_dict = {key: dsets["evaluation"].image_dict.get(key, [])
               + [[x[0], len(dsets["evaluation"].image_list)+x[1]] for x in dsets["testing"].image_dict.get(key, [])]
               for key in set.union(set(list(dsets["evaluation"].image_dict.keys())), set(list(dsets["testing"].image_dict.keys())))}
    if hasattr(dsets["evaluation"], "metadata") and hasattr(dsets["testing"], "metadata"):
        dset.metadata = pd.concat([dsets["evaluation"].metadata, dsets["testing"].metadata])

    dataloaders["combined"] = DataLoader(dset, num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)

    # split_indices[key] holds, per sample of that loader, its subset membership:
    #   class + imbalance : 1 if the sample's class is not in opt.config["min_classes"], else 0
    #   class + balanced  : one such 0/1 list per seed in args.seeds
    #   other attributes  : the raw metadata value of args.attribute
    split_indices = {}
    if args.attribute == "class":
        subsets = ["overall", "reduced", "inflated", "gap"]
        balance = int(not opt.imbalance)
        for key in dataloaders:
            image_list = dataloaders[key].dataset.image_list
            if opt.imbalance:
                split_indices[key] = [1 if item[1] not in opt.config["min_classes"] else 0 for item in image_list]
            else:
                assert len(args.seeds)
                split_indices[key] = []
                for seed in args.seeds:
                    # Simulated minority classes: 50 ids drawn with replacement from range(opt.n_classes).
                    np.random.seed(seed)
                    min_classes = np.random.randint(low=0, high=opt.n_classes, size = 50)
                    split_indices[key].append([1 if item[1] not in min_classes else 0 for item in image_list])
        index = pd.MultiIndex.from_arrays([args.eval_splits, [balance]*len(args.eval_splits), [job_id]*len(args.eval_splits)], names=["split", "balance", "job_id"])
    else:
        balance = None
        for key in dataloaders:
            # Validate before reading metadata so a missing column fails clearly.
            assert hasattr(dataloaders[key].dataset, "metadata")
            assert args.attribute in dataloaders[key].dataset.metadata.columns
            split_indices[key] = dataloaders[key].dataset.metadata[args.attribute].values.tolist()
        unique_attrs = np.unique(list(itertools.chain.from_iterable(split_indices[key] for key in dataloaders)))
        subsets = ["overall", *["{}_{}".format(args.attribute, unique_attr) for unique_attr in unique_attrs], "gap"]
        if args.attribute == "color" and opt.dataset == "cub200":
            if hasattr(opt, "config_file") and opt.config_file:
                balance = os.path.basename(opt.config_file).split("_")[0]
            else:
                balance = True
            index = pd.MultiIndex.from_arrays([args.eval_splits, [balance]*len(args.eval_splits), [job_id]*len(args.eval_splits)], names=["split", "balance", "job_id"])
        else:
            index = pd.MultiIndex.from_arrays([args.eval_splits, [job_id]*len(args.eval_splits)], names=["split", "job_id"])

    # Whether the row index carries a "balance" level (must match `index` above).
    use_balance_level = args.attribute == "class" or (args.attribute == "color" and opt.dataset == "cub200")

    columns = pd.MultiIndex.from_product([subsets, args.evaluation_metrics], names=["subset", "metric"])
    embeds_df = pd.DataFrame(index=index, columns=columns, dtype=float)

    #################### METRIC COMPUTER ####################
    args.rho_spectrum_embed_dim = opt.rho_spectrum_embed_dim
    metric_computer = metrics.MetricComputer(args.evaluation_metrics, args)

    #################### EVALUATION #########################
    _ = model.eval()

    features_dict = {}
    targets_dict = {}

    if len(opt.evaltypes) == 1:
        evaltype = opt.evaltypes[0]
    else:
        raise NotImplementedError("Not yet implemented multiple evaluation types")

    for key, dataloader in dataloaders.items():
        opt.n_classes = len(dataloader.dataset.image_dict)

        computed_metrics, extra_infos = metric_computer.compute_standard(opt, model, dataloader, opt.evaltypes, args.device, return_input_dicts=True)
        computed_metrics_dict = {}

        # Embeddings/labels of every loader are kept for the downstream stage.
        features_dict[key] = extra_infos[evaltype]["features"]
        targets_dict[key] = extra_infos[evaltype]["target_labels"]

        if key not in args.eval_splits:
            continue

        if args.attribute == "class":
            if opt.imbalance:
                computed_metrics_M, _ = metric_computer.compute_standard(opt, model, dataloader, opt.evaltypes, args.device, indices = np.where(split_indices[key])[0].tolist())
                computed_metrics_m, _ = metric_computer.compute_standard(opt, model, dataloader, opt.evaltypes, args.device, indices = np.where(np.logical_not(split_indices[key]))[0].tolist())
            else:
                # Average the per-seed majority/minority metrics.
                M_array = {name: [] for name in computed_metrics[evaltype]}
                m_array = {name: [] for name in computed_metrics[evaltype]}

                for split_index_list in split_indices[key]:
                    seed_M, _ = metric_computer.compute_standard(opt, model, dataloader, opt.evaltypes, args.device, indices = np.where(split_index_list)[0].tolist())
                    seed_m, _ = metric_computer.compute_standard(opt, model, dataloader, opt.evaltypes, args.device, indices = np.where(np.logical_not(split_index_list))[0].tolist())
                    for name in M_array:
                        M_array[name].append(seed_M[evaltype][name])
                        m_array[name].append(seed_m[evaltype][name])

                computed_metrics_M = {evaltype: {name: np.mean(values) for name, values in M_array.items()}}
                computed_metrics_m = {evaltype: {name: np.mean(values) for name, values in m_array.items()}}

            computed_metrics_gap = {name: computed_metrics_M[evaltype][name]-computed_metrics_m[evaltype][name] for name in set.union(set(computed_metrics_M[evaltype].keys()), set(computed_metrics_m[evaltype].keys()))} # inflated - reduced
        else:
            # Explicit array so `== ul` is an elementwise comparison.
            attr_values = np.asarray(split_indices[key])
            for ul in np.unique(attr_values):
                computed_metrics_dict[ul], _ = metric_computer.compute_standard(opt, model, dataloader, opt.evaltypes, args.device, indices = np.where(attr_values == ul)[0].tolist())

        # .loc[row, col] writes into embeds_df itself; chained `df[col][row] = v`
        # can silently write to a temporary copy (pandas Copy-on-Write).
        index_tuple = (key, balance, job_id) if use_balance_level else (key, job_id)
        for metric in embeds_df.columns.get_level_values("metric").unique():
            if args.attribute == "class":
                embeds_df.loc[index_tuple, ("inflated", metric)] = computed_metrics_M[evaltype][metric]
                embeds_df.loc[index_tuple, ("reduced", metric)] = computed_metrics_m[evaltype][metric]
                embeds_df.loc[index_tuple, ("gap", metric)] = computed_metrics_gap[metric]
            else:
                for ul in np.unique(attr_values):
                    embeds_df.loc[index_tuple, ("{}_{}".format(args.attribute, ul), metric)] = computed_metrics_dict[ul][evaltype][metric]
                embeds_df.loc[index_tuple, ("gap", metric)] = np.nan
            embeds_df.loc[index_tuple, ("overall", metric)] = computed_metrics[evaltype][metric]

    ####### ENSURE THAT DOWNSTREAM TRAINING SET HAS ALL AVAILABLE CLASSES #############
    # Per class, the first ceil(n/2) samples go to training (so every class is
    # seen during fitting) and the remaining ones to testing. Both halves use the
    # same ceil(n/2) boundary, so no sample lands in both splits.
    unique_labels = np.unique(targets_dict["combined"].flatten())
    index_dict = {label: np.where(targets_dict["combined"] == label)[0] for label in unique_labels}

    train_indices = list(itertools.chain.from_iterable(values[:math.ceil(len(values)/2)] for values in index_dict.values()))
    test_indices = list(itertools.chain.from_iterable(values[math.ceil(len(values)/2):] for values in index_dict.values()))

    # Subset membership of every combined sample (`indices`) and of the
    # downstream test samples only (`dindices`); the last axis is the sample axis.
    indices = split_indices["combined"]
    dindices = np.array(indices)[..., test_indices].tolist()

    train_features = features_dict["combined"][train_indices].cpu().numpy()
    train_targets = targets_dict["combined"][train_indices].flatten()

    test_features = features_dict["combined"][test_indices].cpu().numpy()
    test_targets = targets_dict["combined"][test_indices].flatten()

    index = pd.MultiIndex.from_arrays([args.downstream_models]+[[embeds_df.index.levels[i].item()]*len(args.downstream_models) for i in range(1, len(embeds_df.index.levels))], names=["model"]+list(index.names[1:]))
    columns = pd.MultiIndex.from_product([subsets, args.downstream_evaluation_metrics], names=["subset", "metric"])
    downstream_df = pd.DataFrame(index=index, columns=columns, dtype=float)

    for downstream_model in args.downstream_models:
        # Per-model bindings so the unsupervised KMeans branch cannot overwrite
        # the supervised train/test split used by the other models.
        fit_features, fit_targets = train_features, train_targets
        eval_features, eval_targets, eval_dindices = test_features, test_targets, dindices
        if "svm" in downstream_model:
            method = SVC(verbose=True)
        elif "lr" in downstream_model:
            method = LogisticRegression(verbose=True, n_jobs=-1)
        elif "rf" in downstream_model:
            method = RandomForestClassifier(verbose=True, n_jobs=-1)
        elif "kmeans" in downstream_model:
            # Unsupervised: cluster and score on the whole combined set.
            fit_features = eval_features = features_dict["combined"].cpu().numpy()
            fit_targets = eval_targets = targets_dict["combined"].flatten()
            eval_dindices = indices
            method = KMeans(n_clusters=len(np.unique(fit_targets)))
        else:
            raise ValueError("Unknown downstream model: {}".format(downstream_model))

        print("Training downstream model {} on embeddings...".format(downstream_model))
        method.fit(fit_features, fit_targets)
        eval_preds = method.predict(eval_features)

        if "kmeans" in downstream_model:
            # Cluster ids are arbitrary; map them onto class labels before scoring.
            eval_preds = convert(eval_preds, eval_targets)

        subset_ids = np.asarray(eval_dindices)
        index_tuple = (downstream_model, balance, job_id) if use_balance_level else (downstream_model, job_id)

        for metric in args.downstream_evaluation_metrics:
            assert hasattr(sklearn.metrics, metric+"_score")
            metric_func = getattr(sklearn.metrics, metric+"_score")
            aux_inp = {}
            # inspect.signature follows functools.wraps and includes keyword-only
            # parameters (how scikit-learn declares `average`); getfullargspec(...).args does neither.
            if "average" in inspect.signature(metric_func).parameters and np.unique(eval_targets).shape[0] > 2:
                aux_inp["average"] = "macro"
            output = metric_func(eval_targets, eval_preds, **aux_inp)

            if args.attribute == "class":
                if opt.imbalance:
                    majority = subset_ids.astype(bool)
                    output_M = metric_func(eval_targets[majority], eval_preds[majority], **aux_inp)
                    output_m = metric_func(eval_targets[~majority], eval_preds[~majority], **aux_inp)
                else:
                    # One 0/1 row per seed; average the per-seed scores.
                    masks = [np.asarray(index_list).astype(bool) for index_list in subset_ids]
                    output_M = np.mean([metric_func(eval_targets[mask], eval_preds[mask], **aux_inp) for mask in masks])
                    output_m = np.mean([metric_func(eval_targets[~mask], eval_preds[~mask], **aux_inp) for mask in masks])
                output_gap = output_M-output_m # inflated - reduced
                downstream_df.loc[index_tuple, ("inflated", metric)] = output_M
                downstream_df.loc[index_tuple, ("reduced", metric)] = output_m
            else:
                for ul in np.unique(subset_ids):
                    members = np.where(subset_ids == ul)[0]
                    downstream_df.loc[index_tuple, ("{}_{}".format(args.attribute, ul), metric)] = metric_func(eval_targets[members], eval_preds[members], **aux_inp)
                output_gap = np.nan
            downstream_df.loc[index_tuple, ("overall", metric)] = output
            downstream_df.loc[index_tuple, ("gap", metric)] = output_gap

    ### Create CSV file and write metrics dictionary
    JOB_PATH = os.path.join(*[args.save_path, opt.dataset, "CSV_output", "jobs", job_id])
    if opt.dataset == "cub200" and not opt.imbalance:
        JOB_PATH = os.path.join(JOB_PATH, args.attribute)
    CSV_PATH = os.path.join(*[JOB_PATH, "embed.csv"])
    DOWNSTREAM_CSV_PATH = os.path.join(*[JOB_PATH, "downstream.csv"])
    ARGS_PATH = os.path.join(*[JOB_PATH, "downstream.json"])

    JSON_PATH = os.path.join(*[JOB_PATH, "hparam.json"])

    os.makedirs(JOB_PATH, exist_ok = True)

    embeds_df.to_csv(CSV_PATH, sep=",")
    downstream_df.to_csv(DOWNSTREAM_CSV_PATH, sep=",")

    # torch.device values are dropped; NumPy values are handled by `default`.
    with open(JSON_PATH, "w") as fp:
        output_dict = dict(vars(opt))
        output_dict.pop("device", None)
        json.dump(output_dict, fp, default=default)

    with open(ARGS_PATH, "w") as fp:
        output_dict = dict(vars(args))
        output_dict.pop("device", None)
        json.dump(output_dict, fp, default=default)
