from __future__ import annotations
from typing import Dict

import torch as th

from gym import spaces
import torch.nn.functional as F

from achievement_distillation.model.base import BaseModel
from achievement_distillation.impala_cnn import ImpalaCNN
from achievement_distillation.action_head import CategoricalActionHead
from achievement_distillation.mse_head import ScaledMSEHead
from achievement_distillation.torch_util import FanInInitReLULayer
from achievement_distillation.mlp import MLP
from achievement_distillation.model.ppo_ad import PPOADModel


class IHACModel(PPOADModel):
    """PPO-AD model variant that adds a meta value head and meta-value helpers.

    The model builds its own ImpalaCNN encoder, a linear projection to
    ``hidsize``, a categorical policy head, a scalar value head
    (``vf_head``) and a scalar ``v_meta_head``. It also owns an Adam optimizer
    over its parameters. When the extra meta inputs are supplied to
    :meth:`forward`, it additionally computes a one-step Q estimate and a
    KL-regularized meta value (``v_meta``).
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        action_space: spaces.Discrete,
        hidsize: int,
        impala_kwargs: Dict | None = None,
        dense_init_norm_kwargs: Dict | None = None,
        action_head_kwargs: Dict | None = None,
        mse_head_kwargs: Dict | None = None,
        nhidlayer: int = 1,
        temperature: float = 0.1,
        use_memory: bool = True,
        gamma: float = 0.99,  # Discount factor for Q value calculation
        lambda_kl: float = 1.0,  # KL weight in V value calculation
        lr: float = 1e-4,  # Learning rate for optimizer
    ):
        """Build encoder, heads and optimizer.

        Args:
            observation_space: Observation space; its ``shape`` feeds the encoder.
            action_space: Discrete action space; its ``n`` sizes the policy head.
            hidsize: Width of the latent produced by :meth:`encode`.
            impala_kwargs: Extra kwargs for ``ImpalaCNN``; must contain
                ``"outsize"`` (read below to size the linear layer).
            dense_init_norm_kwargs: Kwargs for the dense/linear layers.
            action_head_kwargs: Extra kwargs for ``CategoricalActionHead``.
            mse_head_kwargs: Extra kwargs for both ``ScaledMSEHead`` instances.
            nhidlayer, temperature, use_memory: Forwarded to the parent class.
            gamma: Discount used by :meth:`compute_q_value`.
            lambda_kl: KL weight used by :meth:`compute_v_meta`.
            lr: Adam learning rate.
        """
        # Use fresh dicts per call instead of shared mutable defaults.
        impala_kwargs = {} if impala_kwargs is None else impala_kwargs
        dense_init_norm_kwargs = {} if dense_init_norm_kwargs is None else dense_init_norm_kwargs
        action_head_kwargs = {} if action_head_kwargs is None else action_head_kwargs
        mse_head_kwargs = {} if mse_head_kwargs is None else mse_head_kwargs

        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            hidsize=hidsize,
            impala_kwargs=impala_kwargs,
            dense_init_norm_kwargs=dense_init_norm_kwargs,
            action_head_kwargs=action_head_kwargs,
            mse_head_kwargs=mse_head_kwargs,
            nhidlayer=nhidlayer,
            temperature=temperature,
            use_memory=use_memory,
        )

        # Model components. These assignments replace any attributes of the
        # same name the parent constructor may have created.
        obs_shape = getattr(self.observation_space, "shape")
        self.enc = ImpalaCNN(
            obs_shape,
            dense_init_norm_kwargs=dense_init_norm_kwargs,
            **impala_kwargs,
        )
        outsize = impala_kwargs["outsize"]  # required key; KeyError if absent
        self.linear = FanInInitReLULayer(
            outsize,
            hidsize,
            layer_type="linear",
            **dense_init_norm_kwargs,
        )
        self.hidsize = hidsize

        # Heads
        num_actions = getattr(self.action_space, "n")
        self.pi_head = CategoricalActionHead(
            insize=hidsize,
            num_actions=num_actions,
            **action_head_kwargs,
        )
        self.vf_head = ScaledMSEHead(
            insize=hidsize,
            outsize=1,
            **mse_head_kwargs,
        )
        self.v_meta_head = ScaledMSEHead(
            insize=hidsize,
            outsize=1,  # Assuming v_meta is also scalar output
            **mse_head_kwargs,
        )

        # Discount factor and KL weight
        self.gamma = gamma
        self.lambda_kl = lambda_kl

        # Optimizer over every parameter registered up to this point.
        self.optimizer = th.optim.Adam(self.parameters(), lr=lr)

    def compute_q_value(
        self,
        rewards: th.Tensor,
        v_meta_next: th.Tensor,
        pi_meta_next: th.Tensor,
        clip_range: tuple[float, float] | None = (-1.0, 1.0),
    ) -> th.Tensor:
        """Compute a one-step Q estimate from rewards and next-state meta values.

        ``Q = rewards + gamma * sum(pi_meta_next * v_meta_next, dim=-1, keepdim=True)``

        Args:
            rewards: Immediate rewards; broadcast against the keepdim sum.
            v_meta_next: Next-state values, weighted by ``pi_meta_next``.
            pi_meta_next: Next-state weights (presumably π_meta probabilities;
                not normalized here).
            clip_range: ``(low, high)`` bounds passed to ``clamp``. The default
                reproduces the previous hard-coded ``[-1, 1]``. ``None``
                disables clipping.

        Returns:
            The (optionally clamped) Q tensor.
        """
        expected_next = (pi_meta_next * v_meta_next).sum(dim=-1, keepdim=True)
        q_value = rewards + self.gamma * expected_next
        if clip_range is None:
            return q_value
        low, high = clip_range
        return q_value.clamp(low, high)

    def compute_v_meta(self, q_values, pi_logits, pi_meta_logits):
        """Compute the KL-regularized meta value for the current policy.

        Evaluates ``sum(π * Q) - lambda_kl * KL(π || π_meta)`` at the policy
        given by ``pi_logits``. This is the objective inside
        ``max_π {...}``; no maximization is performed here.

        NOTE(review): ``compute_q_value`` returns a tensor whose last dim is
        produced by ``keepdim=True``. If ``q_values`` has last dim 1, then
        ``sum(π * Q)`` reduces to ``Q`` itself because π sums to 1. Confirm
        that per-action Q values are not intended here.
        """
        pi = th.softmax(pi_logits, dim=-1)
        pi_meta = th.softmax(pi_meta_logits, dim=-1)

        q_expectation = (pi * q_values).sum(dim=-1, keepdim=True)
        # The 1e-10 epsilon keeps log() finite when a probability is exactly 0.
        kl_div = th.sum(
            pi * (th.log(pi + 1e-10) - th.log(pi_meta + 1e-10)), dim=-1, keepdim=True
        )
        return q_expectation - self.lambda_kl * kl_div

    def forward(self, obs: th.Tensor, **kwargs) -> Dict[str, th.Tensor]:
        """Encode ``obs`` and run the policy and value heads.

        Optional kwargs ``rewards``, ``v_meta_next``, ``pi_meta_next`` and
        ``pi_meta_logits`` enable the meta-value computation. It runs only
        when all four are provided. Otherwise it is skipped rather than
        raising.

        Returns:
            Dict with ``"latents"``, ``"pi_logits"``, ``"vpreds"`` and, when
            computed, ``"v_meta"``.
        """
        rewards = kwargs.get("rewards")
        v_meta_next = kwargs.get("v_meta_next")
        pi_meta_next = kwargs.get("pi_meta_next")
        pi_meta_logits = kwargs.get("pi_meta_logits")

        latents = self.encode(obs)
        pi_logits = self.pi_head(latents)
        vpreds = self.vf_head(latents)

        outputs = {
            "latents": latents,
            "pi_logits": pi_logits,
            "vpreds": vpreds,
        }

        # compute_v_meta needs pi_meta_logits too, so it is part of the guard.
        # Without it, softmax(None) would raise a TypeError.
        meta_inputs = (rewards, v_meta_next, pi_meta_next, pi_meta_logits)
        if all(t is not None for t in meta_inputs):
            q_values = self.compute_q_value(rewards, v_meta_next, pi_meta_next)
            outputs["v_meta"] = self.compute_v_meta(q_values, pi_logits, pi_meta_logits)

        return outputs

    @th.no_grad()
    def act(self, obs: th.Tensor, **kwargs) -> Dict[str, th.Tensor]:
        """Sample actions for rollout (no gradients; model must be in eval mode).

        Returns the :meth:`forward` outputs updated with ``"actions"``,
        ``"log_probs"``, and ``"vpreds"`` denormalized through ``vf_head``.
        """
        assert not self.training

        outputs = self.forward(obs, **kwargs)

        pi_logits = outputs["pi_logits"]
        actions = self.pi_head.sample(pi_logits)
        log_probs = self.pi_head.log_prob(pi_logits, actions)

        # forward() yields head-space values; map them back via the value head.
        vpreds = self.vf_head.denormalize(outputs["vpreds"])

        outputs.update({
            "actions": actions,
            "log_probs": log_probs,
            "vpreds": vpreds,
        })
        return outputs

    def encode(self, obs: th.Tensor) -> th.Tensor:
        """Map observations to latents via the CNN encoder and linear layer."""
        x = self.enc(obs)
        x = self.linear(x)
        return x

    def compute_losses(
        self,
        obs: th.Tensor,
        actions: th.Tensor,
        log_probs: th.Tensor,
        v_meta: th.Tensor,  # v_meta from storage
        advs: th.Tensor,
        rewards: th.Tensor = None,
        v_meta_next: th.Tensor = None,  # Next state v_meta for Q value computation
        pi_meta_next: th.Tensor = None,  # Next state policy π_meta for Q value computation
        pi_meta_logits: th.Tensor = None,  # π_meta logits for KL computation
        imitation_phase: bool = False,  # Whether it's the imitation phase
        clip_param: float = 0.2,
        **kwargs,
    ) -> Dict[str, th.Tensor]:
        """Compute the loss terms for either the imitation or the PPO phase.

        Imitation phase: ``"v_meta_loss"``, plus ``"pi_kl_loss"`` only when
        ``pi_meta_logits`` is given (the KL term is optional).
        PPO phase: clipped-surrogate ``"pi_loss"``, ``"vf_loss"`` (value head
        regressed onto the stored ``v_meta``) and ``"entropy"``.

        Returns:
            Dict of scalar loss tensors. They are not combined here.
        """
        # outputs["v_meta"], if computed, is not consumed by the losses below.
        outputs = self.forward(
            obs,
            rewards=rewards,
            v_meta_next=v_meta_next,
            pi_meta_next=pi_meta_next,
            pi_meta_logits=pi_meta_logits,
            **kwargs,
        )
        pi_logits = outputs["pi_logits"]
        losses = {}

        if imitation_phase:
            v_pred = outputs["vpreds"]  # value-head output
            # NOTE(review): arguments are (stored v_meta, model prediction). The
            # PPO branch below passes (model prediction, stored target). Also,
            # v_meta_head's own output is never computed in this class. Confirm
            # the intended pred/target order for ScaledMSEHead.mse_loss.
            losses["v_meta_loss"] = self.v_meta_head.mse_loss(v_meta, v_pred).mean()

            # Optional KL between π and π_meta. Skip it when no π_meta logits
            # are given instead of calling kl_divergence with None.
            if pi_meta_logits is not None:
                losses["pi_kl_loss"] = self.pi_head.kl_divergence(
                    pi_logits, pi_meta_logits
                ).mean()
        else:
            # PPO clipped surrogate objective.
            new_log_probs = self.pi_head.log_prob(pi_logits, actions)
            ratio = th.exp(new_log_probs - log_probs)
            ratio_clipped = th.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
            losses["pi_loss"] = -th.min(advs * ratio, advs * ratio_clipped).mean()

            # Value loss uses the stored v_meta as the regression target.
            losses["vf_loss"] = self.vf_head.mse_loss(outputs["vpreds"], v_meta).mean()
            losses["entropy"] = self.pi_head.entropy(pi_logits).mean()

        return losses
