"""
Selects the best set of hyperparameters based on a metric and
aggregates results over the runs that share those hyperparameters across seeds.
"""


from typing import DefaultDict, List, Dict, Any
import os
import json
import torch
from omegaconf import OmegaConf
from pathlib import Path, PosixPath


class BestRuns:
    """Selects the best hyperparameters based on validation_metric then
    identifies all runs with best hyperparameters across seeds

    Args:
        sweep_dir: logs directory where runs are stored
        validation_metric: metric to use for selecting the best run
        minimum: if True the run with the lowest validation_metric is the
            best one, otherwise the run with the highest value is
        load_aposteriori: if True, also load ``a_posteriori_results.json``
            for every run that has one
        param: optional hyperparameter name (or sequence of names) used to
            restrict the candidate runs before picking the best one
        param_value: required value(s) for ``param``, matched by position
    """

    def __init__(
        self,
        sweep_dir: str,
        validation_metric: str = "val_canonical_loss",
        minimum: bool = True,
        load_aposteriori: bool = True,
        param=None,
        param_value=None,
    ):
        self.sweep_dir = sweep_dir
        self.validation_metric = validation_metric
        self.minimum = minimum

        self.run_dirs: List[os.DirEntry] = self.get_run_dirs(sweep_dir)
        self.id_to_metrics = self.load_metrics()
        self.id_to_parameters = self.load_hyper_parameters()

        self.best_run_id: str = self.find_best_run_id(param, param_value)
        self.best_run_ids: List = self.find_matching_runs(self.best_run_id)

        self.id_to_results = self.load_results()
        # Always define the attribute so get_best_runs_result_apost never
        # fails with AttributeError when a posteriori loading is disabled.
        self.id_to_results_apost: Dict[str, dict] = dict()
        if load_aposteriori:
            self.id_to_results_apost = self.load_results_a_posterori()

    def load_metrics(self) -> Dict[str, dict]:
        """Returns run name -> merged train_metrics.json and eval_metrics.json.

        Eval entries overwrite train entries that share the same key.
        """
        id_to_metrics = dict()
        for run in self.run_dirs:
            run_metrics = dict()
            run_metrics.update(self.load_metrics_json(run.path, prefix="train_"))
            run_metrics.update(self.load_metrics_json(run.path, prefix="eval_"))
            id_to_metrics[run.name] = run_metrics
        return id_to_metrics

    def load_results(self) -> Dict[str, dict]:
        """Returns run name -> merged train_results.json and eval_results.json.

        Eval entries overwrite train entries that share the same key.
        """
        id_to_results = dict()
        for run in self.run_dirs:
            run_results = dict()
            run_results.update(self.load_results_json(run.path, prefix="train_"))
            run_results.update(self.load_results_json(run.path, prefix="eval_"))
            id_to_results[run.name] = run_results
        return id_to_results

    def load_results_a_posterori(self) -> Dict[str, dict]:
        """Returns run name -> contents of a_posteriori_results.json.

        Runs without that file are reported on stdout and left out of the
        returned mapping.
        """
        id_to_results_apost = dict()
        for run in self.run_dirs:
            try:
                run_results = self.load_results_json(run.path, prefix="a_posteriori_")
            except FileNotFoundError:
                # Only a missing file is tolerated; other failures (e.g. a
                # corrupt JSON file or Ctrl-C) must not be silently hidden.
                print(f"A posteriori results not found for {run}")
                continue
            id_to_results_apost[run.name] = run_results
        return id_to_results_apost

    @property
    def best_run_hyperparameters(self) -> Dict[str, dict]:
        """Hyperparameters stored for the best run."""
        return self.id_to_parameters[self.best_run_id]

    def get_run_dirs(self, sweep_dir: str) -> List[os.DirEntry]:
        """Returns the valid run directories inside sweep_dir, sorted by name.

        os.scandir yields entries in arbitrary order; sorting keeps tie-breaks
        in the best-run selection and the order of results reproducible.
        """
        with os.scandir(sweep_dir) as entries:
            run_dirs = [run for run in entries if self._is_valid_run_dir(run)]
        run_dirs.sort(key=lambda run: run.name)
        return run_dirs

    def _is_valid_run_dir(
        self, run_dir: os.DirEntry, skip_if_no_json: bool = True
    ) -> bool:
        """Checks whether directory contain results of a run

        A valid run is a directory, other than ``.submitit``, that contains
        both train_metrics.json and eval_metrics.json.

        NOTE: ``skip_if_no_json`` is accepted but currently has no effect;
        directories without the json files are always skipped.
        """
        if not run_dir.is_dir():
            return False
        if run_dir.name == ".submitit":
            return False
        files = os.listdir(run_dir.path)
        for required in ("train_metrics.json", "eval_metrics.json"):
            if required not in files:
                print(f"{required} not found skipping {run_dir.path}")
                return False
        return True

    @staticmethod
    def load_metrics_json(run_dir: str, prefix="") -> dict:
        """Loads ``<run_dir>/<prefix>metrics.json``."""
        metric_path = os.path.join(run_dir, f"{prefix}metrics.json")
        with open(metric_path) as f:
            metrics = json.load(f)
        return metrics

    @staticmethod
    def load_results_json(run_dir: str, prefix="") -> dict:
        """Loads ``<run_dir>/<prefix>results.json``."""
        results_path = os.path.join(run_dir, f"{prefix}results.json")
        with open(results_path) as f:
            results = json.load(f)
        return results

    def load_hyper_parameters(self) -> Dict[str, dict]:
        """Returns run name -> hyperparameters read from the best checkpoint.

        The checkpoint's ``hyper_parameters`` are extended with its
        ``datamodule_hparams`` (when present) and with whatever
        get_extra_run_configs returns for the run.
        """
        id_to_params = dict()
        for run in self.run_dirs:
            ckpt = self._load_ckpt(run.path)
            params = ckpt["hyper_parameters"]
            if "datamodule_hparams" in ckpt:
                params.update(ckpt["datamodule_hparams"])
            extra_configs = self.get_extra_run_configs(run)
            if extra_configs:
                params.update(extra_configs)
            id_to_params[run.name] = params
        return id_to_params

    def get_extra_run_configs(self, run: os.DirEntry) -> dict:
        """Used to return extra configs to store in hyperparameters for analysis"""
        return dict()

    def _load_ckpt(self, run_dir: str):
        """Loads the ``best*ckpt`` checkpoint found directly inside run_dir.

        If several files match, the last one in sorted path order is used so
        the choice does not depend on directory listing order.

        Raises:
            ValueError: if no matching checkpoint exists in run_dir.
        """
        with os.scandir(run_dir) as entries:
            ckpt_paths = sorted(
                entry.path
                for entry in entries
                if entry.name.startswith("best") and entry.name.endswith("ckpt")
            )
        if not ckpt_paths:
            raise ValueError(f"no best[...].ckpt found in {run_dir}")
        # NOTE: torch.load unpickles the file; only point this at checkpoints
        # produced by trusted runs.
        params = torch.load(ckpt_paths[-1], map_location="cpu")
        return params

    def find_best_run_id(self, param=None, param_value=None) -> str:
        """Returns the id of the run with the best validation_metric.

        Args:
            param: hyperparameter name, or sequence of names, to filter on.
                None disables filtering.
            param_value: value(s) the hyperparameter(s) must equal; matched
                with ``param`` by position.

        Raises:
            ValueError: if ``param_value`` has fewer entries than ``param`` or
                if no run is left after filtering.
        """
        candidates = self.id_to_metrics
        if param is not None:
            if isinstance(param, str):
                # A bare string would otherwise be iterated per character.
                param, param_value = [param], [param_value]
            if len(param_value) < len(param):
                raise ValueError(
                    "param_value must supply a value for every entry in param"
                )
            for name, wanted in zip(param, param_value):
                candidates = {
                    run_id: metrics
                    for run_id, metrics in candidates.items()
                    if self.id_to_parameters[run_id][name] == wanted
                }
        if not candidates:
            raise ValueError(
                f"no runs left to select from (param={param}, "
                f"param_value={param_value})"
            )
        select = min if self.minimum else max
        return select(
            candidates,
            key=lambda run_id: candidates[run_id][self.validation_metric],
        )

    def find_matching_runs(self, run_id: str) -> List[str]:
        """Returns a list of runs with matching hyper- parameters

        Two runs match when their stored hyperparameter dicts compare equal.
        NOTE(review): if the seed is itself stored among the hyperparameters,
        only the run itself can match - confirm against the checkpoints.
        """
        best_params = self.id_to_parameters[run_id]
        return [
            other_id
            for other_id, params in self.id_to_parameters.items()
            if params == best_params
        ]

    def get_best_runs_metric(self, name: str) -> List[Any]:
        """Returns the metric ``name`` for every run in best_run_ids"""
        return [self.id_to_metrics[run_id][name] for run_id in self.best_run_ids]

    def get_best_runs_result(self, name: str) -> List[Any]:
        """Returns the result ``name`` for every run in best_run_ids"""
        return [self.id_to_results[run_id][name] for run_id in self.best_run_ids]

    def get_best_runs_result_apost(self, name: str) -> List[Any]:
        """Returns the a posteriori result ``name`` for every run in best_run_ids"""
        return [
            self.id_to_results_apost[run_id][name] for run_id in self.best_run_ids
        ]


class DiverseRun(BestRuns):
    """Stores metrics and parameters for runs with varying training diversity.

    Runs are grouped by the ``train_prop_to_vary`` value found in each run's
    ``run_configs.yaml``.

    Args:
        run_dir: checkpoint directory where results for the run are saved
        metric_names: names of the metrics to collect for every group;
            defaults to DEFAULT_METRIC_NAMES
    """

    # Metrics collected when the caller does not pass metric_names. Kept as
    # an immutable tuple so instances can never share and mutate the default.
    DEFAULT_METRIC_NAMES = (
        "train_canonical_top_1_accuracy",
        "train_diverse_2d_top_1_accuracy",
        "train_diverse_3d_top_1_accuracy",
        "val_canonical_top_1_accuracy",
        "val_diverse_2d_top_1_accuracy",
        "val_diverse_3d_top_1_accuracy",
        "test_canonical_top_1_accuracy",
        "test_diverse_2d_top_1_accuracy",
        "test_diverse_3d_top_1_accuracy",
        "diverse_2d_train_canonical_top_1_accuracy",
        "diverse_3d_train_canonical_top_1_accuracy",
    )

    def __init__(
        self,
        run_dir: str,
        metric_names=None,
    ):
        self.run_dir = run_dir

        self.run_dirs: List[os.DirEntry] = self.get_run_dirs(run_dir)
        self.id_to_metrics = self.load_metrics()
        self.id_to_parameters = self.load_hyper_parameters()

        # A fresh list per instance replaces the former mutable default.
        if metric_names is None:
            metric_names = list(self.DEFAULT_METRIC_NAMES)
        self.metric_names = metric_names

        self.prop_to_ids: Dict[float, List[str]] = self.build_prop_to_ids()
        self.prop_to_metrics: Dict[
            float, Dict[str, list]
        ] = self.build_prop_to_metrics()

    def get_extra_run_configs(self, run: os.DirEntry) -> dict:
        """Returns train_prop_to_vary

        The value is read from ``data_module.train_prop_to_vary`` in the
        run's ``run_configs.yaml``.
        """
        run_config_path = os.path.join(run.path, "run_configs.yaml")
        with open(run_config_path) as f:
            config = OmegaConf.load(f)
        extra_configs = {
            "train_prop_to_vary": config.data_module.train_prop_to_vary,
        }
        return extra_configs

    def build_prop_to_ids(self) -> Dict[float, List[str]]:
        """Groups run ids by their ``train_prop_to_vary`` hyperparameter."""
        prop_to_ids: Dict[float, list] = dict()
        for run_id, params in self.id_to_parameters.items():
            # Append in place instead of rebuilding the list for every run.
            prop_to_ids.setdefault(params["train_prop_to_vary"], []).append(run_id)
        return prop_to_ids

    def build_prop_to_metrics(self) -> Dict[float, dict]:
        """Collects, per proportion, a list of values for each metric name.

        Returns a defaultdict mapping proportion -> metric name -> list of
        that metric's values, one per run in the group (in prop_to_ids order).
        """
        # Local import keeps this block self-contained; the file only imports
        # the typing alias DefaultDict, which is not meant as a constructor.
        from collections import defaultdict

        prop_to_metrics: Dict[float, dict] = defaultdict(dict)
        for prop, run_ids in self.prop_to_ids.items():
            for run_id in run_ids:
                for metric_name in self.metric_names:
                    metric = self.id_to_metrics[run_id][metric_name]
                    prop_to_metrics[prop].setdefault(metric_name, []).append(metric)
        return prop_to_metrics


class BestRunsLieLoss(BestRuns):
    """Selects the best run among runs whose constraint loss decreased enough.

    A run is a candidate only if its final ``loss_constraint`` metric is at
    most MAX_LOSS_RATIO times the value stored in its ``first_*.ckpt``
    checkpoint. The best candidate is then chosen by validation_metric.

    Args:
        sweep_dir: logs directory where runs are stored
        validation_metric: metric to use for selecting the best run
        loss_constraint: name of the metric that must have decreased; None
            disables the constraint
        minimum: if True the run with the lowest validation_metric is the
            best one, otherwise the run with the highest value is
        param: optional hyperparameter name (or sequence of names) used to
            restrict the candidate runs
        param_value: required value(s) for ``param``, matched by position

    NOTE(review): the default validation_metric is an accuracy while the
    default ``minimum=True`` selects the lowest value - confirm callers pass
    ``minimum=False`` when that is not intended.
    """

    # A run is kept only if its final constraint loss is at most this
    # fraction of the value recorded in its first checkpoint.
    MAX_LOSS_RATIO = 0.6

    def __init__(
        self,
        sweep_dir: str,
        validation_metric: str = "online_val_combined_top_1_accuracy",
        loss_constraint: str = "val_diverse_2d_lie_loss_epoch",
        minimum=True,
        param=None,
        param_value=None,
    ):
        self.sweep_dir = sweep_dir
        self.validation_metric = validation_metric
        self.minimum = minimum

        self.run_dirs: List[os.DirEntry] = self.get_run_dirs(sweep_dir)
        self.id_to_metrics = self.load_metrics()
        self.start_metrics = self.load_metrics_start()
        self.id_to_parameters = self.load_hyper_parameters()

        self.best_run_id: str = self.find_best_run_id_constrained(
            loss_constraint, param, param_value
        )
        self.best_run_ids: List = self.find_matching_runs(self.best_run_id)

        self.id_to_results = self.load_results()

    def load_metrics_start(self):
        """Load metrics from each run's first checkpoint instead of json

        Returns:
            run name -> metrics stored in the run's single ``first_*.ckpt``.

        Raises:
            ValueError: if a run does not contain exactly one ``first_*.ckpt``.
        """
        start_metrics = dict()

        for run in self.run_dirs:
            ckpt_paths = list(Path(run.path).glob("first_*.ckpt"))
            # Explicit raise rather than assert: asserts vanish under -O.
            if len(ckpt_paths) != 1:
                raise ValueError(
                    f"expected exactly one first_*.ckpt in {run.path}, "
                    f"found {ckpt_paths}"
                )
            start_metrics[run.name] = self.load_checkpoint_metrics(ckpt_paths[0])
        return start_metrics

    def load_checkpoint_metrics(self, ckpt_path: Path) -> dict:
        """Returns metrics from the checkpoint

        Values exposing ``.item()`` (e.g. tensors) are converted through it;
        other values are passed through unchanged.
        """
        # map_location="cpu" matches _load_ckpt and lets checkpoints saved on
        # a GPU be read on CPU-only machines.
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        metrics = torch.load(ckpt_path, map_location="cpu")["metrics"]
        return {
            name: metric.item() if hasattr(metric, "item") else metric
            for name, metric in metrics.items()
        }

    def find_best_run_id_constrained(self, loss_constraint, param=None, param_value=None, ) -> str:
        """Returns the id of the best run that satisfies the loss constraint.

        Args:
            loss_constraint: metric whose final value must be at most
                MAX_LOSS_RATIO times its first-checkpoint value; None skips
                this filter.
            param: hyperparameter name, or sequence of names, to filter on.
                None disables filtering.
            param_value: value(s) the hyperparameter(s) must equal; matched
                with ``param`` by position.

        Raises:
            ValueError: if ``param_value`` has fewer entries than ``param`` or
                if no run is left after filtering.
        """
        candidates = self.id_to_metrics

        # filter on param
        if param is not None:
            if isinstance(param, str):
                # A bare string would otherwise be iterated per character.
                param, param_value = [param], [param_value]
            if len(param_value) < len(param):
                raise ValueError(
                    "param_value must supply a value for every entry in param"
                )
            for name, wanted in zip(param, param_value):
                candidates = {
                    run_id: metrics
                    for run_id, metrics in candidates.items()
                    if self.id_to_parameters[run_id][name] == wanted
                }

        # filter on metrics
        if loss_constraint is not None:
            candidates = {
                run_id: metrics
                for run_id, metrics in candidates.items()
                if self.id_to_metrics[run_id][loss_constraint]
                <= self.start_metrics[run_id][loss_constraint] * self.MAX_LOSS_RATIO
            }

        if not candidates:
            raise ValueError(
                f"no runs satisfy the filters (loss_constraint={loss_constraint}, "
                f"param={param}, param_value={param_value})"
            )

        select = min if self.minimum else max
        return select(
            candidates,
            key=lambda run_id: candidates[run_id][self.validation_metric],
        )