import copy
import math
import parser
import random
from typing import List, Tuple

from tqdm import tqdm

from global_utils import GeneralUtils
from plotter import Plotter


class MetricCollector:
    """Collect top-k hit indicators and normalized J-errors for a selection algorithm.

    ``j_table`` is indexed as ``j_table[env_index][policy_index]`` and holds the
    J value of that policy in that environment. Metrics are stored per
    ground-truth environment (the indices in
    ``parser.args.ground_truth_for_traverse``) and per policy index.
    """

    def __init__(self, policy_indices, j_table, algorithm_name: str, k: int):
        """
        Args:
            policy_indices: policy identifiers, used as keys into each row of
                ``j_table`` and into the metric tables.
            j_table: nested mapping ``env_index -> policy_index -> J value``.
            algorithm_name: name of the selection algorithm. With
                ``"trivial_random"``, ``update_j_error`` averages over the whole
                ranking instead of the top ``parser.args.top_k`` entries.
            k: cutoff used by ``update_top_k_indicator``.
        """
        self.toolbox = GeneralUtils()
        (
            self.env_class,
            self.policy_class,
            self.datasets,
        ) = self.toolbox.load_env_policy_dataset()

        self.policy_indices = policy_indices
        self.algorithm_name = algorithm_name
        dataset_indices = parser.args.ground_truth_for_traverse
        # v_functions_offline[i][j][t] = toolbox.load_v_function(i, j, t) for
        # every env i in env_class, policy j, and ground-truth dataset index t.
        self.v_functions_offline = {
            i: {
                j: {t: self.toolbox.load_v_function(i, j, t) for t in dataset_indices}
                for j in policy_indices
            }
            for i in range(len(self.env_class))
        }
        self.plotter = Plotter()

        self.j_tables = j_table
        self.plot_j_table()
        self.k = k
        # Both tables map ground-truth env index -> policy index -> metric value.
        self.top_k_hits = {i: {j: 0 for j in policy_indices} for i in dataset_indices}
        self.j_error_table = {
            i: {j: 0 for j in policy_indices} for i in dataset_indices
        }
        # NOTE(review): not read anywhere in this class; kept in case outside
        # code reads it.
        self.normalizer = None

    def collect_initial_states(self, num_states):
        """Sample ``num_states`` initial states from randomly chosen environments.

        Each entry is ``[obs, qpos, qvel]``. ``obs`` is the observation returned
        by ``reset()``. ``qpos`` and ``qvel`` are deep copies of ``data.qpos``
        and ``data.qvel`` taken right after that reset; presumably these are
        simulator state arrays, which should be confirmed.
        """
        initial_states = []
        for _ in range(num_states):
            env = random.choice(self.env_class)
            obs, _ = env.reset()
            # Copy so the stored state is unaffected by later changes to env.data.
            qpos = copy.deepcopy(env.data.qpos[:])
            qvel = copy.deepcopy(env.data.qvel[:])
            initial_states.append([obs, qpos, qvel])
        return initial_states

    def plot_j_table(self):
        """Plot the J table twice.

        The first pass draws one plot per environment, showing J over policies.
        The second pass draws one plot per policy in ``self.policy_indices``,
        showing J over environments.
        """
        for target in tqdm(self.j_tables):
            j_table = list(self.j_tables[target].values())
            policy_indices = list(self.j_tables[target].keys())
            self.plotter.plot_j_table_with_pi(policy_indices, j_table, target)
        policy_indices = self.policy_indices
        for target in tqdm(policy_indices):
            j_table = [self.j_tables[i][target] for i in self.j_tables]
            env_indices = list(self.j_tables.keys())
            self.plotter.plot_j_table_with_env(env_indices, j_table, target)

    def update_top_k_indicator(
        self,
        env_ground_truth: int,
        ranked_env_indices: List[int],
        policy_index: int,
    ):
        """Record 1 if the ground-truth env is among the first ``self.k`` ranked envs, else 0."""
        self.top_k_hits[env_ground_truth][policy_index] = (
            1 if env_ground_truth in ranked_env_indices[: self.k] else 0
        )

    def update_j_error(
        self,
        env_ground_truth: int,
        ranked_env_indices: List[int],
        policy_index: int,
    ):
        """Store the normalized J-error of a ranking for ground-truth env M_i and policy π_j.

        Given a ranked MDP class {M_{A_1}, ..., M_{A_L}}, the ranking being
        produced by some selection algorithm A, the stored value is the mean
        over the selected envs M_{A_l} of

            |J_{M_i}(π_j) - J_{M_{A_l}}(π_j)| / σ_i,

        where σ_i² is the variance of J over the policy class on M_i, as
        returned by ``normalize_j_error_by_variance``.

        The selected envs are the first ``parser.args.top_k`` entries of the
        ranking. For ``"trivial_random"``, the whole ranking is used. The mean
        is taken over the envs actually selected, so a ranking shorter than
        top-k is not under-reported. An empty selection stores 0.0.
        """
        top_k = (
            parser.args.top_k
            if self.algorithm_name != "trivial_random"
            else len(ranked_env_indices)
        )
        selected_envs = ranked_env_indices[:top_k]
        if not selected_envs:
            # Nothing selected: avoid dividing by zero and report no error.
            self.j_error_table[env_ground_truth][policy_index] = 0.0
            return
        # The normalizer depends only on the target env, so compute it once.
        normalizer = self.normalize_j_error_by_variance(target=env_ground_truth)
        self.j_error_table[env_ground_truth][policy_index] = (
            self._mean_normalized_j_error(
                env_ground_truth, selected_envs, policy_index, normalizer
            )
        )

    def _mean_normalized_j_error(
        self, env_ground_truth, selected_envs, policy_index, normalizer
    ) -> float:
        """Return the mean of sqrt((J_true - J_selected)² / normalizer) over ``selected_envs``, or 0.0 if it is empty."""
        if not selected_envs:
            return 0.0
        j_true = self.j_tables[env_ground_truth][policy_index]
        errors = [
            math.sqrt((j_true - self.j_tables[env][policy_index]) ** 2 / normalizer)
            for env in selected_envs
        ]
        return sum(errors) / len(errors)

    def _variance_normalizer(self, target) -> float:
        """Return the population variance of J over ``self.policy_indices`` on env ``target``.

        Returns 1.0, meaning no normalization, when the variance is undefined
        (zero or one policy) or not positive. This prevents division by zero
        when every policy has the same J on ``target``.
        """
        policy_indices = self.policy_indices
        if len(policy_indices) <= 1:
            # Do not normalize if at most one policy is involved.
            return 1.0
        values = [self.j_tables[target][j] for j in policy_indices]
        expectation = sum(values) / len(values)
        variance = sum((v - expectation) ** 2 for v in values) / len(values)
        return variance if variance > 0 else 1.0

    def normalize_j_error_by_variance(self, target):
        """Return the J-error normalizer (a variance) for env ``target``.

        As a side effect, the text " (with normalizer=<std>) " is prepended to
        ``parser.text_for_comparison`` unless that exact text is already there.
        """
        normalizer = self._variance_normalizer(target)
        text_normalizer = f" (with normalizer={math.sqrt(normalizer)}) "
        if text_normalizer not in parser.text_for_comparison:
            parser.text_for_comparison = text_normalizer + parser.text_for_comparison
        return normalizer

    def take_average_on_policy_class(self, env_index: int) -> Tuple[float, float]:
        """Return (mean top-k hit, mean J-error) over ``self.policy_indices`` for ``env_index``."""
        policy_indices = self.policy_indices
        return (
            sum(self.top_k_hits[env_index][j] for j in policy_indices)
            / len(policy_indices),
            sum(self.j_error_table[env_index][j] for j in policy_indices)
            / len(policy_indices),
        )

    def fetch_metric(self):
        """Return the ``(top_k_hits, j_error_table)`` metric tables."""
        return self.top_k_hits, self.j_error_table
