# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
from collections import OrderedDict
import dataclasses
import logging
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

from url_benchmark.dmc import TimeStep
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark import utils
from .fb_modules import mlp
from .encoder_rl import DiscreteRLAgent, DiscreteRLAgentConfig
from url_benchmark import goals as _goals

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Type alias: a read-only mapping from string keys to numpy arrays.
MetaDict = tp.Mapping[str, np.ndarray]


@dataclasses.dataclass
class PVFAgentConfig:
    """Configuration schema for ``PVFAgent``.

    Fields set to ``omegaconf.II(...)`` are interpolations resolved from the
    named top-level config key; fields set to ``omegaconf.MISSING`` have no
    default and must be supplied by the caller.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.pvf.PVFAgent"
    name: str = "pvf_agent"
    reward_free: bool = omegaconf.II("reward_free")
    custom_reward: tp.Optional[str] = omegaconf.II("custom_reward")
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4  # learning rate of the Adam optimizer for the PVF encoder
    critic_target_tau: float = 0.01  # soft-update rate of the target PVF encoder
    # NOTE(review): annotated as float although the default is an integer.
    update_every_steps: float = 2
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING
    hidden_dim: int = 1024  # hidden width of the PVF encoder
    feature_dim: int = 50  # output (embedding) width of the PVF encoder
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)"
    stddev_clip: float = 0.3  # 1.0
    nstep: int = 1
    batch_size: int = 1024  # 256 for pixels
    init_critic: bool = True
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")
    fb_reward: bool = False
    future_ratio: float = 0
    preprocess: bool = False
    add_trunk: bool = False
    supervised: bool = True
    
    # PVF-specific parameters
    ortho_coef: float = 1.0  # weight of the orthonormality term in the PVF loss
    num_inference_steps: int = 25000  # RL updates run by PVFAgent.inference

    # NOTE(review): there is no type annotation here, so `dataclasses` does
    # not treat this as a field. It is a plain class attribute: a single
    # DiscreteRLAgentConfig instance shared by every PVFAgentConfig, and
    # PVFAgent.__init__ mutates it in place (q_type, action_shape).
    rl_config = DiscreteRLAgentConfig()

# Import-time side effect: register the schema with Hydra's ConfigStore
# under group "agent" with the name "pvf_agent".
cs = ConfigStore.instance()
cs.store(group="agent", name="pvf_agent", node=PVFAgentConfig)

# PVF encoder
class PVFEncoder(nn.Module):
    """Network mapping an observation to its proto-value-function embedding.

    The network itself is produced by the project's ``mlp`` helper; the
    string entries of the layer specification ("ntanh", "irelu") are
    interpreted by that helper.

    Attributes:
        encoder: the module returned by ``mlp``.
        output_dim: width of the produced embedding (``embed_dim``).
    """

    def __init__(self, obs_dim: int, hidden_dim, embed_dim: int):
        super().__init__()
        self.output_dim = embed_dim
        layer_spec = (obs_dim, hidden_dim, "ntanh", hidden_dim, "irelu", embed_dim)
        self.encoder = mlp(*layer_spec)
        # Applied last so that the freshly registered `encoder` submodules
        # are visited by the initializer.
        self.apply(utils.weight_init)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``obs``."""
        embedding = self.encoder(obs)
        return embedding


class PVFAgent:
    """
    Agent that uses PVF (Proto-Value Functions) 
    to learn a representation space in the update function.
    """

    def __init__(self, **kwargs: tp.Any):
        """Build the PVF encoder pair, its optimizer and the wrapped RL agent.

        Args:
            **kwargs: field values forwarded verbatim to ``PVFAgentConfig``.

        Raises:
            AssertionError: if ``action_shape`` does not have exactly one
                dimension.
        """
        print('initializing PVF agent')
        cfg = PVFAgentConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        self.solved_meta: tp.Any = None

        self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(
                "feature_dim %s should not be smaller than obs_dim %s",
                cfg.feature_dim,
                self.obs_dim,
            )

        # Online and target PVF encoders; the target starts as an exact copy
        # of the online network.
        self.pvf_encoder = PVFEncoder(self.obs_dim, cfg.hidden_dim, cfg.feature_dim).to(cfg.device)
        self.pvf_target_encoder = PVFEncoder(self.obs_dim, cfg.hidden_dim, cfg.feature_dim).to(cfg.device)
        self.pvf_target_encoder.load_state_dict(self.pvf_encoder.state_dict())

        self.pvf_encoder.requires_grad_(True)

        # Only the online encoder is handed to the optimizer; the target
        # encoder is changed solely through the soft update in update_pvf.
        self.pvf_optimizer = torch.optim.Adam(self.pvf_encoder.parameters(), lr=cfg.lr)

        # RL agent (encoder_rl) operating on top of the PVF features.
        # NOTE(review): `rl_config` is a class-level attribute of
        # PVFAgentConfig, so these assignments mutate an object shared by
        # every config instance.
        self.cfg.rl_config.q_type = "linear"
        self.cfg.rl_config.action_shape = cfg.action_shape
        self.rl_agent = DiscreteRLAgent(self.cfg.feature_dim, self.cfg.rl_config)
        self.rl_agent.load_encoder(self.pvf_encoder)

        # Puts both the RL agent and the PVF encoder in training mode.
        self.train()

    def train(self, training: bool = True) -> None:
        """Set training mode for the agent."""
        self.training = training
        self.rl_agent.train(training)
        self.pvf_encoder.train(training)


    def update_pvf(self, obs: torch.Tensor, next_obs: torch.Tensor, discount: torch.Tensor) -> tp.Dict[str, float]:
        """Run one optimization step on the PVF encoder.

        The loss is the sum of two terms:

        * a temporal-difference term: the mean squared difference between
          the online embedding of ``obs`` and the target embedding of
          ``next_obs``;
        * ``cfg.ortho_coef`` times an orthonormality term computed on the
          batch Gram matrix ``B @ B.T`` of the online embeddings: the mean
          squared off-diagonal entry minus twice the mean diagonal entry.

        After the optimizer step the target encoder is soft-updated towards
        the online encoder with rate ``cfg.critic_target_tau``.

        Args:
            obs: batch of observations; the embeddings must be 2-D
                (batch, feature) for the Gram matrix below.
            next_obs: batch of successor observations, aligned with ``obs``.
            discount: unused; kept so the signature stays stable for callers.

        Returns:
            Scalar metrics: ``pvf_b_loss``, ``pvf_orth_loss``,
            ``pvf_total_loss`` and ``pvf_grad`` (sum of the per-parameter
            gradient norms of the online encoder).
        """
        device = self.cfg.device
        self.pvf_encoder.to(device)
        self.pvf_target_encoder.to(device)
        # Re-enable gradients on every call. Presumably the encoder, which is
        # shared with the RL agent via load_encoder, can be frozen elsewhere
        # -- verify against DiscreteRLAgent.
        self.pvf_encoder.requires_grad_(True)

        b_obs = self.pvf_encoder(obs)
        # The target network is only a regression target: no graph is built
        # for it, so backward() does not accumulate gradients that nothing
        # ever zeroes or consumes.
        with torch.no_grad():
            b_next_obs = self.pvf_target_encoder(next_obs)
        b_loss = (b_obs - b_next_obs).pow(2).mean()

        # Orthonormality loss on the batch Gram matrix.
        gram = torch.matmul(b_obs, b_obs.T)
        off_diag = ~torch.eye(gram.shape[0], dtype=torch.bool, device=gram.device)
        orth_loss_diag = -2 * gram.diag().mean()
        off_diag_values = gram[off_diag]
        if off_diag_values.numel() > 0:
            orth_loss_offdiag = off_diag_values.pow(2).mean()
        else:
            # A batch of one has no off-diagonal entries; the mean of an
            # empty tensor is NaN and would poison the whole loss.
            orth_loss_offdiag = gram.new_zeros(())
        orth_loss = orth_loss_offdiag + orth_loss_diag

        pvf_loss = b_loss + self.cfg.ortho_coef * orth_loss

        self.pvf_optimizer.zero_grad(set_to_none=True)
        pvf_loss.backward()
        # Logged diagnostic: sum of the gradient norms of the online encoder.
        grad = sum(
            (
                param.grad.norm().item()
                for param in self.pvf_encoder.parameters()
                if param.grad is not None
            ),
            0.0,
        )
        self.pvf_optimizer.step()

        # Soft update of the target network towards the online network.
        utils.soft_update_params(self.pvf_encoder, self.pvf_target_encoder, self.cfg.critic_target_tau)

        return {
            "pvf_b_loss": b_loss.item(),
            "pvf_orth_loss": orth_loss.item(),
            "pvf_total_loss": pvf_loss.item(),
            "pvf_grad": grad,
        }


    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        """Main update method: Only update PVF features during training. RL agent update is not called here."""
        metrics: tp.Dict[str, float] = {}

        # Only update PVF representation learning
        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)
        obs = batch.obs
        next_obs = batch.next_obs
        discount = batch.discount
        metrics.update(self.update_pvf(obs, next_obs, discount))
        # if step % 500 == 0:
        #     print(f"PVF training step {step}, metrics: {metrics}")
        #     print(self.pvf_encoder(obs))
        return metrics

    def inference(self, replay_loader: ReplayBuffer, infer_logger, q_pos, q_neg, reward_fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> tp.Dict[str, float]:
        """Run RL agent update for reward inference/learning, separate from PVF update."""
        # metrics = {}
        self.rl_agent.init_networks()
        self.rl_agent.load_encoder(self.pvf_encoder)  # Ensure RL agent uses latest PVF encoder
        for step in range(self.cfg.num_inference_steps):
            rl_metrics = self.rl_agent.update(replay_loader, step, reward_fn)
            infer_logger.log_metrics(rl_metrics)
        
        return {}
    
    def q_function_inference(self, obs: torch.Tensor, q_pos: torch.Tensor, q_neg: torch.Tensor, z: tp.Dict[str, float]) -> torch.Tensor:
        """Get Q-values from RL agent."""
        return self.rl_agent.q_function(obs)

    # Optional methods that some agents implement
    def infer_w_goal(self, replay_loader: ReplayBuffer, goal: np.ndarray) -> tp.Dict[str, float]:
        """Infer goal representation (optional method)."""
        # This is a placeholder - implement if needed
        return {}

    def distill_actor_ddpg(self, replay_loader: ReplayBuffer, logger, goal: np.ndarray) -> tp.Dict[str, float]:
        """Distill actor using DDPG (optional method)."""
        # This is a placeholder - implement if needed
        return {} 