# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
from collections import OrderedDict
import dataclasses
import logging
import random
import os

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

# Limit CPU core usage. NOTE: this runs at import time and affects the whole
# process, not just agents built from this module; RLAgent.__init__ later
# re-applies the thread count from its own config (RLConfig.num_threads).
torch.set_num_threads(4)  # Limit PyTorch to 4 threads
# os.environ['OMP_NUM_THREADS'] = '4'  # Limit OpenMP threads
# os.environ['MKL_NUM_THREADS'] = '4'  # Limit Intel MKL threads
# os.environ['NUMEXPR_NUM_THREADS'] = '4'  # Limit NumExpr threads
# os.environ['OPENBLAS_NUM_THREADS'] = '4'  # Limit OpenBLAS threads

from url_benchmark.dmc import TimeStep
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark import utils
from .fb_modules import mlp
from .encoder_rl import DiscreteRLAgent, DiscreteRLAgentConfig
from url_benchmark import goals as _goals

logger = logging.getLogger(__name__)  # module-level logger (used by RLAgent.check_cpu_settings)
# Type alias: a read-only mapping from meta key (str) to a numpy array.
MetaDict = tp.Mapping[str, np.ndarray]


@dataclasses.dataclass
class RLConfig:
    """Hydra structured config for ``RLAgent``, registered as group ``agent``, name ``rl_agent``.

    Fields whose default is ``omegaconf.II(...)`` are interpolated from the
    top-level config; fields whose default is ``omegaconf.MISSING`` must be
    supplied before the config is used.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.rl.RLAgent"
    name: str = "rl_agent"
    reward_free: bool = omegaconf.II("reward_free")
    custom_reward: tp.Optional[str] = omegaconf.II("custom_reward")
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later; RLAgent uses obs_shape[0]
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later; RLAgent asserts it is 1-D
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4
    critic_target_tau: float = 0.01
    update_every_steps: float = 2  # NOTE(review): annotated float but looks like a step count — confirm
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING
    hidden_dim: int = 1024
    feature_dim: int = 512
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)"
    stddev_clip: float = 0.3  # 1.0
    nstep: int = 1
    batch_size: int = 1024  # 256 for pixels
    init_critic: bool = True
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")
    fb_reward: bool = False  # when True, RLAgent creates an nn.Identity feature_net
    future_ratio: float = 0
    preprocess: bool = False
    add_trunk: bool = False
    supervised: bool = True
    rl_training_steps: int = 25000  # number of DiscreteRLAgent.update calls in RLAgent.inference

    # CPU usage control
    num_threads: int = 4  # passed by RLAgent.__init__ to torch.set_num_threads / set_num_interop_threads
    num_workers: int = 0  # For data loading

    # NOTE: ClassVar (and, equivalently, no annotation) means this is NOT a
    # dataclass field: it is a single DiscreteRLAgentConfig instance stored on
    # the class and shared by every RLConfig instance. It is not part of the
    # Hydra schema and cannot be passed to RLConfig(...) as a keyword argument.
    # Mutating it in place affects all instances.
    rl_config: tp.ClassVar[DiscreteRLAgentConfig] = DiscreteRLAgentConfig()

# Register RLConfig with Hydra so that `agent=rl_agent` selects it.
cs = ConfigStore.instance()
cs.store(name="rl_agent", node=RLConfig, group="agent")

class RLAgent:
    """Agent-API wrapper around a ``DiscreteRLAgent`` that acts on raw observations.

    The wrapped agent is given an ``nn.Identity`` encoder, so observations are
    passed to it unchanged. :meth:`update` is currently a no-op. RL training
    happens in :meth:`inference`, which re-initialises the wrapped agent's
    networks and trains it against a caller-supplied reward function.
    """

    def __init__(self, **kwargs: tp.Any):
        """Build the agent from keyword arguments accepted by ``RLConfig``.

        Raises:
            AssertionError: if ``action_shape`` is not one-dimensional.
        """
        import copy  # only needed here; kept local to leave file-level imports untouched

        cfg = RLConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1, (
            f"RLAgent expects a one-dimensional action_shape, got {cfg.action_shape}"
        )
        self.action_dim = cfg.action_shape[0]

        # Apply CPU usage controls.
        torch.set_num_threads(cfg.num_threads)
        # torch.set_num_interop_threads can only succeed once per process and
        # before any inter-op parallel work starts. A repeat call (e.g. when a
        # second agent is constructed) raises RuntimeError. Skip the call when
        # the value already matches, and degrade to a warning otherwise.
        if torch.get_num_interop_threads() != cfg.num_threads:
            try:
                torch.set_num_interop_threads(cfg.num_threads)
            except RuntimeError as err:
                logger.warning(
                    "Could not set PyTorch inter-op threads to %d: %s",
                    cfg.num_threads,
                    err,
                )

        # RLConfig.rl_config is a class attribute shared by every RLConfig
        # instance. Copy it before mutating so that building one agent does not
        # rewrite the config object that other agents were built with.
        self.cfg.rl_config = copy.deepcopy(cfg.rl_config)
        self.cfg.rl_config.q_type = "full"
        self.cfg.rl_config.action_shape = cfg.action_shape

        # RL agent from encoder_rl, fed raw observations through an identity encoder.
        self.rl_agent = DiscreteRLAgent(cfg.obs_shape[0], self.cfg.rl_config)
        self.rl_agent.load_encoder(nn.Identity())

        # Feature network for fb_reward (optional); an identity map when enabled.
        self.feature_net: tp.Optional[nn.Module] = None
        if self.cfg.fb_reward:
            self.feature_net = nn.Identity()
            self.feature_net.eval()

        self.train()

    def train(self, training: bool = True) -> None:
        """Set training mode on this wrapper and on the wrapped RL agent."""
        self.training = training
        self.rl_agent.train(training)

    def check_cpu_settings(self) -> None:
        """Log the current PyTorch thread counts and the OMP/MKL thread env vars."""
        logger.info("PyTorch threads: %s", torch.get_num_threads())
        logger.info("PyTorch interop threads: %s", torch.get_num_interop_threads())
        for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
            logger.info("%s: %s", var, os.environ.get(var, "not set"))

    def act(self, obs, meta, step, eval_mode) -> np.ndarray:
        """Select an action by delegating to the wrapped ``DiscreteRLAgent.act``."""
        return self.rl_agent.act(obs, meta, step, eval_mode)

    def update_agent(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> tp.Dict[str, float]:
        """Placeholder for a representation-learning step; currently returns no metrics."""
        return {}

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        """Per-step training hook; currently a no-op that returns no metrics.

        RL training for this agent is done in :meth:`inference` instead.
        ``replay_loader`` and ``step`` are accepted for interface compatibility.
        """
        return {}

    def inference(self, replay_loader: ReplayBuffer, infer_logger, q_pos, q_neg, reward_fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> tp.Dict[str, float]:
        """Re-initialise the wrapped RL agent and train it against ``reward_fn``.

        Calls ``DiscreteRLAgent.update(replay_loader, step, reward_fn)`` for
        ``cfg.rl_training_steps`` steps and prints the metrics every 1000 steps.

        Args:
            replay_loader: replay buffer forwarded to ``DiscreteRLAgent.update``.
            infer_logger: accepted for interface compatibility; not used.
            q_pos: accepted for interface compatibility; not used.
            q_neg: accepted for interface compatibility; not used.
            reward_fn: reward function forwarded to ``DiscreteRLAgent.update``.

        Returns:
            The metrics returned by the last ``DiscreteRLAgent.update`` call,
            or ``{}`` when ``cfg.rl_training_steps`` is 0.
        """
        self.rl_agent.init_networks()
        # Load the identity encoder again after init_networks(), as in __init__.
        self.rl_agent.load_encoder(nn.Identity())
        metrics: tp.Dict[str, float] = {}
        for step in range(self.cfg.rl_training_steps):
            metrics = self.rl_agent.update(replay_loader, step, reward_fn)
            if step % 1000 == 0:
                print(f"RL training step {step}, metrics: {metrics}")
        return metrics

    def q_function_inference(self, obs: torch.Tensor, q_pos: torch.Tensor, q_neg: torch.Tensor, z: tp.Dict[str, float]) -> torch.Tensor:
        """Return Q-values for ``obs`` from the wrapped agent's ``q_function``.

        ``q_pos``, ``q_neg`` and ``z`` are accepted for interface compatibility
        and are not used.
        """
        return self.rl_agent.q_function(obs)

    # Optional methods that some agents implement
    def infer_w_goal(self, replay_loader: ReplayBuffer, goal: np.ndarray) -> tp.Dict[str, float]:
        """Infer a goal representation (optional hook); placeholder that returns no metrics."""
        return {}

    def distill_actor_ddpg(self, replay_loader: ReplayBuffer, logger, goal: np.ndarray) -> tp.Dict[str, float]:
        """Distill an actor with DDPG (optional hook); placeholder that returns no metrics.

        Note: the ``logger`` parameter shadows the module-level logger inside this method.
        """
        return {}