import numpy as np
import torch
import torch.nn.functional as F
import wandb

# Import existing implementations
from algs.ddqn import DDQN, ReplayBuffer
from dynamics.deep_ens import EnsembleModel
from dynamics.utils import compute_uncertainty, RollingUncertaintyNormalizer
from mount_car.utils.utils import mountain_car_noise


class BE_MPC_DDQN:
    """Exploration agent that couples a dynamics ensemble with a DDQN policy.

    Real transitions collected from ``env`` train an ``EnsembleModel``. When
    ``policy == 'DDQN'`` the agent is trained on imagined transitions that the
    ensemble generates and whose reward is augmented with a normalised
    uncertainty bonus (see ``imaginary_rollout``). For any other policy value
    no DDQN agent is created and the rewards stored in the real buffer are
    replaced by the intrinsic bonus instead (see ``compute_uncertainty_all``).
    """

    def __init__(self,
                 env,
                 seed,
                 device,
                 uncertainty_method='IG',
                 policy='DDQN',
                 mod_reward=False,
                 noise_std=0.0,
                 noise_model='heteroskedastic',
                 wandb_flag=True,
                 eta_uncert=1.0,
                 num_rollouts=10,
                 horizon=100,
                 warmup_steps=50,
                 agent_update_freq=1,
                 model_update_freq=10,
                 agent_epochs=100,
                 model_epochs=100,
                 model_batch_size=256,
                 model_lr=1e-3,
                 ddqn_lr=0.0003,
                 gamma=0.99,
                 tau=0.005,
                 ensemble_size=10,
                 buffer_size=1_000):
        """Store hyperparameters and build the agent, the ensemble and the buffers.

        Args:
            env: Environment exposing ``observation_space.shape``,
                ``action_space.n``, ``reset()`` and ``step()``.
            seed: Run seed; only used here to name the wandb run.
            device: Torch device on which tensors and the ensemble are placed.
            uncertainty_method: Method name forwarded to ``compute_uncertainty``.
            policy: ``'DDQN'`` builds a DDQN agent; ``'Random'`` always samples
                random actions in ``train``.
            mod_reward: Forwarded to ``EnsembleModel``; when true the last column
                of a model prediction is read as the reward.
            noise_std: Passed to ``mountain_car_noise`` as ``noise_fraction``.
            noise_model: Passed to ``mountain_car_noise`` as ``noise_model``.
            wandb_flag: Enable wandb logging.
            eta_uncert: Weight of the intrinsic reward in imagined rollouts.
            num_rollouts: Number of imagined trajectories per rollout call.
            horizon: Default length of each imagined trajectory.
            warmup_steps: Number of environment steps taken with random actions
                before the policy is used and updated.
            agent_update_freq: The agent is updated when ``total_steps`` is a
                multiple of this value.
            model_update_freq: The model is updated when the in-episode step is
                a multiple of this value.
            agent_epochs: Number of DDQN updates per policy update.
            model_epochs: Number of minibatch updates per model update.
            model_batch_size: Minibatch size for model updates.
            model_lr: Learning rate of the ensemble.
            ddqn_lr: Learning rate of the DDQN agent.
            gamma: Discount factor forwarded to ``DDQN``.
            tau: Forwarded to ``DDQN`` as ``tau``.
            ensemble_size: Number of ensemble members (previously fixed at 10).
            buffer_size: Capacity of both replay buffers (previously fixed at
                1000).
        """
        self.env = env
        self.seed = seed
        self.wandb_flag = wandb_flag
        self.device = device
        self.noise_std = noise_std
        self.noise_model = noise_model
        self.uncertainty_method = uncertainty_method
        self.eta_uncert = eta_uncert

        # MPC rollout parameters
        self.horizon = horizon
        self.num_rollouts = num_rollouts

        # Agent parameters
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        # One policy batch covers every transition produced by one rollout call.
        self.agent_batch_size = self.horizon * self.num_rollouts
        self.agent_update_freq = agent_update_freq
        self.agent_epochs = agent_epochs

        # Model parameters
        self.warmup_steps = warmup_steps
        self.model_update_freq = model_update_freq
        self.model_epochs = model_epochs
        self.model_batch_size = model_batch_size
        self.ensemble_size = ensemble_size
        self.buffer_size = buffer_size

        # Initialize DDQN agent (only for the DDQN policy)
        self.policy = policy
        self.ddqn_agent = None
        if self.policy == 'DDQN':
            self.ddqn_agent = DDQN(
                self.state_dim,
                self.action_dim,
                lr=ddqn_lr,
                gamma=gamma,
                tau=tau,
                epsilon_start=0.3,
                epsilon_end=0.01,
            )

        # Initialize ensemble model
        self.mod_reward = mod_reward
        self.ensemble_model = EnsembleModel(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            mod_reward=self.mod_reward,
            ensemble_size=ensemble_size,
            lr=model_lr
        ).to(self.device)

        # Initialize replay buffers
        self.real_buffer = ReplayBuffer(buffer_size)
        self.model_buffer = ReplayBuffer(buffer_size)

        # Initialize metrics tracking
        self.total_steps, self.episode = 0, 0
        self.count, self.solved_step, self.perc_states = None, None, None
        self.uncertainty_normalizer = RollingUncertaintyNormalizer(device=self.device, clip_range=(-5.0, 5.0))

    # ------------------------------------------------------------------ helpers

    def _one_hot_actions(self, actions):
        """Return the float one-hot encoding of a sequence of discrete actions."""
        indices = torch.LongTensor(actions).to(self.device).unsqueeze(1)
        return F.one_hot(indices, num_classes=self.action_dim).squeeze(1).float()

    def _metric_index(self, total_steps=None):
        """Return the row of ``count`` / ``perc_states`` for a given step.

        The metric arrays are offset by the warm-up phase. The offset is clamped
        at zero so that warm-up steps land on row 0 instead of wrapping around
        to the end of the arrays through a negative index.
        """
        if total_steps is None:
            total_steps = self.total_steps
        return max(total_steps - self.warmup_steps - 1, 0)

    def _wandb_config(self):
        """Return the scalar hyperparameters of this agent for run configuration."""
        config = {
            key: value
            for key, value in vars(self).items()
            if isinstance(value, (bool, int, float, str))
        }
        config['device'] = str(self.device)
        return config

    # ----------------------------------------------------------------- training

    def update_model(self):
        """Update the dynamics model using collected experience.

        Runs ``model_epochs`` minibatch updates on samples from the real buffer
        and then clears that buffer.

        Returns:
            The mean of the per-epoch losses, or ``0.0`` when there is nothing
            to train on (empty buffer or no epochs).
        """
        buffer_len = len(self.real_buffer)
        if buffer_len == 0 or self.model_epochs <= 0:
            return 0.0

        # The buffer does not change while training, so the batch size is fixed.
        model_batch = min(buffer_len, self.model_batch_size)

        model_loss = 0
        for _ in range(self.model_epochs):
            # Sample batch from real buffer
            states, actions, rewards, diff_states, _ = self.real_buffer.sample(model_batch)

            # Prepare tensors
            states = torch.FloatTensor(states).to(self.device)
            actions = self._one_hot_actions(actions)
            rewards = torch.FloatTensor(rewards).to(self.device).unsqueeze(1)
            diff_states = torch.FloatTensor(diff_states).to(self.device)

            # Train the ensemble model
            model_loss += self.ensemble_model.train_models(states, actions, diff_states, rewards)

        # Clear the real buffer after training
        self.real_buffer.clear()

        return model_loss / self.model_epochs

    def update_policy(self):
        """Update the DDQN policy using the model buffer.

        Runs ``agent_epochs`` DDQN updates and then clears the model buffer.

        Returns:
            The loss reported by the last update, or ``0`` when the model buffer
            is empty or no epoch is run.
        """
        buffer_len = len(self.model_buffer)
        if buffer_len == 0:
            return 0

        ag_batch_size = min(buffer_len, self.agent_batch_size)

        q_loss = 0
        for _ in range(self.agent_epochs):
            q_loss = self.ddqn_agent.update(self.model_buffer, ag_batch_size)

        # Clear model buffer
        self.model_buffer.clear()

        return q_loss

    def imaginary_rollout(self, initial_state, horizon=None):
        """Generate imagined trajectories with the ensemble model and the DDQN policy.

        ``num_rollouts`` trajectories start from a copy of ``initial_state`` and
        are advanced together as one batch for ``horizon`` steps. At every step
        each trajectory follows the prediction of one randomly drawn ensemble
        member, and every imagined transition is pushed to the model buffer with
        reward ``model reward + eta_uncert * normalised uncertainty``.

        Args:
            initial_state: State from which every trajectory starts.
            horizon: Number of imagined steps; defaults to ``self.horizon``.

        Returns:
            ``(mean raw uncertainty, mean augmented reward)`` over all steps, or
            ``(0.0, 0.0)`` when ``horizon`` is not positive.
        """
        if horizon is None:
            horizon = self.horizon
        if horizon <= 0:
            return 0.0, 0.0

        # Every rollout starts from its own copy of the same state.
        start = np.array(initial_state, dtype=np.float32).reshape(-1)
        states = np.tile(start, (self.num_rollouts, 1))

        raw_uncert_measures, total_rewards = [], []

        for _ in range(horizon):
            # The agent is queried one state at a time, as in `train`.
            actions = [int(self.ddqn_agent.select_action(s)) for s in states]

            states_tensor = torch.FloatTensor(states).to(self.device)
            actions_onehot = self._one_hot_actions(actions)

            # Forward pass through ensemble model
            with torch.no_grad():
                mean_preds, var_preds = self.ensemble_model.forward_all(states_tensor, actions_onehot)

            # Each rollout follows one randomly chosen ensemble member.
            # Assumes mean_preds is indexed as [member, rollout, feature] - TODO confirm.
            member_idx = torch.as_tensor(
                np.random.choice(self.ensemble_model.ensemble_size, self.num_rollouts),
                device=mean_preds.device,
            )
            rollout_idx = torch.arange(self.num_rollouts, device=mean_preds.device)
            model_pred = mean_preds[member_idx, rollout_idx]

            # NOTE(review): the prediction is used directly as the next state, while
            # `train` fits the model on state differences when diffs=True - confirm
            # that `forward_all` already returns absolute next states.
            if self.mod_reward:
                next_states = model_pred[:, :-1]
                rewards = model_pred[:, -1].unsqueeze(1)
            else:
                next_states = model_pred
                rewards = torch.zeros_like(next_states[:, 0]).unsqueeze(1)
            dones = torch.zeros_like(rewards)

            # Compute uncertainty for the selected method only
            uncert_measure = compute_uncertainty(
                states_tensor,
                mean_preds.detach().clone(),
                var_preds.detach().clone(),
                self.ensemble_model,
                method=self.uncertainty_method
            )
            raw_uncert_measures.append(uncert_measure.mean().item())

            # Normalise with the current statistics, then fold this batch into them.
            normalized_uncert = self.uncertainty_normalizer.normalize(uncert_measure)
            self.uncertainty_normalizer.update(uncert_measure)

            # Compute intrinsic rewards
            if normalized_uncert.dim() > 1:
                intrinsic_reward = normalized_uncert.mean(dim=0)
            else:
                intrinsic_reward = normalized_uncert
            intrinsic_reward = intrinsic_reward.unsqueeze(1)
            augmented_rewards = rewards + self.eta_uncert * intrinsic_reward
            total_rewards.append(augmented_rewards.mean().item())

            # Store transitions in model buffer
            next_states_np = next_states.detach().cpu().numpy()
            for i in range(self.num_rollouts):
                self.model_buffer.push(
                    states[i],
                    actions[i],
                    augmented_rewards[i].item(),
                    next_states_np[i],
                    dones[i].item()
                )

            states = next_states_np

        return np.mean(raw_uncert_measures), np.mean(total_rewards)

    def compute_uncertainty_all(self, method='Error'):
        """Replace the rewards in the real buffer with normalised uncertainty.

        Args:
            method: Method name forwarded to ``compute_uncertainty``.

        Returns:
            ``(mean raw uncertainty, mean normalised uncertainty)`` for logging,
            or ``(0.0, 0.0)`` when the real buffer is empty.
        """
        buffer_len = len(self.real_buffer)
        if buffer_len == 0:
            # Happens right after `update_model` has cleared the buffer.
            return 0.0, 0.0

        # Sample all transitions
        states, actions, _, outputs, _ = self.real_buffer.sample(buffer_len)

        # Prepare tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = self._one_hot_actions(actions)
        outputs = torch.FloatTensor(outputs).to(self.device)

        # Forward pass through the ensemble model
        # NOTE(review): `imaginary_rollout` uses `forward_all` here - confirm that
        # `forward` returns what `compute_uncertainty` expects.
        with torch.no_grad():
            mean_preds, var_preds = self.ensemble_model.forward(states, actions)

        # Compute uncertainties with the chosen method
        uncert_measure = compute_uncertainty(states,
                                             mean_preds.detach().clone(),
                                             var_preds.detach().clone(),
                                             self.ensemble_model,
                                             method=method,
                                             diff_states=outputs)

        # Average uncertainty (for reporting only)
        raw_uncert_log = uncert_measure.mean().item()

        # Normalise with the current statistics, then update them.
        normalized_uncert = self.uncertainty_normalizer.normalize(uncert_measure)
        norm_uncert_log = normalized_uncert.mean().item()
        self.uncertainty_normalizer.update(uncert_measure)

        # Compute intrinsic rewards
        if normalized_uncert.dim() > 1:
            intrinsic_reward = normalized_uncert.mean(dim=0)
        else:
            intrinsic_reward = normalized_uncert

        # The tensor may live on an accelerator; move it to host memory first.
        # The extrinsic reward is dropped here: only the intrinsic part is kept.
        augmented_rewards = intrinsic_reward.detach().cpu().numpy().tolist()

        # NOTE(review): this assumes `sample(len(buffer))` returns transitions in
        # storage order; if it shuffles, rewards end up on the wrong transitions.
        self.real_buffer.rewards = augmented_rewards

        return raw_uncert_log, norm_uncert_log

    def compute_count_perc(self, eval_states, denoise_next_state):
        """Record which grid states have been visited and the resulting coverage.

        A grid state counts as visited when every component of
        ``denoise_next_state`` is within 0.05 of it. Visits are written from the
        current metric row down to the last row of ``self.count``, so the last
        row always holds everything visited so far.

        Args:
            eval_states: Array of grid states, one per row.
            denoise_next_state: The state to compare against the grid.
        """
        within_bounds = np.all(np.abs(eval_states - denoise_next_state) <= 0.05, axis=1)
        close_indices = np.where(within_bounds)[0]

        idx = self._metric_index()
        self.count[idx:, close_indices] = 1.0

        # Only the cumulative last row is needed; summing it alone avoids
        # reducing the whole count matrix on every step.
        self.perc_states[idx] = np.sum(self.count[-1]) / eval_states.shape[0]

    def train(self, max_steps=200, num_eps=1_000, diffs=True):
        """Run the interaction / training loop.

        Args:
            max_steps: Number of environment steps per episode.
            num_eps: Number of episodes.
            diffs: When true the model target is ``next_state - state``,
                otherwise it is ``next_state``.

        Returns:
            ``(perc_states, solved_step)``: the coverage curve and the in-episode
            step at which position 0.5 was last first reached (``None`` if never).
        """
        if self.wandb_flag:
            group_name = f"MPC_{self.policy}_{self.uncertainty_method}_noise_{self.noise_model}"
            wandb.init(
                project="MC_EXPLORE",
                name=f"{group_name}_seed_{self.seed}",
                # Only scalar hyperparameters: the env, networks and buffers are not config.
                config=self._wandb_config(),
                group=group_name,
                dir='/tmp'
            )

        # Discretize the continuous state space into a 20 x 20 grid
        eval_states = np.array([[x, y] for x in np.linspace(-1.2, 0.6, 20) for y in np.linspace(-0.07, 0.07, 20)])

        # Initialize metrics
        self.count = np.zeros((max_steps * num_eps, eval_states.shape[0]))
        self.perc_states = np.zeros(max_steps * num_eps)

        # Logged values keep their most recent value until they are recomputed.
        model_loss, q_loss = 0, 0
        avg_raw_uncert, avg_norm_uncert, avg_intrin_rew = 0, 0, 0

        for _ in range(num_eps):

            state, _ = self.env.reset()
            solved = False

            for step in range(max_steps):

                # Select action using the DDQN policy after the warm-up steps
                if self.total_steps < self.warmup_steps or self.policy == 'Random':
                    action = self.env.action_space.sample()
                else:
                    action = self.ddqn_agent.select_action(state)

                # Execute action and add noise
                # NOTE(review): the episode is not reset on `done`/`truncated`; the
                # early exit was deliberately disabled in the original code.
                denoise_next_state, reward, done, truncated, _ = self.env.step(action)
                next_state, local_noise, position_factor, velocity_factor = mountain_car_noise(
                    next_state=denoise_next_state,
                    noise_fraction=self.noise_std,
                    current_state=state,
                    noise_model=self.noise_model)

                outputs = next_state - state if diffs else next_state

                # Store transition for model training and update normalizers
                self.real_buffer.push(state, action, reward, outputs, done)
                self.ensemble_model.normalizer.update(state, action, outputs, reward)

                if step % self.model_update_freq == 0 and self.total_steps >= (self.warmup_steps // 2):
                    model_loss = self.update_model()

                # The agent is only updated once the warm-up phase is over
                if self.total_steps >= self.warmup_steps and self.total_steps % self.agent_update_freq == 0:

                    q_loss = 0
                    if self.policy == 'DDQN':
                        # Imagined rollouts start from the noise-free next state.
                        avg_raw_uncert, avg_intrin_rew = self.imaginary_rollout(denoise_next_state,
                                                                                horizon=self.horizon)
                        q_loss = self.update_policy()
                    else:
                        # Compute uncertainties and intrinsic rewards for the real buffer
                        avg_raw_uncert, avg_norm_uncert = self.compute_uncertainty_all(method=self.uncertainty_method)

                    # Log metrics
                    if self.wandb_flag:
                        wandb.log({
                            "Models/Model Loss": model_loss,
                            "Models/Policy Loss": q_loss,
                            f"Uncertainty/Raw_Avg_{self.uncertainty_method}": avg_raw_uncert,
                            f"Uncertainty/Norm_Avg_{self.uncertainty_method}": avg_norm_uncert,
                            "Uncertainty/Intrinsic Reward": avg_intrin_rew,
                        }, step=self.total_steps)

                # Update the visit count if the corresponding grid state has been visited
                self.compute_count_perc(eval_states, denoise_next_state)
                metric_idx = self._metric_index()

                # Log step metrics
                if self.wandb_flag:
                    wandb.log({
                        "Env/Step": self.total_steps,
                        "Env/Real Reward": reward,
                        "Env/Position": denoise_next_state[0],
                        "Env/Velocity": denoise_next_state[1],
                        "Env/Perc States Visited": self.perc_states[metric_idx],
                        "Env/Local Noise": local_noise,
                        "Env/Position Factor": position_factor,
                        "Env/Velocity Factor": velocity_factor
                    }, step=self.total_steps)

                # Check if environment is solved
                if not solved and denoise_next_state[0] >= 0.5:
                    solved = True
                    self.solved_step = step
                    if self.wandb_flag:
                        wandb.log({"Env/Solved at step": step})

                    # Carry the current coverage forward; later steps overwrite it.
                    self.perc_states[metric_idx:] = self.perc_states[metric_idx]

                # The agent continues from the noisy observation
                state = next_state
                self.total_steps += 1

                # Print step info
                if self.total_steps % 10 == 0:
                    print(f"Step: {self.total_steps}, Pos: {denoise_next_state[0]:.2f}, Vel: {denoise_next_state[1]:.2f}")

            self.episode += 1

        # Rows past the last recorded step were never written; carry the final
        # coverage forward so the returned curve does not drop back to zero.
        if self.total_steps > 0 and self.perc_states.size:
            last_idx = self._metric_index(self.total_steps - 1)
            self.perc_states[last_idx:] = self.perc_states[last_idx]

        if self.wandb_flag:
            wandb.finish()
        return self.perc_states, self.solved_step
