import sys
from typing import List

import gin
import gin.tf
import numpy as np
import torch

gin.enter_interactive_mode()
import time
import os

sys.path.insert(0, './../../thirdparty')
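# Metric modules: not referenced directly in this file. Presumably they are imported
# so their gin-configurable metric functions are registered before the .gin
# configs that bind them are parsed in evaluate().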
from disentanglement_lib.evaluation.metrics import beta_vae
from disentanglement_lib.evaluation.metrics import dci
from disentanglement_lib.evaluation.metrics import downstream_task
from disentanglement_lib.evaluation.metrics import factor_vae
from disentanglement_lib.evaluation.metrics import fairness
from disentanglement_lib.evaluation.metrics import irs
from disentanglement_lib.evaluation.metrics import mig
from disentanglement_lib.evaluation.metrics import modularity_explicitness
from disentanglement_lib.evaluation.metrics import reduced_downstream_task
from disentanglement_lib.evaluation.metrics import sap_score
from disentanglement_lib.evaluation.metrics import unsupervised_metrics

import thirdparty.mcc_metric.metric as mcc

# Reference every metric module imported above so linters / import sorters do
# not flag them as unused and auto-remove them (they are presumably imported
# only for their side effects; see the note above the imports).
stupid_linting = [
    beta_vae,
    dci,
    downstream_task,
    factor_vae,
    fairness,
    irs,
    modularity_explicitness,
    mig,
    reduced_downstream_task,
    sap_score,
    unsupervised_metrics,
    mcc,
]


def evaluate_dislib(epoch: int,
                    model,
                    dataset,
                    split_name: str,
                    writer,
                    dis_lib_metrics: List[str],
                    all_metrics: bool = False, supervised=False):
    """Compute disentanglement_lib metrics and log them to ``writer``.

    Args:
        epoch: Step value passed to ``writer.add_scalar``.
        model: Torch model forwarded to ``evaluate``.
        dataset: Dataset object forwarded to ``evaluate``.
        split_name: Inserted into the scalar tag ``dis_lib_{split_name}/...``.
        writer: Object exposing ``add_scalar(tag, value, step)`` (presumably a
            TensorBoard ``SummaryWriter`` -- confirm against callers).
        dis_lib_metrics: Metric names; each becomes the config file name
            ``'<name>.gin'`` that ``evaluate`` loads.
        all_metrics: When False, metrics considered slow (currently ``dci``)
            are skipped.
        supervised: Forwarded to ``evaluate``.

    Returns:
        Dict mapping metric name to that metric's raw results dict.
    """
    slow = {'dci'}
    metrics = [f'{m}.gin' for m in dis_lib_metrics
               if all_metrics or m not in slow]
    with torch.no_grad():
        out = evaluate(model, dataset, metrics, supervised=supervised)

    # Primary score logged for metrics whose results dict has several entries.
    metric_names = {'mcc': 'meanabscorr',
                    'dci': 'disentanglement',
                    'mig': 'discrete_mig',
                    'sap_score': 'SAP_score',
                    'modularity_explicitness': 'modularity_score'
                    }
    for metric, infos in out.items():
        primary_key = metric_names.get(metric)
        if primary_key is not None:
            writer.add_scalar(f'dis_lib_{split_name}/{metric}',
                              infos[primary_key],
                              epoch)
            continue
        # No primary score is known for this metric: log each scalar entry
        # instead of raising KeyError and discarding the computed results.
        for key, value in infos.items():
            if (isinstance(value, (int, float, np.integer, np.floating))
                    and not isinstance(value, bool)):
                writer.add_scalar(f'dis_lib_{split_name}/{metric}/{key}',
                                  value,
                                  epoch)
    return out


def evaluate(model, dataset, eval_config_files=None, device='cuda',
             supervised=True, config_dir=None):
    """Evaluate ``model`` on ``dataset`` with gin-configured metrics.

    Each config file must bind ``evaluation.evaluation_fn`` (it is
    ``gin.REQUIRED`` below). This function adds an ``evaluation.random_seed``
    binding per config, drawn from a fixed ``RandomState(0)``. The seeds
    therefore depend only on the order of ``eval_config_files``.

    Args:
        model: Torch model. It is moved to ``device`` and called on float
            tensors built from the numpy batches the metric provides.
        dataset: Passed as the first argument to each metric function. When
            ``supervised`` is True it must expose ``active_factors``.
        eval_config_files: Config file names (or paths) to run. ``None`` runs
            every file in ``config_dir`` that does not start with ``'.'``
            and does not contain ``'others'``.
        device: Torch device for the model and its inputs. ``None`` selects
            ``'cuda'`` when available and ``'cpu'`` otherwise.
        supervised: If True, keep only the representation entries selected
            by ``dataset.active_factors``.
        config_dir: Directory the config files are resolved against. An
            absolute path in ``eval_config_files`` overrides it (see
            ``os.path.join``). Defaults to the project's metric_configs dir.

    Returns:
        Dict mapping metric name (config basename with ``.gin`` removed) to
        the metric's results dict, extended with ``elapsed_time`` in seconds.
    """
    if config_dir is None:
        config_dir = '/home/anonymous/dev/projects/GeneralizationStudy/metric_configs/'
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    def model_wrapper(x):
        # Channels-last numpy batches (last dim 1 or 3) are transposed to
        # channels-first before being handed to the torch model.
        if x.shape[-1] == 3 or x.shape[-1] == 1:
            x = np.transpose(x, (0, 3, 1, 2))
        # Metrics need forward passes only; no_grad avoids building autograd
        # graphs when evaluate() is called outside a no_grad context.
        with torch.no_grad():
            representation = model(torch.from_numpy(x).float().to(device))
        # Only compute metrics for factors with more than one value
        # (presumably what dataset.active_factors lists -- confirm).
        if supervised:
            representation = representation.T[dataset.active_factors].T
        return np.array(representation.detach().cpu())

    # Registered anew on every call. This relies on gin interactive mode
    # (enabled at import time in this module) to permit re-registration.
    @gin.configurable("evaluation")
    def evaluate_model(
            evaluation_fn=gin.REQUIRED,
            random_seed=gin.REQUIRED):
        experiment_timer = time.time()
        results_dict = evaluation_fn(
            dataset,
            model_wrapper,
            random_state=np.random.RandomState(random_seed))
        results_dict["elapsed_time"] = time.time() - experiment_timer
        return results_dict

    random_state = np.random.RandomState(0)

    if eval_config_files is None:
        eval_config_files = [f for f in os.listdir(config_dir) if
                             not (f.startswith('.') or 'others' in f)]

    all_results = {}
    for eval_config in eval_config_files:
        metric_name = os.path.basename(eval_config).replace(".gin", "")

        # Explicit int64: the default integer dtype is 32-bit on some
        # platforms, where an upper bound of 2**32 is out of range.
        seed = random_state.randint(2 ** 32, dtype=np.int64)
        eval_bindings = [f'evaluation.random_seed = {seed}']
        try:
            gin.parse_config_files_and_bindings(
                [os.path.join(config_dir, eval_config)], eval_bindings)
            all_results[metric_name] = evaluate_model()
        finally:
            # Always reset gin so a failing metric does not leave its
            # bindings active for later gin usage.
            gin.clear_config()
    return all_results
