import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .distributions import Categorical
from .common import *


class MultigridNetwork(DeviceAwareModule):
    """
    Actor-Critic module (grid-size invariant pre-RNN features).

    The image observation is encoded by a single conv layer followed by
    global average pooling, so the image embedding is always
    ``conv_filters`` wide regardless of the grid's height and width.
    Optional x/y one-hots, a one-hot scalar embedding ("direction", falling
    back to "time_step") and a random ``z`` vector are concatenated onto it,
    passed through an optional recurrent core, and fed to separate actor
    (``Categorical`` policy) and critic (single value output) heads.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        actor_fc_layers=(32, 32),
        value_fc_layers=(32, 32),
        conv_filters=16,
        conv_kernel_size=3,
        scalar_fc=5,
        scalar_dim=4,
        random_z_dim=0,
        xy_dim=0,
        recurrent_arch="lstm",
        recurrent_hidden_size=256,
        random=False,
    ):
        """
        Args:
            observation_space: dict-like space; ``observation_space["image"].shape``
                is read to obtain the image channel count.
            action_space: action space; ``action_space.n`` sets the number of
                policy outputs.
            actor_fc_layers: hidden sizes of the actor MLP; the last entry is
                the input width of the ``Categorical`` head.
            value_fc_layers: hidden sizes of the critic MLP; the last entry is
                the input width of the final value layer.
            conv_filters: output channels of the image conv, which is also the
                image embedding width after pooling.
            conv_kernel_size: kernel size of the image conv ("valid" padding).
            scalar_fc: output width of the scalar embedding layer.
            scalar_dim: number of one-hot classes for the scalar input;
                0 disables the scalar embedding.
            random_z_dim: width of the random ``z`` input; 0 disables it.
            xy_dim: number of one-hot classes for each of x and y; 0 disables
                the positional inputs.
            recurrent_arch: passed to ``RNN`` as ``arch``; a falsy value
                disables the recurrent core.
            recurrent_hidden_size: hidden size of the recurrent core.
            random: if True, ``act`` bypasses the network and draws actions
                with ``action_space.sample()``.
        """
        super().__init__()

        self.random = random
        self.action_space = action_space
        num_actions = action_space.n

        # ===== Image embedding (adaptive to HxW, constant output width) =====
        # The shape is indexed as (..., C, H, W): the channel count is read
        # from dim -3 and used as the conv's input-channel count, so this
        # assumes channel-first image observations.
        # NOTE(review): confirm the env wrapper delivers channel-first images.
        in_channels = observation_space["image"].shape[-3]

        self.conv_filters = conv_filters
        self.image_conv = nn.Sequential(
            Conv2d_tf(
                in_channels,
                conv_filters,
                kernel_size=conv_kernel_size,
                stride=1,
                padding="valid",
            ),
            nn.ReLU(inplace=True),
        )
        # _forward_base global-average-pools the conv output, so the image
        # embedding is always [B, conv_filters] whatever the grid size.
        self.image_embedding_size = conv_filters
        self.preprocessed_input_size = self.image_embedding_size

        # ===== x, y positional one-hots (optional) =====
        self.xy_dim = xy_dim
        if self.xy_dim:
            # x_oh and y_oh are concatenated directly (2 * xy_dim features).
            self.preprocessed_input_size += 2 * self.xy_dim

        # ===== Scalar embedding (optional) =====
        self.scalar_dim = scalar_dim
        if self.scalar_dim:
            self.scalar_embed = nn.Linear(self.scalar_dim, scalar_fc)
            self.preprocessed_input_size += scalar_fc
        else:
            self.scalar_embed = None

        # ===== Random z (optional) =====
        self.random_z_dim = random_z_dim
        if self.random_z_dim:
            self.preprocessed_input_size += self.random_z_dim

        self.base_output_size = self.preprocessed_input_size

        # ===== Recurrent core (optional) =====
        self.rnn = None
        if recurrent_arch:
            self.rnn = RNN(
                input_size=self.preprocessed_input_size,
                hidden_size=recurrent_hidden_size,
                arch=recurrent_arch,
            )
            self.base_output_size = recurrent_hidden_size

        # ===== Heads =====
        self.actor = nn.Sequential(
            make_fc_layers_with_hidden_sizes(
                actor_fc_layers, input_size=self.base_output_size
            ),
            Categorical(actor_fc_layers[-1], num_actions),
        )
        self.critic = nn.Sequential(
            make_fc_layers_with_hidden_sizes(
                value_fc_layers, input_size=self.base_output_size
            ),
            init_(nn.Linear(value_fc_layers[-1], 1)),
        )

        apply_init_(self.modules())
        self.train()

    @property
    def is_recurrent(self):
        """Whether this model has a recurrent core."""
        return self.rnn is not None

    @property
    def recurrent_hidden_state_size(self):
        """Size of the recurrent hidden state, or 0 without a recurrent core."""
        if self.rnn is None:
            return 0
        return self.rnn.recurrent_hidden_state_size

    def forward(self, inputs, rnn_hxs, masks):
        # Not implemented: this module exposes act / get_value /
        # evaluate_actions instead.
        raise NotImplementedError

    def _global_avg_pool_layout_agnostic(self, feat: torch.Tensor) -> torch.Tensor:
        """
        Global-average-pool a 4D conv feature map over its spatial dims.

        Accepts NHWC (B, H, W, C) or NCHW (B, C, H, W) and returns [B, C].
        The layout is inferred by matching dims against ``self.conv_filters``:
        the map is treated as NHWC only when the last dim matches and dim 1
        does not; every other case (including ambiguous ones) is NCHW.

        Raises:
            RuntimeError: if ``feat`` is not 4D.
        """
        if feat.dim() != 4:
            raise RuntimeError(f"Expected 4D conv output, got {feat.shape}")

        is_nhwc = (
            feat.size(-1) == self.conv_filters
            and feat.size(1) != self.conv_filters
        )
        if is_nhwc:
            return feat.mean(dim=(1, 2))  # [B, C]
        return F.adaptive_avg_pool2d(feat, 1).flatten(1)  # [B, C]

    def _one_hot_on_device(self, dim, indices):
        """One-hot encode ``indices`` into ``dim`` classes on ``self.device``."""
        # The trailing .to() is a no-op when one_hot already honoured
        # ``device``; it guarantees every one-hot input ends up on the model's
        # device so the later torch.cat cannot mix devices.
        return one_hot(dim, indices, device=self.device).to(self.device)

    def _coerce_random_z(self, in_z, batch_size, dtype):
        """
        Return the random z input as a [batch_size, random_z_dim] tensor.

        ``None`` becomes zeros. A 1D vector or a [1, random_z_dim] tensor is
        broadcast across the batch. The result is moved to ``self.device`` and
        cast to ``dtype`` so it concatenates cleanly with the other features.

        Raises:
            ValueError: if ``in_z`` cannot be shaped to [batch_size, random_z_dim].
        """
        if in_z is None:
            return torch.zeros(
                batch_size, self.random_z_dim, device=self.device, dtype=dtype
            )

        given_shape = tuple(in_z.shape)
        if in_z.dim() == 1:
            in_z = in_z.unsqueeze(0)
        if in_z.dim() == 2 and in_z.size(0) == 1:
            # A single z shared by the whole batch: broadcast as a view.
            in_z = in_z.expand(batch_size, -1)

        if (
            in_z.dim() != 2
            or in_z.size(0) != batch_size
            or in_z.size(1) != self.random_z_dim
        ):
            raise ValueError(
                f"random_z must have shape [{batch_size}, {self.random_z_dim}] "
                f"(or be a single z broadcastable to it), got {given_shape}"
            )
        return in_z.to(device=self.device, dtype=dtype)

    def _forward_base(self, inputs, rnn_hxs, masks):
        """
        Encode ``inputs`` into core features and run the optional recurrent core.

        Args:
            inputs: mapping with an "image" entry plus, when enabled, "x"/"y",
                "direction" (or, if that is absent, "time_step") and
                "random_z" (optional even when enabled; defaults to zeros).
            rnn_hxs: recurrent hidden state passed to the RNN core.
            masks: masks passed to the RNN core.

        Returns:
            (core_features, rnn_hxs): features of width ``base_output_size``
            and the hidden state (updated by the RNN when present).

        Raises:
            ValueError: if an enabled x/y or scalar input is missing, or a
                provided random_z cannot be shaped to [B, random_z_dim].
        """
        image = inputs.get("image")

        scalar = inputs.get("direction")
        if scalar is None:
            scalar = inputs.get("time_step")

        # ---- Image -> fixed-width [B, conv_filters] ----
        feat = self.image_conv(image)
        in_image = self._global_avg_pool_layout_agnostic(feat)
        batch_size = in_image.size(0)

        # Width-0 stand-in for disabled features; matching the image
        # embedding's dtype keeps torch.cat from promoting the result.
        no_features = torch.empty(
            batch_size, 0, device=self.device, dtype=in_image.dtype
        )

        # ---- XY one-hots -> [B, 2 * xy_dim] ----
        if self.xy_dim:
            x = inputs.get("x")
            y = inputs.get("y")
            if x is None or y is None:
                raise ValueError("xy_dim>0 but x/y not provided")
            in_xy = torch.cat(
                [
                    self._one_hot_on_device(self.xy_dim, x),
                    self._one_hot_on_device(self.xy_dim, y),
                ],
                dim=-1,
            )
        else:
            in_xy = no_features

        # ---- Scalar one-hot -> linear -> [B, scalar_fc] ----
        if self.scalar_embed is not None:
            if scalar is None:
                raise ValueError("scalar_dim>0 but scalar not provided")
            scalar_oh = self._one_hot_on_device(self.scalar_dim, scalar)
            in_scalar = self.scalar_embed(scalar_oh)
        else:
            in_scalar = no_features

        # ---- Random z -> [B, random_z_dim] ----
        if self.random_z_dim:
            in_z = self._coerce_random_z(
                inputs.get("random_z", None), batch_size, in_image.dtype
            )
        else:
            in_z = no_features

        # ---- Concatenate all features (fixed width) ----
        in_embedded = torch.cat((in_image, in_xy, in_scalar, in_z), dim=-1)

        # ---- Core ----
        if self.rnn is None:
            return in_embedded, rnn_hxs
        return self.rnn(in_embedded, rnn_hxs, masks)

    def act(self, inputs, rnn_hxs, masks, deterministic=False):
        """
        Select actions for a batch of observations.

        Args:
            inputs, rnn_hxs, masks: as for ``_forward_base``.
            deterministic: use ``dist.mode()`` instead of ``dist.sample()``.

        Returns:
            (value, action, action_log_dist, rnn_hxs). In ``random`` mode the
            network is bypassed: values are zeros, each action comes from
            ``action_space.sample()``, action_log_dist is an all-ones
            placeholder (not a real log-distribution), and rnn_hxs is
            returned unchanged.
        """
        if self.random:
            batch_size = inputs["image"].shape[0]
            action = torch.zeros(
                (batch_size, 1), dtype=torch.int64, device=self.device
            )
            values = torch.zeros((batch_size, 1), device=self.device)
            action_log_dist = torch.ones(
                batch_size, self.action_space.n, device=self.device
            )
            for b in range(batch_size):
                action[b] = self.action_space.sample()
            return values, action, action_log_dist, rnn_hxs

        core_features, rnn_hxs = self._forward_base(inputs, rnn_hxs, masks)
        dist = self.actor(core_features)
        value = self.critic(core_features)

        action = dist.mode() if deterministic else dist.sample()
        return value, action, dist.logits, rnn_hxs

    def get_value(self, inputs, rnn_hxs, masks):
        """Return the critic's value estimate for ``inputs``."""
        core_features, _ = self._forward_base(inputs, rnn_hxs, masks)
        return self.critic(core_features)

    def evaluate_actions(self, inputs, rnn_hxs, masks, action, return_policy_logits=False):
        """
        Evaluate given actions under the current policy.

        Returns:
            (value, action_log_probs, dist_entropy, rnn_hxs), where
            dist_entropy is the batch-mean entropy. When
            ``return_policy_logits`` is True the policy distribution object
            itself is appended as a fifth element.
        """
        core_features, rnn_hxs = self._forward_base(inputs, rnn_hxs, masks)
        dist = self.actor(core_features)
        value = self.critic(core_features)
        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        if return_policy_logits:
            return value, action_log_probs, dist_entropy, rnn_hxs, dist
        return value, action_log_probs, dist_entropy, rnn_hxs
