# maml_rl/anil_trpo.py
from copy import deepcopy
from typing import Dict, Tuple, Optional
import weakref
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.kl import kl_divergence
import numpy as np
from .anil_networks import PolicyBackbone, GaussianPolicyHead, ValueBackbone, ValueHead


class _PolicyAlias(nn.Module):
    """
    Policy-only view of a parent TRPO agent:
      - parameters()/named_parameters() => expose only policy_backbone + policy_head
      - anil_update_params => expose only the heads (plus value_head when the parent
        has ``adapt_value_head`` set), i.e. the inner-loop parameters
      - state_dict()/load_state_dict() => save/load only the policy part

    The parent is held through a weakref to avoid a reference cycle with it.

    NOTE(review): ``copy.deepcopy`` treats ``weakref.ref`` objects as atomic, so a
    deep copy of the parent presumably ends up with an alias that still points at
    the ORIGINAL parent -- verify before deep-copying agents.
    """
    def __init__(self, parent: "TRPO"):
        super().__init__()
        # Store via object.__setattr__ so nn.Module does not register the parent
        # as a submodule (which would create a module cycle).
        object.__setattr__(self, "_parent_ref", weakref.ref(parent))

    # Internal convenience
    def _parent(self) -> "TRPO":
        """Return the live parent agent; raise RuntimeError if it was garbage collected."""
        p = self._parent_ref()
        if p is None:
            raise RuntimeError("Parent TRPO has been garbage collected.")
        return p

    # -------- Parameter exposure (policy only) --------
    def parameters(self, recurse: bool = True):
        """Yield the parameters of the parent's policy_backbone, then policy_head."""
        p = self._parent()
        yield from p.policy_backbone.parameters(recurse=recurse)
        yield from p.policy_head.parameters(recurse=recurse)

    def named_parameters(self, prefix: str = "", recurse: bool = True):
        """
        Yield ``(name, param)`` pairs for the policy parts.

        Names look like ``policy_backbone.<...>`` / ``policy_head.<...>``; a
        non-empty ``prefix`` is prepended with a dot separator (as nn.Module does),
        e.g. ``prefix="pol"`` -> ``pol.policy_backbone.<...>``.
        """
        p = self._parent()
        base = f"{prefix}." if prefix else ""
        yield from p.policy_backbone.named_parameters(prefix=f"{base}policy_backbone", recurse=recurse)
        yield from p.policy_head.named_parameters(prefix=f"{base}policy_head", recurse=recurse)

    # -------- ANIL inner-loop targets (heads only) --------
    @property
    def anil_update_params(self):
        """Head parameters adapted in the inner loop (value head included if enabled on the parent)."""
        p = self._parent()
        params = list(p.policy_head.parameters())
        if getattr(p, "adapt_value_head", False):
            params += list(p.value_head.parameters())
        return params

    # -------- Legacy interface delegation --------
    def forward(self, obs: torch.Tensor):
        return self._parent().forward(obs)

    def distribution(self, obs: torch.Tensor):
        return self._parent().distribution(obs)

    @property
    def is_deterministic(self) -> bool:
        return self._parent().is_deterministic

    @is_deterministic.setter
    def is_deterministic(self, flag: bool):
        self._parent().is_deterministic = flag

    # -------- Serialization: policy part only --------
    def state_dict(self, destination=None, prefix:str="", keep_vars:bool=False):
        """
        Return a state_dict containing only ``policy_backbone.*`` and ``policy_head.*``.

        The parent module is not traversed (avoids recursing through the cycle).
        ``destination`` is intentionally not written to; a fresh OrderedDict is
        returned instead. NOTE(review): this alias is a registered child of its
        parent, so writing into ``destination`` would presumably add duplicate
        ``policy.*`` entries to the parent's state_dict -- keep it this way.
        """
        p = self._parent()
        od = OrderedDict()
        for k, v in p.policy_backbone.state_dict(keep_vars=keep_vars).items():
            od[f"{prefix}policy_backbone.{k}"] = v
        for k, v in p.policy_head.state_dict(keep_vars=keep_vars).items():
            od[f"{prefix}policy_head.{k}"] = v
        return od

    def load_state_dict(self, state_dict, strict: bool = True):
        """
        Load a dict in the format produced by state_dict() (empty prefix) into the
        parent's policy_backbone / policy_head, in that order.

        Keys starting with neither ``policy_backbone.`` nor ``policy_head.`` are
        ignored, even with ``strict=True``; strictness is enforced per sub-module by
        its own load_state_dict.

        Returns:
            The sub-modules' missing/unexpected keys, re-prefixed and merged, in the
            same named-tuple type that nn.Module.load_state_dict returns -- or None
            if the installed PyTorch returns None from load_state_dict.
        """
        p = self._parent()
        results = []
        for sub_prefix, module in (("policy_backbone.", p.policy_backbone),
                                   ("policy_head.", p.policy_head)):
            sub_sd = {k[len(sub_prefix):]: v
                      for k, v in state_dict.items() if k.startswith(sub_prefix)}
            results.append((sub_prefix, module.load_state_dict(sub_sd, strict=strict)))

        first = results[0][1]
        if first is None:
            # Very old PyTorch versions returned None; nothing to merge.
            return None
        missing = [pre + k for pre, res in results for k in res.missing_keys]
        unexpected = [pre + k for pre, res in results for k in res.unexpected_keys]
        # _replace keeps the concrete IncompatibleKeys type returned by PyTorch.
        return first._replace(missing_keys=missing, unexpected_keys=unexpected)

class FrozenPolicy(nn.Module):
    """
    Detached snapshot of the policy, used for old/new probability ratios and KL.

    It owns its own policy backbone and head and no value network, and keeps no
    reference to the live agent. Its parameter names (``policy_backbone.*`` /
    ``policy_head.*``) therefore line up with the agent's for name-based copies.
    """
    def __init__(self, obs_dim: int, policy_hidden_dim: int, action_dim: int,
                 deterministic: bool = False, device: Optional[torch.device] = None):
        super().__init__()
        backbone = PolicyBackbone(obs_dim, policy_hidden_dim)
        self.policy_backbone = backbone
        head = GaussianPolicyHead(policy_hidden_dim, action_dim, deterministic=deterministic)
        self.policy_head = head
        if device is None:
            return
        self.to(device)

    def _features(self, obs: torch.Tensor):
        """Run observations through the snapshot backbone."""
        return self.policy_backbone(obs)

    def forward(self, obs: torch.Tensor):
        """Return the action distribution produced by the head for ``obs``."""
        dist_obj, _unused = self.policy_head.dist(self._features(obs))
        return dist_obj

    def distribution(self, obs: torch.Tensor):
        """Alias of forward(), kept for the legacy interface."""
        return self.forward(obs)

    def get_log_prob(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``action`` under the snapshot policy (computed by the head)."""
        return self.policy_head.log_prob(self._features(obs), action)


class TRPO(nn.Module):
    """
    ANIL-style TRPO agent:
      - Inner loop: update ONLY the heads (policy_head [+ optional value_head]).
      - Outer loop: meta-update backbone + head via the query loss.

    Provides the legacy attributes/methods used by MetaLearner (``policy`` alias,
    KL, entropy, HVP / conjugate gradient, etc.).

    Batch dicts are read through the ``_*_KEYS`` tuples below (first matching key
    wins). GAE, the losses, KL, entropy and baselines therefore all read the same
    tensors regardless of the sampler's key naming.
    """

    # Candidate batch keys, tried in order by _get_batch_tensor.
    _OBS_KEYS = ("obs", "observations", "states", "state", "ob")
    _ACTION_KEYS = ("actions", "action", "acts", "act")
    _REWARD_KEYS = ("rewards", "reward", "rews")
    _DONE_KEYS = ("dones", "done", "terminals", "terminal")

    def __init__(
        self,
        observ_dim: int,
        action_dim: int,
        policy_hidden_dim: int,
        vf_hidden_dim: int,
        device: torch.device,
        # PG / RL params
        gamma: float = 0.99,
        lamda: float = 1.0,
        vf_learning_rate: float = 0.1,
        vf_learning_iters: int = 1,
        # ANIL toggles
        adapt_value_head: bool = False,
        deterministic_policy: bool = False,
        **kwargs,
    ):
        """
        Args:
            observ_dim: observation dimension.
            action_dim: action dimension.
            policy_hidden_dim: hidden size of the policy backbone/head.
            vf_hidden_dim: hidden size of the value backbone/head.
            device: device every sub-network is moved to.
            gamma: discount factor.
            lamda: GAE lambda.
            vf_learning_rate: stored as ``self.vf_lr``.
            vf_learning_iters: stored as ``self.vf_iters``.
            adapt_value_head: also include value_head in the inner-loop parameters.
            deterministic_policy: initial deterministic flag of the policy head.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.device = device
        self.gamma = gamma
        self.lamda = lamda
        self.vf_lr = vf_learning_rate
        self.vf_iters = vf_learning_iters
        self.adapt_value_head = adapt_value_head

        # Keep dims for snapshot construction (see snapshot_policy).
        self._obs_dim = observ_dim
        self._act_dim = action_dim
        self._p_hidden = policy_hidden_dim
        self._det_default = deterministic_policy

        # Build ANIL components.
        self.policy_backbone = PolicyBackbone(observ_dim, policy_hidden_dim).to(device)
        self.policy_head = GaussianPolicyHead(policy_hidden_dim, action_dim, deterministic=deterministic_policy).to(device)

        self.value_backbone = ValueBackbone(observ_dim, vf_hidden_dim).to(device)
        self.value_head = ValueHead(vf_hidden_dim).to(device)

        # Inner-loop updatable subset (ANIL): heads only.
        self.anil_update_params = list(self.policy_head.parameters())
        if self.adapt_value_head:
            self.anil_update_params += list(self.value_head.parameters())

        # Legacy alias so code can do DifferentiableSGD(self.agent.policy, ...).
        self.policy = _PolicyAlias(self)

        # Old-policy snapshot (no deepcopy(self), to avoid recursion).
        self._old_policy: Optional[FrozenPolicy] = None
        self.snapshot_policy()  # initialize once

    # ---------------- Legacy-friendly interface ----------------

    def forward(self, obs: torch.Tensor):
        """Return the current policy's action distribution for ``obs``."""
        feats = self.policy_backbone(obs)
        dist, _ = self.policy_head.dist(feats)
        return dist

    def distribution(self, obs: torch.Tensor):
        """Alias of forward(), kept for the legacy interface."""
        return self.forward(obs)

    # Property used like: agent.policy.is_deterministic = True/False
    @property
    def is_deterministic(self) -> bool:
        return bool(self.policy_head.deterministic)

    @is_deterministic.setter
    def is_deterministic(self, flag: bool):
        self.policy_head.deterministic = bool(flag)

    # ---------------- Policy/Value basic APIs ----------------

    def act(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample an action and its log-probability from the current policy."""
        feats = self.policy_backbone(obs)
        action, logp = self.policy_head.sample(feats)
        return action, logp

    def get_action(self, obs, return_logp: bool = False):
        """
        Interface called by the Sampler.

        Args:
            obs: a single observation or a batch; np.ndarray, torch.Tensor, or any
                array-like accepted by ``torch.as_tensor``. A 1-D input is treated
                as a single observation and given a batch dimension.
            return_logp: also return the sampled action's log-probability as a float.

        Returns:
            A numpy action ready for env.step (with the leading batch dim squeezed),
            or ``(action, logp)`` when ``return_logp`` is True.
        """
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        if obs_t.dim() == 1:
            obs_t = obs_t.unsqueeze(0)  # [obs_dim] -> [1, obs_dim]

        with torch.no_grad():
            action_t, logp_t = self.act(obs_t)

        action = action_t.squeeze(0).cpu().numpy()
        if return_logp:
            return action, float(logp_t.squeeze(0).cpu().item())
        return action

    def get_log_prob(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``action`` under the current policy (computed by the head)."""
        feats = self.policy_backbone(obs)
        return self.policy_head.log_prob(feats, action)

    def value(self, obs: torch.Tensor) -> torch.Tensor:
        """Value estimate V(obs) from the value backbone + head."""
        feats = self.value_backbone(obs)
        return self.value_head(feats)

    # ---------------- Snapshot old policy ----------------

    def snapshot_policy(self):
        """
        Rebuild a frozen, detached policy snapshot WITHOUT deep-copying self.

        Copies only the policy backbone/head weights by name, puts the snapshot in
        eval mode and disables gradients on all of its parameters.
        """
        frozen = FrozenPolicy(
            obs_dim=self._obs_dim,
            policy_hidden_dim=self._p_hidden,
            action_dim=self._act_dim,
            deterministic=self.is_deterministic,
            device=self.device,
        )
        frozen.policy_backbone.load_state_dict(self.policy_backbone.state_dict())
        frozen.policy_head.load_state_dict(self.policy_head.state_dict())
        frozen.eval()
        for p in frozen.parameters():
            p.requires_grad_(False)
        self._old_policy = frozen

    @property
    def old_policy(self) -> "Optional[FrozenPolicy]":
        """The most recent snapshot taken by snapshot_policy() (None before the first one)."""
        return self._old_policy

    def _get_batch_tensor(self, batch: Dict[str, torch.Tensor], keys, default=None) -> torch.Tensor:
        """
        Return the first of ``keys`` present in ``batch`` as a float32 tensor on self.device.

        Falls back to ``default`` (converted the same way) when no key matches.

        Raises:
            KeyError: if no key matches and ``default`` is None.
        """
        for k in keys:
            if k in batch:
                return torch.as_tensor(batch[k], dtype=torch.float32, device=self.device)
        if default is not None:
            return torch.as_tensor(default, dtype=torch.float32, device=self.device)
        raise KeyError(f"Expected one of keys {keys} in batch, but none found. Available: {list(batch.keys())}")

    def infer_baselines(self, batch: Dict[str, torch.Tensor]):
        """
        Per-step baselines for the buffer, defined as V(s).

        The buffer expects a list of length T whose elements are Python floats, so
        the value predictions are flattened and returned as a plain float list.
        """
        obs = self._get_batch_tensor(batch, self._OBS_KEYS)
        with torch.no_grad():
            values = self.value(obs).view(-1)  # [T]
        return values.cpu().tolist()

    # ---------------- Losses ----------------

    def compute_gae(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Generalized Advantage Estimation over one flattened trajectory batch.

        Rewards are reshaped to [T, 1]. ``dones`` defaults to all zeros when absent.
        The value after the last step is taken as 0. A done flag at step t masks
        both the bootstrap value and the GAE accumulation from later steps.

        Returns:
            [T, 1] advantages, normalized to zero mean / unit std. With T < 2 the
            std is undefined (NaN for one sample), so the result is only centered.
        """
        obs = self._get_batch_tensor(batch, self._OBS_KEYS)
        rewards = self._get_batch_tensor(batch, self._REWARD_KEYS).view(-1, 1)
        dones = self._get_batch_tensor(batch, self._DONE_KEYS,
                                       default=torch.zeros_like(rewards)).view(-1, 1)

        with torch.no_grad():
            values = self.value(obs).view(-1, 1)

        T = rewards.shape[0]
        adv = torch.zeros_like(rewards)
        last_gae = 0.0
        for t in reversed(range(T)):
            mask = 1.0 - dones[t]
            next_v = values[t + 1] if t + 1 < T else torch.zeros_like(values[0])
            delta = rewards[t] + self.gamma * next_v * mask - values[t]
            last_gae = delta + self.gamma * self.lamda * mask * last_gae
            adv[t] = last_gae

        if T > 1:
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        else:
            adv = adv - adv.mean()
        return adv

    def policy_loss(self, batch: Dict[str, torch.Tensor], is_meta_loss: bool = False) -> torch.Tensor:
        """
        Policy-gradient loss.

        - is_meta_loss=False (inner loop): backbone features are computed without
          grad, so only the head adapts; loss = -mean(log pi(a|s) * A).
        - is_meta_loss=True (outer/meta, when an old snapshot exists): surrogate
          loss = -mean(exp(log pi_new - log pi_old) * A).

        Advantages come from compute_gae() and are detached from the graph.
        """
        obs = self._get_batch_tensor(batch, self._OBS_KEYS)
        actions = self._get_batch_tensor(batch, self._ACTION_KEYS)

        adv = self.compute_gae(batch).detach()

        if is_meta_loss and self.old_policy is not None:
            # META LOSS (importance ratio against the frozen snapshot)
            with torch.no_grad():
                old_logp = self.old_policy.get_log_prob(obs, actions).view(-1, 1)
            new_logp = self.get_log_prob(obs, actions).view(-1, 1)
            ratio = (new_logp - old_logp).exp()
            return -(ratio * adv).mean()

        # INNER LOOP (ANIL): backbone features carry no grad, only the head does.
        with torch.no_grad():
            feats = self.policy_backbone(obs)
        logp = self.policy_head.log_prob(feats, actions).view(-1, 1)
        return -(logp * adv).mean()

    def value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        MSE between V(s_t) and the one-step TD target r_t + gamma * V(s_{t+1}) * (1 - done_t).

        The next-state value after the last step is 0, and the target carries no
        gradient. ``dones`` defaults to zeros when absent from the batch.
        """
        obs = self._get_batch_tensor(batch, self._OBS_KEYS)
        rewards = self._get_batch_tensor(batch, self._REWARD_KEYS).view(-1, 1)
        dones = self._get_batch_tensor(batch, self._DONE_KEYS,
                                       default=torch.zeros_like(rewards)).view(-1, 1)
        pred = self.value(obs).view(-1, 1)
        with torch.no_grad():
            # Reuse the prediction (detached) instead of a second forward pass.
            v = pred.detach()
            next_v = torch.cat([v[1:], torch.zeros_like(v[:1])], dim=0)
            target = rewards + self.gamma * next_v * (1.0 - dones)
        return F.mse_loss(pred, target)

    # ---------------- Metrics for MetaLearner ----------------

    def kl_divergence(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Mean KL(old || new) over the batch observations.

        The old distribution comes from the frozen snapshot (or from the current
        policy without grad if no snapshot exists). Per-dimension KLs are summed
        when the result has more than one dimension.
        """
        obs = self._get_batch_tensor(batch, self._OBS_KEYS)
        reference = self.old_policy if self.old_policy is not None else self
        with torch.no_grad():
            old_dist = reference.distribution(obs)
        new_dist = self.distribution(obs)
        kl = kl_divergence(old_dist, new_dist)
        if kl.dim() > 1:
            kl = kl.sum(dim=-1)
        return kl.mean()

    def compute_policy_entropy(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Mean policy entropy over the batch (per-dimension entropies summed when needed)."""
        obs = self._get_batch_tensor(batch, self._OBS_KEYS)
        ent = self.distribution(obs).entropy()
        if ent.dim() > 1:
            ent = ent.sum(dim=-1)
        return ent.mean()

    # ---------------- TRPO math helpers (used in meta_update) ----------------

    @staticmethod
    def flat_grad(grads) -> torch.Tensor:
        """
        Concatenate gradients into one flat vector.

        ``None`` entries are skipped, so the result can be shorter than the total
        size of the corresponding parameters. Returns a scalar 0 tensor when
        nothing is left.
        """
        parts = [g.contiguous().view(-1) for g in grads if g is not None]
        return torch.cat(parts) if parts else torch.tensor(0.0)

    def _unflatten_to_like(self, flat: torch.Tensor, params_list):
        """
        Split ``flat`` into views shaped like each tensor in ``params_list``, in order.

        Raises:
            ValueError: if ``flat`` does not have exactly as many elements as the
                parameters combined (a silent mismatch would misalign the step).
        """
        params_list = list(params_list)
        total = sum(p.numel() for p in params_list)
        if flat.numel() != total:
            raise ValueError(
                f"Flat vector has {flat.numel()} elements but the parameters need {total}."
            )
        views = []
        offset = 0
        for p in params_list:
            n = p.numel()
            views.append(flat[offset:offset + n].view_as(p))
            offset += n
        return views

    def hessian_vector_product(self, kl_scalar: torch.Tensor, params_iter, damping: float = 0.0):
        """
        Build a function v -> (H + damping * I) v, where H is the Hessian of ``kl_scalar``.

        The first-order gradient (with create_graph=True) does not depend on v, so it
        is computed once on the first call and reused by later calls (e.g. the CG
        iterations). Graphs are retained so the returned function can be called
        repeatedly.

        Args:
            kl_scalar: scalar whose Hessian is used (typically the mean KL).
            params_iter: parameters to differentiate with respect to.
            damping: Tikhonov damping added as ``damping * v`` (0 = none).
        """
        params_list = list(params_iter)
        cache = {}

        def _flat_first_order_grad() -> torch.Tensor:
            if "grad" not in cache:
                grads = torch.autograd.grad(kl_scalar, params_list, create_graph=True, retain_graph=True)
                cache["grad"] = self.flat_grad(grads)
            return cache["grad"]

        def _hvp(v_flat: torch.Tensor) -> torch.Tensor:
            gv = (_flat_first_order_grad() * v_flat).sum()
            hv = torch.autograd.grad(gv, params_list, retain_graph=True)
            hv_flat = self.flat_grad(hv)
            if damping:
                hv_flat = hv_flat + damping * v_flat
            return hv_flat

        return _hvp

    def conjugate_gradient(self, Hv, b: torch.Tensor, nsteps: int = 10, residual_tol: float = 1e-10):
        """
        Approximately solve H x = b with conjugate gradient, where ``Hv(p)`` returns H p.

        Runs at most ``nsteps`` iterations, stopping early once the squared residual
        norm drops below ``residual_tol``. A 1e-8 term guards the divisions.
        """
        x = torch.zeros_like(b)
        r = b.clone()
        p = r.clone()
        rdotr = torch.dot(r, r)
        for _ in range(nsteps):
            Hp = Hv(p)
            alpha = rdotr / (torch.dot(p, Hp) + 1e-8)
            x = x + alpha * p
            r = r - alpha * Hp
            new_rdotr = torch.dot(r, r)
            if new_rdotr < residual_tol:
                break
            beta = new_rdotr / (rdotr + 1e-8)
            p = r + beta * p
            rdotr = new_rdotr
        return x

    def compute_descent_step(self, Hv, search_dir_flat: torch.Tensor, max_kl: float, params_iter=None):
        """
        Scale ``search_dir_flat`` to the KL trust-region boundary and unflatten it.

        scale = sqrt(2 * max_kl / (s^T H s)). The scaled step is split into tensors
        shaped like ``params_iter``, which defaults to the policy parameters
        (policy_backbone + policy_head via the ``policy`` alias, NOT the value net).
        """
        sHs = torch.dot(search_dir_flat, Hv(search_dir_flat))
        scale = torch.sqrt(2.0 * max_kl / (sHs + 1e-8))
        step_flat = scale * search_dir_flat
        params_list = list(self.policy.parameters() if params_iter is None else params_iter)
        return self._unflatten_to_like(step_flat, params_list)

    # ---------------- State copy helper (MetaLearner uses this) ----------------

    def update_model(self, target_module: nn.Module, named_params: dict):
        """
        Copy values by name into ``target_module``'s parameters (e.g. self.policy or self.old_policy).

        Names in ``named_params`` that the target does not have are ignored; target
        parameters without a matching name are left untouched. The copy is done
        without recording autograd history.
        """
        target_dict = dict(target_module.named_parameters())
        with torch.no_grad():
            for name, src in named_params.items():
                target = target_dict.get(name)
                if target is None:
                    continue
                val = src if isinstance(src, torch.Tensor) else src.data
                target.copy_(val.detach())