import contextlib
import random

import torch
import wandb
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.data import TensorDictReplayBuffer, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    ProbabilisticActor, 
    TanhDelta, 
    ValueOperator, 
    EGreedyModule, 
    QValueModule, 
    SafeSequential, 
    AdditiveGaussianWrapper
)
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DQNLoss, DDPGLoss, ValueEstimators, SoftUpdate

from src.utils.logging import Logger
from src.utils.pink_noise_wrapper import PinkNoiseWrapper
from src.utils.utils import DoneTransform


def rendering_callback(env, td):
    """Rollout callback that records one RGB frame of ``env`` per step.

    The rendered frame is appended to ``env.frames``, which the caller must
    have initialised as a list. ``td`` is accepted only to satisfy the
    rollout callback signature and is not used.
    """
    frames = env.frames
    frames.append(env.render(agent_index_focus=None, mode="rgb_array"))


class MARL():
    """Common training/evaluation harness for multi-agent RL on VMAS.

    ``__init__`` builds the training environment (wrapped with a per-agent
    episode reward sum), a ``DoneTransform``, a separate evaluation
    environment and, optionally, buffers for recording transition data.

    Subclasses are expected to set ``self.policy`` and ``self.policy_explore``
    and to implement ``train_iteration(tensordict_data)`` returning
    ``(training_tds, env_steps, policy_updates)``; ``train`` relies on all three.
    """

    def __init__(self, config):
        """Create the environments (and a logger, if configured) from ``config``.

        Side effects on ``config``: sets ``policy.frames_per_batch``,
        ``env.seed`` (a fresh random seed for the training env),
        ``scenario.state_dim`` and ``scenario.action_dim``.
        """
        self.config = config
        self.device = self.config.gpu_device
        # One rollout collects max_steps steps in each of num_envs parallel envs.
        self.config.policy.frames_per_batch = self.config.env.num_envs * self.config.env.max_steps
        # The training env gets a new random seed on every run.
        self.config.env.seed = random.randint(1, 1000)

        # Logging: self.logger only exists when a backend is configured.
        if self.config.logger.backend:
            self.logger = Logger(self.config, project=self.config.logger.project_name)

        # Create the training environment
        self.env = VmasEnv(
            scenario=self.config.scenario.name,
            num_envs=self.config.env.num_envs,
            continuous_actions=self.config.scenario.use_continuous_actions,
            max_steps=self.config.env.max_steps,
            seed=self.config.env.seed,
            device=self.device,
            # Scenario kwargs
            **self.config.scenario,
        )
        # Accumulate each agent's reward over the episode under ("agents", "episode_reward").
        self.env = TransformedEnv(
            self.env,
            RewardSum(in_keys=[self.env.reward_key], out_keys=[("agents", "episode_reward")]),
        )
        self.done_transform = DoneTransform(reward_key=self.env.reward_key, done_keys=self.env.done_keys)

        self.config.scenario.state_dim = self.env.observation_spec["agents", "observation"].shape[-1]
        # For discrete actions the action dimension is recorded as 1.
        self.config.scenario.action_dim = self.env.action_spec.shape[-1] if self.config.scenario.use_continuous_actions else 1
        print(f"State dimension: {self.config.scenario.state_dim}, action spec dimension: {self.config.scenario.action_dim}")

        # Evaluation env: one vectorized env per evaluation episode, seeded with
        # config.seed (unlike the randomly seeded training env).
        self.env_test = VmasEnv(
            scenario=self.config.scenario.name,
            num_envs=self.config.eval_env.evaluation_episodes,
            continuous_actions=self.config.scenario.use_continuous_actions,
            max_steps=self.config.eval_env.max_steps,
            seed=self.config.seed,
            device=self.device,
            # Scenario kwargs
            **self.config.scenario,
        )

        if self.config.logger.save_data:
            # Growing buffers of transition data, saved by save_training_states().
            self.obs_inputs = torch.empty((0, self.env.scenario.input_dim)).to(self.device)
            self.obs_targets = torch.empty((0, self.env.scenario.output_dim)).to(self.device)

    def train(self):
        """Alternate data collection and updates until ``config.total_env_steps``.

        Each iteration rolls out ``self.policy_explore`` in ``self.env`` and
        passes the batch to ``train_iteration``. It then logs, evaluates and
        saves according to the config. A final evaluation and save run after
        the loop. Logging and evaluation only happen when a logger backend is
        configured, because ``self.logger`` does not exist otherwise.
        """
        total_env_steps = 0
        total_policy_updates = 0
        use_logger = bool(self.config.logger.backend)

        iteration = 0
        while total_env_steps < self.config.total_env_steps:
            if iteration % 10 == 0:
                print(f"Training iteration {iteration}")

            # Roll out the current policy in the real environment. Collection
            # never needs gradients. Stochastic exploration is enabled only
            # when an exploration strategy is configured. (The previous
            # `torch.no_grad() and set_exploration_type(...)` entered only the
            # second manager, so gradients were tracked.)
            explore_ctx = (
                contextlib.nullcontext()
                if self.config.exploration.explore_type == "none"
                else set_exploration_type(ExplorationType.RANDOM)
            )
            with torch.no_grad(), explore_ctx:
                tensordict_data = self.env.rollout(
                    max_steps=self.config.env.max_steps,
                    policy=self.policy_explore,
                    auto_cast_to_device=True,
                    break_when_any_done=False,
                    auto_reset=True,
                )
                tensordict_data = self.done_transform(tensordict_data)

            if self.config.logger.save_data:
                obs_inputs, obs_targets = self.env.get_transition_data(tensordict_data)
                # Newest transitions are placed in front of the existing data.
                self.obs_inputs = torch.cat((obs_inputs, self.obs_inputs), dim=0)
                self.obs_targets = torch.cat((obs_targets, self.obs_targets), dim=0)

            # Execute custom training iteration
            training_tds, env_steps, policy_updates = self.train_iteration(tensordict_data)
            total_env_steps += env_steps
            total_policy_updates += policy_updates

            # Log per-iteration statistics
            if use_logger:
                training_tds = torch.stack(training_tds)
                self.logger.log_training(iteration, training_tds, tensordict_data, total_env_steps, total_policy_updates)

            # Evaluate learning
            if (
                self.config.eval_env.evaluation_episodes > 0
                and iteration % self.config.eval_env.evaluation_interval == 0
                and use_logger
            ):
                self.evaluation(iteration, total_env_steps, total_policy_updates)

            # Save the dataset and models
            if self.config.logger.save_data and iteration % self.config.eval_env.evaluation_interval == 0:
                self.save_training_states()

            iteration += 1

        # Evaluate final performance. Use the same guard as in-loop evaluation,
        # since evaluation and finish() both require self.logger.
        if use_logger and self.config.eval_env.evaluation_episodes > 0:
            self.evaluation(iteration, total_env_steps, total_policy_updates)
        if self.config.logger.save_data:
            self.save_training_states()

        if use_logger:
            self.logger.finish()

    def evaluation(self, i, total_env_steps, total_policy_updates):
        """Roll out ``self.policy_explore`` in ``self.env_test`` and log the result.

        Actions use the distribution mode when
        ``config.exploration.eval_deterministic_actions`` is set, and are
        sampled otherwise. Frames are recorded into ``self.env_test.frames``
        through ``rendering_callback``. Requires ``self.logger``.
        """
        eval_type = (
            ExplorationType.MODE
            if self.config.exploration.eval_deterministic_actions
            else ExplorationType.RANDOM
        )
        # Evaluation never backpropagates, so skip building autograd graphs.
        with torch.no_grad(), set_exploration_type(eval_type):
            self.env_test.frames = []
            rollouts = self.env_test.rollout(
                max_steps=self.config.eval_env.max_steps,
                policy=self.policy_explore,
                callback=rendering_callback,
                auto_cast_to_device=True,
                # We are running vectorized evaluation; we do not want it to stop when just one env is done.
                break_when_any_done=False,
            )
            self.logger.log_evaluation(i, rollouts, self.env_test, total_env_steps, total_policy_updates)

    def save_training_states(self):
        """Save the recorded transition data and the policy to ``config.logger.output_dir``."""
        torch.save(self.obs_inputs, f"{self.config.logger.output_dir}/train_inputs.pt")
        torch.save(self.obs_targets, f"{self.config.logger.output_dir}/train_targets.pt")
        torch.save(self.policy, f"{self.config.logger.output_dir}/policy.pt")


class IQL(MARL):
    """Independent Q-learning: a DQN loss over per-agent, decentralized Q-networks.

    The Q-network is a ``MultiAgentMLP`` with ``centralised=False``. Its
    parameters are shared across agents when
    ``config.policy.shared_parameters`` is set.
    """

    def __init__(self, config):
        super().__init__(config)

        # Initialize independent Q networks for each agent: one output per discrete action.
        net = MultiAgentMLP(
            n_agent_inputs=self.config.scenario.state_dim,
            n_agent_outputs=self.env.action_spec.space.n,
            n_agents=self.config.scenario.n_agents,
            centralised=False,
            share_params=self.config.policy.shared_parameters,
            device=self.device,
            depth=2,
            num_cells=256,
            activation_class=nn.Tanh,
        )
        module = TensorDictModule(
            net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")]
        )

        # Find argmax of action values
        value_module = QValueModule(
            action_value_key=("agents", "action_value"),
            out_keys=[
                ("agents", "action"),
                ("agents", "action_value"),
                ("agents", "chosen_action_value"),
            ],
            spec=self.env.unbatched_action_spec,
            action_space=None,
        )
        self.policy = SafeSequential(module, value_module)  # This is really the Q network

        if self.config.exploration.explore_type == "e-greedy":
            # Epsilon is annealed from 0.3 to 0 over the first half of training,
            # driven by the frame counts passed to step() in train_iteration.
            self.policy_explore = TensorDictSequential(
                self.policy,
                EGreedyModule(
                    eps_init=0.3,
                    eps_end=0,
                    annealing_num_steps=int(self.config.total_env_steps * (1 / 2)),
                    action_key=self.env.action_key,
                    spec=self.env.unbatched_action_spec,
                ),
            )
        elif self.config.exploration.explore_type == "none":
            self.policy_explore = self.policy
        else:
            raise RuntimeError("Unsupported exploration type.")

        # The storage must hold a full rollout (frames_per_batch transitions).
        # A capacity of minibatch_size would keep only the last minibatch_size
        # transitions of each rollout, so every minibatch in train_iteration
        # would contain the same samples. With rollouts of exactly
        # frames_per_batch frames, each extend replaces the previous batch.
        self.replay_buffer = TensorDictReplayBuffer(
            storage=LazyTensorStorage(self.config.policy.frames_per_batch, device=self.device),
            sampler=SamplerWithoutReplacement(shuffle=True, seed=self.config.seed, device=self.device),
            batch_size=self.config.policy.minibatch_size,
        )

        # delay_value=True keeps a target Q network, updated by SoftUpdate below.
        self.loss_module = DQNLoss(self.policy, delay_value=True)
        self.loss_module.set_keys(
            action_value=("agents", "action_value"),
            action=self.env.action_key,
            value=("agents", "chosen_action_value"),
            reward=self.env.reward_key,
            done=("agents", "done"),
            terminated=("agents", "terminated"),
        )
        self.loss_module.make_value_estimator(ValueEstimators.TD0, gamma=self.config.loss.gamma)
        self.target_net_updater = SoftUpdate(self.loss_module, eps=1 - self.config.loss.tau)
        self.optim = torch.optim.Adam(self.loss_module.parameters(), self.config.policy.lr)

    def train_iteration(self, tensordict_data):
        """Store a rollout and run the DQN updates for one training iteration.

        Runs ``num_epochs * (frames_per_batch // minibatch_size)`` gradient
        steps, each followed by a soft target-network update.

        Returns:
            (training_tds, env_steps, policy_updates): the detached loss
            tensordicts (each with a ``"grad_norm"`` entry), the number of
            frames in ``tensordict_data``, and the number of optimizer steps.
        """
        env_steps = tensordict_data.numel()
        data_view = tensordict_data.reshape(-1)
        self.replay_buffer.extend(data_view)

        policy_updates = 0
        training_tds = []
        for _ in range(self.config.policy.num_epochs):
            for _ in range(self.config.policy.frames_per_batch // self.config.policy.minibatch_size):
                subdata = self.replay_buffer.sample()
                loss_vals = self.loss_module(subdata)
                training_tds.append(loss_vals.detach())

                loss_value = loss_vals["loss"]
                loss_value.backward()
                total_norm = torch.nn.utils.clip_grad_norm_(self.loss_module.parameters(), self.config.policy.max_grad_norm)
                training_tds[-1].set("grad_norm", total_norm.mean())

                self.optim.step()  # update policy weights
                self.optim.zero_grad()
                self.target_net_updater.step()
                policy_updates += 1
        # Index 1 is the EGreedyModule under e-greedy exploration. With no
        # exploration it is the QValueModule, which the hasattr guard skips.
        if hasattr(self.policy_explore[1], "step"):
            self.policy_explore[1].step(frames=env_steps)  # update exploration annealing
        return training_tds, env_steps, policy_updates
    

class DDPG(MARL):
    """Multi-agent DDPG: decentralized deterministic actors with a Q critic.

    Actors are per-agent ``MultiAgentMLP``s (``centralised=False``) whose
    outputs are squashed into the action bounds by ``TanhDelta``. The critic
    scores observation/action pairs and is centralized when
    ``config.policy.centralized_critic`` is set.
    """

    def __init__(self, config):
        super().__init__(config)
        activation = nn.Mish if self.config.policy.activation == "Mish" else nn.Tanh

        # Initialize independent actors for each agent
        actor_net = MultiAgentMLP(
            n_agent_inputs=self.config.scenario.state_dim,
            n_agent_outputs=self.env.action_spec.shape[-1],
            n_agents=self.config.scenario.n_agents,
            centralised=False,
            share_params=self.config.policy.shared_parameters,
            device=self.device,
            depth=2,
            num_cells=256,
            activation_class=activation,
        )
        policy_module = TensorDictModule(
            actor_net, in_keys=[("agents", "observation")], out_keys=[("agents", "param")]
        )
        # TanhDelta maps the actor output into the action spec's [low, high] range.
        self.policy = ProbabilisticActor(
            module=policy_module,
            spec=self.env.unbatched_action_spec,
            in_keys=[("agents", "param")],
            out_keys=[self.env.action_key],
            distribution_class=TanhDelta,
            distribution_kwargs={
                "min": self.env.unbatched_action_spec[("agents", "action")].space.low,
                "max": self.env.unbatched_action_spec[("agents", "action")].space.high,
            },
            return_log_prob=False,
        )

        if self.config.exploration.explore_type == "additive gaussian":
            self.policy_explore = AdditiveGaussianWrapper(
                self.policy,
                annealing_num_steps=int(self.config.total_env_steps * (1 / 2)),
                action_key=self.env.action_key,
                seed=self.config.seed,
            )
        elif self.config.exploration.explore_type == "pink noise":
            self.policy_explore = PinkNoiseWrapper(
                self.policy,
                batch_size=self.config.env.num_envs,
                seq_len=self.config.env.max_steps,
                annealing_num_steps=self.config.total_env_steps,
                random_num_steps=0,
                action_key=("agents", "action"),
                sigma_init=self.config.exploration.eps_init,
                sigma_end=self.config.exploration.eps_end,
            )
        elif self.config.exploration.explore_type == "none":
            self.policy_explore = self.policy
        else:
            raise RuntimeError("Unsupported exploration type.")

        # Initialize the critic
        module = MultiAgentMLP(
            n_agent_inputs=self.config.scenario.state_dim + self.env.action_spec.shape[-1],  # Q critic input: observation concatenated with action
            n_agent_outputs=1,
            n_agents=self.config.scenario.n_agents,
            centralised=self.config.policy.centralized_critic,
            share_params=self.config.policy.shared_parameters,
            device=self.device,
            depth=2,
            num_cells=256,
            activation_class=activation,
        )
        self.value_module = ValueOperator(
            module=module,
            in_keys=[("agents", "observation"), self.env.action_key],
            out_keys=[("agents", "state_action_value")],
        )

        if self.config.replay_buffer.use_priority:
            # Priorities are read from ("agents", "td_error"), which the loss writes (see set_keys below).
            self.replay_buffer = TensorDictPrioritizedReplayBuffer(
                alpha=0.7,
                beta=0.5,
                storage=LazyTensorStorage(self.config.replay_buffer.memory_size, device=self.device),
                batch_size=self.config.replay_buffer.minibatch_size,
                priority_key=("agents", "td_error"),
            )
        else:
            self.replay_buffer = TensorDictReplayBuffer(
                storage=LazyTensorStorage(self.config.replay_buffer.memory_size, device=self.device),
                sampler=SamplerWithoutReplacement(shuffle=True, seed=self.config.seed, device=self.device),
                batch_size=self.config.replay_buffer.minibatch_size,
            )

        # Set up DDPG loss for each agent
        self.loss_module = DDPGLoss(actor_network=self.policy, value_network=self.value_module, delay_value=True)
        self.loss_module.set_keys(
            state_action_value=("agents", "state_action_value"),
            reward=self.env.reward_key,
            priority=("agents", "td_error"),
            done=("agents", "done"),
            terminated=("agents", "terminated"),
        )
        self.loss_module.make_value_estimator(ValueEstimators.TD0, gamma=self.config.loss.gamma)
        self.target_net_updater = SoftUpdate(self.loss_module, eps=1 - self.config.loss.tau)
        self.optim = torch.optim.Adam(self.loss_module.parameters(), lr=self.config.learning_rate.lr, eps=self.config.learning_rate.adam_eps)

        # Optional linear warm-up from lr_start to lr over lr_iters scheduler
        # steps (one step per train_iteration). None when disabled, so
        # train_iteration steps exactly the scheduler that was created.
        self.scheduler = None
        if self.config.learning_rate.use_scheduler:
            self.scheduler = torch.optim.lr_scheduler.LinearLR(
                self.optim,
                start_factor=self.config.learning_rate.lr_start / self.config.learning_rate.lr,
                end_factor=1.0,
                total_iters=self.config.learning_rate.lr_iters,
            )

    def train_iteration(self, tensordict_data):
        """Store a rollout and run the DDPG actor/critic updates for one iteration.

        Runs ``num_epochs * (frames_per_batch // replay_buffer.minibatch_size)``
        gradient steps on the summed actor and value losses. Each step
        refreshes replay priorities (when prioritized) and soft-updates the
        target networks.

        Returns:
            (training_tds, env_steps, policy_updates): the detached loss
            tensordicts (each with a ``"grad_norm"`` entry), the number of
            frames in ``tensordict_data``, and the number of optimizer steps.
        """
        env_steps = tensordict_data.numel()
        data_view = tensordict_data.reshape(-1)
        self.replay_buffer.extend(data_view)

        policy_updates = 0
        training_tds = []
        for _ in range(self.config.policy.num_epochs):
            for _ in range(self.config.policy.frames_per_batch // self.config.replay_buffer.minibatch_size):
                subdata = self.replay_buffer.sample()
                loss_vals = self.loss_module(subdata)
                training_tds.append(loss_vals.detach())

                loss_value = loss_vals["loss_actor"] + loss_vals["loss_value"]
                loss_value.backward()
                total_norm = torch.nn.utils.clip_grad_norm_(self.loss_module.parameters(), self.config.policy.max_grad_norm)
                training_tds[-1].set("grad_norm", total_norm.mean())
                self.replay_buffer.update_tensordict_priority(subdata)

                self.optim.step()  # update policy weights
                self.optim.zero_grad()
                self.target_net_updater.step()
                policy_updates += 1

        if hasattr(self.policy_explore, "step"):
            self.policy_explore.step(frames=env_steps)  # update exploration annealing
        # Step the scheduler only if __init__ created one. The old check read
        # config.policy.use_scheduler, which differs from the flag used to
        # build the scheduler.
        if self.scheduler is not None:
            self.scheduler.step()
        return training_tds, env_steps, policy_updates