import contextlib
import logging
from typing import Tuple, Union

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import dispatch
from torch import distributions as d
from torch.distributions.categorical import Categorical
from torch.distributions.kl import kl_divergence
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.utils import distance_loss

from utils import SeparateDisretizedProbabilisticActor, DisretizedProbabilisticActor

class WeightedClipPPOLoss(ClipPPOLoss):
    """Clipped PPO loss with several optional extra terms.

    On top of :class:`ClipPPOLoss` this adds:

    * a critic loss weighted by the detached, clipped importance ratio and
      optionally clamped to ``[0, vf_clip_param]``;
    * a KL penalty ``kl_coeff * KL(p_stored || p_current)``, where
      ``p_stored`` is built from the ``"logits"`` entry of the input data,
      with an adaptive coefficient (see :meth:`update_kl`);
    * linear annealing of the entropy coefficient (see :meth:`set_entropy`);
    * a squared-hinge "energy" penalty on fixed in-distribution (ID) and/or
      out-of-distribution (OOD) batches (see :meth:`compute_energy_loss`);
    * a passthrough of a ``"distill_loss"`` entry found in the input data.

    Keyword Args:
        kl_coeff: initial KL-penalty coefficient; ``0`` disables the penalty.
        kl_target: KL target used by :meth:`update_kl`; a negative value
            keeps the coefficient fixed.
        vf_clip_param: upper bound applied to the critic loss; a falsy value
            disables clamping.
        linear_entropy: optional ``(start, end)`` pair for the entropy
            coefficient schedule; ``start`` is applied immediately.
        id_file, ood_file: optional paths passed to ``torch.load``. The loaded
            object is either a tensor (stored under ``"image"``) or a dict of
            tensors sharing their leading dimension.
        margin_in, margin_out: hinge margins for the ID / OOD energy terms.
        lambda_energy: scale applied to the summed energy penalty.
        device: device the ID/OOD batches are moved to.

    All other positional and keyword arguments are forwarded to
    :class:`ClipPPOLoss`.
    """

    # Logger for this module; a class attribute keeps the class self-contained.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        *args,
        kl_coeff=0,
        kl_target=0.01,
        vf_clip_param=None,
        linear_entropy=None,
        id_file=None,
        ood_file=None,
        margin_in=12,
        margin_out=14,
        lambda_energy=0.0001,
        device=torch.device("cpu"),
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.kl_coeff = kl_coeff
        self.kl_target = kl_target
        self.vf_clip_param = vf_clip_param
        self.lambda_energy = lambda_energy
        self.linear_entropy = linear_entropy
        self.margin_in = margin_in
        self.margin_out = margin_out
        self.device = device
        self._logger.info(
            "lambda_energy=%s margin_in=%s margin_out=%s",
            self.lambda_energy,
            self.margin_in,
            self.margin_out,
        )

        # NOTE(review): torch.load unpickles its input; only pass trusted files.
        self.id_data_loaded = id_file is not None
        if self.id_data_loaded:
            self.id_batch = self._to_batch(torch.load(id_file, map_location=device), device)

        self.ood_data_loaded = ood_file is not None
        if self.ood_data_loaded:
            self.ood_batch = self._to_batch(torch.load(ood_file, map_location=device), device)

        if self.linear_entropy is not None:
            start, _ = self.linear_entropy
            self._set_entropy_coef(start)

    @staticmethod
    def _to_batch(data: Union[torch.Tensor, dict], device) -> TensorDict:
        """Wrap ``data`` in a TensorDict living on ``device``.

        A dict is stored key-by-key (its values must share their leading
        dimension); a bare tensor is stored under the ``"image"`` key. The
        leading dimension of the first entry becomes the batch size.

        Raises:
            ValueError: if ``data`` is an empty dict.
        """
        if isinstance(data, dict):
            if not data:
                raise ValueError("cannot build a batch from an empty dict")
            entries = {key: value.to(device) for key, value in data.items()}
        else:
            entries = {"image": data.to(device)}
        batch_size = next(iter(entries.values())).shape[0]
        return TensorDict(entries, batch_size=[batch_size], device=device)

    def _actor_context(self):
        """Return a context that binds the actor params when ``self.functional`` is set, else a no-op."""
        if self.functional:
            return self.actor_network_params.to_module(self.actor_network)
        return contextlib.nullcontext()

    def _critic_context(self):
        """Return a context that binds the critic params when ``self.functional`` is set, else a no-op."""
        if self.functional:
            return self.critic_network_params.to_module(self.critic_network)
        return contextlib.nullcontext()

    def _set_entropy_coef(self, value: float) -> None:
        """Overwrite ``self.entropy_coef`` in place with ``value`` (same dtype/device)."""
        new_val = torch.tensor(value, dtype=self.entropy_coef.dtype, device=self.entropy_coef.device)
        self.entropy_coef.data.copy_(new_val)

    def compute_energy_loss(self):
        """Return ``lambda_energy`` times the squared-hinge energy penalty.

        The actor is run in eval mode on the stored ID and/or OOD batches and
        is expected to write a ``"raw_energy"`` entry into them. ID samples
        contribute ``mean(relu(margin_in - raw_energy) ** 2)`` and OOD samples
        ``mean(relu(raw_energy - margin_out) ** 2)``. The actor's previous
        train/eval mode is restored even if a forward pass raises.
        """
        was_training = self.actor_network.training
        self.actor_network.eval()
        try:
            with self._actor_context():
                if self.id_data_loaded:
                    self.actor_network(self.id_batch)
                if self.ood_data_loaded:
                    self.actor_network(self.ood_batch)
        finally:
            # Restore the caller's mode even when a forward pass failed.
            if was_training:
                self.actor_network.train()

        loss_energy = 0.0
        if self.id_data_loaded:
            # Penalise ID samples whose raw energy falls below margin_in.
            raw_eng_in = self.id_batch["raw_energy"]
            loss_energy += torch.mean(torch.relu(self.margin_in - raw_eng_in) ** 2)
        if self.ood_data_loaded:
            # Penalise OOD samples whose raw energy rises above margin_out.
            raw_eng_out = self.ood_batch["raw_energy"]
            loss_energy += torch.mean(torch.relu(raw_eng_out - self.margin_out) ** 2)
        return self.lambda_energy * loss_energy

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Compute all loss terms for a batch.

        Returns a TensorDict with ``loss_objective``, ``loss_unclipped``,
        ``loss_clipped``, ``iso`` (mean importance ratio), ``loss_kl``,
        ``ESS`` (effective sample size divided by the batch length),
        ``loss_energy`` and ``loss_distill``; plus ``entropy`` /
        ``loss_entropy`` when ``entropy_bonus`` is set and ``loss_critic``
        when ``critic_coef`` is non-zero.
        """
        tensordict = tensordict.clone(False)
        advantage = tensordict.get(self.tensor_keys.advantage, None)
        if advantage is None:
            self.value_estimator(
                tensordict,
                params=self._cached_critic_network_params_detached,
                target_params=self.target_critic_network_params,
            )
            advantage = tensordict.get(self.tensor_keys.advantage)
        if self.normalize_advantage and advantage.numel() > 1:
            loc = advantage.mean()
            scale = advantage.std().clamp_min(1e-6)
            advantage = (advantage - loc) / scale

        # Build the stored-logits distribution BEFORE _log_weight: running the
        # actor there may overwrite "logits" in this tensordict. Only needed
        # (and only required to exist) when the KL term is active.
        previous_dist = Categorical(logits=tensordict["logits"]) if self.kl_coeff else None
        log_weight, dist = self._log_weight(tensordict)

        # Kish effective sample size (sum w)^2 / sum w^2, computed in log space.
        with torch.no_grad():
            lw = log_weight.squeeze()
            ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
            batch = log_weight.shape[0]
        if ess.numel() > 1:
            ess = ess.mean()
        if log_weight.dim() > advantage.dim():
            advantage = advantage.unsqueeze(-1)

        is_ratio = log_weight.exp()
        gain1 = is_ratio * advantage
        log_weight_clip = log_weight.clamp(*self._clip_bounds)
        gain2 = log_weight_clip.exp() * advantage
        gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
        if gain.dim() == 3:
            gain = gain.mean(1)

        td_out = TensorDict(
            {
                "loss_objective": -gain.mean(),
                "loss_unclipped": -gain1.mean().detach(),
                "loss_clipped": -gain2.mean().detach(),
                "iso": is_ratio.mean().detach(),
            },
            [],
        )

        if self.entropy_bonus:
            entropy = self.get_entropy_bonus(dist)
            td_out.set("entropy", entropy.mean().detach())
            td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
        if self.critic_coef:
            # Weight each sample's critic loss by its detached clipped ratio.
            weight = log_weight_clip.exp().detach()
            loss_critic = (weight * self.loss_critic(tensordict)).mean()
            td_out.set("loss_critic", loss_critic)

        zero = torch.tensor(0.0, device=dist.logits.device)
        kl = kl_divergence(previous_dist, dist).mean() if previous_dist is not None else zero
        td_out.set("loss_kl", self.kl_coeff * kl)
        td_out.set("ESS", ess.mean() / batch)

        self.energy_penalty = self.id_data_loaded or self.ood_data_loaded
        loss_energy = self.compute_energy_loss() if self.energy_penalty else zero
        td_out.set("loss_energy", loss_energy)

        # Distillation loss written by the KickStart collector, if present.
        distill = tensordict.get("distill_loss", None)
        td_out.set("loss_distill", distill.mean() if distill is not None else zero)

        return td_out

    def compute_bc_loss(self, tensordict: TensorDictBase, bc_coef: float = 0.01) -> torch.Tensor:
        """Return ``bc_coef`` times the behaviour-cloning NLL over advised steps.

        Steps are selected by the ``"take_advice"`` entry (any dtype; non-zero
        means selected). The stored action is assumed to be one-hot and is
        turned into an index with ``argmax``. Returns a zero tensor when
        ``"take_advice"`` is absent or no step is selected.
        """
        take_advice = tensordict.get("take_advice", None)
        if take_advice is None:
            return torch.tensor(0.0, device=tensordict.device)

        # Read the action before running the actor, as the original did.
        executed_action = tensordict.get(self.tensor_keys.action)
        with self._actor_context():
            dist = self.actor_network.get_dist(tensordict)

        log_probs = torch.log_softmax(dist.logits, dim=-1)
        action_idx = executed_action.argmax(dim=-1, keepdim=True)
        bc_loss_all = -log_probs.gather(-1, action_idx).squeeze(-1)  # negative log-likelihood

        # Cast to bool so an int/float 0/1 flag acts as a mask (not as gather
        # indices), and reshape so a trailing singleton dim does not break it.
        mask = take_advice.bool().reshape(bc_loss_all.shape)
        if mask.any():
            bc_loss = bc_loss_all[mask].mean()
        else:
            bc_loss = torch.tensor(0.0, device=bc_loss_all.device)
        return bc_coef * bc_loss

    def _log_weight(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, d.Distribution]:
        """Return ``(log_prob_current - stored_log_prob).unsqueeze(-1)`` and the current dist.

        For the discretised actor wrappers, the stored action values are first
        converted to integer indicators by exact comparison with the actor's
        ``action_mapping``.

        Raises:
            RuntimeError: if the stored action or stored log-prob requires grad.
        """
        action = tensordict.get(self.tensor_keys.action)
        if action.requires_grad:
            raise RuntimeError(f"tensordict stored {self.tensor_keys.action} requires grad.")

        with self._actor_context():
            dist = self.actor_network.get_dist(tensordict)
            if isinstance(self.actor_network, SeparateDisretizedProbabilisticActor):
                action = (torch.abs(
                    action.unsqueeze(-1) - self.actor_network.action_mapping.unsqueeze(0)
                ) == 0).long()
            elif isinstance(self.actor_network, DisretizedProbabilisticActor):
                # A mapping row matches only if every component is equal.
                action = (torch.abs(
                    action.unsqueeze(1) - self.actor_network.action_mapping.unsqueeze(0)
                ).sum(-1) == 0).long()

        log_prob = dist.log_prob(action)
        prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
        if prev_log_prob.requires_grad:
            raise RuntimeError("tensordict prev_log_prob requires grad.")
        log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
        return log_weight, dist

    def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
        """Return ``critic_coef`` times the value loss, without a final reduction.

        The value loss is ``distance_loss(value_target, value)``, optionally
        clamped to ``[0, vf_clip_param]``; ``forward`` weights and averages it.

        Raises:
            KeyError: if the value target or the critic's value output is missing.
        """
        if self.separate_losses:
            self._logger.debug("separate_losses is true, detaching tensordict %s", self.separate_losses)
            tensordict = tensordict.detach()
        # Use an explicit None default: newer tensordict versions may return
        # None rather than raise KeyError for a missing key.
        target_return = tensordict.get(self.tensor_keys.value_target, None)
        if target_return is None:
            raise KeyError(
                f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                "Make sure you provided the right key and the value_target (i.e. the target return) has been retrieved accordingly."
            )
        with self._critic_context():
            state_value_td = self.critic_network(tensordict)
        state_value = state_value_td.get(self.tensor_keys.value, None)
        if state_value is None:
            raise KeyError(
                f"the key {self.tensor_keys.value} was not found in the input tensordict. "
                "Make sure that the value_key passed to PPO is accurate."
            )
        loss_value = distance_loss(target_return, state_value, loss_function=self.loss_critic_type)
        if self.vf_clip_param:
            loss_value = torch.clamp(loss_value, 0, self.vf_clip_param)
        return self.critic_coef * loss_value

    def update_kl(self, scaled_kl):
        """Adapt ``kl_coeff`` from an observed KL loss and return the new value.

        ``scaled_kl`` is expected to be ``kl_coeff * KL`` (the ``loss_kl``
        value produced by ``forward``). The coefficient doubles when the KL
        exceeds ``1.5 * kl_target`` and halves when it is below
        ``kl_target / 1.5``. Returns ``0`` when the penalty is disabled, and
        the unchanged coefficient when ``kl_target`` is negative.
        """
        if self.kl_coeff == 0:
            return 0
        if self.kl_target < 0:
            return self.kl_coeff

        kl = scaled_kl / self.kl_coeff
        if kl > self.kl_target * 1.5:
            self.kl_coeff *= 2
        elif kl < self.kl_target / 1.5:
            self.kl_coeff *= 0.5
        return self.kl_coeff

    def set_entropy(self, current_step: int, total_steps: int):
        """Linearly interpolate ``entropy_coef`` along ``linear_entropy = (start, end)``.

        The progress fraction ``current_step / total_steps`` is capped at 1. A
        non-positive ``total_steps`` is treated as a finished schedule (``end``)
        instead of dividing by zero. Does nothing when ``linear_entropy`` is None.
        """
        if self.linear_entropy is None:
            return
        start, end = self.linear_entropy
        fraction = min(current_step / total_steps, 1.0) if total_steps > 0 else 1.0
        self._set_entropy_coef(start + fraction * (end - start))
        self._logger.debug("entropy_coef set to %s", self.entropy_coef.data)

    def set_id_batch(self, image_tensor: Union[torch.Tensor, dict]):
        """Replace the in-distribution batch used by ``compute_energy_loss``.

        Accepts either a single tensor (stored under ``"image"``) or a dict
        such as ``{'image': ..., 'recipe': ...}`` whose values share their
        leading dimension. Data is moved to ``self.device`` and the ID energy
        term is enabled.
        """
        self.id_batch = self._to_batch(image_tensor, self.device)
        self.id_data_loaded = True





