import os
import numpy as np
import torch
import torch.nn.functional as F
import gymnasium as gym
import wandb
import copy


# Import existing implementations
from algs.ppo import PPO, ReplayBuffer
# from dynamics.deep_ens import EnsembleModel
from dynamics.deep_ens import EnsembleModel
from dynamics.utils import compute_uncertainty, RollingUncertaintyNormalizer
from mount_car.utils.utils import mountain_car_noise


class BE_MPC_PPO:
    def __init__(self, env, seed, device,
                 uncertainty_method='IG', policy='PPO', mod_reward=False,
                 noise_std=0.0, noise_model='heteroskedastic',
                 wandb_flag=True, eta_uncert=1.0,
                 warmup_steps=50, agent_update_freq=1, model_update_freq=10,
                 agent_epochs=100, model_epochs=100, model_batch_size=256,
                 num_rollouts=10, horizon=100,
                 model_lr=1e-3, ppo_lr=0.001):
        """Store the hyper-parameters and build the PPO agent, ensemble model and buffers.

        Args:
            env: Environment exposing ``observation_space.shape`` and a discrete
                ``action_space.n``.
            seed: Stored on the instance; within this class it only appears in the
                wandb run name.
            device: Torch device the ensemble model and rollout tensors are moved to.
            uncertainty_method: Forwarded as ``method`` to ``compute_uncertainty``.
            policy: Label used for wandb naming and for deciding whether the policy
                loss is logged. A PPO agent is built regardless of its value.
            mod_reward: Forwarded to ``EnsembleModel``; when true the last element of
                a model prediction is read as the reward during imagined rollouts.
            noise_std: Passed to ``mountain_car_noise`` as ``noise_fraction``.
            noise_model: Passed to ``mountain_car_noise`` as ``noise_model``.
            wandb_flag: Enables wandb initialisation and logging in ``train``.
            eta_uncert: Weight of the uncertainty bonus added to imagined rewards.
            warmup_steps: Number of total steps before policy updates begin.
            agent_update_freq: Policy is updated when ``total_steps`` is a multiple
                of this value (after warm-up).
            model_update_freq: Model is refit when the episode step is a multiple of
                this value.
            agent_epochs: Passed to PPO as ``K_epochs``.
            model_epochs: Number of training passes per model update.
            model_batch_size: Upper bound on the batch sampled per training pass.
            num_rollouts: Number of imagined rollouts per policy update.
            horizon: Length of each imagined rollout.
            model_lr: Learning rate handed to ``EnsembleModel``.
            ppo_lr: Learning rate used for both the PPO actor and critic.
        """
        # --- run configuration -------------------------------------------------
        self.env = env
        self.seed = seed
        self.wandb_flag = wandb_flag
        self.device = device
        self.noise_std = noise_std
        self.noise_model = noise_model
        self.uncertainty_method = uncertainty_method
        self.eta_uncert = eta_uncert

        # --- imagined-rollout shape --------------------------------------------
        self.horizon = horizon
        self.num_rollouts = num_rollouts

        # --- agent settings ----------------------------------------------------
        obs_shape = env.observation_space.shape
        self.state_dim = obs_shape[0]
        self.action_dim = env.action_space.n
        # One policy update sees every imagined step of every rollout.
        self.agent_batch_size = self.horizon * self.num_rollouts
        self.agent_update_freq = agent_update_freq
        self.agent_epochs = agent_epochs

        # --- dynamics-model settings -------------------------------------------
        self.warmup_steps = warmup_steps
        self.model_update_freq = model_update_freq
        self.model_epochs = model_epochs
        self.model_batch_size = model_batch_size

        # --- PPO agent (always constructed, whatever `policy` says) ------------
        self.policy = policy
        self.ppo_agent = PPO(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            lr_actor=ppo_lr,
            lr_critic=ppo_lr,
            gamma=0.99,
            K_epochs=agent_epochs,
            eps_clip=0.2,
            has_continuous_action_space=False,
        )

        # --- ensemble dynamics model -------------------------------------------
        self.mod_reward = mod_reward
        ensemble = EnsembleModel(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            mod_reward=self.mod_reward,
            ensemble_size=20,
            lr=model_lr,
        )
        self.ensemble_model = ensemble.to(self.device)

        # --- storage for real transitions --------------------------------------
        self.real_buffer = ReplayBuffer(500)

        # --- bookkeeping; count / perc_states are allocated by train() ---------
        self.total_steps = 0
        self.episode = 0
        self.count = self.solved_step = self.perc_states = None
        self.uncertainty_normalizer = RollingUncertaintyNormalizer(
            device=self.device,
            clip_range=(-5.0, 5.0),
        )

    def update_model(self):
        """Update the dynamics model using collected experience"""

        model_loss = 0
        for _ in range(self.model_epochs):

            if len(self.real_buffer) < self.model_batch_size:
                model_batch = len(self.real_buffer)
            else:
                model_batch = self.model_batch_size

            # Sample batch from real buffer
            states, actions, rewards, diff_states, _ = self.real_buffer.sample(model_batch)

            # Prepare tensors
            states = torch.FloatTensor(states).to(self.device)
            # actions = torch.FloatTensor(actions).to(self.device).unsqueeze(1)
            actions = torch.LongTensor(actions).to(self.device).unsqueeze(1)
            actions = F.one_hot(actions, num_classes=self.action_dim).squeeze(1)
            rewards = torch.FloatTensor(rewards).to(self.device).unsqueeze(1)
            diff_states = torch.FloatTensor(diff_states).to(self.device)

            # Train the ensemble model
            batch_loss = self.ensemble_model.train_models(states, actions, diff_states, rewards)
            model_loss += batch_loss

        # Clear the real buffer after training
        self.real_buffer.clear()

        return model_loss / self.model_epochs

    def update_policy(self):
        """Run one PPO update on the contents of the agent's own buffer.

        Returns:
            Whatever ``self.ppo_agent.update()`` returns; ``train`` logs it as the
            policy loss.
        """
        agent = self.ppo_agent
        loss = agent.update()

        # The transitions have been consumed by the update; discard them.
        agent.buffer.clear()
        return loss

    # def imaginary_rollout(self, initial_state, horizon=None):
    #     """Generate imaginary trajectories using the ensemble model and PPO policy"""
    #     initial_state = copy.deepcopy(initial_state)
    #
    #     # Copy the state for K rollouts
    #     initial_states = np.repeat(initial_state[None, :], self.num_rollouts, axis=0)
    #     raw_uncert_measures, norm_uncert_measures = [], []
    #     total_rewards = []
    #
    #     if horizon is None:
    #         horizon = self.horizon
    #
    #     for h in range(horizon):
    #
    #         # Get actions from PPO policy if policy is PPO otherwise random actions
    #         actions = [self.ppo_agent.select_action(initial_state) for initial_state in initial_states]
    #
    #         # One hot encode the actions
    #         # actions_onehot = torch.FloatTensor(actions).to(self.device).unsqueeze(1)
    #         actions_onehot = torch.LongTensor(actions).to(self.device).unsqueeze(1)
    #         actions_onehot = F.one_hot(actions_onehot, num_classes=self.action_dim).squeeze(1)
    #
    #         # Forward pass through ensemble model
    #         initial_states = torch.FloatTensor(initial_states).to(self.device)
    #         with torch.no_grad():
    #             mean_preds, var_preds = self.ensemble_model.forward_all(initial_states, actions_onehot)
    #
    #         # Next state should be a sample from the ensemble model predictions for each rollout
    #         ensemble_indices = np.random.choice(self.ensemble_model.ensemble_size, self.num_rollouts)
    #         # model_pred = mean_preds.mean(dim=0)
    #         model_pred = mean_preds[ensemble_indices, np.arange(self.num_rollouts)]
    #         if self.mod_reward:
    #             next_states = model_pred[:, :-1]
    #             rewards = model_pred[:, -1].unsqueeze(1)
    #         else:
    #             next_states = model_pred
    #             rewards = torch.zeros_like(next_states[:, 0]).unsqueeze(1)
    #
    #         # Compute uncertainty for the selected method only
    #         uncert_measure = compute_uncertainty(
    #             initial_states,
    #             mean_preds.detach().clone(),
    #             var_preds.detach().clone(),
    #             self.ensemble_model,
    #             method=self.uncertainty_method
    #         )
    #
    #         # Average uncertainty across ensemble members and actions
    #         raw_uncert_measures.append(uncert_measure.mean().item())
    #
    #         # Compute intrinsic rewards
    #         # normalized_uncert = self.uncertainty_normalizer.normalize(uncert_measure)
    #         # norm_uncert_measures.append(normalized_uncert.mean().item())
    #
    #         # Update the normalizer stats
    #         # self.uncertainty_normalizer.update(uncert_measure)
    #         #
    #         # intrinsic_reward = normalized_uncert.mean(dim=0) if normalized_uncert.dim() > 1 else normalized_uncert
    #         intrinsic_reward = uncert_measure.mean(dim=0) if uncert_measure.dim() > 1 else uncert_measure
    #         intrinsic_reward = intrinsic_reward.unsqueeze(1)
    #         augmented_rewards = rewards + self.eta_uncert * intrinsic_reward
    #         total_rewards.append(augmented_rewards.mean().item())
    #
    #         # Store total rewards in the agent's buffer by appending one by one
    #         for i_roll in range(self.num_rollouts):
    #             self.ppo_agent.buffer.rewards.append(augmented_rewards[i_roll])
    #             if h == horizon - 1:
    #                 self.ppo_agent.buffer.is_terminals.append(True)
    #             else:
    #                 self.ppo_agent.buffer.is_terminals.append(False)
    #
    #         initial_states = next_states.detach().cpu().numpy()
    #
    #     return np.mean(raw_uncert_measures), np.mean(total_rewards)

    def imaginary_rollout(self, initial_state, horizon=None):
        """Roll the PPO policy forward inside the learned ensemble model.

        Performs ``self.num_rollouts`` sequential rollouts of ``horizon`` steps, each
        starting from a copy of ``initial_state``. At every imagined step the policy
        picks an action, the ensemble predicts the outcome, and an uncertainty bonus
        (scaled by ``self.eta_uncert``) is added to the predicted reward. The
        resulting reward and a terminal flag (true only on the last step of a
        rollout) are appended to ``self.ppo_agent.buffer``.

        Args:
            initial_state: Observation every rollout starts from. It is copied and
                never modified.
            horizon: Number of imagined steps per rollout; defaults to
                ``self.horizon``.

        Returns:
            Tuple ``(mean uncertainty, mean augmented reward)`` averaged over every
            imagined step, or ``(0.0, 0.0)`` when no step was imagined (zero rollouts
            or a horizon below 1) instead of the NaN that averaging an empty list
            would give.
        """
        start_state = copy.deepcopy(initial_state)

        if horizon is None:
            horizon = self.horizon

        raw_uncert_measures = []
        total_rewards = []

        for _ in range(self.num_rollouts):
            # Every rollout restarts from its own copy of the same state.
            current_state = np.copy(start_state)

            for h in range(horizon):
                # select_action presumably also records state/action/log-prob in the
                # agent's buffer - confirm against algs.ppo. Only rewards and
                # terminal flags are appended explicitly below.
                action = self.ppo_agent.select_action(current_state)

                # one_hot appends the class axis: shape (1,) -> (1, action_dim).
                action_index = torch.LongTensor([action]).to(self.device)
                actions_onehot = F.one_hot(action_index, num_classes=self.action_dim).float()

                # Add the batch dimension expected by the model.
                current_state_tensor = torch.FloatTensor(current_state).to(self.device).unsqueeze(0)

                with torch.no_grad():
                    mean_preds, var_preds = self.ensemble_model.forward_all(current_state_tensor, actions_onehot)

                # A new ensemble member is drawn at every step (not once per
                # rollout); batch index 0 because the batch holds a single state.
                ensemble_idx = np.random.choice(self.ensemble_model.ensemble_size)
                model_pred = mean_preds[ensemble_idx, 0]

                # NOTE(review): train() fits the model on `next_state - state` when
                # diffs=True, yet the prediction is used here as an absolute next
                # state. Confirm forward_all already returns absolute states.
                if self.mod_reward:
                    # Last element is the predicted reward, the rest is the state.
                    next_state = model_pred[:-1]
                    reward = model_pred[-1].unsqueeze(0)
                else:
                    next_state = model_pred
                    reward = torch.zeros_like(model_pred[0]).unsqueeze(0)

                # Clones keep compute_uncertainty from touching the tensors used above.
                uncert_measure = compute_uncertainty(
                    current_state_tensor,
                    mean_preds.detach().clone(),
                    var_preds.detach().clone(),
                    self.ensemble_model,
                    method=self.uncertainty_method
                )
                raw_uncert_measures.append(uncert_measure.mean().item())

                # Un-normalised uncertainty is used directly as the intrinsic reward.
                intrinsic_reward = uncert_measure.mean(dim=0) if uncert_measure.dim() > 1 else uncert_measure
                intrinsic_reward = intrinsic_reward.unsqueeze(1)
                augmented_reward = reward + self.eta_uncert * intrinsic_reward
                total_rewards.append(augmented_reward.mean().item())

                # Mark the end of each imagined trajectory for the PPO return computation.
                self.ppo_agent.buffer.rewards.append(augmented_reward[0])
                self.ppo_agent.buffer.is_terminals.append(h == horizon - 1)

                current_state = next_state.detach().cpu().numpy()

        if not raw_uncert_measures:
            return 0.0, 0.0

        return np.mean(raw_uncert_measures), np.mean(total_rewards)

    def compute_count_perc(self, eval_states, denoise_next_state):
        """Compute the percentage of states visited"""
        diff_abs = np.abs(eval_states - denoise_next_state)
        within_bounds = np.all(diff_abs <= 0.05, axis=1)
        close_indices = np.where(within_bounds)[0]

        self.count[(self.total_steps - self.warmup_steps - 1):, close_indices] = 1.0
        perc_states = np.sum(self.count, axis=1) / eval_states.shape[0]
        self.perc_states[self.total_steps - self.warmup_steps - 1] = perc_states[-1]

    def train(self, max_steps=200, num_eps=1_000, diffs=True):
        """Interact with the real environment while training the model and the policy.

        On every real step the method acts with the PPO policy, perturbs the observed
        next state with ``mountain_car_noise``, stores the transition for the
        dynamics model, periodically refits the model and, once warm-up is over,
        updates the policy from imagined rollouts. An episode ends early the first
        time the noise-free position reaches 0.5.

        Args:
            max_steps: Maximum number of environment steps per episode.
            num_eps: Number of episodes to run.
            diffs: If True the model target is ``next_state - state``; otherwise it
                is the (noisy) next state itself.

        Returns:
            Tuple ``(self.perc_states, self.solved_step)``: the per-step coverage
            array and the episode step at which the goal was most recently reached
            (``None`` if it never was).
        """
        if self.wandb_flag:
            wandb.init(
                project="MC_EXPLORE",
                name=f"MPC_{self.policy}_{self.uncertainty_method}_noise_{self.noise_model}_seed_{self.seed}",
                config=self.__dict__,
                group=f"MPC_{self.policy}_{self.uncertainty_method}_noise_{self.noise_model}",
                dir='/tmp'
            )

        # 20 x 20 evaluation grid over (position, velocity) used to measure coverage.
        eval_states = np.array([[x, y] for x in np.linspace(-1.2, 0.6, 20) for y in np.linspace(-0.07, 0.07, 20)])

        # One row per possible real step across all episodes.
        self.count = np.zeros((max_steps * num_eps, eval_states.shape[0]))
        self.perc_states = np.zeros(max_steps * num_eps)

        # Most recent dynamics-model loss; reported as 0 until the first model update.
        model_loss = 0

        for _ in range(num_eps):

            state, _ = self.env.reset()

            for step in range(max_steps):

                action = self.ppo_agent.select_action(state)

                # The environment returns the noise-free next state.
                denoise_next_state, reward, done, _truncated, _ = self.env.step(action)

                # These real transitions stay in the PPO buffer only until the next
                # policy update, which clears the buffer before imagining rollouts.
                self.ppo_agent.buffer.rewards.append(reward)
                self.ppo_agent.buffer.is_terminals.append(done)

                next_state, local_noise, position_factor, velocity_factor = mountain_car_noise(
                    next_state=denoise_next_state,
                    noise_fraction=self.noise_std,
                    current_state=state,
                    noise_model=self.noise_model,
                )

                outputs = next_state - state if diffs else next_state

                # Store the transition for model training and update the normalizers.
                self.real_buffer.push(state, action, reward, outputs, done)
                self.ensemble_model.normalizer.update(state, action, outputs, reward)

                # Imagined rollouts below start from the noise-free state; `state` is
                # replaced by the noisy observation at the end of the step.
                state = denoise_next_state

                # The first 10 total steps only collect data (threshold is hard-coded).
                if step % self.model_update_freq == 0 and self.total_steps >= 10:
                    model_loss = self.update_model()

                # After warm-up the policy is trained purely on imagined data.
                if self.total_steps >= self.warmup_steps and self.total_steps % self.agent_update_freq == 0:
                    self.ppo_agent.buffer.clear()
                    avg_raw_uncert, avg_intrin_rew = self.imaginary_rollout(state, horizon=self.horizon)
                    ppo_loss = self.update_policy()

                    if self.wandb_flag:
                        wandb.log({
                            "Models/Model Loss": model_loss,
                            "Models/Policy Loss": ppo_loss if self.policy == 'PPO' else 0,
                            f"Uncertainty/Raw_Avg_{self.uncertainty_method}": avg_raw_uncert,
                            "Uncertainty/Intrinsic Reward": avg_intrin_rew,
                        }, step=self.total_steps)

                # Update the visit count if the corresponding grid state has been visited.
                self.compute_count_perc(eval_states, denoise_next_state)

                env_metrics = {
                    "Env/Step": self.total_steps,
                    "Env/Real Reward": reward,
                    "Env/Position": denoise_next_state[0],
                    "Env/Velocity": denoise_next_state[1],
                    "Env/Perc States Visited": self.perc_states[self.total_steps - self.warmup_steps - 1],
                    "Env/Local Noise": local_noise,
                    "Env/Position Factor": position_factor,
                    "Env/Velocity Factor": velocity_factor
                }
                if self.wandb_flag:
                    wandb.log(dict(env_metrics), step=self.total_steps)

                # Goal reached: record it, pad the logs and end the episode.
                if denoise_next_state[0] >= 0.5:
                    self.solved_step = step
                    # All wandb calls stay behind the flag so that training with
                    # logging disabled never touches an uninitialised wandb run.
                    if self.wandb_flag:
                        wandb.log({"Env/Solved at step": step})

                        # Repeat the final values for the remaining steps of the episode.
                        # NOTE(review): `i` already starts at `step`, so adding it to
                        # total_steps looks like a double offset - confirm whether
                        # `self.total_steps + (i - step)` was intended.
                        for i in range(step, max_steps):
                            wandb.log({**env_metrics, "Env/Step": self.total_steps + i},
                                      step=self.total_steps + i)

                    break

                # The policy acts on the noisy observation at the next step.
                state = next_state
                self.total_steps += 1

                if self.total_steps % 10 == 0:
                    print(f"Step: {self.total_steps}, Pos: {denoise_next_state[0]:.2f}, Vel: {denoise_next_state[1]:.2f}")

        if self.wandb_flag:
            wandb.finish()
        return self.perc_states, self.solved_step
