import numpy as np
import torch
import torch.nn.functional as F
import wandb


# Import existing implementations
from algs.ppo import PPO, ReplayBuffer
from dynamics.deep_ens import EnsembleModel
from dynamics.utils import compute_uncertainty, RollingUncertaintyNormalizer
from mount_car.utils.utils import mountain_car_noise


class BE_PPO:
    """PPO agent rewarded by the uncertainty of a learned dynamics ensemble.

    Every real transition is pushed into ``real_buffer`` and used to fit an
    ``EnsembleModel``. After ``warmup_steps`` (random actions), every
    ``agent_update_freq`` steps the ensemble's uncertainty over the buffered
    transitions is normalized and written into the PPO buffer as its reward
    list before ``PPO.update`` is called. Exploration progress is tracked as
    the fraction of a 20x20 (position, velocity) evaluation grid visited.
    """

    # Episode counts as "solved" once the observed position reaches this value
    # (the upper end of the position range of the evaluation grid in `train`).
    SOLVED_POSITION = 0.6
    # Observations may be float32; float32(0.6) != 0.6 under NumPy 1.x promotion
    # rules, so the solved check uses a small tolerance instead of `==`.
    _POSITION_TOL = 1e-6
    # Half-width of the box around a grid point within which it counts as visited.
    VISIT_RADIUS = 0.05

    def __init__(self,
                 env,
                 seed,
                 device,
                 wandb_flag=True,
                 uncertainty_method='IG',
                 policy='PPO',
                 mod_reward=False,
                 noise_std=0.0,
                 noise_model='heteroskedastic',
                 eta_uncert=1.0,
                 warmup_steps=64,
                 agent_update_freq=4,
                 model_update_freq=16,
                 agent_epochs=100,
                 model_epochs=100,
                 model_batch_size=256,
                 agent_batch_size=256,
                 model_lr=1e-3,
                 ppo_lr=0.001):
        """Build the PPO agent, dynamics ensemble and replay buffer.

        Args:
            env: environment with ``observation_space.shape`` and a discrete
                ``action_space.n``; ``step`` returns a 5-tuple.
            seed: only used in the wandb run name (nothing is seeded here).
            device: torch device for the ensemble and its input tensors.
            wandb_flag: enable wandb logging.
            uncertainty_method: ``method`` passed to ``compute_uncertainty``.
            policy: label used in wandb names/logging; a PPO agent is always used.
            mod_reward: forwarded to ``EnsembleModel``.
            noise_std: forwarded as ``noise_fraction`` to ``mountain_car_noise``.
            noise_model: forwarded to ``mountain_car_noise``.
            eta_uncert: stored only; not currently used to scale rewards.
            warmup_steps: number of random-action steps before PPO acts.
            agent_update_freq: PPO update period in total environment steps.
            model_update_freq: ensemble update period in within-episode steps.
            agent_epochs: passed to PPO as ``K_epochs``.
            model_epochs: ``train_models`` calls per ensemble update.
            model_batch_size: max transitions per ensemble training batch.
            agent_batch_size: stored only; not used in this class.
            model_lr: ensemble learning rate.
            ppo_lr: actor and critic learning rate.
        """
        self.env = env
        self.seed = seed
        self.device = device
        self.wandb_flag = wandb_flag
        self.noise_std = noise_std
        self.noise_model = noise_model
        self.uncertainty_method = uncertainty_method
        self.eta_uncert = eta_uncert

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.agent_batch_size = agent_batch_size
        self.agent_update_freq = agent_update_freq
        self.agent_epochs = agent_epochs

        self.warmup_steps = warmup_steps
        self.model_update_freq = model_update_freq
        self.model_epochs = model_epochs
        self.model_batch_size = model_batch_size

        # A PPO agent is always constructed; `policy` only affects naming/logging.
        self.policy = policy
        self.ppo_agent = PPO(state_dim=self.state_dim, action_dim=self.action_dim, lr_actor=ppo_lr, lr_critic=ppo_lr,
                             gamma=0.99, K_epochs=agent_epochs, eps_clip=0.2, has_continuous_action_space=False)

        # Dynamics ensemble
        self.mod_reward = mod_reward
        self.ensemble_model = EnsembleModel(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            mod_reward=self.mod_reward,
            ensemble_size=5,
            lr=model_lr
        ).to(self.device)

        # Real transitions for model training (1_000 is presumably the capacity).
        self.real_buffer = ReplayBuffer(1_000)

        # Metrics; arrays are allocated in `train`.
        self.total_steps, self.episode = 0, 0
        self.raw_uncertainty_metrics = None
        self.count = None  # per-grid-cell visited indicator (1.0 = visited)
        self.solved_step = None
        self.perc_states = None
        self.uncertainty_normalizer = RollingUncertaintyNormalizer(device=self.device, clip_range=(-5.0, 5.0))

    def _prepare_batch(self, states, actions, diff_states):
        """Convert sampled transitions into device tensors for the ensemble.

        Actions are one-hot encoded to width ``self.action_dim``.
        """
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device).unsqueeze(1)
        actions = F.one_hot(actions, num_classes=self.action_dim).squeeze(1)
        diff_states = torch.FloatTensor(diff_states).to(self.device)
        return states, actions, diff_states

    def update_model(self):
        """Fit the dynamics ensemble on batches sampled from the real buffer.

        Calls ``EnsembleModel.train_models`` ``model_epochs`` times, each on a
        batch of ``min(len(real_buffer), model_batch_size)`` transitions.

        Returns:
            Mean of the per-batch losses (0.0 when ``model_epochs`` is 0).
        """
        if self.model_epochs <= 0:
            return 0.0

        # Buffer size is fixed for the duration of this call.
        batch_size = min(len(self.real_buffer), self.model_batch_size)

        total_loss = 0
        for _ in range(self.model_epochs):
            states, actions, rewards, diff_states, _ = self.real_buffer.sample(batch_size)
            states, actions, diff_states = self._prepare_batch(states, actions, diff_states)
            rewards = torch.FloatTensor(rewards).to(self.device).unsqueeze(1)
            total_loss += self.ensemble_model.train_models(states, actions, diff_states, rewards)

        return total_loss / self.model_epochs

    def compute_uncertainty_all(self, method='Error'):
        """Score every buffered transition by model uncertainty and use it as PPO reward.

        The raw uncertainty is normalized with the running statistics
        accumulated so far (then the normalizer is updated with this batch),
        and the resulting values replace ``ppo_agent.buffer.rewards`` wholesale.

        Args:
            method: uncertainty method passed to ``compute_uncertainty``.

        Returns:
            ``(raw_mean, normalized_mean)`` of the uncertainty, for logging.
        """
        # Sample the whole buffer (environment rewards are not used here).
        states, actions, _, diff_states, _ = self.real_buffer.sample(len(self.real_buffer))
        states, actions, diff_states = self._prepare_batch(states, actions, diff_states)

        with torch.no_grad():
            mean_preds, var_preds = self.ensemble_model.forward(states, actions)

        uncert_measure = compute_uncertainty(states,
                                             mean_preds.detach().clone(),
                                             var_preds.detach().clone(),
                                             self.ensemble_model,
                                             method=method,
                                             diff_states=diff_states)

        raw_uncert_log = uncert_measure.mean().item()

        # Normalize with the statistics seen so far, then fold this batch in.
        normalized_uncert = self.uncertainty_normalizer.normalize(uncert_measure)
        norm_uncert_log = normalized_uncert.mean().item()
        self.uncertainty_normalizer.update(uncert_measure)

        # Average over dim 0 when present (presumably ensemble members - TODO confirm).
        intrinsic_reward = normalized_uncert.mean(dim=0) if normalized_uncert.dim() > 1 else normalized_uncert

        # Pure intrinsic reward; `eta_uncert` and the env reward are currently unused.
        # .cpu() is required: converting a CUDA tensor to NumPy/Python fails otherwise.
        # NOTE(review): these rewards cover up to the whole real buffer, while the PPO
        # buffer only holds states since its last update; lengths/order may not match
        # unless PPO.update and ReplayBuffer.sample handle this - confirm.
        self.ppo_agent.buffer.rewards = intrinsic_reward.detach().cpu().numpy().tolist()

        return raw_uncert_log, norm_uncert_log

    def compute_count_perc(self, eval_states, denoise_next_state):
        """Mark grid cells near ``denoise_next_state`` as visited and record coverage.

        A cell counts as visited when every coordinate is within
        ``VISIT_RADIUS`` of the state. Coverage is cumulative (warm-up visits
        included). It is stored in ``perc_states[total_steps - warmup_steps]``
        only when that index is inside the array, so warm-up steps never wrap
        around to the end.

        Args:
            eval_states: ``(n_cells, state_dim)`` array of grid points.
            denoise_next_state: observed state to test against the grid.

        Returns:
            Fraction of grid cells visited so far.
        """
        diff_abs = np.abs(eval_states - denoise_next_state)
        within_bounds = np.all(diff_abs <= self.VISIT_RADIUS, axis=1)
        self.count[within_bounds] = 1.0

        coverage = float(np.sum(self.count)) / eval_states.shape[0]

        metric_idx = self.total_steps - self.warmup_steps
        if 0 <= metric_idx < len(self.perc_states):
            self.perc_states[metric_idx] = coverage
        return coverage

    def train(self, max_steps=200, num_eps=1_000, diffs=True):
        """Run the exploration loop and track state-space coverage.

        Episodes always run ``max_steps`` steps (no early break on ``done``).

        Args:
            max_steps: environment steps per episode.
            num_eps: number of episodes.
            diffs: if True the model target is ``next_state - state``,
                otherwise the next state itself.

        Returns:
            ``(perc_states, solved_step)``: coverage after each post-warm-up
            step, and the within-episode step at which the most recent solving
            episode first reached ``SOLVED_POSITION`` (None if never).
        """
        if self.wandb_flag:
            wandb.init(
                project="MC_EXPLORE",
                name=f"{self.policy}_{self.uncertainty_method}_noise_{self.noise_model}_seed_{self.seed}",
                config=self.__dict__,
                group=f"{self.policy}_{self.uncertainty_method}_noise_{self.noise_model}",
                dir='/tmp'
            )

        # 20x20 grid over (position, velocity) used to measure coverage.
        eval_states = np.array([[x, y] for x in np.linspace(-1.2, 0.6, 20) for y in np.linspace(-0.07, 0.07, 20)])

        # Metric arrays cover only post-warm-up steps (index = total_steps - warmup_steps).
        n_metric_steps = max(num_eps * max_steps - self.warmup_steps, 0)
        self.raw_uncertainty_metrics = np.zeros(n_metric_steps)
        self.count = np.zeros(eval_states.shape[0])
        self.perc_states = np.zeros(n_metric_steps)

        model_loss = 0  # last ensemble loss; logged as 0 until the first model update

        for _ in range(num_eps):
            state, _ = self.env.reset()
            solved = False

            for step in range(max_steps):
                # Random actions during warm-up, PPO afterwards.
                if self.total_steps < self.warmup_steps:
                    action = self.env.action_space.sample()
                else:
                    action = self.ppo_agent.select_action(state)

                denoise_next_state, reward, done, truncated, _ = self.env.step(action)

                # Env reward/done go to PPO; rewards are later replaced by
                # `compute_uncertainty_all`.
                self.ppo_agent.buffer.rewards.append(reward)
                self.ppo_agent.buffer.is_terminals.append(done)

                next_state, local_noise, position_factor, velocity_factor = mountain_car_noise(next_state=denoise_next_state,
                                                                                               noise_fraction=self.noise_std,
                                                                                               current_state=state,
                                                                                               noise_model=self.noise_model)
                outputs = next_state - state if diffs else next_state

                # Store transition for model training and update model normalizers.
                self.real_buffer.push(state, action, reward, outputs, done)
                self.ensemble_model.normalizer.update(state, action, outputs, reward)

                # NOTE(review): model updates use the within-episode `step`, while agent
                # updates below use `total_steps`; confirm this asymmetry is intended.
                if step % self.model_update_freq == 0 and self.total_steps >= (self.warmup_steps // 2):
                    model_loss = self.update_model()

                # After warm-up: recompute intrinsic rewards and update the policy.
                if self.total_steps >= self.warmup_steps and self.total_steps % self.agent_update_freq == 0:
                    avg_raw_uncert, avg_norm_uncert = self.compute_uncertainty_all(method=self.uncertainty_method)
                    ppo_loss = self.ppo_agent.update()

                    if self.wandb_flag:
                        wandb.log({
                            "Models/Model Loss": model_loss,
                            "Models/Policy Loss": ppo_loss if self.policy == 'PPO' else 0,
                            f"Uncertainty/Raw_Avg_{self.uncertainty_method}": avg_raw_uncert,
                            f"Uncertainty/Norm_Avg_{self.uncertainty_method}": avg_norm_uncert
                        }, step=self.total_steps)

                # Update grid coverage with the noise-free state.
                coverage = self.compute_count_perc(eval_states, denoise_next_state)
                metric_idx = self.total_steps - self.warmup_steps

                if self.wandb_flag:
                    wandb.log({
                        "Env/Step": self.total_steps,
                        "Env/Real Reward": reward,
                        "Env/Position": denoise_next_state[0],
                        "Env/Velocity": denoise_next_state[1],
                        "Env/Perc States Visited": coverage,
                        "Env/Local Noise": local_noise,
                        "Env/Position Factor": position_factor,
                        "Env/Velocity Factor": velocity_factor
                    }, step=self.total_steps)

                # Tolerance-based check: exact float equality is unreliable for float32 observations.
                if not solved and denoise_next_state[0] >= self.SOLVED_POSITION - self._POSITION_TOL:
                    solved = True
                    self.solved_step = step
                    if self.wandb_flag:
                        wandb.log({"Env/Solved at step": step}, step=self.total_steps)

                    # Fill the remaining metrics with the current value (only for
                    # post-warm-up indices, so warm-up never wraps to the tail).
                    if 0 <= metric_idx < n_metric_steps:
                        self.raw_uncertainty_metrics[metric_idx + 1:] = self.raw_uncertainty_metrics[metric_idx]
                        self.perc_states[metric_idx + 1:] = self.perc_states[metric_idx]

                # The agent continues from the (possibly noisy) observed next state.
                state = next_state
                self.total_steps += 1

                if self.total_steps % 10 == 0:
                    print(f"Step: {self.total_steps}, Pos: {denoise_next_state[0]:.2f}, Vel: {denoise_next_state[1]:.2f}")

        if self.wandb_flag:
            wandb.finish()
        return self.perc_states, self.solved_step
