import warnings
from collections import defaultdict
from typing import Tuple

import torch
from torch_geometric.data import Data
from tqdm import tqdm

from ltsgns_mp.algorithms import AbstractAlgorithm
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ValueDict


class Evaluator:
    """
    Periodically evaluates the trajectory predictions of an algorithm against ground-truth trajectories.

    For every configured context size, the evaluator iterates over ``env.eval_iterators[context_size]``,
    lets the algorithm predict a trajectory for each evaluation task, computes the configured metrics and
    averages them over all evaluated tasks. Tasks whose index is listed in ``config.animation_indices`` are
    additionally stored for visualization. After a full evaluation, a majority vote over the context sizes
    decides whether ``algorithm.save_checkpoint_this_epoch`` is set.
    """

    def __init__(self, config, algorithm: AbstractAlgorithm, env, eval_only: bool):
        """
        Args:
            config: Evaluation config. Attributes read by this class: context_test_sizes, context_val_sizes,
                eval_interval, initial_eval, animation_indices, metric, evaluation_type, time_interval,
                early_stopping_metric, early_stopping_majority_needed.
            algorithm: Algorithm providing ``simulator``, ``predict_trajectory``, ``config.verbose`` and the
                ``save_checkpoint_this_epoch`` flag.
            env: Environment providing ``eval_iterators`` keyed by context size.
            eval_only: If True, the test context sizes are evaluated, otherwise the validation context sizes.
        """
        self.config = config
        self.algorithm = algorithm
        self.env = env
        self._eval_only = eval_only
        if eval_only:
            self.context_sizes = config.context_test_sizes
        else:
            self.context_sizes = config.context_val_sizes
        # best value of the early-stopping metric seen so far, tracked separately per context size
        self.current_best_eval_loss = {context_size: float("inf") for context_size in self.context_sizes}
        # (re-)created by eval_step(); collects scalar and visualization results of all context sizes
        self._final_eval_dict = None

    def eval_step(self, epoch, force_eval=False, visualize_only: bool = False) -> ValueDict:
        """
        Performs one evaluation step if the evaluation interval is reached (or if ``force_eval`` is set).

        Args:
            epoch: Current epoch. An evaluation runs when ``epoch`` is a multiple of ``config.eval_interval``;
                epoch 0 is only evaluated if ``config.initial_eval`` is set.
            force_eval: If True, evaluate regardless of the epoch.
            visualize_only: If True, only the tasks listed in ``config.animation_indices`` are evaluated.
                Such a run covers only a subset of the evaluation tasks, so it does not take part in the
                early-stopping / checkpointing decision.

        Returns:
            An empty dict if no evaluation was performed. Otherwise a dict with the keys ``keys.SCALARS`` and
            ``keys.VISUALIZATIONS``, each mapping ``get_context_size_str(context_size)`` to the averaged scalar
            metrics and to the visualization data of that context size, respectively.
        """
        if not force_eval:
            if epoch % self.config.eval_interval != 0 or (epoch == 0 and not self.config.initial_eval):
                return {}
        # big dictionary where all different context size results get stored.
        self._initialize_final_eval_dict()
        if not self.verbose:
            print(f"Evaluating epoch {epoch} ...")
        for context_size in tqdm(self.context_sizes, desc=f"Evaluating epoch {epoch} ...", disable=not self.verbose):
            self.evaluate_context_size(context_size, visualize_only)
        # A visualize-only run averages over a (possibly empty) subset of the tasks. Its scores are not
        # comparable to a full evaluation and must not update the best losses or trigger a checkpoint.
        if not visualize_only:
            self.check_early_stopping()
        return self._final_eval_dict

    def evaluate_context_size(self, context_size, visualize_only):
        """
        Evaluates all tasks of the evaluation iterator belonging to one context size and stores the aggregated
        results in the final evaluation dict.

        Args:
            context_size: Key into ``env.eval_iterators``.
            visualize_only: If True, tasks whose index is not in ``config.animation_indices`` are skipped.
        """
        # NOTE(review): the simulator is switched to eval mode here and never switched back; presumably the
        # training loop re-enables training mode itself — confirm.
        self.algorithm.simulator.eval()
        eval_iterator = self.env.eval_iterators[context_size]
        eval_dict, visualization_dict = self.initialize_eval_dict()
        if not self.verbose:
            print(f"Evaluating context size {context_size} ...")
        for idx, eval_traj in tqdm(enumerate(eval_iterator), total=len(eval_iterator),
                                   desc=f"Evaluating context size {context_size} ...", disable=not self.verbose):
            visualize = idx in self.config.animation_indices
            if visualize_only and not visualize:
                continue
            # apply the algorithm to the batch
            predicted_trajectory, additional_visualizations = self.algorithm.predict_trajectory(eval_traj,
                                                                                                visualize=visualize,
                                                                                                eval_only=self._eval_only)
            evaluation_results = self.evaluate_trajectory(predicted_trajectory, eval_traj)
            eval_dict = self.update_eval_dict(eval_dict, evaluation_results)

            # check if this trajectory was meant to be visualized
            if visualize:
                # keep ground truth and prediction for later visualization
                visualization_dict[f"Task_{idx}"]["to_visualize"] = {"eval_traj": eval_traj,
                                                                     "predicted_traj": predicted_trajectory}
                # add the additional visualizations returned by the algorithm
                visualization_dict[f"Task_{idx}"].update(additional_visualizations)
        self.finalize_eval_dict(eval_dict, visualization_dict, context_size)

    def evaluate_trajectory(self, predicted_trajectory: torch.Tensor, eval_traj: Data) -> ValueDict:
        """
        Evaluates the predicted trajectory against the ground truth trajectory.

        Both trajectories are restricted to ``eval_traj.evaluation_indices[0]`` before comparison. Every
        combination of ``config.metric`` x ``config.evaluation_type`` x ``config.time_interval`` is computed;
        combinations that cannot be evaluated (time interval out of bounds) are skipped with a warning.

        Returns: The evaluation results as a ValueDict mapping metric names (see ``_get_name``) to floats.

        Raises:
            AssertionError: If the selected ground truth and predicted positions differ in shape.
        """
        evaluation_results = {}
        # get the correct evaluation interval of ground truth and predicted trajectory
        eval_indices = eval_traj.evaluation_indices[0]
        gth_mesh_position = eval_traj[keys.CONTEXT_NODE_POSITIONS][0][eval_indices]
        predicted_mesh_positions = predicted_trajectory[eval_indices]
        assert gth_mesh_position.shape == predicted_mesh_positions.shape, (
            f"Shapes of ground truth and predicted trajectory do not match: "
            f"{tuple(gth_mesh_position.shape)} vs. {tuple(predicted_mesh_positions.shape)}."
        )

        for metric in self.config.metric:
            for evaluation_type in self.config.evaluation_type:
                for time_interval in self.config.time_interval:
                    try:
                        eval_result = self._eval_single_metric(predicted_mesh_positions,
                                                               gth_mesh_position,
                                                               time_interval,
                                                               evaluation_type,
                                                               metric)
                    except IndexError:
                        # best effort: a too-short trajectory should not abort the whole evaluation
                        warnings.warn(f"Could not evaluate {self._get_name(time_interval, evaluation_type, metric)}.")
                        continue
                    evaluation_results[self._get_name(time_interval, evaluation_type, metric)] = eval_result
        return evaluation_results

    def _eval_single_metric(self, predicted_mesh_positions, gth_mesh_positons, time_interval, evaluation_type,
                            metric) -> float:
        """
        Evaluates a single metric for a single time interval.

        The first axis of both position tensors is indexed by ``time_interval``.

        Args:
            predicted_mesh_positions: Predicted positions, same shape as ``gth_mesh_positons``.
            gth_mesh_positons: Ground truth positions.
            time_interval: Index along the first axis; -1 denotes the full rollout.
            evaluation_type: "last" compares only the entry at ``time_interval``; "mean" compares all entries
                up to and including ``time_interval`` (all entries for -1).
            metric: Name of the metric; only "mse" is supported.

        Returns: The evaluation result as a Float.

        Raises:
            IndexError: If ``time_interval`` is beyond the length of the ground truth.
            ValueError: For an unknown evaluation type or metric.
        """
        if time_interval >= len(gth_mesh_positons):
            raise IndexError("Time interval {} is out of bounds for ground truth trajectory of length {}.".format(
                time_interval, len(gth_mesh_positons)))
        if evaluation_type == "last":
            gth = gth_mesh_positons[time_interval]
            prediction = predicted_mesh_positions[time_interval]
        elif evaluation_type == "mean":
            if time_interval == -1:
                gth = gth_mesh_positons
                prediction = predicted_mesh_positions
            else:
                gth = gth_mesh_positons[:time_interval + 1]
                prediction = predicted_mesh_positions[:time_interval + 1]
        else:
            raise ValueError("Unknown evaluation type: {}".format(evaluation_type))
        if metric == "mse":
            return mse(gth, prediction).item()
        raise ValueError("Unknown metric: {}".format(metric))

    def _get_name(self, time_interval, evaluation_type, metric):
        """Builds the result key, e.g. ``"10_steps_mean_mse"`` or ``"full_rollout_last_mse"`` for -1."""
        if time_interval == -1:
            time_name = "full_rollout"
        else:
            time_name = f"{time_interval}_steps"
        return f"{time_name}_{evaluation_type}_{metric}"

    def _initialize_final_eval_dict(self):
        """Resets the dict that collects the results of all context sizes of one evaluation step."""
        self._final_eval_dict = {
            keys.SCALARS: {},
            keys.VISUALIZATIONS: {}
        }

    def initialize_eval_dict(self) -> Tuple[ValueDict, ValueDict]:
        """
        Initializes the per-context-size evaluation dicts.

        Returns: A ``defaultdict(list)`` collecting per-task metric values and a ``defaultdict(dict)``
        collecting per-task visualization data.
        """
        eval_dict = defaultdict(list)
        visualization_dict = defaultdict(dict)
        return eval_dict, visualization_dict

    def update_eval_dict(self, eval_dict: ValueDict, evaluation_results: ValueDict) -> ValueDict:
        """
        Appends the evaluation results of a single task to the per-metric lists of the evaluation dict.

        Returns: The updated evaluation dict (the same object that was passed in).
        """
        for key, value in evaluation_results.items():
            eval_dict[key].append(value)
        return eval_dict

    def finalize_eval_dict(self, eval_dict: ValueDict, visualization_dict: ValueDict, context_size: int) -> ValueDict:
        """
        Aggregates the evaluation results over all evaluated tasks (arithmetic mean per metric) and stores them,
        together with the visualization data, in the final evaluation dict under the context size's key.

        Returns: The finalized (averaged) scalar evaluation dict of this context size.
        """
        # every list is non-empty: keys are only created by appending a value in update_eval_dict
        scalars = {key: sum(values) / len(values) for key, values in eval_dict.items()}
        context_key = self.get_context_size_str(context_size)
        self._final_eval_dict[keys.SCALARS][context_key] = scalars
        self._final_eval_dict[keys.VISUALIZATIONS][context_key] = dict(visualization_dict)
        return scalars

    def check_early_stopping(self):
        """
        Decides by majority vote over the context sizes whether the current evaluation is better than the best
        one so far.

        For every context size, the value of ``config.early_stopping_metric`` is compared with the stored best
        value. If the fraction of context sizes that improved is strictly larger than
        ``config.early_stopping_majority_needed``, ``algorithm.save_checkpoint_this_epoch`` is set to True and
        the stored best values are replaced by the current ones. Context sizes for which the metric was not
        computed count as "not improved" and keep their stored best value.
        """
        if not self.context_sizes:
            # nothing was evaluated, so there is nothing to vote on (and no division by zero)
            return
        metric_name = self.config.early_stopping_metric
        current_losses = []
        for context_size in self.context_sizes:
            eval_dict = self._final_eval_dict[keys.SCALARS][self.get_context_size_str(context_size)]
            if metric_name not in eval_dict:
                warnings.warn(f"Early stopping metric '{metric_name}' is not available for context size "
                              f"{context_size}; counting it as no improvement.")
                continue
            current_losses.append((context_size, eval_dict[metric_name]))

        num_improvements = sum(1 for context_size, current_eval_loss in current_losses
                               if current_eval_loss < self.current_best_eval_loss[context_size])

        if num_improvements / len(self.context_sizes) > self.config.early_stopping_majority_needed:
            self.algorithm.save_checkpoint_this_epoch = True
            for context_size, current_eval_loss in current_losses:
                self.current_best_eval_loss[context_size] = current_eval_loss

    def get_context_size_str(self, context_size) -> str:
        """Key used for a context size in the final evaluation dict, e.g. ``"eval_context_size_005"``."""
        return f"eval_context_size_{context_size:03}"

    @property
    def verbose(self):
        """Verbosity flag taken from the algorithm's config (controls progress bars vs. plain prints)."""
        return self.algorithm.config.verbose
