from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from booml.common import layers, math, init

class WorldModel(nn.Module):
    """
    TD-MPC2 implicit world model architecture.
    Can be used for both single-task and multi-task experiments.
    """

    def __init__(self, cfg, device=None):
        super().__init__()
        self.cfg = cfg
        if cfg.multitask:
            self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
            self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
            for i in range(len(cfg.tasks)):
                self._action_masks[i, : cfg.action_dims[i]] = 1.0
        self._encoder = layers.enc(cfg)
        self._dynamics = layers.mlp(
            cfg.latent_dim + cfg.action_dim + cfg.task_dim,
            2 * [cfg.mlp_dim],
            cfg.latent_dim,
            act=layers.SimNorm(cfg),
        )
        self._reward = layers.mlp(
            cfg.latent_dim + cfg.action_dim + cfg.task_dim,
            2 * [cfg.mlp_dim],
            max(cfg.num_bins, 1),
        )
        self._pi = layers.mlp(
            cfg.latent_dim + cfg.task_dim, 2 * [cfg.mlp_dim], 2 * cfg.action_dim
        )
        self._Qs = layers.Ensemble(
            [
                layers.mlp(
                    cfg.latent_dim + cfg.action_dim + cfg.task_dim,
                    2 * [cfg.mlp_dim],
                    max(cfg.num_bins, 1),
                    dropout=cfg.dropout,
                )
                for _ in range(cfg.num_q)
            ]
        )

        self._Vs = layers.Ensemble(
            [
                layers.mlp(
                    cfg.latent_dim + cfg.task_dim,
                    2 * [cfg.mlp_dim],
                    max(cfg.num_bins, 1),
                    dropout=cfg.dropout,
                )
                for _ in range(cfg.num_v)
            ]
        )
        
        # Score function for Langevin dynamics
        score_input_dim = cfg.latent_dim + cfg.action_dim
        if cfg.multitask:
            score_input_dim += cfg.task_dim
        score_mlp_dims = getattr(cfg, 'score_mlp_dims', 2 * [cfg.mlp_dim])
        self._score_function = layers.mlp(
            score_input_dim,
            score_mlp_dims,
            cfg.action_dim,
            act=None,
        )
        
        self.apply(init.weight_init)
        init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
        self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
        self._target_Vs = deepcopy(self._Vs).requires_grad_(False)
        self.log_std_min = torch.tensor(cfg.log_std_min)
        self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min

        # use torch.compile to speed up models
        self._encoder = torch.compile(self._encoder['state'])
        self._dynamics = torch.compile(self._dynamics)
        self._reward = torch.compile(self._reward)
        self._pi = torch.compile(self._pi)
        self._Qs = torch.compile(self._Qs)
        self._Vs = torch.compile(self._Vs)
        self._target_Qs = torch.compile(self._target_Qs)
        self._target_Vs = torch.compile(self._target_Vs)
        self._task_emb = torch.compile(self._task_emb) if cfg.multitask else None
        self._score_function = torch.compile(self._score_function)

    @property
    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def to(self, *args, **kwargs):
        """
        Overriding `to` method to also move additional tensors to device.
        """
        super().to(*args, **kwargs)
        if self.cfg.multitask:
            self._action_masks = self._action_masks.to(*args, **kwargs)
        self.log_std_min = self.log_std_min.to(*args, **kwargs)
        self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
        return self

    def train(self, mode=True):
        """
        Overriding `train` method to keep target Q-networks in eval mode.
        """
        super().train(mode)
        self._target_Qs.train(False)
        return self

    def track_q_grad(self, mode=True):
        """
        Enables/disables gradient tracking of Q-networks.
        Avoids unnecessary computation during policy optimization.
        This method also enables/disables gradients for task embeddings.
        """
        for p in self._Qs.parameters():
            p.requires_grad_(mode)
        if self.cfg.multitask:
            for p in self._task_emb.parameters():
                p.requires_grad_(mode)

    def track_v_grad(self, mode=True):
        """
        Enables/disables gradient tracking of V-networks.
        Avoids unnecessary computation during policy optimization.
        This method also enables/disables gradients for task embeddings.
        """
        for p in self._Vs.parameters():
            p.requires_grad_(mode)
        if self.cfg.multitask:
            for p in self._task_emb.parameters():
                p.requires_grad_(mode)

    def soft_update_target_Q(self):
        """
        Soft-update target Q-networks using Polyak averaging.
        """
        with torch.no_grad():
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
                p_target.data.lerp_(p.data, self.cfg.tau)

    def soft_update_target_V(self):
        """
        Soft-update target V-networks using Polyak averaging.
        """
        with torch.no_grad():
            for p, p_target in zip(self._Vs.parameters(), self._target_Vs.parameters()):
                p_target.data.lerp_(p.data, self.cfg.tau)

    def task_emb(self, x, task):
        """
        Continuous task embedding for multi-task experiments.
        Retrieves the task embedding for a given task ID `task`
        and concatenates it to the input `x`.
        """
        if isinstance(task, int):
            task = torch.tensor([task], device=x.device)
        emb = self._task_emb(task.long())
        if x.ndim == 3:
            emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1)
        elif emb.shape[0] == 1:
            emb = emb.repeat(x.shape[0], 1)
        return torch.cat([x, emb], dim=-1)

    def encode(self, obs, task):
        """
        Encodes an observation into its latent representation.
        This implementation assumes a single state-based observation.
        """
        if self.cfg.multitask:
            obs = self.task_emb(obs, task)
        if self.cfg.obs == "rgb" and obs.ndim == 5:
            return torch.stack([self._encoder[self.cfg.obs](o) for o in obs])
        # print(self.cfg.obs) # state
        return self._encoder(obs)

    def next(self, z, a, task):
        """
        Predicts the next latent state given the current latent state and action.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        z = torch.cat([z, a], dim=-1)
        return self._dynamics(z)

    def reward(self, z, a, task):
        """
        Predicts instantaneous (single-step) reward.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        z = torch.cat([z, a], dim=-1)
        y = self._reward(z)
        # print("reward_head_out stats:",
        #     y.shape, y.min().item(), y.max().item(), y.abs().mean().item(), y.std().item())

        return y

    def pi(self, z, task):
        """
        Samples an action from the policy prior.
        The policy prior is a Gaussian distribution with
        mean and (log) std predicted by a neural network.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)

        # Gaussian policy prior
        mu, log_std = self._pi(z).chunk(2, dim=-1)
        log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
        eps = torch.randn_like(mu)

        if self.cfg.multitask:  # Mask out unused action dimensions
            mu = mu * self._action_masks[task]
            log_std = log_std * self._action_masks[task]
            eps = eps * self._action_masks[task]
            action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
        else:  # No masking
            action_dims = None

        log_pi = math.gaussian_logprob(eps, log_std, size=action_dims)
        pi = mu + eps * log_std.exp()
        mu, pi, log_pi = math.squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std
    
    def log_pi_action(self, z, a, task):
        """
        Compute the log probability of an action sequence given the latent states.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        mu, log_std = self._pi(z).chunk(2, dim=-1)
        eps = (a - mu) / (log_std.exp() + 1e-8)

        if self.cfg.multitask:  # Mask out unused action dimensions
            mu = mu * self._action_masks[task]
            log_std = log_std * self._action_masks[task]
            eps = eps * self._action_masks[task]
            action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
        else:  # No masking
            action_dims = None
            
        log_pi = math.gaussian_logprob(eps, log_std, size=action_dims)
        return log_pi

    def Q(self, z, a, task, return_type="min", target=False):
        """
        Predict state-action value.
        `return_type` can be one of [`min`, `avg`, `all`]:
                - `min`: return the minimum of two randomly subsampled Q-values.
                - `avg`: return the average of two randomly subsampled Q-values.
                - `max`: return the maximum of two randomly subsampled Q-values.
                - `all`: return all Q-values.
        `target` specifies whether to use the target Q-networks or not.
        """
        assert return_type in {"min", "avg", "all", "max"}

        if self.cfg.multitask:
            z = self.task_emb(z, task)

        z = torch.cat([z, a], dim=-1)
        out = (self._target_Qs if target else self._Qs)(z)

        if return_type == "all":
            return out

        Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)]
        Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg)

        if return_type == "min":
            return torch.min(Q1, Q2)
        elif return_type == "avg":
            return (Q1 + Q2) / 2
        elif return_type == "max":
            qs_thot = [math.two_hot_inv(q, self.cfg) for q in out]
            qs_thot = torch.stack(qs_thot, dim=0)
            return torch.max(qs_thot, dim=0)[0]

    def V(self, z, task, return_type="min", target=False):
        """
        Predict state value.
        `return_type` can be one of [`min`, `avg`, `all`]:
                - `min`: return the minimum of two randomly subsampled V-values.
                - `avg`: return the average of two randomly subsampled V-values.
                - `max`: return the maximum of two randomly subsampled V-values.
                - `all`: return all V-values.
        `target` specifies whether to use the target V-networks or not.
        """
        assert return_type in {"min", "avg", "all", "max"}
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        out = (self._target_Vs if target else self._Vs)(z)
        if return_type == "all":
            return out
        V1, V2 = out[np.random.choice(self.cfg.num_v, 2, replace=False)]
        V1, V2 = math.two_hot_inv(V1, self.cfg), math.two_hot_inv(V2, self.cfg)
        
        if return_type == "min":
            return torch.min(V1, V2)
        elif return_type == "avg":
            return (V1 + V2) / 2
        elif return_type == "max":
            vs_thot = [math.two_hot_inv(v, self.cfg) for v in out]
            vs_thot = torch.stack(vs_thot, dim=0)
            return torch.max(vs_thot, dim=0)[0]
