import torch
import random
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.data import TensorDictReplayBuffer, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from src.wrappers import HallucVmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    ProbabilisticActor, 
    TanhDelta, 
    ValueOperator, 
    QValueModule, 
    SafeSequential, 
)
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DQNLoss, DDPGLoss, ValueEstimators, SoftUpdate
from src.marl_systems import rendering_callback
from src.utils.logging import Logger
from src.utils.utils import DoneTransform
from src.models.gp_model import MultitaskGPModel
from src.models.nn_model import NNModel


class Halluc_MARL:
    """Base trainer for model-based multi-agent RL with hallucinated rollouts.

    The constructor builds a real VMAS environment, a hallucinated environment
    driven by a learned transition model, an evaluation environment, the
    transition model itself (GP or NN), and a replay buffer.

    This class does not define a policy. Subclasses are expected to provide
    ``self.policy``, ``self.done_transform`` and
    ``update_policy(tensordict_data) -> (training_tds, policy_updates)``,
    all of which are used by :meth:`train`, :meth:`pretrain_model` and
    :meth:`evaluation`.
    """

    def __init__(self, config):
        """Build environments, the transition model, the replay buffer and the logger.

        Note that ``config`` is mutated in place: the environment seeds are
        replaced with random values, and ``total_halluc_env_steps`` and
        ``policy.frames_per_batch`` are derived and written back. When no
        logger backend is configured, ``model.num_epochs`` is forced to 1.

        Args:
            config: Attribute-style configuration object (nested namespaces
                such as ``env``, ``halluc_env``, ``eval_env``, ``scenario``,
                ``model``, ``optimism``, ``policy``, ``replay_buffer``,
                ``logger``).

        Raises:
            RuntimeError: If both smoothed and unsmoothed reward learning are
                requested.
            NotImplementedError: If ``config.model.model_type`` is neither
                ``"GP"`` nor ``"NN"``.
        """
        self.config = config
        # The configured env seeds are deliberately overwritten with random ones.
        self.config.env.seed = random.randint(1, 1000)
        self.config.halluc_env.seed = random.randint(1, 1000)
        self.device = self.config.gpu_device
        self.config.total_halluc_env_steps = (
            self.config.total_env_steps * self.config.halluc_env.num_envs / self.config.env.num_envs
        )
        # Policy updates consume hallucinated rollouts, so the batch size is
        # derived from the hallucinated environment's dimensions.
        self.config.policy.frames_per_batch = (
            self.config.halluc_env.num_envs * self.config.halluc_env.max_steps
        )
        print(f"State dimension: {self.config.scenario.state_dim}, action spec dimension: {self.config.scenario.action_dim}")

        if self.config.model.learn_smoothed_reward and self.config.model.learn_unsmoothed_reward:
            raise RuntimeError("Learning the smoothed reward is exclusive with the unsmoothed reward")

        if not self.config.logger.backend:
            self.config.model.num_epochs = 1  # to speed up testing

        # Create the environment
        self.env = VmasEnv(
            scenario=self.config.scenario.name,
            num_envs=self.config.env.num_envs,
            continuous_actions=self.config.scenario.use_continuous_actions,
            max_steps=self.config.env.max_steps,
            seed=self.config.env.seed,
            device=self.device,
            # Scenario kwargs
            **self.config.scenario,
        )
        self.env = TransformedEnv(
            self.env,
            RewardSum(in_keys=[self.env.reward_key], out_keys=[("agents", "episode_reward")]),
        )

        # Create the hallucination environment
        self.halluc_env = HallucVmasEnv(
            scenario=self.config.scenario.name,
            num_envs=self.config.halluc_env.num_envs,
            continuous_actions=self.config.scenario.use_continuous_actions,
            max_steps=self.config.halluc_env.max_steps,
            seed=self.config.halluc_env.seed,
            device=self.device,
            use_k_branching=self.config.model.use_k_branching,
            use_optimism=self.config.optimism.use_optimism,
            use_hucrl_approx=self.config.optimism.use_hucrl_approx,
            optimism_after_iter=self.config.optimism.optimism_after_iter,
            beta=self.config.optimism.initial_beta,
            num_samples=self.config.optimism.num_samples,
            lower_percentile=self.config.optimism.initial_lower_percentile,
            upper_percentile=self.config.optimism.upper_percentile,
            logger=self.config.logger.backend,
            **self.config.scenario,
        )
        # Wrapping keeps the HallucVmasEnv reachable as the transformed env's base env.
        self.halluc_env = TransformedEnv(
            self.halluc_env,
            RewardSum(in_keys=[self.halluc_env.reward_key], out_keys=[("agents", "episode_reward")]),
        )

        # Create the test environment (seeded from the top-level config seed,
        # unlike the two training environments above).
        self.env_test = VmasEnv(
            scenario=self.config.scenario.name,
            num_envs=self.config.eval_env.evaluation_episodes,
            continuous_actions=self.config.scenario.use_continuous_actions,
            max_steps=self.config.eval_env.max_steps,
            seed=self.config.seed,
            device=self.device,
            # Scenario kwargs
            **self.config.scenario,
        )

        # Create the transition model
        if self.config.model.model_type == "GP":
            print("Using GP to learn dynamics")
            self.model = MultitaskGPModel(
                learn_smoothed_reward=self.config.model.learn_smoothed_reward,
                learn_unsmoothed_reward=self.config.model.learn_unsmoothed_reward,
                num_epochs=self.config.model.num_epochs,
                minibatch_size=self.config.model.minibatch_size,
                input_dim=self.env.scenario.input_dim,
                pos_vel_dim=self.env.scenario.pos_vel_dim,
                max_nn_dataset_size=self.config.model.max_nn_dataset_size,
                max_gp_dataset_size=self.config.model.max_gp_dataset_size,
                num_inducing_points=self.config.model.num_inducing_points,
                hidden_layer_width=self.config.model.hidden_layer_width,
                gp_learning_rate=self.config.model.gp_lr,
                nn_learning_rate=self.config.model.nn_lr,
                use_separate_reward_cov=self.config.model.use_separate_reward_cov,
                use_coregionalization=self.config.model.use_coregionalization,
                use_thompson_sampling=self.config.optimism.use_thompson_sampling,
                device=self.device,
            )
        elif self.config.model.model_type == "NN":
            print("Using a neural network to learn dynamics")
            self.model = NNModel(  # TODO: try the calibrated neural network approach Aidan's student used (from H-UCRL)
                num_epochs=self.config.model.num_epochs,
                minibatch_size=self.config.model.minibatch_size,
                input_dim=self.env.scenario.input_dim,
                output_dim=self.env.scenario.output_dim,
                max_dataset_size=self.config.model.max_dataset_size,
                # NOTE(review): the width is read from `num_inducing_points`, a
                # GP-specific setting. This looks like a copy-paste slip
                # (`hidden_layer_width` was probably intended) -- confirm
                # against the NN model config before changing.
                hidden_layer_width=self.config.model.num_inducing_points,
                learning_rate=self.config.model.nn_lr,
                device=self.device,
            )
        else:
            raise NotImplementedError("Model type is not supported.")

        self.env.register_model(self.model)
        self.halluc_env.register_model(self.model)

        # Initialize the replay buffer (for K-branched rollouts)
        if self.config.replay_buffer.use_priority:
            # The "navigation" scenario keeps the buffer's default priority key;
            # every other scenario prioritizes on the per-agent reward.
            priority_kwargs = {}
            if self.config.scenario.name != "navigation":
                priority_kwargs["priority_key"] = ("next", "agents", "reward")
            self.replay_buffer = TensorDictPrioritizedReplayBuffer(
                alpha=0.7,
                beta=1.1,
                storage=LazyTensorStorage(self.config.replay_buffer.buffer_size, device=self.device),
                batch_size=self.config.halluc_env.num_envs,
                **priority_kwargs,
            )
        else:
            self.replay_buffer = TensorDictReplayBuffer(
                storage=LazyTensorStorage(self.config.replay_buffer.buffer_size, device=self.device),
                sampler=SamplerWithoutReplacement(shuffle=True, seed=self.config.seed, device=self.device),
                batch_size=self.config.halluc_env.num_envs,
            )

        # Logging: `self.logger` only exists when a backend is configured, so
        # every use of it must be guarded by the same check.
        if self.config.logger.backend:
            self.logger = Logger(self.config, self.model, project=self.config.logger.project_name)

    @staticmethod
    def _linear_anneal(initial, final, total_iterations, iteration):
        """Linearly interpolate from ``initial`` towards ``final``.

        Returns ``initial`` at ``iteration == 0`` and ``final`` at
        ``iteration == total_iterations``. The value is not clamped, so it
        keeps moving past ``final`` if ``iteration`` exceeds
        ``total_iterations``.
        """
        return initial - (initial - final) / total_iterations * iteration

    def train(self):
        """Run the full training loop until the real-environment step budget is spent.

        Each iteration: (1) roll out the policy in the hallucinated
        environment, (2) anneal the optimism parameters, (3) update the policy
        on the hallucinated data, (4) roll out the policy in the real
        environment and store the data, (5) retrain the transition model, and
        (6) log / evaluate when a logger backend is configured.

        Logging, evaluation and ``logger.finish()`` are skipped entirely when
        no logger backend is configured, since no logger object exists then.
        """
        total_halluc_env_steps = 0
        total_policy_updates = 0
        total_model_updates = 0

        self.pretrain_model()
        # Pretraining data already counts against the real-environment budget.
        total_env_steps = len(self.model.nn_train_dataset)

        iteration = 0
        while total_env_steps < self.config.total_env_steps:
            if iteration % 10 == 0:
                print(f"Training iteration {iteration}")

            evaluate_on_iteration = (
                iteration % self.config.eval_env.evaluation_interval == 0
                and self.config.logger.backend
            )
            callback = None
            if evaluate_on_iteration:
                # Reset the frame buffers and render during this iteration's rollouts.
                self.halluc_env.frames = []
                self.env.frames = []
                callback = rendering_callback

            # Hallucinate an optimistic trajectory
            with torch.no_grad():
                tensordict_data, total_halluc_env_steps = self.halluc_env.halluc_rollout(
                    max_steps=self.config.halluc_env.max_steps,
                    policy=self.policy,
                    callback=callback,
                    auto_cast_to_device=True,
                    break_when_any_done=False,
                    auto_reset=True,
                    replay_buffer=self.replay_buffer,
                    total_env_steps=total_env_steps,
                    total_halluc_env_steps=total_halluc_env_steps,
                    iteration=iteration,
                )
                tensordict_data = self.done_transform(tensordict_data)

            # Anneal values linearly over the nominal number of training iterations.
            anneal_iterations = self.config.total_env_steps / (self.env.max_steps * self.config.env.num_envs)
            annealed_beta = self._linear_anneal(
                self.config.optimism.initial_beta,
                self.config.optimism.final_beta,
                anneal_iterations,
                iteration,
            )
            self.halluc_env.anneal_beta(annealed_beta)
            annealed_lower_percentile = self._linear_anneal(
                self.config.optimism.initial_lower_percentile,
                self.config.optimism.final_lower_percentile,
                anneal_iterations,
                iteration,
            )
            self.halluc_env.anneal_lower_percentile(annealed_lower_percentile)

            # Update the policy
            training_tds, policy_updates = self.update_policy(tensordict_data)
            # NOTE(review): `halluc_rollout` also returns a step counter above;
            # if it already includes this rollout, the steps are counted twice
            # here -- confirm against HallucVmasEnv.halluc_rollout.
            total_halluc_env_steps += tensordict_data.numel()
            total_policy_updates += policy_updates

            # Roll out the current policy in the real environment
            with torch.no_grad():
                rollouts = self.env.rollout(
                    max_steps=self.config.env.max_steps,
                    policy=self.policy,
                    callback=callback,
                    auto_cast_to_device=True,
                    break_when_any_done=False,  # should be false when matching IQL performance
                    auto_reset=True,
                    learn_model=True,
                )
            self.replay_buffer.extend(rollouts.reshape(-1))
            total_env_steps += rollouts.numel()

            # Update the transition model with the newly collected real data
            train_inputs, train_targets, rollout_rewards = self.env.get_transition_data(
                rollouts,
                learn_smoothed_reward=self.config.model.learn_smoothed_reward,
                learn_unsmoothed_reward=self.config.model.learn_unsmoothed_reward,
            )
            self.model.register_new_data(train_inputs, train_targets, rollout_rewards)
            model_updates, nn_loss, gp_loss = self.model.train()
            total_model_updates += model_updates

            # Log per-iteration statistics
            if self.config.logger.backend:
                training_tds = torch.stack(training_tds)
                self.logger.log_training(
                    iteration,
                    training_tds,
                    tensordict_data,
                    total_env_steps,
                    total_policy_updates,
                    total_halluc_env_steps,
                    total_model_updates,
                    len(self.model.nn_train_dataset),
                    len(self.model.gp_train_dataset),
                    len(self.replay_buffer),
                    annealed_beta,
                    evaluate_on_iteration,
                    self.halluc_env,
                    self.env,
                    nn_loss,
                    gp_loss,
                )

            # Evaluate learning (only ever true when a logger backend exists)
            if evaluate_on_iteration:
                self.evaluation(iteration, total_env_steps, total_halluc_env_steps, total_policy_updates, total_model_updates)
                if self.config.logger.save_data:
                    self.save_training_states()

            iteration += 1

        # Evaluate final performance. `evaluation` reports through the logger,
        # which does not exist without a backend, so it is skipped in that case
        # instead of crashing at the very end of training.
        if self.config.logger.backend:
            self.evaluation(iteration, total_env_steps, total_halluc_env_steps, total_policy_updates, total_model_updates)
        if self.config.logger.save_data:
            self.save_training_states()

        if self.config.logger.backend:
            self.logger.finish()

    def pretrain_model(self):
        """Collect real rollouts with the current policy and fit the transition model once.

        Rollouts are gathered until at least ``config.model.num_pretrain_steps``
        transitions have been collected. Every rollout is also added to the
        replay buffer.
        """
        print("Pretraining the model.")
        num_pretrain_steps = 0
        inputs = []
        targets = []
        rewards = torch.zeros(0, device=self.device)
        with torch.no_grad():
            while num_pretrain_steps < self.config.model.num_pretrain_steps:
                rollout = self.env.rollout(
                    max_steps=self.config.env.max_steps,
                    policy=self.policy,
                    auto_cast_to_device=True,
                    break_when_any_done=False,
                    auto_reset=True,
                    learn_model=True,
                )
                num_pretrain_steps += rollout.numel()
                train_inputs, train_targets, rollout_rewards = self.env.get_transition_data(
                    rollout,
                    learn_smoothed_reward=self.config.model.learn_smoothed_reward,
                    learn_unsmoothed_reward=self.config.model.learn_unsmoothed_reward,
                )
                inputs.append(train_inputs)
                targets.append(train_targets)
                rewards = torch.cat((rewards, rollout_rewards))
                self.replay_buffer.extend(rollout.reshape(-1))

        inputs = torch.vstack(inputs)
        targets = torch.vstack(targets)
        self.model.register_new_data(inputs, targets, rewards)
        self.model.train()

    def evaluation(self, i, total_env_steps, total_halluc_env_steps, total_policy_updates, total_model_updates):
        """Roll out the policy in the test environment and log the result.

        Requires ``self.logger``, which is only created when a logger backend
        is configured.

        Args:
            i: Current training iteration, forwarded to the logger.
            total_env_steps: Real-environment steps taken so far.
            total_halluc_env_steps: Hallucinated-environment steps taken so far.
            total_policy_updates: Policy gradient updates performed so far.
            total_model_updates: Transition-model updates performed so far.
        """
        with torch.no_grad():
            self.env_test.frames = []
            rollouts = self.env_test.rollout(
                max_steps=self.config.eval_env.max_steps,
                policy=self.policy,
                callback=rendering_callback,
                auto_cast_to_device=True,
                break_when_any_done=False,
                auto_reset=True,
            )
            # Only the first evaluation episode is converted to model inputs for logging.
            rollout_model_input, _, _ = self.env_test.get_transition_data(
                rollouts[0].unsqueeze(dim=0),
                learn_smoothed_reward=self.config.model.learn_smoothed_reward,
                learn_unsmoothed_reward=self.config.model.learn_unsmoothed_reward,
            )
            self.logger.log_evaluation(
                i,
                rollouts,
                self.env_test,
                total_env_steps,
                total_policy_updates,
                total_halluc_env_steps,
                total_model_updates,
                rollout_model_input,
            )

    def save_training_states(self):
        """Serialize the model, the policy and the NN training data to ``config.logger.output_dir``."""
        output_dir = self.config.logger.output_dir
        torch.save(self.model, f"{output_dir}/model.pt")
        torch.save(self.policy, f"{output_dir}/policy.pt")
        torch.save(self.model.nn_train_dataset.inputs, f"{output_dir}/nn_train_inputs.pt")
        torch.save(self.model.nn_train_dataset.targets, f"{output_dir}/nn_train_targets.pt")
        torch.save(self.model.nn_priorities, f"{output_dir}/nn_train_priorities.pt")


class Halluc_IQL(Halluc_MARL):
    """Independent Q-learning on top of the hallucinated-rollout trainer.

    Each agent gets a decentralised Q-network (optionally with shared
    parameters); actions are chosen greedily from the action values and the
    networks are trained with a DQN loss using TD(0) targets and a
    soft-updated target network.
    """

    def __init__(self, config):
        """Build the per-agent Q-networks, the DQN loss, target updater and optimizer."""
        super().__init__(config)

        observation_key = ("agents", "observation")
        action_value_key = ("agents", "action_value")
        chosen_value_key = ("agents", "chosen_action_value")
        policy_cfg = self.config.policy

        # Per-agent Q-network: observation -> one value per discrete action.
        num_actions = self.halluc_env.action_spec.space.n
        q_mlp = MultiAgentMLP(
            n_agent_inputs=self.config.scenario.state_dim,
            n_agent_outputs=num_actions,
            n_agents=self.config.scenario.n_agents,
            centralised=False,
            share_params=policy_cfg.shared_parameters,
            device=self.device,
            depth=2,
            num_cells=256,
            activation_class=nn.Tanh,
        )
        q_network = TensorDictModule(q_mlp, in_keys=[observation_key], out_keys=[action_value_key])

        # Greedy head: picks the arg-max action and records its value.
        greedy_head = QValueModule(
            action_value_key=action_value_key,
            out_keys=[("agents", "action"), action_value_key, chosen_value_key],
            spec=self.halluc_env.unbatched_action_spec,
            action_space=None,
        )

        # The "policy" is the Q-network followed by the greedy action selector.
        self.policy = SafeSequential(q_network, greedy_head)
        self.done_transform = DoneTransform(reward_key=self.env.reward_key, done_keys=self.env.done_keys)

        # DQN loss with a delayed (target) value network.
        dqn_loss = DQNLoss(self.policy, delay_value=True)
        dqn_loss.set_keys(
            action_value=action_value_key,
            action=self.halluc_env.action_key,
            value=chosen_value_key,
            reward=self.halluc_env.reward_key,
            done=("agents", "done"),
            terminated=("agents", "terminated"),
        )
        dqn_loss.make_value_estimator(ValueEstimators.TD0, gamma=policy_cfg.gamma)
        self.loss_module = dqn_loss

        self.target_net_updater = SoftUpdate(self.loss_module, eps=1 - policy_cfg.tau)
        self.optim = torch.optim.Adam(self.loss_module.parameters(), policy_cfg.lr)

    def update_policy(self, tensordict_data):
        """Run the configured number of DQN epochs over one batch of rollout data.

        Args:
            tensordict_data: Rollout data; it is flattened and loaded into a
                temporary buffer that is sampled without replacement.

        Returns:
            A tuple ``(training_tds, policy_updates)``: the detached loss
            tensordicts (each with an added ``"grad_norm"`` entry), one per
            gradient step, and the number of gradient steps taken.
        """
        policy_cfg = self.config.policy

        # Temporary buffer holding only this batch, sampled in shuffled minibatches.
        minibatch_buffer = TensorDictReplayBuffer(
            storage=LazyTensorStorage(policy_cfg.frames_per_batch, device=self.device),
            sampler=SamplerWithoutReplacement(shuffle=True, seed=self.config.seed, device=self.device),
            batch_size=policy_cfg.minibatch_size,
        )
        minibatch_buffer.extend(tensordict_data.reshape(-1))

        steps_per_epoch = policy_cfg.frames_per_batch // policy_cfg.minibatch_size
        training_tds = []
        for _ in range(policy_cfg.num_epochs * steps_per_epoch):
            losses = self.loss_module(minibatch_buffer.sample())
            step_record = losses.detach()

            losses["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.loss_module.parameters(), policy_cfg.max_grad_norm
            )
            step_record.set("grad_norm", grad_norm.mean())
            training_tds.append(step_record)

            # Apply the gradient step, clear gradients, then soft-update the target net.
            self.optim.step()
            self.optim.zero_grad()
            self.target_net_updater.step()

        return training_tds, len(training_tds)


class Halluc_DDPG(Halluc_MARL):
    """Multi-agent DDPG on top of the hallucinated-rollout trainer.

    Each agent has a decentralised deterministic actor; the critic scores
    (observation, action) pairs and may be centralised depending on
    ``config.policy.centralized_critic``. Actor and critic are trained jointly
    with a DDPG loss using TD(0) targets and soft-updated target networks.
    """

    def __init__(self, config):
        """Build the actors, the critic, the DDPG loss, target updater and optimizer."""
        super().__init__(config)

        observation_key = ("agents", "observation")
        param_key = ("agents", "param")
        q_value_key = ("agents", "state_action_value")
        policy_cfg = self.config.policy
        scenario_cfg = self.config.scenario
        action_dim = self.env.action_spec.shape[-1]
        unbatched_action_spec = self.env.unbatched_action_spec
        agent_action_space = unbatched_action_spec[("agents", "action")].space

        # Per-agent actor: observation -> action parameters.
        actor_mlp = MultiAgentMLP(
            n_agent_inputs=scenario_cfg.state_dim,
            n_agent_outputs=action_dim,
            n_agents=scenario_cfg.n_agents,
            centralised=False,
            share_params=policy_cfg.shared_parameters,
            device=self.device,
            depth=2,
            num_cells=256,
            activation_class=nn.Tanh,
        )
        actor_module = TensorDictModule(actor_mlp, in_keys=[observation_key], out_keys=[param_key])

        # Deterministic policy: a tanh-squashed delta distribution bounded by the action space.
        self.policy = ProbabilisticActor(
            module=actor_module,
            spec=unbatched_action_spec,
            in_keys=[param_key],
            out_keys=[self.env.action_key],
            distribution_class=TanhDelta,
            distribution_kwargs={
                "min": agent_action_space.low,
                "max": agent_action_space.high,
            },
            return_log_prob=False,
        )
        self.done_transform = DoneTransform(reward_key=self.env.reward_key, done_keys=self.env.done_keys)

        # Critic: takes the observation and the action, outputs a single Q-value per agent.
        critic_mlp = MultiAgentMLP(
            n_agent_inputs=scenario_cfg.state_dim + action_dim,
            n_agent_outputs=1,
            n_agents=scenario_cfg.n_agents,
            centralised=policy_cfg.centralized_critic,
            share_params=policy_cfg.shared_parameters,
            device=self.device,
            depth=2,
            num_cells=256,
            activation_class=nn.Tanh,
        )
        self.value_module = ValueOperator(
            module=critic_mlp,
            in_keys=[observation_key, self.env.action_key],
            out_keys=[q_value_key],
        )

        # DDPG loss with a delayed (target) value network.
        ddpg_loss = DDPGLoss(actor_network=self.policy, value_network=self.value_module, delay_value=True)
        ddpg_loss.set_keys(
            state_action_value=q_value_key,
            reward=self.env.reward_key,
            done=("agents", "done"),
            terminated=("agents", "terminated"),
        )
        ddpg_loss.make_value_estimator(ValueEstimators.TD0, gamma=policy_cfg.gamma)
        self.loss_module = ddpg_loss

        self.target_net_updater = SoftUpdate(self.loss_module, eps=1 - policy_cfg.tau)
        self.optim = torch.optim.Adam(self.loss_module.parameters(), policy_cfg.lr)

    def update_policy(self, tensordict_data):
        """Run the configured number of DDPG epochs over one batch of rollout data.

        Args:
            tensordict_data: Rollout data; it is flattened and loaded into a
                temporary buffer that is sampled without replacement.

        Returns:
            A tuple ``(training_tds, policy_updates)``: the detached loss
            tensordicts (each with an added ``"grad_norm"`` entry), one per
            gradient step, and the number of gradient steps taken.
        """
        policy_cfg = self.config.policy

        # Temporary buffer holding only this batch, sampled in shuffled minibatches.
        minibatch_buffer = TensorDictReplayBuffer(
            storage=LazyTensorStorage(policy_cfg.frames_per_batch, device=self.device),
            sampler=SamplerWithoutReplacement(shuffle=True, seed=self.config.seed, device=self.device),
            batch_size=policy_cfg.minibatch_size,
        )
        minibatch_buffer.extend(tensordict_data.reshape(-1))

        steps_per_epoch = policy_cfg.frames_per_batch // policy_cfg.minibatch_size
        training_tds = []
        for _ in range(policy_cfg.num_epochs * steps_per_epoch):
            losses = self.loss_module(minibatch_buffer.sample())
            step_record = losses.detach()

            # Actor and critic share one optimizer, so their losses are summed.
            combined_loss = losses["loss_actor"] + losses["loss_value"]
            combined_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.loss_module.parameters(), policy_cfg.max_grad_norm
            )
            step_record.set("grad_norm", grad_norm.mean())
            training_tds.append(step_record)

            # Apply the gradient step, clear gradients, then soft-update the target nets.
            self.optim.step()
            self.optim.zero_grad()
            self.target_net_updater.step()

        return training_tds, len(training_tds)