#  Copyright (c) 2024-2025
import time

import matplotlib
import torch
import wandb
from tensordict import unravel_key
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torch.distributions import Categorical

from torchrl._utils import logger as torchrl_logger
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from utils.callback import FootBallHeuristicCurriculum
from utils.gen_agg import plot_gen_agg
from utils.logging_tools import init_logging, log_evaluation, log_training, log_iter
from utils.utils import standardize


def swap_last(source, dest):
    """Replace the last element of key ``source`` with the last element of ``dest``.

    Both keys are first normalised with ``unravel_key``. When the normalised
    ``source`` is a plain string (a single-level key), the result is simply the
    last element of ``dest``.

    Example: ``swap_last(("agents", "reward"), "done")`` -> ``("agents", "done")``.
    """
    source_key = unravel_key(source)
    dest_key = unravel_key(dest)
    # The element that will take the place of source's leaf.
    leaf = dest_key if isinstance(dest_key, str) else dest_key[-1]
    if isinstance(source_key, str):
        return leaf
    return source_key[:-1] + (leaf,)


def rendering_callback(env, td):
    """Rollout callback that stores an RGB frame of ``env`` in ``env.frames``.

    ``td`` (the current tensordict passed by ``rollout``) is not used.
    """
    frames = env.frames
    frames.append(env.render(mode="rgb_array", visualize_when_rgb=False))


def create_agent_nets(shared_params: bool, env, cfg):
    """Build the policy and the critic for one team of agents.

    Args:
        shared_params: forwarded to every ``MultiAgentMLP`` as ``share_params``;
            True gives all agents a single set of weights.
        env: environment whose ``("agents", "observation")`` spec sizes the
            network inputs and whose action spec sizes the policy output.
        cfg: config; reads ``env.continuous_actions``, ``model.reset_with_env``,
            ``model.centralised_critic`` and ``train.device``.

    Returns:
        ``(policy, value_module)``: a ``ProbabilisticActor`` (TanhNormal for
        continuous actions, Categorical otherwise) and a ``ValueOperator``.
    """
    obs_key = ("agents", "observation")
    obs_dim = env.observation_spec["agents", "observation"].shape[-1]

    def agent_mlp(n_in, n_out, *, centralised=False, depth=2, num_cells=256):
        # All networks here share one per-agent MLP recipe; only the sizes and
        # the critic's centralisation differ.
        return MultiAgentMLP(
            n_agent_inputs=n_in,
            n_agent_outputs=n_out,
            n_agents=env.n_agents,
            centralised=centralised,
            share_params=shared_params,
            device=cfg.train.device,
            depth=depth,
            num_cells=num_cells,
            activation_class=nn.Tanh,
        )

    # Policy
    if cfg.env.continuous_actions:
        # Two outputs (loc, scale) per action dimension.
        n_policy_outputs = 2 * env.action_spec.shape[-1]
    else:
        n_policy_outputs = env.unbatched_action_spec[("agents", "action")].space.n

    if cfg.model.reset_with_env:
        # 16-unit bottleneck followed by a small head. train() re-initialises
        # the head (index 1) after each environment update.
        net = nn.Sequential(
            agent_mlp(obs_dim, 16),
            agent_mlp(16, n_policy_outputs, depth=1, num_cells=16),
        )
    else:
        net = agent_mlp(obs_dim, n_policy_outputs)

    if cfg.env.continuous_actions:
        action_space = env.unbatched_action_spec[("agents", "action")].space
        param_keys = [("agents", "loc"), ("agents", "scale")]
        policy = ProbabilisticActor(
            module=TensorDictModule(
                nn.Sequential(net, NormalParamExtractor("biased_softplus_1.0")),
                in_keys=[obs_key],
                out_keys=param_keys,
            ),
            spec=env.unbatched_action_spec,
            in_keys=param_keys,
            out_keys=[env.action_key],
            distribution_class=TanhNormal,
            distribution_kwargs={"min": action_space.low, "max": action_space.high},
            return_log_prob=True,
        )
    else:
        logits_key = ("agents", "logits")
        policy = ProbabilisticActor(
            module=TensorDictModule(net, in_keys=[obs_key], out_keys=[logits_key]),
            spec=env.unbatched_action_spec,
            in_keys=[logits_key],
            out_keys=[("agents", "action")],
            distribution_class=Categorical,
            return_log_prob=True,
        )

    # Critic (built after the policy, as before, so weight init order is kept).
    critic_net = agent_mlp(obs_dim, 1, centralised=cfg.model.centralised_critic)
    value_module = ValueOperator(module=critic_net, in_keys=[obs_key])
    return policy, value_module


def reset_child_params(module):
    """Re-initialise every descendant of ``module`` that has ``reset_parameters``.

    Descendants are visited depth-first in pre-order: a child is reset before
    its own children. ``module`` itself is not reset. Prints "reset" for each
    layer that is re-initialised.
    """
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they are popped in their natural order, exactly like a recursive walk.
    pending = list(module.children())
    pending.reverse()
    while pending:
        layer = pending.pop()
        if hasattr(layer, "reset_parameters"):
            print("reset")
            layer.reset_parameters()
        grandchildren = list(layer.children())
        grandchildren.reverse()
        pending.extend(grandchildren)


def create_loss(policy, value_module, env, cfg):
    """Build a clipped-PPO loss with a GAE value estimator for one team.

    Advantage normalisation is disabled in the loss (``normalize_advantage=
    False``) because ``train_agents`` optionally standardises the advantage
    itself. The done/terminated keys point at per-agent entries under
    ``"agents"``.

    Args:
        policy: the team's ``ProbabilisticActor``.
        value_module: the team's critic.
        env: environment providing ``reward_key`` and ``action_key``.
        cfg: config; reads ``loss.clip_epsilon``, ``loss.entropy_eps``,
            ``loss.gamma`` and ``loss.lmbda``.

    Returns:
        The configured ``ClipPPOLoss``.
    """
    loss_cfg = cfg.loss
    loss_module = ClipPPOLoss(
        actor_network=policy,
        critic_network=value_module,
        clip_epsilon=loss_cfg.clip_epsilon,
        entropy_coef=loss_cfg.entropy_eps,
        normalize_advantage=False,
        samples_mc_entropy=10,
    )

    key_overrides = dict(
        reward=env.reward_key,
        action=env.action_key,
        done=("agents", "done"),
        terminated=("agents", "terminated"),
    )
    loss_module.set_keys(**key_overrides)

    loss_module.make_value_estimator(
        ValueEstimators.GAE, gamma=loss_cfg.gamma, lmbda=loss_cfg.lmbda
    )
    return loss_module


def create_buffer(cfg):
    """Create the on-policy buffer that one rollout is written into.

    Capacity is ``cfg.buffer.memory_size`` on ``cfg.train.device``. Minibatches
    of ``cfg.train.minibatch_size`` are drawn without replacement.
    """
    storage = LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device)
    sampler = SamplerWithoutReplacement()
    return TensorDictReplayBuffer(
        storage=storage,
        sampler=sampler,
        batch_size=cfg.train.minibatch_size,
    )


def get_training_routine(cfg):
    """Return the ``(env_trains, agents_train)`` schedule named in the config.

    Args:
        cfg: config; ``cfg.train.training_routine`` must be ``"concurrent"``
            or ``"alternated"``.

    Returns:
        A pair of predicates mapping an iteration index to whether the
        environment / the agents train on that iteration.

    Raises:
        ValueError: if ``cfg.train.training_routine`` is not a known routine.
            (An ``assert`` would be stripped under ``python -O`` and make this
            function silently return ``None``.)
    """
    routine = cfg.train.training_routine
    if routine == "concurrent":
        return concurrent_training_routine(cfg)
    if routine == "alternated":
        return alternated_training_routine(cfg)
    raise ValueError(
        f"Unknown training routine {routine!r}; "
        "expected 'concurrent' or 'alternated'"
    )


def concurrent_training_routine(cfg):
    """Schedule where agents train every iteration and the env trains periodically.

    The environment trains on iterations below ``cfg.collector.n_iters_env``
    that are multiples of ``cfg.collector.env_optim_interval``. The config is
    read each time a predicate is called, not when the schedule is built.

    Returns:
        ``(env_trains, agents_train)`` predicates over the iteration index.
    """

    def env_trains(iteration):
        # Short-circuit: the interval is only consulted inside the env budget.
        return (
            iteration < cfg.collector.n_iters_env
            and iteration % cfg.collector.env_optim_interval == 0
        )

    def agents_train(iteration):
        return True

    return env_trains, agents_train


def alternated_training_routine(cfg, *, n_agents_iters=100, n_env_iters=50):
    """Schedule that alternates environment and agent training in fixed cycles.

    Each cycle lasts ``n_env_iters + n_agents_iters`` iterations. The first
    ``n_env_iters`` iterations of a cycle train the environment (only while
    ``iteration < cfg.collector.n_iters_env``). The remaining
    ``n_agents_iters`` iterations train the agents.

    NOTE(review): after ``n_iters_env`` the env no longer trains, but agents
    still skip the first ``n_env_iters`` iterations of every cycle. Confirm
    this idle window is intended.

    Args:
        cfg: config; ``cfg.collector.n_iters_env`` is read on each call to
            ``env_trains``.
        n_agents_iters: agent-training iterations per cycle (default 100).
        n_env_iters: env-training iterations per cycle (default 50).

    Returns:
        ``(env_trains, agents_train)`` predicates over the iteration index.

    Raises:
        ValueError: if either length is negative or the cycle length is zero.
    """
    if n_agents_iters < 0 or n_env_iters < 0:
        raise ValueError("cycle lengths must be non-negative")
    cycle_len = n_agents_iters + n_env_iters
    if cycle_len == 0:
        raise ValueError("cycle length must be positive")

    def env_trains(iteration):
        return (
            iteration < cfg.collector.n_iters_env
            and iteration % cycle_len < n_env_iters
        )

    def agents_train(iteration):
        return iteration % cycle_len >= n_env_iters

    return env_trains, agents_train


def train(cfg: "DictConfig"):  # noqa: F821
    """Co-train heterogeneous and homogeneous agent teams and the environment.

    Each iteration:
    1. rolls out both teams in the same vectorised env (``het``: per-agent
       parameters, ``hom``: shared parameters);
    2. computes the regret as the mean difference of their accumulated
       episode rewards;
    3. optionally steps the scenario parameters to maximise that regret
       (``train_env``);
    4. runs PPO on each team (``train_agents``).

    Args:
        cfg: experiment config (presumably an OmegaConf ``DictConfig``).
            Derived fields are written back into it: ``train.device``,
            ``env.device``, ``env.vmas_envs``, ``collector.total_frames``,
            ``buffer.memory_size``, and ``collector.n_iters_env`` when the
            scenario has no trainable parameters.
    """
    # Device
    cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
    cfg.env.device = cfg.train.device

    # Seeding
    torch.manual_seed(cfg.seed)

    # Sampling
    cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps
    cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters
    cfg.buffer.memory_size = cfg.collector.frames_per_batch

    # Create env and env_test. The training env keeps autograd enabled (except
    # for "soccer_design") so the regret can be backpropagated into the
    # scenario parameters in train_env.
    env = VmasEnv(
        scenario=cfg.env.scenario_name,
        num_envs=cfg.env.vmas_envs,
        continuous_actions=cfg.env.continuous_actions,
        max_steps=cfg.env.max_steps,
        device=cfg.env.device,
        seed=cfg.seed,
        grad_enabled=cfg.env.scenario_name != "soccer_design",
        # Scenario kwargs
        **cfg.env.scenario,
    )
    env = TransformedEnv(
        env,
        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
    )

    env_test = VmasEnv(
        scenario=cfg.env.scenario_name,
        num_envs=cfg.eval.evaluation_episodes,
        continuous_actions=cfg.env.continuous_actions,
        max_steps=cfg.env.max_steps,
        device=cfg.env.device,
        seed=cfg.seed,
        grad_enabled=False,
        # Scenario kwargs
        **cfg.env.scenario,
    )
    callbacks = (
        [
            # FootBallHeuristicCurriculum(
            #     iter_add_adversary=100,
            #     ai_strength_start=0,
            #     n_iters_annealing=300,
            #     ai_strength_end=1,
            #     nonlinear_annealing=True,
            #     env=env,
            #     test_env=env_test,
            # )
            FootBallHeuristicCurriculum(
                iter_add_adversary=100,
                ai_strength_start=0,
                n_iters_annealing=100,
                ai_strength_end=0.5,
                nonlinear_annealing=True,
                env=env,
                test_env=env_test,
                iter_disable_beta=300,
            )
        ]
        if cfg.env.scenario_name == "soccer_design"
        else []
    )
    hom_actor, hom_critic = create_agent_nets(
        shared_params=True,
        env=env,
        cfg=cfg,
    )
    het_actor, het_critic = create_agent_nets(
        shared_params=False,
        env=env,
        cfg=cfg,
    )

    het_loss = create_loss(het_actor, het_critic, env, cfg)
    hom_loss = create_loss(hom_actor, hom_critic, env, cfg)

    het_buffer = create_buffer(cfg)
    hom_buffer = create_buffer(cfg)

    # Logging. ``logger`` stays None when no backend is configured. The
    # logging-only calls below are skipped in that case instead of failing on
    # an unbound name.
    logger = None
    if cfg.logger.backend:
        model_name = ("MA" if cfg.model.centralised_critic else "I") + "PPO"
        logger = init_logging(cfg, model_name)

    scenario_params = list(env.scenario.parameters())
    env_optimizer = None
    if scenario_params:
        # The test scenario shares the training scenario's parameter storage,
        # so in-place optimiser updates are visible to evaluation.
        for parameter, parameter_test in zip(
            scenario_params, env_test.scenario.parameters()
        ):
            parameter_test.data = parameter.data
        env_optimizer = torch.optim.Adam(scenario_params, cfg.train.env_lr)
        env_trains, agents_train = get_training_routine(cfg)
    else:
        # Nothing to optimise in the scenario: only the agents train.
        cfg.collector.n_iters_env = 0

        def env_trains(iteration):
            return False

        def agents_train(iteration):
            return True

    het_optim = torch.optim.Adam(
        het_loss.parameters(), cfg.train.lr, eps=cfg.train.adam_eps
    )
    hom_optim = torch.optim.Adam(
        hom_loss.parameters(), cfg.train.lr, eps=cfg.train.adam_eps
    )

    total_time = 0
    total_frames = 0
    sampling_start = time.time()

    for iteration in range(cfg.collector.n_iters):
        # NOTE(review): callbacks receive logger=None when no logging backend
        # is configured. Confirm they tolerate it.
        for callback in callbacks:
            callback.on_iter(iteration, env, env_test, logger)
        if iteration == cfg.collector.n_iters_env or iteration == 0:
            # Fresh networks at the start and again once env training ends.
            for model in [het_actor, hom_actor, het_critic, hom_critic]:
                model.reset_parameters_recursive()

        torchrl_logger.info(f"\nIteration {iteration}")

        # Collection
        with set_exploration_type(ExplorationType.RANDOM):
            het_collection_td = env.rollout(
                max_steps=cfg.env.max_steps,
                policy=het_actor,
                break_when_any_done=False,
            )
            hom_collection_td = env.rollout(
                max_steps=cfg.env.max_steps,
                policy=hom_actor,
                break_when_any_done=False,
            )

        sampling_time = time.time() - sampling_start

        # Regret: mean gap between the two teams' RewardSum accumulators,
        # read from the "next" entries of the rollouts.
        het_episode_reward = het_collection_td.get(
            ("next", "agents", "episode_reward")
        )
        hom_episode_reward = hom_collection_td.get(("next", "agents", "episode_reward"))
        regret = (het_episode_reward - hom_episode_reward).mean()
        # ratio = (
        #     het_episode_reward
        #     / torch.where(
        #         hom_episode_reward == 0,
        #         torch.sign(het_episode_reward) * 1e-8,
        #         hom_episode_reward,
        #     )
        # ).mean()

        if logger is not None:
            log_env_stuff(
                cfg, regret, het_collection_td, hom_collection_td, logger, env, iteration
            )

        if env_trains(iteration):
            train_env(cfg, regret, env, env_optimizer, logger)

            if cfg.model.reset_with_env:
                _reset_policy_head(het_actor, cfg.env.continuous_actions)
                _reset_policy_head(hom_actor, cfg.env.continuous_actions)

        current_frames = (hom_collection_td.numel() + het_collection_td.numel()) / 2
        total_frames += current_frames

        # Index 0 is the heterogeneous team, index 1 the homogeneous one.
        training_times = [0, 0]
        eval_returns = []
        teams = (
            (True, het_loss, het_collection_td, het_actor, het_optim, het_buffer),
            (False, hom_loss, hom_collection_td, hom_actor, hom_optim, hom_buffer),
        )
        for team_idx, team in enumerate(teams):
            het, loss_module, tensordict_data, policy, optim, replay_buffer = team

            if agents_train(iteration):
                training_times[team_idx] = train_agents(
                    cfg,
                    het,
                    tensordict_data,
                    env,
                    loss_module,
                    replay_buffer,
                    optim,
                    logger,
                    iteration,
                )

            if (
                cfg.eval.evaluation_episodes > 0
                and iteration % cfg.eval.evaluation_interval == 0
                and cfg.logger.backend
            ):
                returns = evaluate_agents(
                    cfg, policy, env_test, logger, het, step=iteration
                )
                eval_returns.append(returns)
        if eval_returns:
            logger.experiment.log(
                {"eval_regret": eval_returns[0] - eval_returns[1]},
                commit=False,
            )
        iteration_time = sampling_time + sum(training_times)
        total_time += iteration_time
        if logger is not None:
            log_iter(
                logger=logger,
                sampling_time=sampling_time,
                total_time=total_time,
                iteration=iteration,
                current_frames=current_frames,
                iteration_time=iteration_time,
                total_frames=total_frames,
                step=iteration,
            )

        if cfg.logger.backend == "wandb":
            logger.experiment.log({}, commit=True)
        sampling_start = time.time()
    wandb.finish()


def _reset_policy_head(actor, continuous_actions):
    """Re-initialise the head MLP of a policy built with ``reset_with_env``.

    With ``reset_with_env``, ``create_agent_nets`` builds
    ``nn.Sequential(bottleneck_mlp, head_mlp)``. For continuous actions that
    Sequential is wrapped again as ``nn.Sequential(net, NormalParamExtractor)``
    before entering the policy's TensorDictModule. For discrete actions it is
    used directly. Indexing ``[0][1]`` is therefore only valid in the
    continuous case.

    Assumes the ProbabilisticActor keeps its TensorDictModule at ``module[0]``.
    """
    net = actor.module[0].module
    if continuous_actions:
        net = net[0]  # strip the NormalParamExtractor wrapper
    net[1].reset_parameters()


def train_env(cfg, regret, env, env_optimizer, logger):
    """Take ``cfg.train.num_epochs_env`` optimiser steps that maximise ``regret``.

    ``regret`` comes from rollouts that were already collected, so its graph,
    and hence its gradient w.r.t. the scenario parameters, is fixed. The
    gradient is computed and clipped once, and the optimiser then steps on it
    repeatedly. Calling ``backward()`` on the same graph once per epoch fails
    for more than one epoch, because autograd frees the graph after the first
    pass.

    Args:
        cfg: config; reads ``train.num_epochs_env`` and ``train.max_grad_norm``.
        regret: scalar tensor still attached to the rollout graph.
        env: environment whose ``scenario`` parameters are optimised.
        env_optimizer: optimiser over those parameters.
        logger: logger for ``env_grad_norm`` (pre-clip norm); may be None.
    """
    n_epochs = cfg.train.num_epochs_env
    if n_epochs <= 0:
        return

    loss = -regret
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(
        env.scenario.parameters(), cfg.train.max_grad_norm
    )
    for _ in range(n_epochs):
        env_optimizer.step()
    env_optimizer.zero_grad()

    if logger is not None:
        logger.experiment.log(
            {"env_grad_norm": total_norm},
            commit=False,
        )


def train_agents(
    cfg,
    het,
    tensordict_data,
    env,
    loss_module,
    replay_buffer,
    optim,
    logger,
    iter,
):
    """Run PPO updates for one team on its latest rollout.

    Args:
        cfg: config; reads ``loss.normalize_advantage``, ``train.num_epochs``,
            ``collector.frames_per_batch``, ``train.minibatch_size``,
            ``train.max_grad_norm`` and ``logger.backend``.
        het: True for the per-agent-parameter team (selects the log label).
        tensordict_data: rollout from ``env.rollout``; it is detached here.
        env: environment whose ``done_keys`` / ``reward_key`` locate the flags
            to copy.
        loss_module: the team's ``ClipPPOLoss``.
        replay_buffer: buffer the rollout is written into and sampled from.
        optim: optimiser over ``loss_module``'s parameters.
        logger: forwarded to ``log_training`` when a backend is configured.
        iter: current training iteration, used as the logging step. The name
            shadows the builtin and is kept for call-site compatibility.

    Returns:
        Wall-clock seconds spent in the optimisation loop.
    """
    agent_group = "het" if het else "hom"
    tensordict_data = tensordict_data.detach()
    # Copy the done flag next to the reward (swap_last swaps the reward key's
    # leaf for the done key's leaf) and expand it to the reward's shape.
    # NOTE(review): only the first done key is copied, so the loss's
    # ("agents", "terminated") key is not populated here. Confirm the value
    # estimator's fallback for a missing terminated key is intended.
    for done_key in env.done_keys[:1]:
        new_name = swap_last(env.reward_key, done_key)
        tensordict_data.set(
            ("next", new_name),
            tensordict_data.get(("next", done_key))
            .unsqueeze(-1)
            .expand(tensordict_data.get(("next", env.reward_key)).shape),
        )
    with torch.no_grad():
        loss_module.value_estimator(
            tensordict_data,
            params=loss_module.critic_network_params,
            target_params=loss_module.target_critic_params,
        )
        advantage = tensordict_data.get(loss_module.tensor_keys.advantage)
        if cfg.loss.normalize_advantage and advantage.numel() > 1:
            advantage = standardize(advantage, exclude_dims=[-2])
            tensordict_data.set(loss_module.tensor_keys.advantage, advantage)

    data_view = tensordict_data.reshape(-1)
    replay_buffer.extend(data_view)

    # The rollout was collected with autograd enabled, so the regret backward
    # in train_env may have left gradients on the policy parameters. Clear
    # them so the first PPO step does not apply them.
    optim.zero_grad()

    training_tds = []
    training_start = time.time()
    for _ in range(cfg.train.num_epochs):
        for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size):
            subdata = replay_buffer.sample()
            loss_vals = loss_module(subdata)
            training_tds.append(loss_vals.detach())

            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            loss_value.backward()

            total_norm = torch.nn.utils.clip_grad_norm_(
                loss_module.parameters(),
                cfg.train.max_grad_norm,
                error_if_nonfinite=True,
            )
            training_tds[-1].set("grad_norm", total_norm.mean())

            optim.step()
            optim.zero_grad()

    training_time = time.time() - training_start
    training_tds = torch.stack(training_tds)

    # More logs
    if cfg.logger.backend:
        log_training(
            logger,
            training_tds,
            tensordict_data,
            training_time,
            step=iter,
            agent_group=agent_group,
        )
    return training_time


def evaluate_agents(cfg, policy, env_test, logger, het, step=None):
    """Roll out ``policy`` on ``env_test`` while rendering, then log the results.

    Args:
        cfg: config; reads ``eval.explore`` (random vs. mode actions) and
            ``env.max_steps``.
        policy: actor to evaluate.
        env_test: evaluation env. ``env_test.frames`` is reset and filled by
            ``rendering_callback`` during the rollout.
        logger: logger forwarded to ``log_evaluation``.
        het: True for the per-agent-parameter team (selects the log label).
        step: training iteration attached to the logged metrics. Defaults to
            None when the caller does not supply it.

    Returns:
        The value returned by ``log_evaluation``.
    """
    agent_group = "het" if het else "hom"
    evaluation_start = time.time()
    with torch.no_grad(), set_exploration_type(
        ExplorationType.RANDOM if cfg.eval.explore else ExplorationType.MODE
    ):
        env_test.frames = []
        rollouts = env_test.rollout(
            max_steps=cfg.env.max_steps,
            policy=policy,
            callback=rendering_callback,
            break_when_any_done=False,
            # We are running vectorized evaluation we do not want it to stop when just one env is done
        )

        evaluation_time = time.time() - evaluation_start

        # The step must come from the caller: a bare ``iter`` here would
        # resolve to the Python builtin function.
        returns = log_evaluation(
            logger,
            rollouts,
            env_test,
            evaluation_time,
            step=step,
            agent_group=agent_group,
        )
        return returns


def log_env_stuff(cfg, regret, het_collection_td, hom_collection_td, logger, env, iter):
    """Log scenario values, the current regret and (optionally) GenAgg plots.

    Scenario values come from ``env.scenario.to_log()``. The GenAgg plots of
    ``env.scenario.task_agg`` / ``agent_agg`` are only produced for
    "flag_capture*" scenarios, on evaluation iterations, with the wandb
    backend.

    Args:
        cfg: config; reads the scenario name, evaluation settings, logger
            backend and env device.
        regret: regret value to log under ``"regret"``.
        het_collection_td, hom_collection_td: team rollouts. Not used here;
            kept for call-site compatibility.
        logger: logger whose ``experiment.log`` receives the values.
        env: training environment.
        iter: current training iteration. The name shadows the builtin and is
            kept for compatibility.
    """
    logger.experiment.log(env.scenario.to_log(), commit=False)
    logger.experiment.log({"regret": regret}, commit=False)

    plot_due = (
        cfg.env.scenario_name.startswith("flag_capture")
        and cfg.eval.evaluation_episodes > 0
        and iter % cfg.eval.evaluation_interval == 0
        and cfg.logger.backend == "wandb"
    )
    if not plot_due:
        return

    # ``import matplotlib`` alone does not load the pyplot submodule, so import
    # it explicitly rather than relying on another module having done so.
    import matplotlib.pyplot as plt

    for log_key, agg_attr in (
        ("GenAgg_task", "task_agg"),
        ("GenAgg_agent", "agent_agg"),
    ):
        fig = plot_gen_agg(
            getattr(env.scenario, agg_attr), range=(0, 1), device=cfg.env.device
        )
        wandb.log({log_key: wandb.Image(fig)}, commit=False)
        plt.close()
