from typing import Optional, List, Dict, Union

import numpy as np
import torch
import wandb
from tensordict.tensordict import TensorDictBase, TensorDict

from src.models.gp_model import MultitaskGPModel
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.environment import Environment
from vmas.simulator.utils import DEVICE_TYPING


class HallucEnvironment(Environment):
    """Vectorized environment whose steps can be "hallucinated" by a learned model.

    ``real_step`` and ``optimistic_real_step`` advance the simulator through
    ``self.world.step()``. ``step`` and the ``*_optimistic_step`` /
    ``hucrl_thompson_sampling_step`` variants do not call ``world.step()``;
    they query ``self.model`` for the next state (and optionally the reward)
    and write the result back into the scenario.

    NOTE(review): ``self.model`` and ``self.policy`` are not assigned in
    ``__init__``; ``self.policy`` is set through ``register_policy`` and
    ``self.model`` is presumably assigned by the caller before ``step`` is
    used -- confirm.
    """

    # Render modes and the vectorized flag advertised by this environment.
    metadata = {
        "render.modes": ["human", "rgb_array"],
        "runtime.vectorized": True,
    }

    def __init__(
        self,
        scenario: BaseScenario,
        num_envs: int = 32,
        device: DEVICE_TYPING = "cpu",
        max_steps: Optional[int] = None,
        continuous_actions: bool = True,
        seed: Optional[int] = None,
        dict_spaces: bool = False,
        **kwargs,
    ):
        """Store the hallucination options, then initialise the base environment.

        The options listed below are popped from ``kwargs`` so that only the
        remaining entries are forwarded to ``Environment.__init__``.

        Keyword options (default):
            use_k_branching (True): when set, every step method first resets
                the scenario to the observation stored in the tensordict via
                ``scenario.set_env_to_obs``.
            use_optimism (False): lets ``step`` dispatch to an optimistic
                variant once ``iteration >= optimism_after_iter``.
            use_hucrl_approx (False): selects ``hucrl_optimistic_step`` over
                ``new_optimistic_step`` for that dispatch.
            optimism_after_iter (100): first iteration at which optimism applies.
            beta (1): forwarded to ``scenario.get_perturbed_obs``.
            num_samples (5): number of candidates drawn by the sampling-based
                step methods.
            lower_percentile (0.9), upper_percentile (1): forwarded to
                ``model.optimistic_posterior_prediction``.
            logger (None): the model-based step methods only log to wandb
                when this is not ``None``.
        """
        self.use_k_branching = kwargs.pop("use_k_branching", True)
        self.use_optimism = kwargs.pop("use_optimism", False)
        self.use_hucrl_approx = kwargs.pop("use_hucrl_approx", False)
        self.optimism_after_iter = kwargs.pop("optimism_after_iter", 100)
        self.beta = kwargs.pop("beta", 1)
        self.num_samples = kwargs.pop("num_samples", 5)
        self.lower_percentile = kwargs.pop("lower_percentile", 0.9)
        self.upper_percentile = kwargs.pop("upper_percentile", 1)
        self.logger = kwargs.pop("logger", None)
        # Build the shape directly instead of allocating a tensor just to read
        # its ``.shape``. Kept ahead of the base-class constructor, as before.
        self.batch_size = torch.Size([num_envs])
        super().__init__(scenario, num_envs, device, max_steps, continuous_actions, seed, dict_spaces, **kwargs)
    
    def register_policy(self, policy):
        self.policy = policy

    def anneal_beta(self, beta):
        self.beta = beta
    
    def anneal_lower_percentile(self, lower_percentile):
        self.lower_percentile = lower_percentile

    def real_step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ):
        """Take one step in the real (simulated) environment, without optimism.

        Args:
            tensordict: must provide ``("agents", "observation")`` and
                ``("agents", "action")``; ``"replay_info"`` is also read when
                ``use_k_branching`` is enabled.
            total_env_steps: unused here; kept for signature parity with the
                other step methods.
            total_halluc_env_steps: unused here.
            iteration: unused here.

        Returns:
            ``(obs, rewards, dones, infos)`` as produced by
            ``get_from_scenario``.
        """
        stored_obs = tensordict.get(("agents", "observation"))
        stored_action = tensordict.get(("agents", "action"))

        # With k-branching the scenario is first reset to the stored state.
        if self.use_k_branching:
            self.scenario.set_env_to_obs(stored_obs, tensordict.get("replay_info"))

        # Hand each agent its action, let the scenario post-process the
        # actions of every world agent, then advance the world.
        per_agent_actions = self.preprocess_action(stored_action)
        for idx, agent in enumerate(self.agents):
            self._set_action(per_agent_actions[idx], agent)
        for world_agent in self.world.agents:
            self.scenario.env_process_action(world_agent)
        self.world.step()

        self.scenario.steps += 1
        obs, rewards, dones, infos = self.get_from_scenario(
            get_observations=True,
            get_infos=True,
            get_rewards=True,
            get_dones=True,
            get_replay_info=False,
        )
        return obs, rewards, dones, infos
    
    def optimistic_real_step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ):
        """Step the simulator, then keep the best-rewarded perturbed next state.

        After ``world.step()`` the method draws ``self.num_samples`` perturbed
        versions of the next observation via ``scenario.get_perturbed_obs``,
        reads the scenario reward/done for each, and returns, per environment,
        the candidate with the highest reward.

        Args:
            tensordict: must provide ``("agents", "observation")`` and
                ``("agents", "action")``; ``"replay_info"`` is read only when
                ``use_k_branching`` is enabled.
            total_env_steps: logged only.
            total_halluc_env_steps: logged; logging happens when it is a
                multiple of 10.
            iteration: logged only.

        Returns:
            ``(max_obs, max_rewards, max_dones, infos)`` where ``max_obs`` and
            ``max_rewards`` are lists with one entry per agent and ``infos``
            is a list of empty dicts.
        """
        prev_obs = tensordict.get(("agents", "observation"))
        action = tensordict.get(("agents", "action"))

        # Set the environment to match the state sampled from the replay
        # buffer. ``replay_info`` is fetched only when it is actually needed,
        # as in the other step methods.
        if self.use_k_branching:
            replay_info = tensordict.get("replay_info")
            self.scenario.set_env_to_obs(prev_obs, replay_info)

        # Execute the action in the simulator.
        action_processed = self.preprocess_action(action)
        for i, agent in enumerate(self.agents):
            self._set_action(action_processed[i], agent)
        for agent in self.world.agents:
            self.scenario.env_process_action(agent)
        self.world.step()

        # Compute the next observation.
        self.scenario.steps += 1
        next_obs = self.get_from_scenario(
            get_observations=True, get_infos=False, get_rewards=False, get_dones=False, get_replay_info=False
        )
        default_infos = [{} for _ in range(self.n_agents)]

        # Draw plausible displacements around the next observation.
        # NOTE(review): 0.02 occupies the argument slot that
        # hucrl_optimistic_step fills with the model's stddev -- presumably a
        # fixed perturbation scale; confirm against get_perturbed_obs.
        next_agent_pos_vel = next_obs[0][0][:, :self.scenario.output_dim]
        reward_samples = []
        done_samples = []
        obs_samples = []
        for _ in range(self.num_samples):
            perturbed_next_obs = self.scenario.get_perturbed_obs(next_agent_pos_vel, self.beta, 0.02)
            # sample=False: per the original author's note, this does not mark
            # the candidate as sampled.
            rewards, dones = self.get_from_scenario(
                get_observations=False, get_infos=False, get_rewards=True, get_dones=True, get_replay_info=False, sample=False,
            )
            reward_samples.append(torch.stack(rewards, dim=0))
            done_samples.append(dones)
            obs_samples.append(torch.stack([perturbed_next_obs], dim=0))

        # Select, per environment, the candidate with the highest reward.
        all_rewards = torch.stack(reward_samples, dim=0)[:, 0]  # rewards are the same for all agents
        all_dones = torch.stack(done_samples, dim=0)
        all_obs = torch.stack(obs_samples, dim=0)
        best = all_rewards.max(dim=0)  # computed once, reused below
        joint_max_rewards = best.values
        max_rewards = [joint_max_rewards for _ in range(self.n_agents)]
        joint_indices = best.indices[None, None, :, None].expand(all_obs.shape)
        joint_max_obs = all_obs.gather(dim=0, index=joint_indices)[0]
        max_obs = [joint_max_obs[i] for i in range(self.n_agents)]
        max_dones = all_dones.gather(dim=0, index=best.indices.view(1, -1)).squeeze(0)

        # Log reward statistics, guarded on the logger like every other step
        # method so that a logger-less environment never touches wandb.
        if self.logger is not None and total_halluc_env_steps % 10 == 0:
            wandb.log({
                "train/total_env_steps": total_env_steps,
                "train/total_halluc_env_steps": total_halluc_env_steps,
                "train/iteration": iteration,
                "train/sample_reward/reward_min": all_rewards.min(dim=0).values.mean().item(),
                "train/sample_reward/reward_max": joint_max_rewards.mean().item(),
                "train/sample_reward/reward_mean": all_rewards.mean().item(),
                "train/sample_reward/reward_std" : all_rewards.std(0).mean(),
            })

        return max_obs, max_rewards, max_dones, default_infos
    
    def step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ):
        """Take one step in the hallucinated (model-based) environment.

        When optimism is enabled and ``iteration`` has reached
        ``optimism_after_iter`` the call is delegated to
        ``hucrl_optimistic_step`` or ``new_optimistic_step``. Otherwise the
        next agent state is the model's posterior mean delta added to the
        current state; ``world.step()`` is not called.

        Args:
            tensordict: must provide ``("agents", "observation")`` and
                ``("agents", "action")``; ``"replay_info"`` is also read when
                ``use_k_branching`` is enabled.
            total_env_steps: logged only.
            total_halluc_env_steps: logged; logging happens when it is a
                multiple of 10 and a logger is set.
            iteration: compared with ``optimism_after_iter`` and logged.

        Returns:
            ``(next_obs, rewards, dones, infos)``. ``rewards`` comes from the
            model when it learns a reward, otherwise from the scenario.
        """
        if self.use_optimism and iteration >= self.optimism_after_iter:
            optimistic_variant = (
                self.hucrl_optimistic_step if self.use_hucrl_approx else self.new_optimistic_step
            )
            return optimistic_variant(tensordict, total_env_steps, total_halluc_env_steps, iteration)

        stored_obs = tensordict.get(("agents", "observation"))
        stored_action = tensordict.get(("agents", "action"))

        # With k-branching the scenario is first reset to the stored state.
        if self.use_k_branching:
            self.scenario.set_env_to_obs(stored_obs, tensordict.get("replay_info"))

        # Register the actions with the agents; the world itself is not stepped.
        per_agent_actions = self.preprocess_action(stored_action)
        for idx, agent in enumerate(self.agents):
            self._set_action(per_agent_actions[idx], agent)
        for world_agent in self.world.agents:
            self.scenario.env_process_action(world_agent)

        # The model predicts the change in observation; add it to the current state.
        model_input = self.scenario.get_joint_input(stored_obs, stored_action)
        output = self.model.posterior_prediction(model_input)
        state_dim = self.scenario.pos_vel_dim
        next_agent_obs = output.mean[:, 0, :state_dim] + model_input[:, :state_dim]
        self.scenario.set_agent_pos_vel(next_agent_obs)
        next_obs = [self.scenario.observation(self.world.agents[0])]  # TODO: support more agents

        self.scenario.steps += 1
        smoothed = self.model.learn_smoothed_reward
        uses_learned_reward = smoothed or self.model.learn_unsmoothed_reward

        learned_rewards = None
        if uses_learned_reward:
            # Output column 4 is used as the predicted reward.
            learned_rewards = [output.mean[:, 0, 4]]
            if smoothed:
                # zero out reward at previously sampled cells; this must run
                # before the scenario query below.
                zero_mask = self.scenario.sampled_at_pos(next_agent_obs[:, :2]) == 0
                learned_rewards[0][zero_mask] = 0

        scenario_rewards, dones, infos = self.get_from_scenario(
            get_observations=False, get_infos=True, get_rewards=True, get_dones=True, get_replay_info=False
        )
        rewards = scenario_rewards if learned_rewards is None else learned_rewards

        if self.logger is not None and total_halluc_env_steps % 10 == 0:
            to_log = {
                "train/total_env_steps": total_env_steps,
                "train/total_halluc_env_steps": total_halluc_env_steps,
                "train/iteration": iteration,
                "train_gp/obs_delta_stddev": output.stddev.mean(),
            }
            if uses_learned_reward:
                to_log["train_learned_reward/mean_error"] = torch.abs(scenario_rewards[0] - learned_rewards[0]).mean()
            wandb.log(to_log)

        return next_obs, rewards, dones, infos

    def hucrl_optimistic_step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ):
        """Optimistic model-based step: pick the best-rewarded perturbed prediction.

        If the model learns a reward (smoothed or unsmoothed) the call is
        delegated to ``hucrl_thompson_sampling_step``. Otherwise the model's
        posterior mean gives the predicted next state, ``self.num_samples``
        perturbed candidates are drawn around it with the posterior stddev,
        and for each environment the candidate with the highest scenario
        reward is written back into the scenario.

        Args:
            tensordict: must provide ``("agents", "observation")`` and
                ``("agents", "action")``; ``"replay_info"`` is also read when
                ``use_k_branching`` is enabled.
            total_env_steps: logged only.
            total_halluc_env_steps: logged; logging happens when it is a
                multiple of 10 and a logger is set.
            iteration: logged only.

        Returns:
            ``([max_obs], rewards, dones, infos)`` where the last three come
            from ``get_from_scenario`` evaluated at the selected state.
        """
        if self.model.learn_smoothed_reward or self.model.learn_unsmoothed_reward:
            return self.hucrl_thompson_sampling_step(tensordict, total_env_steps, total_halluc_env_steps, iteration)

        prev_obs = tensordict.get(("agents", "observation"))
        action = tensordict.get(("agents", "action"))

        # Set the environment to match the state sampled from the replay buffer.
        if self.use_k_branching:
            replay_info = tensordict.get("replay_info")
            self.scenario.set_env_to_obs(prev_obs, replay_info)

        # Register the actions with the agents; the world itself is not stepped.
        action_processed = self.preprocess_action(action)
        for i, agent in enumerate(self.agents):
            self._set_action(action_processed[i], agent)
        for agent in self.world.agents:
            self.scenario.env_process_action(agent)

        # Predict the next observation under the dynamics model.
        joint_input = self.scenario.get_joint_input(prev_obs, action)
        output = self.model.posterior_prediction(joint_input)
        state_dim = self.scenario.pos_vel_dim
        next_agent_obs = output.mean[:, 0, :state_dim] + joint_input[:, :state_dim]
        agent_obs_stddev = output.stddev[:, 0, :state_dim]

        # Draw plausible displacements around the prediction and score each one.
        # Only rewards and observations are collected: the per-candidate done
        # flags were never consulted, so they are no longer accumulated.
        reward_samples = []
        obs_samples = []
        for _ in range(self.num_samples):
            perturbed_next_obs = self.scenario.get_perturbed_obs(next_agent_obs, self.beta, agent_obs_stddev)
            candidate_rewards, _candidate_dones = self.get_from_scenario(
                get_observations=False, get_infos=False, get_rewards=True, get_dones=True, get_replay_info=False, sample=False,
            )
            reward_samples.append(torch.stack(candidate_rewards, dim=0))
            obs_samples.append(torch.stack([perturbed_next_obs], dim=0))

        # Select the candidate leading to the highest reward in each environment.
        all_rewards = torch.stack(reward_samples, dim=0).squeeze(dim=1)
        all_obs = torch.stack(obs_samples, dim=0)
        joint_indices = all_rewards.max(dim=0).indices[None, None, :, None].expand(all_obs.shape)
        max_obs = all_obs.gather(dim=0, index=joint_indices)[0].squeeze(0)
        self.scenario.set_agent_pos_vel(max_obs)

        self.scenario.steps += 1
        # The scenario is queried again (this time with sampling) at the
        # selected state.
        # NOTE(review): the observation returned here is discarded and the
        # perturbed candidate is returned instead -- confirm that is intended.
        _obs, rewards, dones, infos = self.get_from_scenario(
            get_observations=True, get_infos=True, get_rewards=True, get_dones=True, get_replay_info=False
        )

        # Log reward statistics.
        if self.logger is not None and total_halluc_env_steps % 10 == 0:
            wandb.log({
                "train/total_env_steps": total_env_steps,
                "train/total_halluc_env_steps": total_halluc_env_steps,
                "train/iteration": iteration,
                "train_gp/obs_delta_stddev": output.stddev.mean(),
                "train_sample_reward/reward_min": all_rewards.min(dim=0).values.mean().item(),
                "train_sample_reward/reward_max": rewards[0].mean().item(),
                "train_sample_reward/reward_mean": all_rewards.mean().item(),
                "train_sample_reward/reward_std" : all_rewards.std(0).mean(),
            })
        return [max_obs], rewards, dones, infos

    def hucrl_thompson_sampling_step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ):
        """Draw several optimistic model samples and keep the best-rewarded one.

        ``model.optimistic_posterior_prediction`` is called
        ``self.num_samples`` times; for each environment the sample with the
        highest predicted reward is selected and written into the scenario.

        Args:
            tensordict: must provide ``("agents", "observation")`` and
                ``("agents", "action")``; ``"replay_info"`` is also read when
                ``use_k_branching`` is enabled.
            total_env_steps: logged only.
            total_halluc_env_steps: logged; logging happens when it is a
                multiple of 10 and a logger is set.
            iteration: logged only.

        Returns:
            ``(next_obs, max_rewards, dones, infos)``; the first two are
            single-element lists, the last two come from ``get_from_scenario``.
        """
        stored_obs = tensordict.get(("agents", "observation"))
        stored_action = tensordict.get(("agents", "action"))

        # With k-branching the scenario is first reset to the stored state.
        if self.use_k_branching:
            self.scenario.set_env_to_obs(stored_obs, tensordict.get("replay_info"))

        # Register the actions with the agents; the world itself is not stepped.
        per_agent_actions = self.preprocess_action(stored_action)
        for idx, agent in enumerate(self.agents):
            self._set_action(per_agent_actions[idx], agent)
        for world_agent in self.world.agents:
            self.scenario.env_process_action(world_agent)

        # Sample the model repeatedly; each draw yields a state delta and a reward.
        model_input = self.scenario.get_joint_input(stored_obs, stored_action)
        reward_samples, obs_samples = [], []
        for _ in range(self.num_samples):
            obs_delta, sampled_rewards, posterior = self.model.optimistic_posterior_prediction(
                model_input, self.lower_percentile, self.upper_percentile
            )
            candidate_obs = obs_delta + model_input[:, :self.scenario.pos_vel_dim]
            if self.model.learn_smoothed_reward:
                # zero out reward at previously sampled cells
                zero_mask = self.scenario.sampled_at_pos(candidate_obs[:, :2]) == 0
                sampled_rewards[zero_mask] = 0
            reward_samples.append(torch.stack([sampled_rewards], dim=0))
            obs_samples.append(torch.stack([candidate_obs], dim=0))

        # Per environment, pick the sample with the highest reward.
        all_rewards = torch.stack(reward_samples, dim=0).squeeze(dim=1)
        all_obs = torch.stack(obs_samples, dim=0)
        best = all_rewards.max(dim=0)
        gather_index = best.indices[None, None, :, None].expand(all_obs.shape)
        max_rewards = [best.values]
        max_obs = all_obs.gather(dim=0, index=gather_index)[0].squeeze(0)
        self.scenario.set_agent_pos_vel(max_obs)
        next_obs = [self.scenario.observation(self.world.agents[0])]

        self.scenario.steps += 1
        # Query the scenario at the selected state for done flags and infos.
        dones, infos = self.get_from_scenario(
            get_observations=False, get_infos=True, get_rewards=False, get_dones=True, get_replay_info=False
        )

        # ``posterior`` below is the one returned by the last draw of the loop.
        if self.logger is not None and total_halluc_env_steps % 10 == 0:
            wandb.log({
                "train/total_env_steps": total_env_steps,
                "train/total_halluc_env_steps": total_halluc_env_steps,
                "train/iteration": iteration,
                "train_gp/obs_delta_stddev": posterior.stddev.mean(),
                "train_sample_reward/reward_min": all_rewards.min(dim=0).values.mean().item(),
                "train_sample_reward/reward_max": max_rewards[0].mean().item(),
                "train_sample_reward/reward_mean": all_rewards.mean().item(),
                "train_sample_reward/reward_std" : all_rewards.std(0).mean(),
            })
        return next_obs, max_rewards, dones, infos

    def new_optimistic_step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ):
        """Model-based step using a single optimistic posterior prediction.

        The next state and the returned reward both come from one call to
        ``model.optimistic_posterior_prediction``. The scenario is still
        queried for done flags and infos, and its reward is used only for the
        logged error metrics.

        Args:
            tensordict: must provide ``("agents", "observation")`` and
                ``("agents", "action")``; ``"replay_info"`` is also read when
                ``use_k_branching`` is enabled.
            total_env_steps: logged only.
            total_halluc_env_steps: logged; logging happens when it is a
                multiple of 10 and a logger is set.
            iteration: logged only.

        Returns:
            ``(next_obs, optimistic_rewards, dones, infos)``; the first two
            are single-element lists.
        """
        stored_obs = tensordict.get(("agents", "observation"))
        stored_action = tensordict.get(("agents", "action"))

        # With k-branching the scenario is first reset to the stored state.
        if self.use_k_branching:
            self.scenario.set_env_to_obs(stored_obs, tensordict.get("replay_info"))

        # Register the actions with the agents; the world itself is not stepped.
        per_agent_actions = self.preprocess_action(stored_action)
        for idx, agent in enumerate(self.agents):
            self._set_action(per_agent_actions[idx], agent)
        for world_agent in self.world.agents:
            self.scenario.env_process_action(world_agent)

        # Optimistic prediction of the state delta and the reward.
        model_input = self.scenario.get_joint_input(stored_obs, stored_action)
        obs_delta, reward_sample, posterior = self.model.optimistic_posterior_prediction(
            model_input, self.lower_percentile, self.upper_percentile
        )
        next_agent_obs = obs_delta + model_input[:, :self.scenario.pos_vel_dim]
        self.scenario.set_agent_pos_vel(next_agent_obs)
        next_obs = [self.scenario.observation(self.world.agents[0])]

        self.scenario.steps += 1
        if self.model.learn_smoothed_reward:
            # zero out reward at previously sampled cells; this must run
            # before the scenario query below.
            zero_mask = self.scenario.sampled_at_pos(next_agent_obs[:, :2]) == 0
            reward_sample[zero_mask] = 0
        optimistic_rewards = [reward_sample]
        real_rewards, dones, infos = self.get_from_scenario(
            get_observations=False, get_infos=True, get_rewards=True, get_dones=True, get_replay_info=False
        )

        if self.logger is not None and total_halluc_env_steps % 10 == 0:
            # Output column 4 of the posterior mean is used as the mean reward.
            mean_rewards = posterior.mean[:, 0, 4]
            wandb.log({
                "train/total_env_steps": total_env_steps,
                "train/total_halluc_env_steps": total_halluc_env_steps,
                "train/iteration": iteration,
                "train_gp/obs_delta_stddev": posterior.stddev.mean(),
                "train_learned_reward/mean_error": torch.abs(real_rewards[0] - mean_rewards).mean(),
                "train_learned_reward/sample_error": torch.abs(real_rewards[0] - optimistic_rewards[0]).mean(),
                "train_learned_reward/lower_reward_percentile": self.lower_percentile,
                "train_learned_reward/upper_reward_percentile": self.upper_percentile,
            })

        return next_obs, optimistic_rewards, dones, infos