from pathlib import Path
import hydra
from omegaconf import OmegaConf, open_dict

from ltsgns_mp.envs import Env
from ltsgns_mp.util.initialization import main_initialization
from ltsgns_mp.util.own_types import ConfigDict
from ltsgns_mp.util.util import load_omega_conf_resolvers


class CustomPrinter:
    """
    Small printer that indents messages by their level and filters them against a verbosity value.

    A message is emitted only if its ``threshold`` is strictly greater than the configured verbosity.
    Each emitted message is prefixed with two spaces per threshold level.
    """

    def __init__(self, verbosity: int):
        self._verbosity = verbosity

    def print(self, string: str, threshold: int = 0):
        # NOTE(review): the comparison means a *lower* verbosity lets *more* messages through
        # (e.g. verbosity -1 prints every message with threshold >= 0). Kept as-is because callers rely on it.
        if threshold <= self._verbosity:
            return
        indent = "  " * threshold
        print(indent + string)

    def __call__(self, string: str, threshold: int = 0):
        # calling the instance is shorthand for .print(...)
        self.print(string, threshold)


class Evaluator:
    """Re-evaluates stored experiment runs (one or many seeds) from their saved configs and checkpoints."""

    def __init__(self, evaluation_config: ConfigDict):
        """
        Evaluates the results of an experiment. The experiment is assumed to be structured as follows:
        root_path (usually the date the experiment was scheduled)
        ├── <exp_name>
        │   ├── <job_type>
        │   │   ├── seed_00
        │   │   │   ├── checkpoints
        │   │   │   │   ├── checkpoint.ckpt
        │   │   │   ├── .hydra
        │   │   │   │   ├── config.yaml
        │   │   │   ├── ...
        │   │   ├── seed_01
        │   │   │   ├── checkpoints
        │   │   │   │   ├── checkpoint.ckpt
        │   │   │   ├── .hydra
        │   │   │   │   ├── config.yaml
        │   │   │   ├── ...
        │   │   ├── ...

        For a run without a multirun sweep, ``<exp_name>`` itself contains the ``checkpoints`` and ``.hydra``
        folders.

        Args:
            evaluation_config: Hydra config for the current evaluation. It is merged on top of each run's
                ``.hydra/config.yaml``, so any value set here overrides the stored one. For example, you can
                overwrite the ``evaluation`` section to evaluate a run differently than during training.
                Reads ``loading.root_path``, ``loading.checkpoint_iteration``, ``exp_name``,
                ``evaluation.seeds`` (None for all seeds, a single int, or a collection of ints) and
                ``evaluation.blender_mode``.
        """
        self._evaluation_config = evaluation_config
        self._exp_name = evaluation_config.exp_name

        seeds = evaluation_config.evaluation.seeds
        # a single seed may be given as a bare int; normalise so membership tests work uniformly
        self._seeds = [seeds] if isinstance(seeds, int) else seeds

        root_path = evaluation_config.loading.root_path
        self._root_path = Path(root_path) if isinstance(root_path, str) else root_path
        self._checkpoint_iteration = evaluation_config.loading.checkpoint_iteration
        # verbosity -1: every message with threshold >= 0 is printed
        self.printer = CustomPrinter(-1)
        self._blender_mode = evaluation_config.evaluation.blender_mode

    def evaluate_experiment(self):
        """
        Evaluates every run stored under ``root_path / exp_name``.

        If that folder directly contains a ``checkpoints`` folder, it is evaluated as a single repetition (no
        multirun). Otherwise each sub-folder is evaluated as a subexperiment, in sorted order.

        Raises:
            FileNotFoundError: If ``root_path / exp_name`` does not exist.
        """
        experiment_path = self._root_path / self._exp_name  # Path overloads the / operator to join paths
        if not experiment_path.exists():
            raise FileNotFoundError(f"Path '{experiment_path}' not found.")
        # sorted so that the evaluation order is deterministic (iterdir order is filesystem-dependent)
        subexperiment_paths = sorted(x for x in experiment_path.iterdir() if x.is_dir())

        if any(path.name == "checkpoints" for path in subexperiment_paths):
            # if there was no multirun, there is only one subexperiment, so we can evaluate it directly
            self._evaluate_repetition(experiment_path)
            return

        self.printer(f"Experiment '{self._exp_name}' contains {len(subexperiment_paths)} subexperiments.", 0)
        # multiple subexperiments, evaluate each one separately
        for subexperiment_path in subexperiment_paths:
            self.printer(f"Evaluating subexperiment '{subexperiment_path.name}'...", 1)
            self._evaluate_subexperiment(subexperiment_path)

    def _evaluate_subexperiment(self, subexperiment_path: Path):
        """
        Evaluates a single subexperiment (i.e., one list/grid entry from the experiment grid) for all selected
        repetitions, visiting the seed folders in sorted order.

        The environment returned by one repetition is handed to the next one so it can be reused.

        Args:
            subexperiment_path: Path to the subexperiment folder. Its name is used as the wandb ``job_type``.

        Raises:
            FileNotFoundError: If ``subexperiment_path`` does not exist.
        """
        # raise instead of assert: asserts are stripped when Python runs with -O
        if not subexperiment_path.exists():
            raise FileNotFoundError(f"Path '{subexperiment_path}' not found.")

        job_type = subexperiment_path.name
        env = None
        for seed_path in sorted(x for x in subexperiment_path.iterdir() if x.is_dir()):
            # seed folders are named like "seed_00"; the trailing integer is the seed
            if self._seeds is not None and int(seed_path.name.split("_")[-1]) not in self._seeds:
                continue
            env = self._evaluate_repetition(seed_path, env, job_type)

    def _evaluate_repetition(self, seed: Path, env: Env | None = None, job_type: str | None = None) -> Env | None:
        """
        Evaluates a single repetition of a subexperiment (i.e., one run of the experiment with a specific seed).

        Args:
            seed: Path to the final experiment folder. Should contain a .hydra folder and a checkpoints folder.
            env: Environment returned by a previously evaluated repetition, passed to ``main_initialization`` for
                reuse; None to let it build a new one.
            job_type: If given, overrides ``recorder.wandb.job_type`` in the loaded config.

        Returns: The environment returned by ``main_initialization``. It is passed to the next seed to save time,
                 since the env then does not need to be reinitialized. Exceptions raised during evaluation
                 propagate to the caller.
        """
        config = self._get_config(seed, job_type)
        print(OmegaConf.to_yaml(config))
        env, _, evaluator, recorder = main_initialization(config, env)
        # take last epoch as current epoch
        evaluation_metrics = evaluator.eval_step(epoch=config.epochs, force_eval=True,
                                                 visualize_only=self._blender_mode)
        if self._blender_mode:
            self._export_blender_visualizations(evaluation_metrics["visualizations"])

        recorder.record_iteration(iteration=config.epochs, recorded_values=evaluation_metrics)
        # close wandb, save the final model, ...
        recorder.finalize()
        return env

    def _export_blender_visualizations(self, vis_metrics) -> None:
        """
        Writes OBJ mesh sequences for every visualized task so they can be imported into Blender.

        Files go to ``<hydra output dir>/blender_output/<context_size>/<task_name>/`` in the sub-folders
        ``predicted``, ``ground_truth`` and (if the data has a collider) ``collider``.

        Args:
            vis_metrics: Nested mapping ``{context_size: {task_name: task_dict}}`` taken from the
                ``"visualizations"`` entry of the evaluation metrics. Each ``task_dict["to_visualize"]`` holds
                ``"eval_traj"`` (the evaluation data) and ``"predicted_traj"``.
        """
        output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
        for context_size, context_size_dict in vis_metrics.items():
            for task_name, task_dict in context_size_dict.items():
                # str() so that non-string keys (e.g. integer context sizes) still form a valid path
                save_path = output_dir / "blender_output" / str(context_size) / str(task_name)
                print(f"Saving visualizations to {save_path}")

                trajectory_dict = task_dict["to_visualize"]
                evaluation_data = trajectory_dict["eval_traj"]
                mesh_faces = self._to_numpy(evaluation_data.mesh_faces)

                # save the predicted trajectory as obj files
                self._trajectory_to_obj(vertices_over_time=self._to_numpy(trajectory_dict["predicted_traj"]),
                                        faces=mesh_faces,
                                        save_path=save_path / "predicted")

                # save the ground truth trajectory (first entry of the context node positions) as obj files
                self._trajectory_to_obj(vertices_over_time=self._to_numpy(evaluation_data.context_node_positions[0]),
                                        faces=mesh_faces,
                                        save_path=save_path / "ground_truth")

                # save the collider information, if there is any
                data_keys = evaluation_data.keys()
                if "visual_collider_vertices" in data_keys:
                    collider_positions = self._to_numpy(evaluation_data.visual_collider_vertices)
                elif "context_collider_positions" in data_keys:
                    collider_positions = self._to_numpy(evaluation_data.context_collider_positions[0])
                else:
                    continue  # no collider
                # NOTE(review): the faces are read from visual_collider_faces in both branches, including when
                # visual_collider_vertices is absent - confirm that the faces always exist in that case.
                collider_faces = self._to_numpy(evaluation_data.visual_collider_faces)
                self._trajectory_to_obj(vertices_over_time=collider_positions,
                                        faces=collider_faces,
                                        save_path=save_path / "collider")

    @staticmethod
    def _to_numpy(tensor):
        """Converts a (presumably torch) tensor to a NumPy array via ``.cpu().detach().numpy()``."""
        return tensor.cpu().detach().numpy()

    def _get_config(self, repetition: Path, job_type: str | None = None):
        """
        Loads the config a repetition was run with and prepares it for evaluation.

        The run's ``.hydra/config.yaml`` is merged with the evaluation config, and values from the evaluation
        config take precedence. The wandb ``job_type`` is then overridden if one is given. Keys missing in older
        runs are back-filled by ``update_config_manually``. Finally, ``loading.checkpoint_path`` is set to the
        repetition's ``checkpoints`` folder.

        Args:
            repetition: Path to the repetition folder (contains ``.hydra`` and ``checkpoints``).
            job_type: Optional override for ``recorder.wandb.job_type``.

        Returns: The merged OmegaConf config.
        """
        # load config.yaml
        with open(repetition / ".hydra" / "config.yaml") as file:
            current_config = OmegaConf.load(file)
        current_config = OmegaConf.merge(current_config, self._evaluation_config)
        if job_type is not None:
            current_config.recorder.wandb.job_type = job_type
        update_config_manually(current_config)
        with open_dict(current_config):
            current_config.loading.checkpoint_path = str(repetition / "checkpoints")
        self.printer(f"Current config: {current_config}", 2)
        return current_config

    def _trajectory_to_obj(self, vertices_over_time, faces, save_path: str | Path):
        """
        Writes one Wavefront OBJ file per time step (``mesh_000.obj``, ``mesh_001.obj``, ...) into ``save_path``.

        Args:
            vertices_over_time: Array indexed as [time step, vertex, coordinate]. 2D coordinates are padded with
                a zero z coordinate.
            faces: Iterable of faces, each an iterable of 0-based vertex indices. The same faces are written
                for every time step.
            save_path: Output folder; it is created (including parents) if missing.
        """
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)

        if vertices_over_time.shape[-1] == 2:
            # 2D trajectory, add z coordinate
            import numpy as np
            padding = np.zeros((*vertices_over_time.shape[:-1], 1))
            vertices_over_time = np.concatenate([vertices_over_time, padding], axis=-1)

        # the faces are identical for every time step, so build their lines once; OBJ files are 1-indexed
        face_lines = ["f " + " ".join(str(index + 1) for index in face) + "\n" for face in faces]

        for time_step, vertices in enumerate(vertices_over_time):
            vertex_lines = ("v " + " ".join(str(coord) for coord in vertex) + "\n" for vertex in vertices)
            with open(save_dir / f"mesh_{time_step:03d}.obj", "w") as file:
                file.writelines(vertex_lines)
                file.writelines(face_lines)



def update_config_manually(current_config: ConfigDict):
    """
    Back-fills, in place, config keys that configs of older runs lack but the current evaluation code expects.

    Keys that are already present are left untouched. Every added key except ``algorithm.verbose`` is announced
    with a printed warning. A boolean ``recorder.visualizations`` flag is replaced by a structured section.

    Args:
        current_config: The merged run/evaluation config. It is modified inside ``open_dict`` so new keys can be
            added even if the config is in struct mode.
    """

    def backfill(node, key, value, warning=None) -> bool:
        # Adds ``key`` to ``node`` only when it is missing (printing ``warning`` if given); reports whether it did.
        if key in node:
            return False
        if warning is not None:
            print(warning)
        setattr(node, key, value)
        return True

    with open_dict(current_config):
        algorithm = current_config.algorithm
        simulator = algorithm.simulator

        backfill(simulator.gnn, "name", "hmpn_gnn",
                 "WARNING: name not in config. Setting it to 'hmpn_gnn'.")
        backfill(current_config.env, "second_order_dynamics", False,
                 "WARNING: second_order_dynamics not in config. Setting it to False.")

        train_iterator = algorithm.train_iterator
        if backfill(train_iterator, "anchor_index_as_feature", False,
                    "WARNING: anchor_index_as_feature not in config. Setting it to False."):
            # keep the evaluation iterator consistent with the back-filled training iterator
            current_config.evaluation.eval_iterator.anchor_index_as_feature = False
        backfill(train_iterator, "anchor_index_mode", "first_context",
                 "WARNING: anchor_index_mode not in config. Setting it to 'first_context'. (Should not matter for eval, this is only a training parameter)")

        # the two keys below are only relevant for algorithms with a posterior learner
        if "posterior_learner" in algorithm:
            backfill(algorithm.posterior_learner.lnpdf.likelihood, "pc_std", 0.005,
                     "WARNING: pc_std not in config. Setting it to 0.005.")
            if backfill(simulator.decoder, "regularization", {},
                        "WARNING: regularization not in config. Setting dropout to 0.0."):
                simulator.decoder.regularization.dropout = 0.0

        backfill(train_iterator, "input_mesh_noise", 0.0,
                 "WARNING: input_mesh_noise not in config. Setting it to 0.0.")
        backfill(train_iterator, "trajectory_targets", False,
                 "WARNING: trajectory_targets not in config. Setting it to False.")
        backfill(simulator, "use_prodmp", False,
                 "WARNING: use_prodmp not in config. Setting it to False.")
        backfill(train_iterator, "last_collider_as_feature", False,
                 "WARNING: last_collider_as_feature not in config. Setting it to False.")
        backfill(train_iterator, "context_history_vel", False,
                 "WARNING: context_history_vel not in config. Setting it to False.")

        recorder = current_config.recorder
        if isinstance(recorder.visualizations, bool):
            # older runs stored a plain flag here; swap it for the structured section the current code reads
            # NOTE(review): the new section is always enabled, whatever the old flag's value was - confirm intended
            del recorder.visualizations
            recorder.visualizations = {}
            recorder.visualizations.enabled = True
            recorder.visualizations.save_on_disk = True

        backfill(algorithm, "verbose", True)


@hydra.main(version_base=None, config_path="../../configs", config_name="evaluation_config")
def evaluate_main(config: ConfigDict) -> None:
    """Hydra entry point: builds an ``Evaluator`` from the composed evaluation config and runs it."""
    Evaluator(evaluation_config=config).evaluate_experiment()


if __name__ == '__main__':
    # the project's OmegaConf resolvers have to be registered before hydra composes the config
    load_omega_conf_resolvers()
    evaluate_main()
