# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
from collections import OrderedDict
import dataclasses
import logging
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

from url_benchmark.dmc import TimeStep
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark import utils
from .fb_modules import mlp
from .rl_dynamics import DiscreteMBRLAgent, DiscreteMBRLAgentConfig
from url_benchmark import goals as _goals

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type alias: a read-only mapping from string keys to NumPy arrays.
MetaDict = tp.Mapping[str, np.ndarray]


@dataclasses.dataclass
class WorldModelConfig:
    """Configuration for ``WorldModelAgent``.

    Registered below with Hydra's ``ConfigStore`` under group ``agent`` as
    ``world_model``.  Fields whose default is ``omegaconf.II(...)`` are
    interpolations resolved from the top-level config key of the same name;
    fields set to ``omegaconf.MISSING`` must be supplied by the caller.

    Within this file the agent only reads ``action_shape``, ``obs_shape``,
    ``device``, ``lr``, ``hidden_dim``, ``batch_size``, ``fb_reward``,
    ``rl_training_steps`` and ``rl_config``.  The remaining fields are not
    used in this module -- presumably kept for parity with the other agent
    configs (TODO confirm).
    """
    # @package agent
    _target_: str = "url_benchmark.agent.world_models.WorldModelAgent"
    name: str = "world_model"
    reward_free: bool = omegaconf.II("reward_free")  # ${reward_free}
    custom_reward: tp.Optional[str] = omegaconf.II("custom_reward")  # ${custom_reward}
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4  # learning rate of the Adam optimizer for the dynamics model
    critic_target_tau: float = 0.01
    # NOTE(review): annotated as float but defaults to an integer step count.
    update_every_steps: float = 2
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING
    hidden_dim: int = 512  # hidden width of the dynamics-model MLP
    feature_dim: int = 512
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)"
    stddev_clip: float = 0.3  # 1.0
    nstep: int = 1
    batch_size: int = 1024  # 256 for pixels
    init_critic: bool = True
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")  # ${goal_space}
    fb_reward: bool = False  # when True the agent creates an Identity feature net
    future_ratio: float = 0
    preprocess: bool = False
    add_trunk: bool = False
    supervised: bool = True
    rl_training_steps: int = 25000  # number of RL updates run by WorldModelAgent.inference

    # NOTE(review): no type annotation, so this is NOT a dataclass field.  It is
    # a plain class attribute: one DiscreteMBRLAgentConfig instance created at
    # import time and shared by every WorldModelConfig, and it cannot be set
    # through the constructor.  Adding an annotation would turn it into a field
    # and change the Hydra schema, so it is left as is.
    rl_config = DiscreteMBRLAgentConfig()

# Register WorldModelConfig with Hydra's ConfigStore as the "world_model"
# entry of the "agent" config group.
cs = ConfigStore.instance()
cs.store(group="agent", name="world_model", node=WorldModelConfig)


class OneHotDynamicsModel(nn.Module):
    """
    Predicts single step dynamics with one-hot observations and actions.
    Input: one-hot observation + one-hot action
    Output: logits over the next observation (one per observation index)
    """

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 256):
        """
        Args:
            obs_dim: Length of the one-hot observation vector.
            action_dim: Number of discrete actions (length of the one-hot action).
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # The network consumes the concatenation [obs one-hot, action one-hot]
        # and emits one logit per possible next observation.
        self.dynamics_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim),
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every linear layer."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0.0)

    def action_to_onehot(self, action: torch.Tensor) -> torch.Tensor:
        """
        Convert discrete action indices to a one-hot encoding.

        Args:
            action: Integer-valued action indices.  Any shape is accepted and
                flattened, so (batch_size,) and (batch_size, 1) both work.

        Returns:
            Tensor of shape (batch_size, action_dim) in the default float dtype.
        """
        # Flatten to a column of indices.  Unlike squeeze(), reshape never
        # collapses the batch dimension, so a batch of one stays (1, 1).
        indices = action.long().reshape(-1, 1)
        action_onehot = torch.zeros(indices.shape[0], self.action_dim, device=action.device)
        action_onehot.scatter_(1, indices, 1)
        return action_onehot

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to predict next observation.

        Args:
            obs: One-hot observation tensor of shape (batch_size, obs_dim)
            action: Discrete action tensor of shape (batch_size,) or
                (batch_size, 1), or one-hot of shape (batch_size, action_dim).
                A (batch_size, 1) tensor is always treated as indices.

        Returns:
            next_obs_logits: Logits for next observation of shape (batch_size, obs_dim)
        """
        # Index-valued actions are expanded to one-hot; anything else is
        # assumed to be one-hot already.
        is_index = action.dim() == 1 or (action.dim() == 2 and action.shape[1] == 1)
        action_onehot = self.action_to_onehot(action) if is_index else action

        # Concatenate observation and action
        obs_action = torch.cat([obs, action_onehot], dim=1)

        return self.dynamics_net(obs_action)

    def predict(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Predict next observation with softmax.

        Returns:
            next_obs_probs: Probability distribution over next observations,
                shape (batch_size, obs_dim)
        """
        next_obs_logits = self.forward(obs, action)
        return F.softmax(next_obs_logits, dim=1)

    def predict_deterministic(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Predict next observation deterministically (argmax).

        Returns:
            next_obs_onehot: One-hot next observation with the same shape,
                dtype and device as ``obs``
        """
        next_obs_logits = self.forward(obs, action)
        next_obs_idx = torch.argmax(next_obs_logits, dim=1)

        next_obs_onehot = torch.zeros_like(obs)
        next_obs_onehot.scatter_(1, next_obs_idx.unsqueeze(1), 1)

        return next_obs_onehot

    def compute_loss(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy loss for dynamics prediction.

        Args:
            obs: Current one-hot observation
            action: Discrete action (indices or one-hot, see ``forward``)
            next_obs: Target one-hot next observation

        Returns:
            loss: Scalar cross-entropy loss (mean over the batch)
        """
        next_obs_logits = self.forward(obs, action)
        # cross_entropy expects class indices, so recover them from the one-hot target.
        next_obs_targets = torch.argmax(next_obs, dim=1)
        return F.cross_entropy(next_obs_logits, next_obs_targets)


class WorldModelAgent:
    """
    Agent that pairs a supervised one-hot dynamics model with a discrete
    model-based RL agent.

    ``update`` fits the dynamics model with a cross-entropy loss on
    transitions sampled from the replay buffer.  ``inference`` then runs the
    wrapped ``DiscreteMBRLAgent`` updates against a reward function, after
    handing it the learned dynamics model.  Action selection and Q-values are
    delegated to the RL agent.
    """

    def __init__(self, **kwargs: tp.Any):
        """Build the RL agent, the dynamics model and its optimizer.

        Args:
            **kwargs: Forwarded verbatim to ``WorldModelConfig``.
        """
        # Function-scope import: ``copy`` is only needed here.
        import copy

        cfg = WorldModelConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]

        # ``rl_config`` is a class-level attribute of WorldModelConfig, i.e. a
        # single object shared by every config instance.  Work on a private
        # copy so the overrides below do not leak into other agents/configs.
        rl_config = copy.deepcopy(cfg.rl_config)
        rl_config.q_type = "full"
        rl_config.action_shape = cfg.action_shape
        self.cfg.rl_config = rl_config

        # RL agent operating directly on raw observations (identity encoder)
        self.rl_agent = DiscreteMBRLAgent(cfg.obs_shape[0], rl_config)
        self.rl_agent.load_encoder(nn.Identity())

        # Dynamics model for world model learning
        self.dynamics_model = OneHotDynamicsModel(
            obs_dim=cfg.obs_shape[0],
            action_dim=self.action_dim,
            hidden_dim=cfg.hidden_dim,
        ).to(cfg.device)

        self.rl_agent.load_dynamics(self.dynamics_model)

        # Optimizer for dynamics model
        self.dynamics_optimizer = torch.optim.Adam(
            self.dynamics_model.parameters(),
            lr=cfg.lr,
        )

        # Feature network for fb_reward (optional)
        self.feature_net: tp.Optional[nn.Module] = None
        if self.cfg.fb_reward:
            self.feature_net = nn.Identity()  # Use Identity as feature net
            self.feature_net.eval()

        self.train()

    def train(self, training: bool = True) -> None:
        """Set training mode for the agent, its RL agent and its dynamics model."""
        self.training = training
        self.rl_agent.train(training)
        self.dynamics_model.train(training)

    def act(self, obs, meta, step, eval_mode) -> np.ndarray:
        """Select an action by delegating to the wrapped RL agent."""
        return self.rl_agent.act(obs, meta, step, eval_mode)

    def update_agent(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> tp.Dict[str, float]:
        """Run one optimizer step on the dynamics model.

        Args:
            obs: Current one-hot observations.
            action: Actions taken (indices or one-hot, as accepted by the model).
            next_obs: Target one-hot next observations.

        Returns:
            Dict with ``dynamics_loss`` (loss before the step) and
            ``dynamics_accuracy`` (argmax accuracy measured after the step).
        """
        # Ensure tensors are on the right device
        device = self.cfg.device
        obs = obs.to(device)
        action = action.to(device)
        next_obs = next_obs.to(device)

        # One supervised gradient step on the cross-entropy dynamics loss
        self.dynamics_optimizer.zero_grad()
        dynamics_loss = self.dynamics_model.compute_loss(obs, action, next_obs)
        dynamics_loss.backward()
        self.dynamics_optimizer.step()

        # Accuracy is computed with the freshly updated weights.
        with torch.no_grad():
            predicted_targets = self.dynamics_model.predict_deterministic(obs, action).argmax(dim=1)
            next_obs_targets = next_obs.argmax(dim=1)
            accuracy = (predicted_targets == next_obs_targets).float().mean()

        return {
            'dynamics_loss': dynamics_loss.item(),
            'dynamics_accuracy': accuracy.item(),
        }

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        """Main update method: Train the dynamics model using supervised learning.

        Samples ``cfg.batch_size`` transitions from ``replay_loader`` and runs
        one ``update_agent`` step on them.

        Returns:
            The metrics produced by ``update_agent``.
        """
        metrics: tp.Dict[str, float] = {}

        # Sample batch from replay buffer
        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)

        # Update dynamics model
        dynamics_metrics = self.update_agent(batch.obs, batch.action, batch.next_obs)
        metrics.update(dynamics_metrics)

        # Periodic progress log
        if step % 1000 == 0:
            logger.info(
                "Step %d: Dynamics Loss = %.4f, Accuracy = %.4f",
                step,
                dynamics_metrics.get('dynamics_loss', 0),
                dynamics_metrics.get('dynamics_accuracy', 0),
            )

        return metrics

    def predict_next_observation(self, obs: torch.Tensor, action: torch.Tensor, deterministic: bool = True) -> torch.Tensor:
        """
        Predict next observation using the learned dynamics model.

        The model is evaluated in eval mode without gradients; its previous
        train/eval mode is restored afterwards.

        Args:
            obs: Current observation (one-hot)
            action: Action to take (discrete or one-hot)
            deterministic: Whether to use deterministic (argmax) or probabilistic prediction

        Returns:
            Predicted next observation (one-hot if ``deterministic``, otherwise
            a probability distribution)
        """
        was_training = self.dynamics_model.training
        self.dynamics_model.eval()
        try:
            with torch.no_grad():
                if deterministic:
                    return self.dynamics_model.predict_deterministic(obs, action)
                return self.dynamics_model.predict(obs, action)
        finally:
            # Do not leave the model in eval mode in the middle of training.
            self.dynamics_model.train(was_training)

    def rollout(self, initial_obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Perform a rollout using the learned dynamics model.

        The model is evaluated in eval mode without gradients; its previous
        train/eval mode is restored afterwards.

        Args:
            initial_obs: Starting observation (batch_size, obs_dim)
            actions: Sequence of actions (batch_size, sequence_length)

        Returns:
            Sequence of predicted observations (batch_size, sequence_length + 1, obs_dim),
            with the initial observation at index 0
        """
        batch_size, sequence_length = actions.shape
        obs_dim = initial_obs.shape[1]

        # Store all observations in the rollout
        observations = torch.zeros(batch_size, sequence_length + 1, obs_dim, device=initial_obs.device)
        observations[:, 0] = initial_obs

        was_training = self.dynamics_model.training
        self.dynamics_model.eval()
        try:
            current_obs = initial_obs
            with torch.no_grad():
                for t in range(sequence_length):
                    # Feed each deterministic prediction back in as the next input.
                    current_obs = self.dynamics_model.predict_deterministic(current_obs, actions[:, t])
                    observations[:, t + 1] = current_obs
        finally:
            self.dynamics_model.train(was_training)

        return observations

    def inference(self, replay_loader: ReplayBuffer, infer_logger, q_pos, q_neg, reward_fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> tp.Dict[str, float]:
        """Train the RL agent for ``cfg.rl_training_steps`` updates against ``reward_fn``.

        The RL agent is first re-pointed at an identity encoder and at the
        current dynamics model.  ``infer_logger``, ``q_pos`` and ``q_neg`` are
        accepted for interface compatibility but not used here.

        Returns:
            Whatever the last ``rl_agent.update`` call returned (presumably a
            metrics dict -- TODO confirm), or an empty dict if no steps ran.
        """
        self.rl_agent.load_encoder(nn.Identity())
        self.rl_agent.load_dynamics(self.dynamics_model)  # Load the learned dynamics model
        rl_metrics: tp.Dict[str, float] = {}
        for step in range(self.cfg.rl_training_steps):
            rl_metrics = self.rl_agent.update(replay_loader, step, reward_fn)
            if step % 1000 == 0:
                logger.info("RL training step %d, metrics: %s", step, rl_metrics)
        return rl_metrics

    def q_function_inference(self, obs: torch.Tensor, q_pos: torch.Tensor, q_neg: torch.Tensor, z: tp.Dict[str, float]) -> torch.Tensor:
        """Get Q-values from the RL agent; ``q_pos``, ``q_neg`` and ``z`` are ignored."""
        return self.rl_agent.q_function(obs)

    # Optional methods that some agents implement
    def infer_w_goal(self, replay_loader: ReplayBuffer, goal: np.ndarray) -> tp.Dict[str, float]:
        """Infer goal representation (optional method); placeholder returning an empty dict."""
        return {}

    def distill_actor_ddpg(self, replay_loader: ReplayBuffer, logger, goal: np.ndarray) -> tp.Dict[str, float]:
        """Distill actor using DDPG (optional method); placeholder returning an empty dict."""
        return {}