# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import dataclasses
import logging
import random
import typing as tp
from collections import OrderedDict

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

from url_benchmark.dmc import TimeStep
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark import utils
from .fb_modules import mlp
from .encoder_rl import DiscreteRLAgent, DiscreteRLAgentConfig
from url_benchmark import goals as _goals

logger = logging.getLogger(__name__)
MetaDict = tp.Mapping[str, np.ndarray]


@dataclasses.dataclass
class AcroAgentConfig:
    """Structured config for ``AcroAgent``.

    Fields whose default is ``omegaconf.II(...)`` are interpolations resolved
    from the top-level config; ``omegaconf.MISSING`` fields must be filled in
    before the config is used.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.discrete_acro.AcroAgent"
    name: str = "acro_agent"
    reward_free: bool = omegaconf.II("reward_free")  # True -> AcroAgent builds reward_model / reward_opt
    custom_reward: tp.Optional[str] = omegaconf.II("custom_reward")
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later; obs_shape[0] is the encoder input size
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later; must be 1-D, action_shape[0] = number of discrete actions
    device: str = omegaconf.II("device")  # ${device}
    # NOTE(review): lr, critic_target_tau, update_every_steps, num_expl_steps,
    # stddev_*, nstep, init_critic, future_ratio, preprocess, add_trunk and
    # supervised are not read by AcroAgent in this file — presumably kept for
    # compatibility with shared agent configs; verify before removing.
    lr: float = 1e-4
    critic_target_tau: float = 0.01
    update_every_steps: float = 2
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING
    hidden_dim: int = 1024  # width of the ACRO MLPs and of the reward model
    feature_dim: int = 50  # size passed to DiscreteRLAgent; compared against obs_dim in a warning
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)"
    stddev_clip: float = 0.3  # 1.0
    nstep: int = 1
    batch_size: int = 1024  # 256 for pixels
    init_critic: bool = True
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")  # if set, precompute_cov encodes next_goal instead of next_obs
    fb_reward: bool = False  # enables feature_net and precompute_cov
    future_ratio: float = 0
    preprocess: bool = False
    add_trunk: bool = False
    supervised: bool = True
    
    # ACRO-specific parameters
    acro_learning_rate: float = 1e-4  # Adam lr for encoder + forward/inverse heads
    acro_weight_decay: float = 1e-5  # Adam weight decay for the same optimizer
    acro_forward_weight: float = 1.0  # scales the forward-dynamics loss; <= 0 disables it
    acro_l2_penalty: float = 0.0  # weight of the L2 penalty on softmax(embedding)
    acro_use_l2_norm: bool = False  # use the L2 penalty instead of the L1 penalty
    acro_l1_penalty: float = 0.0  # weight of the L1 penalty on embeddings
    acro_dynamic_l1_penalty: bool = False  # scale L1 weight by exp(-5 * (inverse accuracy - 1)^2)
    acro_train_stop_epochs: int = 1000000  # not read in this file
    acro_representation_train_steps: int = 1000  # not read in this file
    acro_k_steps: int = 1  # not read in this file (updates use single-step transitions)
    acro_embed_dim: int = 50  # ACRO embedding size
    num_inference_steps: int = 25000  # RL update steps run by AcroAgent.inference

    # Un-annotated, so this is a plain class attribute, not a dataclass field:
    # it cannot be passed as a constructor kwarg, and this single
    # DiscreteRLAgentConfig instance is shared by every AcroAgentConfig.
    rl_config = DiscreteRLAgentConfig()

# Register AcroAgentConfig with Hydra's ConfigStore: config group "agent",
# entry name "acro_agent".
cs = ConfigStore.instance()
cs.store(name="acro_agent", node=AcroAgentConfig, group="agent")


class AcroEncoder(nn.Module):
    """Map raw observations to ACRO embeddings.

    The network is built by the project ``mlp`` helper from the layer spec
    below and ends in ``embed_dim`` outputs; that size is exposed as
    ``output_dim`` for downstream consumers.
    """

    def __init__(self, obs_dim: int, embed_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.output_dim = embed_dim
        layer_spec = (obs_dim, hidden_dim, "ntanh", hidden_dim, "irelu", embed_dim, "ntanh")
        self.encoder = mlp(*layer_spec)
        self.apply(utils.weight_init)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``obs``."""
        embedding = self.encoder(obs)
        return embedding


class AcroForwardDynamics(nn.Module):
    """ACRO forward dynamics model that predicts the next state representation.

    Actions are treated as discrete indices in ``[0, action_dim)``: they are
    one-hot encoded and concatenated with the current embedding before being
    fed to the MLP, which outputs an ``embed_dim``-sized prediction.
    """

    def __init__(self, embed_dim: int, action_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.action_dim = action_dim
        self.forward_model = mlp(embed_dim + action_dim, hidden_dim, "ntanh", hidden_dim, "irelu", embed_dim, "ntanh")
        self.apply(utils.weight_init)

    def forward(self, encoded_obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Predict the embedding of the next observation.

        Args:
            encoded_obs: embedding of the current observation.
            action: discrete action indices; presumably shaped ``(batch, 1)``,
                since ``squeeze(1)`` below removes that singleton axis after
                one-hot encoding (TODO confirm against the replay buffer).

        Returns:
            The predicted next embedding.
        """
        # one_hot on (batch, 1) -> (batch, 1, action_dim); squeeze(1) -> (batch, action_dim).
        action_one_hot = F.one_hot(action.long(), num_classes=self.action_dim).squeeze(1).float()
        combined = torch.cat([encoded_obs, action_one_hot], dim=-1)
        return self.forward_model(combined)


class AcroInverseDynamics(nn.Module):
    """ACRO inverse dynamics head: scores actions from an (obs, next_obs) embedding pair.

    The two embeddings are concatenated (so the input width is ``2 * embed_dim``)
    and mapped to ``action_dim`` outputs per example.
    """

    def __init__(self, embed_dim: int, action_dim: int, hidden_dim: int) -> None:
        super().__init__()
        pair_width = 2 * embed_dim
        self.inverse_model = mlp(pair_width, hidden_dim, "ntanh", hidden_dim, "irelu", action_dim)
        self.apply(utils.weight_init)

    def forward(self, encoded_obs: torch.Tensor, encoded_next_obs: torch.Tensor) -> torch.Tensor:
        """Return per-action outputs for the transition ``encoded_obs -> encoded_next_obs``."""
        pair = torch.cat((encoded_obs, encoded_next_obs), dim=-1)
        scores = self.inverse_model(pair)
        return scores

class AcroAgent:
    """
    Discrete-action agent that learns an ACRO representation.

    The representation is an ``AcroEncoder`` trained jointly with a forward
    dynamics head (predicts the next embedding from the current embedding and
    the one-hot action) and an inverse dynamics head (classifies the action
    taken between two consecutive embeddings). Training is split in two phases:

    * ``update`` trains only the ACRO modules on replay batches.
    * ``inference`` re-initializes the wrapped ``DiscreteRLAgent`` and trains it
      on top of the current ACRO encoder for ``cfg.num_inference_steps`` steps.
    """

    def __init__(self, **kwargs: tp.Any):
        print('initializing ACRO agent')
        cfg = AcroAgentConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        # Number of discrete actions: one-hot width of the forward model and
        # number of outputs of the inverse model.
        self.action_dim = cfg.action_shape[0]
        self.solved_meta: tp.Any = None

        self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(f"feature_dim {cfg.feature_dim} should not be smaller that obs_dim {self.obs_dim}")

        # ACRO modules: encoder plus forward / inverse dynamics heads.
        self.acro_encoder = AcroEncoder(self.obs_dim, cfg.acro_embed_dim, cfg.hidden_dim).to(cfg.device)
        self.acro_forward = AcroForwardDynamics(cfg.acro_embed_dim, self.action_dim, cfg.hidden_dim).to(cfg.device)
        self.acro_inverse = AcroInverseDynamics(cfg.acro_embed_dim, self.action_dim, cfg.hidden_dim).to(cfg.device)

        # One optimizer updates all three ACRO modules jointly.
        self.acro_optimizer = torch.optim.Adam(
            list(self.acro_encoder.parameters()) +
            list(self.acro_forward.parameters()) +
            list(self.acro_inverse.parameters()),
            lr=cfg.acro_learning_rate,
            weight_decay=cfg.acro_weight_decay,
        )

        # `AcroAgentConfig.rl_config` is an un-annotated class attribute, i.e. a
        # single object shared by every config instance. Specialize a private
        # copy so that building this agent does not mutate the shared default.
        rl_config = copy.deepcopy(cfg.rl_config)
        rl_config.q_type = "mlp"
        rl_config.action_shape = cfg.action_shape
        self.cfg.rl_config = rl_config
        # RL agent (encoder_rl) that consumes the ACRO encoder.
        # NOTE(review): it is sized with `feature_dim` while the encoder outputs
        # `acro_embed_dim` features (both default to 50) — confirm which size
        # DiscreteRLAgent expects.
        self.rl_agent = DiscreteRLAgent(self.cfg.feature_dim, self.cfg.rl_config)
        self.rl_agent.load_encoder(self.acro_encoder)

        # Feature network for fb_reward (optional). It is the ACRO encoder
        # itself; note that `self.train()` below switches it back to train mode.
        self.feature_net: tp.Optional[nn.Module] = None
        if self.cfg.fb_reward:
            self.feature_net = self.acro_encoder
            self.feature_net.eval()

        # Reward model for reward-free learning (optional). `train_reward` fits
        # it on ACRO embeddings, so its input size is the embedding size.
        self.reward_model: tp.Optional[torch.nn.Module] = None
        self.reward_opt: tp.Optional[torch.optim.Adam] = None
        if cfg.reward_free:
            self.reward_model = mlp(cfg.acro_embed_dim, cfg.hidden_dim, "ntanh", cfg.hidden_dim,
                                    "relu", cfg.hidden_dim, "relu", 1).to(cfg.device)
            self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), lr=1e-3)

        self.train()

    def train(self, training: bool = True) -> None:
        """Set training mode for the RL agent and all ACRO modules."""
        self.training = training
        self.rl_agent.train(training)
        self.acro_encoder.train(training)
        self.acro_forward.train(training)
        self.acro_inverse.train(training)

    def precompute_cov(self, replay_loader: ReplayBuffer) -> None:
        """Precompute the pseudo-inverse covariance of ACRO features (fb_reward only).

        Encodes at least 100000 sampled ``next_goal`` tensors (when ``goal_space``
        is set) or ``next_obs`` tensors and stores ``pinv(phi^T phi / N)`` in
        ``self.inv_cov``. Does nothing when ``fb_reward`` is disabled.
        """
        if not self.cfg.fb_reward:
            return None

        logger.info("computing Cov of phi to be used at inference")
        obs_list: tp.List[torch.Tensor] = []
        batch_size = 0
        while batch_size < 100000:
            batch = replay_loader.sample(self.cfg.batch_size)
            batch = batch.to(self.cfg.device)
            obs_list.append(batch.next_goal if self.cfg.goal_space is not None else batch.next_obs)
            batch_size += batch.next_obs.size(0)
        obs = torch.cat(obs_list, 0)

        with torch.no_grad():
            phi = self.acro_encoder(obs)
        cov = torch.matmul(phi.T, phi) / phi.shape[0]
        self.inv_cov = torch.linalg.pinv(cov)

    def init_from(self, other: "AcroAgent") -> None:
        """Copy the ACRO encoder and dynamics-head parameters from ``other``.

        The RL agent is not copied: ``inference`` re-initializes its networks
        before training anyway.
        """
        utils.hard_update_params(other.acro_encoder, self.acro_encoder)
        utils.hard_update_params(other.acro_forward, self.acro_forward)
        utils.hard_update_params(other.acro_inverse, self.acro_inverse)

    def init_meta(self, custom_reward: tp.Optional[_goals.BaseReward] = None) -> MetaDict:
        """Return initial meta: a random goal ``g`` for ``MazeMultiGoal``, else empty."""
        if isinstance(custom_reward, _goals.MazeMultiGoal):
            idx = np.random.choice(len(custom_reward.goals))
            desired_goal = custom_reward.goals[idx]
            meta = OrderedDict()
            meta["g"] = desired_goal
            return meta
        else:
            return OrderedDict()

    def update_meta(
        self,
        meta: MetaDict,
        global_step: int,
        time_step: TimeStep,
        finetune: bool = False,
        replay_loader: tp.Optional[ReplayBuffer] = None
    ) -> MetaDict:
        """Return ``meta`` unchanged (hook for subclasses)."""
        return meta

    def get_goal_meta(self, goal_array: np.ndarray) -> MetaDict:
        """Wrap ``goal_array`` into a meta dict under key ``'g'``."""
        meta = OrderedDict()
        meta['g'] = goal_array
        return meta

    def act(self, obs, meta, step, eval_mode) -> np.ndarray:
        """Select an action by delegating to the RL agent (encoder_rl)."""
        return self.rl_agent.act(obs, meta, step, eval_mode)

    def _sample_reward_data(
        self, replay_loader: ReplayBuffer, num_samples: int
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``num_samples`` (ACRO(next_obs), reward) pairs from the replay buffer."""
        feature_list: tp.List[torch.Tensor] = []
        reward_list: tp.List[torch.Tensor] = []
        collected = 0
        while collected < num_samples:
            batch = replay_loader.sample(self.cfg.batch_size)
            _, _, reward, _, next_obs = batch.to(self.cfg.device).unpack()
            with torch.no_grad():
                feature_list.append(self.acro_encoder(next_obs))
            reward_list.append(reward)
            collected += next_obs.size(0)
        features = torch.cat(feature_list, 0)[:num_samples]
        rewards = torch.cat(reward_list, 0)[:num_samples]
        return features, rewards

    def train_reward(self, replay_loader: ReplayBuffer) -> None:
        """Fit ``reward_model`` to replay rewards from ACRO embeddings of ``next_obs``.

        Runs 2000 full-batch Adam steps of MSE regression on 10000 sampled
        transitions (the encoder is frozen), then reports the MSE on a freshly
        sampled held-out set of the same size.
        """
        assert self.reward_model is not None
        assert self.reward_opt is not None
        num_samples = 10000
        num_iters = 2000

        features, reward = self._sample_reward_data(replay_loader, num_samples)

        print('max reward: ', reward.max().cpu().item())
        print('99 percentile: ', torch.quantile(reward, 0.99).cpu().item())
        print('median reward: ', reward.median().cpu().item())
        print('min reward: ', reward.min().cpu().item())
        print('mean reward: ', reward.mean().cpu().item())
        print('num reward: ', reward.shape[0])

        for i in range(num_iters):
            reward_loss = (self.reward_model(features) - reward).pow(2).mean()
            self.reward_opt.zero_grad(set_to_none=True)
            reward_loss.backward()
            self.reward_opt.step()
            # Throttle logging: printing every iteration floods stdout.
            if i % 100 == 0 or i == num_iters - 1:
                print(f"iteration: {i}, reward_loss: {reward_loss.item()}")

        # Held-out loss on transitions sampled after training.
        test_features, test_reward = self._sample_reward_data(replay_loader, num_samples)
        with torch.no_grad():
            test_loss = (self.reward_model(test_features) - test_reward).pow(2).mean()
        print(f"Test Loss: {test_loss.item()}")

    def update_acro(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> tp.Dict[str, float]:
        """Run one optimizer step on the ACRO encoder and dynamics heads.

        Total loss = ``acro_forward_weight`` * forward-dynamics MSE
                   + inverse-dynamics cross-entropy on the taken action
                   + the L2 penalty if ``acro_use_l2_norm`` else the
                     (optionally accuracy-scaled) L1 penalty.

        Args:
            obs: batch of observations.
            action: discrete action indices, assumed shaped ``(batch,)`` or
                ``(batch, 1)`` (the forward model's ``squeeze(1)`` assumes the latter).
            next_obs: batch of next observations.

        Returns:
            Scalar training metrics keyed by name.
        """
        o_encoded = self.acro_encoder(obs)
        on_encoded = self.acro_encoder(next_obs)
        # Flat (batch,) class-index targets for the inverse model / accuracy.
        action_idx = action.long().reshape(-1)

        # Forward dynamics loss: predict the next embedding from embedding + action.
        if self.cfg.acro_forward_weight > 0:
            forward_model_loss = F.mse_loss(
                self.acro_forward(o_encoded, action),
                on_encoded,
            )
        else:
            forward_model_loss = torch.tensor(0.0, device=obs.device)

        # L1 penalty on embeddings (only when the L2 variant is not selected).
        if self.cfg.acro_l1_penalty > 0 and not self.cfg.acro_use_l2_norm:
            l1_loss = (
                torch.linalg.vector_norm(o_encoded, ord=1, dim=1).mean()
                + torch.linalg.vector_norm(on_encoded, ord=1, dim=1).mean()
            ) / 2
        else:
            l1_loss = torch.zeros((), device=obs.device)

        # Inverse dynamics: classify which discrete action led from obs to
        # next_obs. Targets are class indices, so cross-entropy is the loss;
        # MSE against (batch, 1) indices would silently broadcast.
        inverse_model_pred = self.acro_inverse(o_encoded, on_encoded)
        inverse_model_loss = F.cross_entropy(inverse_model_pred, action_idx)

        # Accuracy against the flat targets (comparing (batch,) with (batch, 1)
        # would broadcast to a (batch, batch) matrix).
        accuracy = (inverse_model_pred.argmax(dim=1) == action_idx).float().mean()

        # Dynamic L1 penalty: weight exp(-5 * (accuracy - 1)^2), i.e. full weight
        # when the inverse model is accurate and smaller weight otherwise.
        if self.cfg.acro_dynamic_l1_penalty:
            gain = 5
            cur_l1_penalty = self.cfg.acro_l1_penalty * np.exp(-gain * (accuracy.detach().item() - 1) ** 2)
        else:
            cur_l1_penalty = self.cfg.acro_l1_penalty

        l1_loss = cur_l1_penalty * l1_loss
        forward_model_loss = self.cfg.acro_forward_weight * forward_model_loss

        # L2 penalty on the softmax-normalized embedding.
        p = F.softmax(o_encoded, dim=1)
        l2_per = torch.linalg.norm(p, ord=2, dim=1)
        l2_loss = l2_per.mean() * self.cfg.acro_l2_penalty

        if self.cfg.acro_use_l2_norm:
            total_acro_loss = forward_model_loss + l2_loss + inverse_model_loss
        else:
            total_acro_loss = forward_model_loss + l1_loss + inverse_model_loss

        self.acro_optimizer.zero_grad(set_to_none=True)
        total_acro_loss.backward()
        self.acro_optimizer.step()

        return {
            "acro_inverse_loss": inverse_model_loss.detach().item(),
            "acro_forward_loss": forward_model_loss.detach().item(),
            "acro_l1_loss": l1_loss.detach().item(),
            "acro_l2_loss": l2_loss.detach().item(),
            "acro_total_loss": total_acro_loss.detach().item(),
            "acro_accuracy": accuracy.detach().item(),
            "acro_cur_l1_penalty": float(cur_l1_penalty),
            "acro_mean_element_magnitude": torch.abs(o_encoded).float().mean().detach().item(),
            "acro_mean_representation_magnitude": torch.linalg.vector_norm(o_encoded, ord=1, dim=1).mean().detach().item(),
        }

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        """Train the ACRO modules on one replay batch; the RL agent is not updated here."""
        metrics: tp.Dict[str, float] = {}

        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)
        obs = batch.obs
        action = batch.action
        next_obs = batch.next_obs
        acro_metrics = self.update_acro(obs, action, next_obs)
        metrics.update(acro_metrics)

        return metrics

    def inference(self, replay_loader: ReplayBuffer, infer_logger, q_pos, q_neg, reward_fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> tp.Dict[str, float]:
        """Re-initialize the RL agent and train it on top of the current ACRO encoder.

        Runs ``cfg.num_inference_steps`` RL updates with ``reward_fn`` and logs
        each step's metrics through ``infer_logger.log_metrics``. ``q_pos`` and
        ``q_neg`` are accepted for interface compatibility and are unused.

        Returns:
            The metrics of the last RL update (empty if no step was run).
        """
        self.rl_agent.init_networks()
        self.rl_agent.load_encoder(self.acro_encoder)  # ensure the RL agent uses the latest ACRO encoder
        rl_metrics: tp.Dict[str, float] = {}
        for step in range(self.cfg.num_inference_steps):
            rl_metrics = self.rl_agent.update(replay_loader, step, reward_fn)
            infer_logger.log_metrics(rl_metrics)
        return rl_metrics

    def q_function_inference(self, obs: torch.Tensor, q_pos: torch.Tensor, q_neg: torch.Tensor, z: tp.Dict[str, float]) -> torch.Tensor:
        """Return Q-values for ``obs`` from the RL agent (other arguments are unused)."""
        return self.rl_agent.q_function(obs)

    # Optional methods that some agents implement
    def infer_w_goal(self, replay_loader: ReplayBuffer, goal: np.ndarray) -> tp.Dict[str, float]:
        """Infer goal representation (placeholder; returns an empty dict)."""
        return {}

    def distill_actor_ddpg(self, replay_loader: ReplayBuffer, logger, goal: np.ndarray) -> tp.Dict[str, float]:
        """Distill actor using DDPG (placeholder; returns an empty dict)."""
        return {}