import numpy as np
import torch
import torch.nn.functional as F
import wandb
import random
import time
import copy


# Import the SAC implementation
from algs.sac import SAC, ReplayBuffer
from dynamics.deep_ens import EnsembleModel
from dynamics.utils import compute_uncertainty, RollingUncertaintyNormalizer


def add_transition_noise(next_state, noise_std):
    """Return ``next_state`` perturbed by i.i.d. zero-mean Gaussian noise.

    Args:
        next_state: Array with a ``.shape`` attribute (e.g. a NumPy array).
        noise_std: Standard deviation of the added noise.

    Returns:
        ``next_state + noise`` where ``noise ~ N(0, noise_std**2)`` is drawn
        from NumPy's global RNG with the same shape as ``next_state``.
    """
    perturbation = np.random.normal(loc=0, scale=noise_std, size=next_state.shape)
    noisy_state = next_state + perturbation
    return noisy_state


class BE_MPC_SAC:
    """SAC agent optionally combined with an ensemble dynamics model.

    The ``mpc`` and ``bayes_exp`` flags select one of four variants (reflected
    in ``alg_name``):

    * ``SAC``: SAC trained on random minibatches of real transitions.
    * ``BE_SAC_<method>``: as above, but rewards are augmented with an
      intrinsic bonus ``eta_uncert * normalized ensemble uncertainty``.
    * ``MPC_SAC``: SAC trained only on imaginary rollouts of the ensemble
      model started from the current real state.
    * ``BayesExp_MPC_SAC_<method>``: imaginary rollouts whose rewards are
      augmented with the uncertainty bonus.
    """

    def __init__(self,
                 env,
                 seed,
                 device,
                 uncertainty_method='IG',
                 mpc=False,
                 bayes_exp=False,
                 mod_reward=False,
                 noise_std=0.1,
                 wandb_flag=True,
                 eta_uncert=1.0,
                 lr_model=1e-3,
                 lr_sac=0.0003,
                 gamma=0.99,
                 tau=0.005,
                 alpha=0.2,
                 agent_epochs=32,
                 model_epochs=100,
                 agent_update_freq=2,
                 model_update_freq=128,
                 agent_batch_size=256,
                 model_batch_size=256,
                 horizon=64,
                 num_rollouts=16,
                 warmup_steps=500,
                 buffer_size=1_024):
        """Build the SAC agent, the ensemble model and the replay buffers.

        Args:
            env: Gym-style environment; ``'Maze'`` envs (by ``spec.id``) are
                expected to return dict observations with ``'observation'``
                and ``'desired_goal'`` keys.
            seed: Only used in the wandb run name.
            device: Torch device for the ensemble model and tensors.
            uncertainty_method: Method name passed to ``compute_uncertainty``.
            mpc: Train SAC on imaginary model rollouts instead of real data.
            bayes_exp: Add the uncertainty-based intrinsic reward.
            mod_reward: The ensemble also predicts the reward (last output dim).
            noise_std: Std of Gaussian noise added to each observed next state
                (disabled when ``<= 0``).
            wandb_flag: Enable wandb logging.
            eta_uncert: Scale of the intrinsic reward.
            lr_model, lr_sac, gamma, tau, alpha: Ensemble / SAC hyper-parameters.
            agent_epochs: SAC updates per agent-update step.
            model_epochs: Gradient steps per dynamics-model update.
            agent_update_freq: Update the agent every this many env steps.
            model_update_freq: Update the model every this many env steps.
            agent_batch_size: SAC minibatch size (in MPC mode
                ``horizon * num_rollouts`` is used instead).
            model_batch_size: Dynamics-model minibatch size.
            horizon: Imaginary rollout length.
            num_rollouts: Number of imaginary rollouts per agent update.
            warmup_steps: Random actions until ``total_steps > warmup_steps``;
                model training starts at ``warmup_steps // 2`` and agent
                updates at ``warmup_steps``.
            buffer_size: Capacity of the real buffer and minimum capacity of
                the model buffer.
        """
        self.env = env
        self.seed = seed
        self.wandb_flag = wandb_flag
        self.device = device
        self.noise_std = noise_std

        # Uncertainty / exploration parameters
        self.uncertainty_method = uncertainty_method
        self.eta_uncert = eta_uncert
        self.bayes_exp = bayes_exp

        # MPC rollout parameters
        self.mpc = mpc
        self.horizon = horizon
        self.num_rollouts = num_rollouts

        # Environment dimensions (Maze envs expose a dict observation space)
        if 'Maze' in env.spec.id:
            self.state_dim = env.observation_space['observation'].shape[0]
        else:
            self.state_dim = env.observation_space.shape[0]

        self.action_dim = env.action_space.shape[0]
        self.max_action = float(env.action_space.high[0])

        # In MPC mode one SAC batch is one full set of imaginary rollouts
        self.agent_batch_size = self.horizon * self.num_rollouts if self.mpc else agent_batch_size
        self.agent_epochs = agent_epochs
        self.agent_update_freq = agent_update_freq

        # Algorithm name (used for wandb run / group names)
        if self.mpc:
            if bayes_exp:
                self.alg_name = 'BayesExp_MPC_SAC_' + uncertainty_method
            else:
                self.alg_name = 'MPC_SAC'
        elif self.bayes_exp:
            self.alg_name = 'BE_SAC_' + uncertainty_method
        else:
            self.alg_name = 'SAC'

        # Model parameters
        self.warmup_steps = warmup_steps
        self.model_update_freq = model_update_freq
        self.model_epochs = model_epochs
        self.model_batch_size = model_batch_size

        # SAC agent
        self.sac_agent = SAC(self.state_dim, self.action_dim, self.max_action, lr=lr_sac, gamma=gamma,
                             tau=tau, alpha=alpha)

        # Ensemble dynamics model
        self.mod_reward = mod_reward
        self.ensemble_model = EnsembleModel(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            mod_reward=self.mod_reward,
            ensemble_size=10,
            lr=lr_model
        ).to(self.device)

        # Replay buffers. The model buffer holds at least one full set of
        # imaginary rollouts (horizon * num_rollouts transitions) so that,
        # assuming ReplayBuffer is a fixed-capacity ring buffer, a rollout
        # batch never overwrites itself.
        self.buffer_size = buffer_size
        self.real_buffer = ReplayBuffer(buffer_size)
        self.model_buffer = ReplayBuffer(max(buffer_size, self.horizon * self.num_rollouts))

        # Real transitions with intrinsic-augmented rewards (BE mode). Kept
        # separate from real_buffer so bonuses never compound and the dynamics
        # model is always trained on extrinsic rewards.
        self._augmented_transitions = None

        # Metrics tracking
        self.total_steps, self.episode = 0, 0
        self.count, self.solved_step, self.perc_states = None, None, None
        self.uncertainty_normalizer = RollingUncertaintyNormalizer(device=self.device, clip_range=(-5.0, 5.0))

    def update_model(self):
        """Train the ensemble dynamics model on minibatches of real transitions.

        Runs ``self.model_epochs`` training steps, each on
        ``min(model_batch_size, len(real_buffer))`` sampled transitions.

        Returns:
            Mean of the per-batch losses returned by
            ``ensemble_model.train_models``, or ``0.0`` when no training is
            possible (empty real buffer or ``model_epochs <= 0``).
        """
        n_available = len(self.real_buffer)
        if self.model_epochs <= 0 or n_available == 0:
            return 0.0

        # The buffer does not change while training, so the batch size is fixed
        batch_size = min(self.model_batch_size, n_available)

        total_loss = 0
        for _ in range(self.model_epochs):
            states, actions, rewards, diff_states, _ = self.real_buffer.sample(batch_size)

            states = torch.FloatTensor(states).to(self.device)
            actions = torch.FloatTensor(actions).to(self.device)
            rewards = torch.FloatTensor(rewards).to(self.device).unsqueeze(1)
            diff_states = torch.FloatTensor(diff_states).to(self.device)

            total_loss += self.ensemble_model.train_models(states, actions, diff_states, rewards)

        return total_loss / self.model_epochs

    def imaginary_rollout(self, initial_state, horizon=None):
        """Fill ``self.model_buffer`` with imaginary rollouts of the ensemble model.

        Performs ``self.num_rollouts`` independent rollouts of length
        ``horizon``, each starting from a copy of ``initial_state``. Actions
        come from the SAC policy. At every step one ensemble member is chosen
        uniformly at random, and its mean prediction becomes the next state.
        With ``mod_reward`` the last output dimension is the reward; otherwise
        the reward is zero. With ``bayes_exp`` the reward is augmented by
        ``eta_uncert`` times the normalized uncertainty, and the mean raw and
        normalized uncertainties are logged to wandb.

        Args:
            initial_state: Real state to start every rollout from (not mutated).
            horizon: Rollout length; defaults to ``self.horizon``.
        """
        if horizon is None:
            horizon = self.horizon

        raw_uncert_measures, norm_uncert_measures = [], []

        for _ in range(self.num_rollouts):
            # np.copy: each rollout starts from an untouched initial state
            current_state = np.copy(initial_state)

            for _ in range(horizon):
                # Action from the SAC policy for the current (imagined) state
                action = self.sac_agent.select_action(current_state)
                action = torch.FloatTensor(action).to(self.device)

                # Add a batch dimension only for a single unbatched state
                current_state_tensor = torch.FloatTensor(current_state).to(self.device)
                if current_state_tensor.dim() == 1:
                    current_state_tensor = current_state_tensor.unsqueeze(0)

                with torch.no_grad():
                    mean_preds, var_preds = self.ensemble_model.forward_all(current_state_tensor, action)

                # NOTE(review): uses the mean prediction of one random member, not a
                # sample from its predicted variance. Also, if train(diffs=True) the
                # model was fit on next_state - state targets, yet the prediction is
                # used below as an absolute next state. Confirm both are intended.
                ensemble_idx = np.random.choice(self.ensemble_model.ensemble_size)
                model_pred = mean_preds[ensemble_idx, 0]

                if self.mod_reward:
                    # Last output dimension is the predicted reward
                    next_state = model_pred[:, :-1]
                    reward = model_pred[:, -1].unsqueeze(0)
                else:
                    next_state = model_pred
                    reward = torch.zeros_like(model_pred[0]).unsqueeze(0)
                done_ = torch.zeros_like(reward)

                if self.bayes_exp:
                    uncert_measure = compute_uncertainty(
                        current_state_tensor,
                        mean_preds.detach().clone(),
                        var_preds.detach().clone(),
                        self.ensemble_model,
                        method=self.uncertainty_method
                    )
                    raw_uncert_measures.append(uncert_measure.mean().item())

                    # Normalize with the current stats, then fold this sample in
                    normalized_uncert = self.uncertainty_normalizer.normalize(uncert_measure)
                    norm_uncert_measures.append(normalized_uncert.mean().item())
                    self.uncertainty_normalizer.update(uncert_measure)

                    # NOTE(review): averages over dim 0 here but over dim 1 in
                    # be_augment_rewards; confirm which axis of compute_uncertainty's
                    # output is the ensemble axis.
                    intrinsic_reward = normalized_uncert.mean(dim=0) if normalized_uncert.dim() > 1 else normalized_uncert
                    intrinsic_reward = intrinsic_reward.unsqueeze(1)
                    reward = reward + self.eta_uncert * intrinsic_reward

                self.model_buffer.push(
                    current_state_tensor.flatten().cpu().numpy(),
                    action.flatten().cpu().numpy(),
                    reward.flatten().cpu().numpy(),
                    next_state.flatten().cpu().numpy(),
                    done_.flatten().cpu().numpy()
                )

                current_state = next_state.detach().cpu().numpy()

        # Guard: np.mean([]) would log NaN when no rollout step was taken
        if self.bayes_exp and raw_uncert_measures:
            raw_uncertainty_measure = np.mean(raw_uncert_measures)
            norm_uncertainty_measure = np.mean(norm_uncert_measures)

            if self.wandb_flag:
                wandb.log({"Model/Raw Uncertainty Measure": raw_uncertainty_measure,
                           "Model/Norm Uncertainty Measure": norm_uncertainty_measure},
                          step=self.total_steps)

    def be_augment_rewards(self):
        """Build a reward-augmented snapshot of the real buffer (BE-SAC).

        Each real transition gets ``reward + eta_uncert * intrinsic``, where
        ``intrinsic`` is the normalized ensemble uncertainty of its
        (state, action) pair. The result is stored as a list of transition
        tuples in ``self._augmented_transitions``, which ``get_final_buffer``
        samples from. ``self.real_buffer`` itself is left untouched, so
        bonuses are not stacked on previous bonuses and the dynamics model
        keeps training on extrinsic rewards.
        """
        if len(self.real_buffer) == 0:
            self._augmented_transitions = None
            return

        # Every real transition, in the format returned by ReplayBuffer.sample
        states, actions, rewards, diff_states, dones = self.real_buffer.sample(len(self.real_buffer))

        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.FloatTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)

        with torch.no_grad():
            mean_preds, var_preds = self.ensemble_model.forward_all(states_t, actions_t)

        uncert_measure = compute_uncertainty(
            states_t,
            mean_preds.detach().clone(),
            var_preds.detach().clone(),
            self.ensemble_model,
            method=self.uncertainty_method
        )

        # Fold these samples into the normalizer stats, then normalize
        self.uncertainty_normalizer.update(uncert_measure)
        normalized_uncert = self.uncertainty_normalizer.normalize(uncert_measure)
        # NOTE(review): averages over dim 1 here but over dim 0 in
        # imaginary_rollout; confirm the ensemble axis of compute_uncertainty.
        intrinsic_reward = normalized_uncert.mean(dim=1) if normalized_uncert.dim() > 1 else normalized_uncert

        # Back to numpy so the snapshot matches the format of real transitions
        augmented_rewards = (rewards_t + self.eta_uncert * intrinsic_reward).detach().cpu().numpy()

        self._augmented_transitions = list(zip(states, actions, augmented_rewards, diff_states, dones))

    def get_final_buffer(self):
        """Return the buffer the SAC agent is trained on for one update.

        In MPC mode this is ``self.model_buffer`` (imaginary rollouts).
        Otherwise it is a new ``ReplayBuffer`` holding a random minibatch of up
        to ``agent_batch_size`` real transitions. When Bayesian exploration is
        on and ``be_augment_rewards`` has built a snapshot, the minibatch is
        drawn from the reward-augmented snapshot instead.
        """
        if self.mpc:
            return self.model_buffer

        transitions = self.real_buffer.buffer
        if self.bayes_exp and self._augmented_transitions is not None:
            transitions = self._augmented_transitions

        batch = random.sample(transitions, min(self.agent_batch_size, len(transitions)))

        final_buffer = ReplayBuffer(self.agent_batch_size + 1)
        final_buffer.buffer = batch
        final_buffer.position = len(batch)
        return final_buffer

    def _log_payload(self, episode_reward, cumulative_reward, episode_steps, state,
                     model_loss, critic_loss, actor_loss, alpha_loss):
        """Build the per-step wandb metrics dict logged by ``train``."""
        return {
            "Train/Episode Reward": episode_reward,
            "Train/Cumulative Reward": cumulative_reward,
            "Train/Episode Length": episode_steps,
            "Env/Global Step": self.total_steps,
            "Env/Pos_X": state[0],
            "Env/Pos_Y": state[1],
            "Train/Model Loss": model_loss,
            "Train/Critic Loss": critic_loss,
            "Train/Actor Loss": actor_loss,
            "Train/Alpha Loss": alpha_loss,
            "Train/Episode Number": self.episode
        }

    def train(self, max_steps=200, num_eps=1_000, diffs=False):
        """Run the interaction / training loop.

        Per env step, the loop:
          1. acts (random during warmup, SAC afterwards);
          2. stores the noisy transition;
          3. periodically trains the dynamics model;
          4. periodically trains SAC;
          5. logs.

        Args:
            max_steps: Maximum environment steps per episode.
            num_eps: Number of episodes.
            diffs: Store ``next_state - state`` (instead of ``next_state``) in
                the real buffer as the model's target.
                NOTE(review): the real buffer's 4th slot is also what SAC trains
                on in non-MPC mode; confirm SAC.update handles deltas.

        Returns:
            ``(self.perc_states, self.solved_step)``. ``solved_step`` is the
            in-episode step index at which the most recent episode terminated
            or was truncated.
        """
        # The snapshot is rebuilt before every use; drop it so it is not
        # serialized into the wandb config below.
        self._augmented_transitions = None

        if self.wandb_flag:
            project_name = self.env.spec.id + "_MPC"
            wandb.init(project=project_name, sync_tensorboard=True,
                       name=f"{self.alg_name}_seed_{self.seed}_noise_{self.noise_std}_time_{time.time()}",
                       config=self.__dict__,
                       group=f"{self.alg_name}_noise_{self.noise_std}",
                       dir='/tmp')

        is_maze = 'Maze' in self.env.spec.id
        cumulative_reward = 0
        # Last-known losses; logged as 0 until the corresponding update first runs
        model_loss = critic_loss = actor_loss = alpha_loss = 0

        for _episode in range(num_eps):

            state, _ = self.env.reset()
            if is_maze:
                goal_pos = state['desired_goal']
                state = state['observation']
            else:
                goal_pos = None
            episode_reward, episode_steps = 0, 0

            for step in range(max_steps):

                # 1) Act: random actions during warmup, SAC policy afterwards
                if self.total_steps > self.warmup_steps:
                    action = self.sac_agent.select_action(state).flatten()
                else:
                    action = self.env.action_space.sample()

                denoise_next_state, reward, done, truncated, _ = self.env.step(action)
                if is_maze:
                    denoise_next_state = denoise_next_state['observation']

                # Observation noise on the transition
                if self.noise_std > 0:
                    next_state = add_transition_noise(denoise_next_state, self.noise_std)
                else:
                    next_state = denoise_next_state

                outputs = next_state - state if diffs else next_state

                # Store transition for model training and update the model normalizers
                self.real_buffer.push(state, action, reward, outputs, done)
                self.ensemble_model.normalizer.update(state, action, outputs, reward)

                state = next_state

                # 2) Train the dynamics model
                if ((self.mpc or self.bayes_exp)
                        and self.total_steps % self.model_update_freq == 0
                        and self.total_steps >= (self.warmup_steps // 2)):
                    model_loss = self.update_model()

                # 3) Generate imaginary rollouts / augment rewards and train SAC
                if self.total_steps >= self.warmup_steps and self.total_steps % self.agent_update_freq == 0:
                    # In MPC, rollouts start from the current real state
                    if self.mpc:
                        self.imaginary_rollout(state)

                    if self.bayes_exp:
                        self.be_augment_rewards()

                    for _ in range(self.agent_epochs):
                        final_buffer = self.get_final_buffer()
                        critic_loss, actor_loss, alpha_loss = self.sac_agent.update(final_buffer,
                                                                                    self.agent_batch_size)
                    self.model_buffer.clear()

                # 4) Logging and printing
                if self.wandb_flag:
                    wandb.log(self._log_payload(episode_reward, cumulative_reward, episode_steps, state,
                                                model_loss, critic_loss, actor_loss, alpha_loss),
                              step=self.total_steps)  # Log every step as the episode may be truncated

                if self.total_steps % 100 == 0:
                    print(f"Step: {self.total_steps}, Agent Pos: {state[:2]}, Goal Pos: {goal_pos}")

                episode_reward += reward
                cumulative_reward += reward

                episode_steps += 1
                self.total_steps += 1

                if done or truncated:
                    # Record regardless of wandb: train() returns this value
                    self.solved_step = step

                    if self.wandb_flag:
                        wandb.log({"Solved/Step": step}, step=self.total_steps)

                        if done:
                            # Pad the episode-reward curve for the remaining steps.
                            # NOTE(review): every call uses the same wandb step, so wandb
                            # likely merges them into one row; confirm this has the
                            # intended effect.
                            for _ in range(step, max_steps):
                                wandb.log(self._log_payload(episode_reward, cumulative_reward, episode_steps,
                                                            state, model_loss, critic_loss, actor_loss,
                                                            alpha_loss),
                                          step=self.total_steps)

                    break

            self.episode += 1

        if self.wandb_flag:
            wandb.finish()
        return self.perc_states, self.solved_step
