import re
from typing import Dict, List, Tuple, Any
import os
from utils import env_factory, bound_factory, scheduler_factory, AgentInferenceBuilder
from evaluation import Agent, evaluate
from fire import Fire
from dotenv import load_dotenv
from tqdm import tqdm
import pandas as pd
from tc_mdp import TCMDPFixedAgent, EvalOracleTCMDP, EvalStackedTCMDP
import gymnasium as gym
import numpy as np
import torch
from td3.models import Actor, QNetwork
from td3.td3 import TD3Agent

import warnings

# Silence the pandas FutureWarning about concatenating DataFrames that contain
# empty or all-NA entries; only this specific message is suppressed.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated.",
)


# Load environment variables from a .env file (python-dotenv) at import time.
load_dotenv()


# Seed passed to `evaluate` for every evaluation run in this script.
EVAL_SEED = 42


def main(
    all_agents_folder: str,
    output_folder: str,
    nb_episodes: int = 10,
    device: str = "cuda:0",
    verbose: bool = True,
    is_omniscient_adversary: bool = False,
    radius: float = 0.001,
) -> None:
    """
    Evaluate every saved adversary against its paired agent and write the results to a CSV.

    The checkpoints are read from
    ``<all_agents_folder>/<algo>/<env_name>/<nb_uncertainty_dim>/`` where each
    seed ``i`` is stored as ``agent_<i>_<env_name>_<dim>.pth`` and
    ``adversary_<i>_<env_name>_<dim>.pth``. For each pair, the agent is frozen
    inside a ``TCMDPFixedAgent`` wrapper and the adversary is the policy that
    is rolled out in that wrapped environment.

    Args:
        all_agents_folder (str): Root folder holding one sub-folder per algorithm.
        output_folder (str): Folder (created if missing) that receives ``results.csv``.
        nb_episodes (int): Number of evaluation episodes per agent/adversary pair.
        device (str): Torch device string used to load and run the networks.
        verbose (bool): Whether to print the per-episode rewards of each pair.
        is_omniscient_adversary (bool): Forwarded to ``TCMDPFixedAgent`` as ``is_omniscient``.
        radius (float): Forwarded to ``TCMDPFixedAgent`` as ``radius``.

    Raises:
        OSError: Propagated from ``os.listdir`` when an expected folder is missing.
    """
    result_columns = [
        "algo",
        "env_name",
        "uncertainty_dim",
        "seed",
        "rewards_mean",
        "rewards_std",
        "rewards",
    ]
    env_names = ["Ant", "HalfCheetah", "Hopper", "HumanoidStandup", "Walker"]

    # Per-pair result frames are gathered here and concatenated once at the
    # end, instead of growing a DataFrame with pd.concat on every iteration.
    collected_results = []

    for algo in tqdm(os.listdir(all_agents_folder), desc="Algo", leave=False):
        # Plain "vanilla" runs are stored under dimension 0; everything else
        # (including time-constrained "tc" variants) under dimension 3.
        is_plain_vanilla = "vanilla" in algo and "tc" not in algo
        nb_uncertainty_dim = 0 if is_plain_vanilla else 3

        # Only non-"tc" folders whose name contains "m2td3" use the m2td3
        # agent type; every other folder is loaded as td3.
        is_tc_variant = "tc-" in algo or "tc_" in algo
        agent_type = "m2td3" if "m2td3" in algo and not is_tc_variant else "td3"

        algo_folder = os.path.join(all_agents_folder, algo)
        for env_name in tqdm(env_names, desc="Environment", leave=False):
            agents_folder = os.path.join(algo_folder, env_name, str(nb_uncertainty_dim))
            # Keep only the checkpoint files named adversary_... or agent_...
            checkpoint_files = [
                file
                for file in os.listdir(agents_folder)
                if file.startswith(("adversary_", "agent_"))
            ]
            # Each seed contributes one agent file and one adversary file.
            nb_pairs = len(checkpoint_files) // 2
            for i in tqdm(range(nb_pairs), desc="Seed", leave=False):
                agent_name = f"agent_{i}_{env_name}_{nb_uncertainty_dim}.pth"
                agent_path = os.path.join(agents_folder, agent_name)

                eval_env = env_factory(env_name=env_name)
                # The bounds are always requested for 3 dimensions, regardless
                # of the dimension the agent itself was stored under.
                params_bound: Dict[str, List[float]] = bound_factory(
                    env_name=env_name, nb_dim=3
                )
                # The wrapper choice is driven by substrings of the full path.
                if "oracle" in agent_path:
                    eval_env = EvalOracleTCMDP(eval_env, params_bound)
                if "stacked" in agent_path:
                    eval_env = EvalStackedTCMDP(eval_env, params_bound)
                agent = get_agent(
                    agent_path, agent_type, eval_env, nb_uncertainty_dim, device
                )
                eval_env = TCMDPFixedAgent(
                    env=eval_env,
                    agent=agent,
                    params_bound=params_bound,
                    is_omniscient=is_omniscient_adversary,
                    radius=radius,
                )

                adversary_name = f"adversary_{i}_{env_name}_{nb_uncertainty_dim}.pth"
                adversary_path = os.path.join(agents_folder, adversary_name)
                adversary = get_adversary(
                    adversary_path=adversary_path,
                    env=eval_env,
                    device=device,
                    kwargs={},
                )

                # The adversary is the acting policy in the wrapped environment.
                results = evaluate_agent(
                    agent=adversary,
                    eval_env=eval_env,
                    nb_episodes=nb_episodes,
                    verbose=verbose,
                )

                results["algo"] = algo
                results["env_name"] = env_name
                results["uncertainty_dim"] = nb_uncertainty_dim
                results["seed"] = i
                collected_results.append(results)

    if collected_results:
        # Re-select the columns so the CSV keeps its established column order.
        all_results = pd.concat(collected_results, ignore_index=True)[result_columns]
    else:
        # Nothing was evaluated: still emit a header-only CSV.
        all_results = pd.DataFrame(columns=result_columns)

    os.makedirs(output_folder, exist_ok=True)
    all_results.to_csv(os.path.join(output_folder, "results.csv"), index=False)


def get_adversary(
    adversary_path, env: gym.Env, device: torch.device, kwargs: dict[str, Any]
):
    """
    Build a TD3 adversary whose actor weights are restored from a checkpoint.

    Args:
        adversary_path: Path to a checkpoint readable by ``torch.load`` that
            stores the actor state dict under the "actor" key.
        env (gym.Env): Environment whose observation and action spaces size the networks.
        device (torch.device): Device the networks are moved to and the checkpoint is mapped onto.
        kwargs (dict[str, Any]): Extra keyword arguments forwarded to ``TD3Agent``.

    Returns:
        The constructed ``TD3Agent``. Only the actor is loaded from the
        checkpoint; the two Q-networks keep their freshly constructed weights.
    """
    act_space = env.action_space
    ob_space = env.observation_space

    # Flattened sizes of the observation and action spaces.
    observation_dim: int = np.prod(ob_space.shape)
    action_dim: int = np.prod(act_space.shape)

    actor = Actor(observation_dim=observation_dim, action_space=act_space).to(device)
    checkpoint = torch.load(adversary_path, map_location=device)
    actor.load_state_dict(state_dict=checkpoint["actor"])

    # Two critics with identical architecture, built one after the other.
    qf1, qf2 = (
        QNetwork(observation_dim=observation_dim, action_dim=action_dim).to(device)
        for _ in range(2)
    )
    return TD3Agent(actor, qf1, qf2, device=device, **kwargs)


def mean_and_std(rewards: List[float]) -> Tuple[float, float]:
    """
    Return the mean and population standard deviation of ``rewards``.

    Args:
        rewards (List[float]): Non-empty list of rewards.

    Returns:
        Tuple[float, float]: ``(mean, std)``, where ``std`` divides the summed
        squared deviations by ``len(rewards)`` (population form, not ``n - 1``).

    Raises:
        ValueError: If ``rewards`` is empty.
    """
    count = len(rewards)
    if count == 0:
        # Fail with a clear message rather than an opaque ZeroDivisionError.
        raise ValueError("mean_and_std() requires at least one reward")
    mean = sum(rewards) / count
    variance = sum((reward - mean) ** 2 for reward in rewards) / count
    return mean, variance**0.5


def get_agent(agent_path, agent_type, env, nb_uncertainty_dim, device):
    """
    Build an agent from a saved actor checkpoint via ``AgentInferenceBuilder``.

    Args:
        agent_path: Path of the actor checkpoint handed to the builder.
        agent_type: Agent type identifier handed to the builder (e.g. "td3" or "m2td3").
        env: Environment the builder is constructed with.
        nb_uncertainty_dim: Passed to the builder as ``nb_dim``.
        device: Device given to the builder, both at construction and via ``add_device``.

    Returns:
        The object returned by the builder's ``build()`` call.
    """
    builder = AgentInferenceBuilder(env=env, nb_dim=nb_uncertainty_dim, device=device)
    # Apply the configuration steps one at a time, always continuing from the
    # object each step returns.
    builder = builder.add_actor_path(path=agent_path)
    builder = builder.add_device(device)
    builder = builder.add_agent_type(agent_type)
    loaded_agent: Agent = builder.build()
    return loaded_agent


def evaluate_agent(
    agent,
    eval_env,
    nb_episodes,
    verbose,
):
    """
    Evaluate an agent in an environment and return a one-row summary DataFrame.

    Args:
        agent (Agent): The agent to evaluate.
        eval_env: The environment passed to ``evaluate``.
        nb_episodes (int): Number of episodes to run.
        verbose (bool): Whether to print the collected rewards.

    Returns:
        pd.DataFrame: A single row with the columns "rewards_mean",
        "rewards_std" and "rewards" (the raw rewards returned by ``evaluate``).
    """
    # Every evaluation uses the same module-level seed.
    rewards = evaluate(
        env=eval_env, agent=agent, seed=EVAL_SEED, num_episodes=nb_episodes
    )
    mean_reward, std_reward = mean_and_std(rewards)
    if verbose:
        print(f"Rewards rewards: {rewards}")

    return pd.DataFrame(
        [
            {
                "rewards_mean": mean_reward,
                "rewards_std": std_reward,
                "rewards": rewards,
            }
        ]
    )


# Script entry point: python-fire exposes main()'s parameters as command-line arguments.
if __name__ == "__main__":
    Fire(main)
