from __future__ import annotations

from typing import NamedTuple, Protocol

import gymnasium as gym
import numpy as np
import numpy.typing as npt

from metaworld.env_dict import ALL_V3_ENVIRONMENTS

# NOTE(review): not read by any code shown in this module; presumably a flag
# consumed by an external runner/harness (the name suggests extra time budget
# for compilation) — confirm where it is used before changing or removing it.
extra_time_for_compile = True

class Agent(Protocol):
    """Structural interface for a policy that ``evaluation`` can roll out.

    ``evaluation`` invokes every method declared here on each environment
    step, so an implementation must provide all of them.
    """

    def eval_action(
        self, observations: npt.NDArray[np.float64]
    ) -> npt.NDArray[np.float64]:
        """Return the actions to execute for a batch of observations."""
        ...

    def reset(self, env_mask: npt.NDArray[np.bool_]) -> None:
        """Reset agent state for the sub-environments flagged in ``env_mask``.

        ``evaluation`` calls this with an all-True mask before the first step
        and with the per-environment done flags after every step.
        """
        ...

    def step(self, timestep: Timestep) -> None:
        """Receive one batched transition produced by the vectorised env."""
        ...

    def predictive_losses(
        self,
        next_observations: npt.NDArray[np.float64],
        rewards: npt.NDArray[np.float64],
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Return the ``(t_loss, r_loss)`` pair for the latest step.

        ``evaluation`` indexes both results by sub-environment index.
        Presumably these are the agent's next-observation (transition) and
        reward prediction errors — confirm against implementations.
        """
        ...


class MetaLearningAgent(Agent, Protocol):
    """Agent that can additionally adapt itself from collected experience.

    ``metalearning_evaluation`` calls ``init`` once per evaluation run. Then,
    for every adaptation step, it gathers transitions via ``adapt_action`` /
    ``step`` and afterwards calls ``adapt``.
    """

    def init(self) -> None:
        """Prepare the agent for a freshly sampled set of tasks."""

    def adapt_action(
        self,
        observations: npt.NDArray[np.float64],
    ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
        """Return actions plus a dict of auxiliary policy outputs for adaptation."""

    def step(self, timestep: Timestep) -> None:
        """Receive one batched transition (used during adaptation and evaluation)."""

    def adapt(self) -> None:
        """Run one adaptation update (called once per adaptation step, including step 0)."""


def _get_task_names(
    envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
) -> list[str]:
    """Map each sub-environment to its Meta-World task name.

    Every env's ``task_name`` attribute is matched against the ``__name__`` of
    the environment classes registered in ``ALL_V3_ENVIRONMENTS``; the
    registry key of the matching class is returned, in sub-env order.
    An unknown class name raises ``KeyError``.
    """
    name_by_class: dict[str, str] = {}
    for registered_name, env_cls in ALL_V3_ENVIRONMENTS.items():
        name_by_class[env_cls.__name__] = registered_name

    class_names = envs.get_attr("task_name")
    return [name_by_class[cls_name] for cls_name in class_names]


def evaluation(
    agent: Agent,
    eval_envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
    num_episodes: int = 50,
) -> tuple[float, float, dict[str, float], dict[str, list[float]],
           dict[str, int], dict[str, float], dict[str, float]]:
    """Roll out ``agent`` until every task has ``num_episodes`` finished episodes.

    Episodes are forced to end on success while evaluating. The environments'
    previous ``terminate_on_success`` setting is restored afterwards, also when
    the rollout raises.

    Args:
        agent: Policy under evaluation. Besides ``eval_action`` and ``reset``,
            its ``step`` and ``predictive_losses`` methods are called on every
            environment step.
        eval_envs: Vectorised environments; the task of sub-env ``i`` is
            resolved with ``_get_task_names``. Several sub-envs may share a task.
        num_episodes: Number of finished episodes scored per task (>= 1).

    Returns:
        ``(mean_success_rate, mean_return, success_rate_per_task,
        episodic_returns, error_count_per_task, total_error_t_per_task,
        total_error_r_per_task)``. The success and return statistics use only
        the first ``num_episodes`` episodes of each task. The last three
        entries hold, per task, the number of predictive-loss samples and the
        summed ``t`` / ``r`` losses over all steps that did not end an episode.

    Raises:
        ValueError: If ``num_episodes`` is less than 1.
    """
    if num_episodes < 1:
        # Success rates are divided by ``num_episodes`` below.
        raise ValueError(f"num_episodes must be at least 1, got {num_episodes}")

    # Remember the caller's setting so it can be restored on every exit path.
    terminate_on_success = np.all(eval_envs.get_attr("terminate_on_success")).item()
    eval_envs.call("toggle_terminate_on_success", True)
    try:
        obs: npt.NDArray[np.float64]
        obs, _ = eval_envs.reset()
        agent.reset(np.ones(eval_envs.num_envs, dtype=np.bool_))

        task_names = _get_task_names(eval_envs)
        unique_task_names = set(task_names)
        successes = {task_name: 0 for task_name in unique_task_names}
        predictive_t_error: dict[str, list] = {
            task_name: [] for task_name in unique_task_names
        }
        predictive_r_error: dict[str, list] = {
            task_name: [] for task_name in unique_task_names
        }
        episodic_returns: dict[str, list[float]] = {
            task_name: [] for task_name in unique_task_names
        }

        # Sub-envs whose task already has enough episodes keep running until
        # every task is done.
        while any(len(r) < num_episodes for r in episodic_returns.values()):
            actions = agent.eval_action(obs)
            next_obs, reward, terminations, truncations, infos = eval_envs.step(actions)
            # No termination flags or aux outputs are passed here; episode
            # boundaries reach the agent through ``agent.reset(dones)`` below.
            agent.step(Timestep(obs, actions, reward, None, None, None))

            t_loss, r_loss = agent.predictive_losses(next_obs, reward)
            obs = next_obs
            dones = np.logical_or(terminations, truncations)
            agent.reset(dones)

            for i, env_ended in enumerate(dones):
                task_name = task_names[i]
                if not env_ended:
                    predictive_t_error[task_name].append(t_loss[i])
                    predictive_r_error[task_name].append(r_loss[i])
                    continue

                # Episode-ending steps are excluded from the predictive errors.
                # NOTE(review): presumably because next_obs is already the
                # auto-reset observation there — confirm.
                final_info = infos["final_info"]
                episodic_returns[task_name].append(float(final_info["episode"]["r"][i]))
                # Only the first ``num_episodes`` episodes of a task are scored.
                if len(episodic_returns[task_name]) <= num_episodes:
                    successes[task_name] += int(final_info["success"][i])

        # Drop surplus episodes so every task is scored on exactly num_episodes.
        episodic_returns = {
            task_name: returns[:num_episodes]
            for task_name, returns in episodic_returns.items()
        }

        success_rate_per_task = {
            task_name: task_successes / num_episodes
            for task_name, task_successes in successes.items()
        }
        mean_success_rate = np.mean(list(success_rate_per_task.values()))
        # Every list now has exactly num_episodes entries, so this is an
        # equally weighted mean over all tasks' episodes.
        mean_returns = np.mean(list(episodic_returns.values()))

        error_count_per_task = {
            task_name: len(errors) for task_name, errors in predictive_t_error.items()
        }
        total_error_t_per_task = {
            task_name: sum(errors) for task_name, errors in predictive_t_error.items()
        }
        total_error_r_per_task = {
            task_name: sum(errors) for task_name, errors in predictive_r_error.items()
        }
    finally:
        eval_envs.call("toggle_terminate_on_success", terminate_on_success)

    return (
        float(mean_success_rate),
        float(mean_returns),
        success_rate_per_task,
        episodic_returns,
        error_count_per_task,
        total_error_t_per_task,
        total_error_r_per_task,
    )


def _run_adaptation_rollouts(
    agent: MetaLearningAgent,
    eval_envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
    num_episodes: int,
) -> None:
    """Reset the envs and feed transitions to ``agent`` until each sub-env has finished ``num_episodes`` episodes."""
    obs: npt.NDArray[np.float64]
    obs, _ = eval_envs.reset()
    # Signed counter, so the comparison below never mixes signed and unsigned values.
    episodes_elapsed = np.zeros((eval_envs.num_envs,), dtype=np.int64)
    while (episodes_elapsed < num_episodes).any():
        actions, aux_policy_outs = agent.adapt_action(obs)
        next_obs, rewards, terminations, truncations, _ = eval_envs.step(actions)
        agent.step(
            Timestep(
                obs,
                actions,
                rewards,
                terminations,
                truncations,
                aux_policy_outs,
            )
        )
        episodes_elapsed += np.logical_or(terminations, truncations)
        obs = next_obs


def _summarize_losses(
    totals: npt.NDArray[np.float64],
    task_names: list[str],
) -> tuple[float, dict[str, float]]:
    """Turn per-task ``[loss_sum, count]`` rows into an overall mean and per-task means (NaN where count is 0)."""
    with np.errstate(divide="ignore", invalid="ignore"):
        overall = totals[:, 0].sum() / totals[:, 1].sum()
        per_task = totals[:, 0] / totals[:, 1]
    return float(overall), {name: float(v) for name, v in zip(task_names, per_task)}


def metalearning_evaluation(
    agent: MetaLearningAgent,
    eval_envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
    num_evals: int = 10,  # Assuming 40 goals per test task and meta batch size of 20
    adaptation_steps: int = 1,
    adaptation_episodes: int = 10,
    evaluation_episodes: int = 3,
) -> tuple[dict[int, float], dict[int, float], dict[int, dict[str, float]],
           dict[int, float], dict[int, float], dict[int, dict[str, float]],
           dict[int, dict[str, float]]]:
    """Measure how ``agent`` improves as it adapts, from zero-shot to few-shot.

    For each of ``num_evals`` runs, new tasks are sampled and ``agent.init()``
    is called. Then, for every adaptation step ``s`` in ``0..adaptation_steps``:

    * the agent collects ``adaptation_episodes`` episodes per sub-env
      (none at step 0, the zero-shot baseline);
    * ``agent.adapt()`` is called;
    * the agent is scored with ``evaluation`` over ``evaluation_episodes``
      episodes.

    Task resampling on reset and termination on success are both switched off
    on ``eval_envs``; they are not restored afterwards.

    Returns:
        Seven dicts keyed by adaptation step:

        1. mean success rate
        2. mean return
        3. per-task success rate
        4. mean ``t`` predictive loss
        5. mean ``r`` predictive loss
        6. per-task mean ``t`` loss
        7. per-task mean ``r`` loss

        Success and return values are averaged over the ``num_evals`` runs.
        Loss means are sample-weighted over all runs and are NaN when no
        samples were recorded.

    Raises:
        ValueError: If ``num_evals`` is less than 1.
    """
    if num_evals < 1:
        # Run-averaged statistics are divided by ``num_evals``.
        raise ValueError(f"num_evals must be at least 1, got {num_evals}")

    # Test the performance from zero-shot to few-shot.
    eval_envs.call("toggle_sample_tasks_on_reset", False)
    eval_envs.call("toggle_terminate_on_success", False)
    task_names = _get_task_names(eval_envs)
    # One fixed task order for every per-task array below. Results from
    # ``evaluation`` are looked up by name, never by dict position.
    unique_task_names = sorted(set(task_names))
    num_tasks = len(unique_task_names)
    steps = range(adaptation_steps + 1)

    success_sum = {step: 0.0 for step in steps}
    return_sum = {step: 0.0 for step in steps}
    # success_rates[step][run, k]: success rate of unique_task_names[k] in that run.
    success_rates = {step: np.zeros((num_evals, num_tasks)) for step in steps}
    # Rows follow unique_task_names; columns are [summed loss, sample count].
    t_totals = {step: np.zeros((num_tasks, 2)) for step in steps}
    r_totals = {step: np.zeros((num_tasks, 2)) for step in steps}

    for run in range(num_evals):
        eval_envs.call("sample_tasks")
        agent.init()

        for step in steps:
            # Step 0 collects no experience; adapt() is still called, as before.
            episode_to_adapt = adaptation_episodes if step > 0 else 0
            _run_adaptation_rollouts(agent, eval_envs, episode_to_adapt)
            agent.adapt()
            print(f"Finished Adapt Step {step}, {episode_to_adapt} Episodes Adapted.")

            # Evaluate every time we adapt.
            (
                mean_success_rate,
                mean_return,
                task_success_rates,
                _,
                error_counts,
                t_loss_sums,
                r_loss_sums,
            ) = evaluation(agent, eval_envs, evaluation_episodes)

            success_sum[step] += mean_success_rate
            return_sum[step] += mean_return
            success_rates[step][run] = [
                task_success_rates[name] for name in unique_task_names
            ]
            for k, name in enumerate(unique_task_names):
                t_totals[step][k] += (t_loss_sums[name], error_counts[name])
                r_totals[step][k] += (r_loss_sums[name], error_counts[name])

    mean_success_rate_dict = {step: success_sum[step] / num_evals for step in steps}
    mean_return_dict = {step: return_sum[step] / num_evals for step in steps}
    success_rate_per_task_dict = {
        step: {
            name: float(rate)
            for name, rate in zip(unique_task_names, success_rates[step].mean(axis=0))
        }
        for step in steps
    }

    mean_t_loss_dict: dict[int, float] = {}
    mean_r_loss_dict: dict[int, float] = {}
    mean_t_per_task_dict: dict[int, dict[str, float]] = {}
    mean_r_per_task_dict: dict[int, dict[str, float]] = {}
    for step in steps:
        mean_t_loss_dict[step], mean_t_per_task_dict[step] = _summarize_losses(
            t_totals[step], unique_task_names
        )
        mean_r_loss_dict[step], mean_r_per_task_dict[step] = _summarize_losses(
            r_totals[step], unique_task_names
        )

    return (
        mean_success_rate_dict,
        mean_return_dict,
        success_rate_per_task_dict,
        mean_t_loss_dict,
        mean_r_loss_dict,
        mean_t_per_task_dict,
        mean_r_per_task_dict,
    )

class Timestep(NamedTuple):
    """One batched transition from a vectorised env, handed to ``agent.step``.

    ``evaluation`` builds it with ``None`` for ``terminated``, ``truncated`` and
    ``aux_policy_outputs``, so those three fields are optional.
    """

    # Observations the actions were chosen from.
    observation: npt.NDArray
    action: npt.NDArray
    reward: npt.NDArray
    terminated: npt.NDArray | None
    truncated: npt.NDArray | None
    # Extra outputs returned by ``adapt_action`` alongside the actions.
    aux_policy_outputs: dict[str, npt.NDArray] | None
