from __future__ import annotations

import torch
import torch.nn as nn
from typing import Any
from models import resnet


def normc_initializer(std: float = 1.0) -> Any:
    """Return an in-place weight initializer with per-row norm ``std``.

    The returned callable refills the given tensor with samples from N(0, 1), then
    rescales it so that every slice along dim 0 has Euclidean norm ``std`` when
    summed over dim 1. For an ``nn.Linear`` weight ``(out_features, in_features)``,
    this gives each output unit's incoming weight vector norm ``std``. The callable
    returns ``None``.
    """

    def _init(weight):
        buf = weight.data
        buf.normal_(0, 1)
        row_norms = buf.pow(2).sum(1, keepdim=True).sqrt()
        buf.mul_(std / row_norms)

    return _init


def _create_convolutional_layers(in_channel, conv_filters, embedding_size):
    layers = []
    if isinstance(conv_filters, str):
        # Keep compatibility with resnet-based environments
        layers.append(resnet.create_convolutional_layers(conv_filters, in_channel, embedding_size))
    elif isinstance(conv_filters, list):
        prev_out = in_channel
        for layer_type, spec in conv_filters:
            if layer_type == 'conv2d':
                out_channel, kernel, stride, padding = spec
                conv = nn.Conv2d(prev_out, out_channel, kernel, stride, padding=padding)
                nn.init.xavier_uniform_(conv.weight)
                nn.init.constant_(conv.bias, 0)
                layers.append(conv)
                prev_out = out_channel
            elif layer_type == 'maxpool2d':
                kernel, stride = spec
                layers.append(nn.MaxPool2d(kernel_size=kernel, stride=stride))
            elif layer_type == 'adpavgpool2d':
                layers.append(nn.AdaptiveAvgPool2d(spec))
            elif layer_type == 'relu':
                layers.append(nn.ReLU())
            elif layer_type == 'leaky_relu':
                layers.append(nn.LeakyReLU())
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")

    return nn.ModuleList(layers)


def _create_dense_layers(sizes, activation_type=nn.ReLU, initializer=normc_initializer):
    """Build a flat ModuleList of Linear layers, each optionally followed by an activation.

    Args:
        sizes: iterable of ``(in_features, out_features)`` pairs, one per Linear layer.
        activation_type: module class instantiated with no arguments after every
            Linear layer; ``None`` omits activations.
        initializer: callable applied in place to each Linear weight, or ``None`` to
            keep PyTorch's default initialization. The factory ``normc_initializer``
            itself (the default) is also accepted and is treated as
            ``normc_initializer()``, i.e. std=1.0.

    Returns:
        ``nn.ModuleList`` laid out as ``[Linear, act, Linear, act, ...]`` (or only
        Linears when ``activation_type`` is ``None``).
    """
    # ``normc_initializer`` is a factory. Calling it directly on a weight tensor only
    # builds (and discards) a closure with ``std=weight`` and never touches the
    # weight, so the default previously left every layer un-initialized by normc.
    if initializer is normc_initializer:
        initializer = normc_initializer()

    layers = []
    for in_size, out_size in sizes:
        linear = nn.Linear(in_size, out_size)
        if initializer is not None:
            initializer(linear.weight)
        layers.append(linear)

        if activation_type is not None:
            layers.append(activation_type())

    return nn.ModuleList(layers)


class Actor(nn.Module):
    """Policy network that maps observations (and optionally a recipe vector) to logits.

    Features come either from a ``shared_feature_extractor`` or from the actor's own
    conv + dense stack. ``final_layer``, sized by the last entry of
    ``actor_layer_sizes``, produces the logits.
    """

    def __init__(self, config, shared_feature_extractor=None, using_recipe=False):
        """Build the actor from ``config``.

        Args:
            config: nested config dict. Reads ``config['env']`` and, under
                ``config['model']['custom_model_config']``: ``action_masking``,
                ``actor_layer_sizes``, optional ``use_d2rl`` / ``use_leaky_relu``
                (default False), and, only when no shared extractor is given,
                ``input_conv_channels`` and ``conv_filters``.
            shared_feature_extractor: optional module used instead of the actor's
                own conv/dense layers.
            using_recipe: when True, create ``recipe_proj = nn.Linear(3, 32)``. This
                projects a 3-feature recipe that is then concatenated to the conv
                features.
        """
        super().__init__()
        self.use_shared_network = shared_feature_extractor is not None
        self.action_masking = config['model']['custom_model_config']['action_masking']
        self.use_d2rl = config['model']['custom_model_config'].get('use_d2rl', False)
        use_leaky_relu = config['model']['custom_model_config'].get('use_leaky_relu', False)
        activation_type = nn.LeakyReLU if use_leaky_relu else nn.ReLU

        if self.use_shared_network:
            self.shared_feature_extractor = shared_feature_extractor
        else:
            self.conv_layers = nn.ModuleList(
                _create_convolutional_layers(
                    config['model']['custom_model_config']['input_conv_channels'],
                    config['model']['custom_model_config']['conv_filters'],
                    config['model']['custom_model_config']['actor_layer_sizes'][0][0]
                )
            )
            # All but the last size pair become hidden layers; the last one sizes final_layer.
            self.dense_layers = _create_dense_layers(
                config['model']['custom_model_config']['actor_layer_sizes'][:-1],
                activation_type=activation_type,
            )

        self.final_layer = nn.Linear(
            config['model']['custom_model_config']['actor_layer_sizes'][-1][0],  # 64
            config['model']['custom_model_config']['actor_layer_sizes'][-1][1]  # 6
        )
        self.env = config['env']
        if using_recipe:
            print(f"using recipe")
            self.recipe_proj = nn.Linear(3, 32)
        else:
            print(f"not using recipe")

    def forward(self, obs, recipe=None, action_mask=None):
        """Compute masked and unmasked logits plus their log-sum-exp "energies".

        Args:
            obs: observation tensor. Rank > 3 is treated as ``(*batch, C, H, W)``.
                Rank 3 is treated as ``(*batch, D)`` with two batch dims. Lower ranks
                are passed through unchanged. Leading batch dims are flattened into
                one and are NOT restored on the outputs.
            recipe: optional tensor whose last dim feeds ``recipe_proj`` (3 features).
                It is only used on the non-shared path and requires
                ``using_recipe=True``.
            action_mask: optional tensor added (as ``log(mask)``, clamped at
                -3.4e38) to the logits when ``action_masking`` is enabled. A 0 entry
                therefore pushes that logit to about -3.4e38.

        Returns:
            ``(logits, raw_logits, eng, raw_eng)``. ``eng`` and ``raw_eng`` are the
            ``logsumexp`` over the last dim of ``logits`` and ``raw_logits``. If every
            action in a row is masked, ``eng`` is about -3.4e38 rather than -inf.

        Raises:
            ValueError: if the flattened recipe batch does not match the obs batch.
        """
        if len(obs.shape) > 3:
            *batch, C, H, W = obs.shape
            x = obs.flatten(0, len(batch) - 1) if len(batch) > 1 else obs
        elif len(obs.shape) == 3:
            *batch, D = obs.shape
            x = obs.flatten(0, 1)
        else:
            *batch, D = obs.shape
            x = obs
        x = x.float()

        if self.use_shared_network:
            x = self.shared_feature_extractor(x)
        else:
            for layer in self.conv_layers:
                x = layer(x)
            x = x.reshape(x.shape[0], -1)

            # Concatenate recipe
            if recipe is not None:
                recipe = recipe.float().view(-1, recipe.shape[-1])  # [batch_size, 3]
                if recipe.shape[0] != x.shape[0]:
                    raise ValueError(f"Batch mismatch: x = {x.shape}, recipe = {recipe.shape}")

                projected_recipe = self.recipe_proj(recipe)  # [batch_size, 32]
                x = torch.cat([x, projected_recipe], dim=-1)

            conv_out = x.clone()

            # NOTE(review): dense_layers interleaves Linear and activation modules, so
            # with use_d2rl the features are concatenated before activations as well,
            # not only before each Linear. Confirm this matches the configured sizes.
            for i, layer in enumerate(self.dense_layers):
                if self.use_d2rl and i > 0:
                    x = torch.cat([x, conv_out], dim=-1)
                x = layer(x)

        x = self.final_layer(x)
        raw_logits = x.clone()

        if self.action_masking and action_mask is not None:
            # log(0) = -inf is clamped to a large finite negative value to avoid NaNs.
            inf_mask = torch.clamp(torch.log(action_mask), min=-3.4e38)
            logits = x + inf_mask
        else:
            logits = raw_logits

        # logsumexp subtracts the row max before exponentiating. Unlike
        # log(exp(x).sum()), it stays finite for large logits (float32 exp overflows
        # to inf above ~88).
        eng = torch.logsumexp(logits, dim=-1)
        raw_eng = torch.logsumexp(raw_logits, dim=-1)

        return logits, raw_logits, eng, raw_eng


class Critic(nn.Module):
    """Value network mapping observations (and optionally a recipe vector) to outputs.

    Features come either from a ``shared_feature_extractor`` or from this critic's
    own conv + dense stack. ``final_layer``, sized by the last entry of
    ``critic_layer_sizes``, produces the output. Leading batch dims that were
    flattened on the way in are restored on the way out.
    """

    def __init__(self, config, shared_feature_extractor=None, using_recipe=False):
        """Build the critic from ``config['model']['custom_model_config']``.

        Reads ``use_d2rl`` / ``use_leaky_relu`` (optional, default False) and
        ``critic_layer_sizes``. ``input_conv_channels`` and ``conv_filters`` are read
        only when no shared extractor is supplied. ``use_d2rl`` is stored but
        ``forward`` does not consult it. With ``using_recipe=True`` a
        ``recipe_proj = nn.Linear(3, 32)`` is created.
        """
        super().__init__()
        cfg = config['model']['custom_model_config']
        self.use_shared_network = shared_feature_extractor is not None
        self.use_d2rl = cfg.get('use_d2rl', False)
        act_cls = nn.LeakyReLU if cfg.get('use_leaky_relu', False) else nn.ReLU

        if not self.use_shared_network:
            self.conv_layers = nn.ModuleList(
                _create_convolutional_layers(
                    cfg['input_conv_channels'],
                    cfg['conv_filters'],
                    cfg['critic_layer_sizes'][0][0],
                )
            )
            self.dense_layers = _create_dense_layers(
                cfg['critic_layer_sizes'][:-1],
                activation_type=act_cls,
            )
        else:
            self.shared_feature_extractor = shared_feature_extractor

        if using_recipe:
            self.recipe_proj = nn.Linear(3, 32)

        head_in = cfg['critic_layer_sizes'][-1][0]
        head_out = cfg['critic_layer_sizes'][-1][1]
        self.final_layer = nn.Linear(head_in, head_out)

    def forward(self, obs, recipe=None):
        """Return the critic output for ``obs``.

        Args:
            obs: tensor whose trailing dims are ``(C, H, W)`` when rank > 3, or a
                single feature dim otherwise. When there are two or more leading
                dims they are flattened into one before encoding.
            recipe: optional tensor fed (last dim as features) to ``recipe_proj``.
                It is only used on the non-shared path and needs
                ``using_recipe=True``.

        Returns:
            ``final_layer`` output, with the flattened leading dims restored.

        Raises:
            ValueError: if the flattened recipe batch differs from the obs batch.
        """
        if len(obs.shape) > 3:
            *lead, _c, _h, _w = obs.shape
        else:
            *lead, _d = obs.shape
        # Two or more leading dims collapse into one batch dim; otherwise pass through.
        x = obs.flatten(0, len(lead) - 1) if len(lead) > 1 else obs
        x = x.float()

        if self.use_shared_network:
            x = self.shared_feature_extractor(x)
        else:
            for conv in self.conv_layers:
                x = conv(x)
            x = x.reshape(x.shape[0], -1)

            if recipe is not None:
                flat_recipe = recipe.float().view(-1, recipe.shape[-1])
                if flat_recipe.shape[0] != x.shape[0]:
                    raise ValueError(f"Batch mismatch: x = {x.shape}, recipe = {flat_recipe.shape}")
                x = torch.cat([x, self.recipe_proj(flat_recipe)], dim=-1)

            for dense in self.dense_layers:
                x = dense(x)

        values = self.final_layer(x)
        return values.unflatten(0, lead) if len(lead) > 1 else values


class SharedFeatureExtractor(nn.Module):
    """Conv trunk followed by dense layers, built from the actor's layer sizes.

    Uses ``actor_layer_sizes`` from ``config['model']['custom_model_config']``. The
    conv stack receives ``actor_layer_sizes[0][0]`` as its embedding size, and the
    dense stack is built from every size pair except the last. The output width is
    therefore the out size of ``actor_layer_sizes[-2]``, or the flattened conv
    output when there is only one size pair.
    """

    def __init__(self, config):
        super().__init__()
        model_cfg = config['model']['custom_model_config']
        act_cls = nn.LeakyReLU if model_cfg.get('use_leaky_relu', False) else nn.ReLU

        in_channels = model_cfg['input_conv_channels']
        filters = model_cfg['conv_filters']
        layer_sizes = model_cfg['actor_layer_sizes']
        embed_size = layer_sizes[0][0]  # shared embedding size
        self.conv_layers = nn.ModuleList(
            _create_convolutional_layers(in_channels, filters, embed_size)
        )

        # Hidden dense layers only; each head keeps its own final layer.
        self.shared_dense = _create_dense_layers(layer_sizes[:-1], activation_type=act_cls)

    def forward(self, obs):
        """Encode ``obs`` (cast to float) into a ``(batch, features)`` tensor."""
        features = obs.float()
        for conv in self.conv_layers:
            features = conv(features)
        batch_size = features.shape[0]
        features = features.reshape(batch_size, -1)
        for dense in self.shared_dense:
            features = dense(features)
        return features


from torchrl.modules import ProbabilisticActor
from tensordict.nn.common import dispatch
from tensordict.tensordict import TensorDictBase
from typing import Any, Optional
from tensordict.nn.common import dispatch, TensorDictModule
from tensordict.nn.utils import set_skip_existing
from tensordict.tensordict import TensorDictBase
from tensordict.utils import NestedKey

from typing import Optional, Sequence, Union

import torch

from tensordict import TensorDictBase
from tensordict.nn import (
    dispatch,
    TensorDictModule
)
from tensordict.utils import NestedKey
from torch import nn

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.tensordict_module.probabilistic import (
    SafeProbabilisticModule
)


class SeparateDisretizedProbabilisticActor(ProbabilisticActor):
    """ProbabilisticActor that maps a per-dimension one-hot action to mapped values.

    After the wrapped modules run, ``tensordict['action']`` is expected to be a
    binary tensor of shape ``(*batch, action_dim, n_bins)`` with a 1 at the chosen
    bin. For each action dimension ``d``, the chosen bin ``b`` is replaced with
    ``action_mapping[d, b]``, giving ``(*batch, action_dim)``.
    """

    def __init__(
            self,
            module: TensorDictModule,
            in_keys: Union[NestedKey, Sequence[NestedKey]],
            out_keys: Optional[Sequence[NestedKey]] = None,
            *,
            spec: Optional[TensorSpec] = None,
            action_mapping: Optional[torch.Tensor] = None,
            **kwargs,
    ):
        """Wrap ``module`` like ``ProbabilisticActor`` and store ``action_mapping``.

        ``action_mapping`` has shape ``(action_dim, n_bins)``. It is stored as a
        frozen ``nn.Parameter``, so it appears in ``parameters()`` and
        ``state_dict``.
        """
        super().__init__(module, in_keys, out_keys=out_keys, spec=spec, **kwargs)
        # NOTE(review): with the default action_mapping=None, nn.Parameter(None) holds
        # an empty tensor and the indexing in forward() would fail. Callers presumably
        # always pass a mapping.
        self.action_mapping = nn.Parameter(action_mapping, requires_grad=False)

    @dispatch(auto_batch_size=False)
    @set_skip_existing(None)
    def forward(
            self,
            tensordict: TensorDictBase,
            tensordict_out: TensorDictBase | None = None,
            **kwargs: Any,
    ) -> TensorDictBase:
        """Run the wrapped modules, then map the one-hot ``action`` through the table.

        Raises:
            RuntimeError: if any extra keyword arguments are passed.
        """
        if kwargs:
            raise RuntimeError(
                f"TensorDictSequential does not support keyword arguments other than 'tensordict_out' or in_keys: {self.in_keys}. Got {kwargs.keys()} instead."
            )
        for module in self.module:
            tensordict = self._run_module(module, tensordict)

        # tensordict['action'] in shape batch x action_dim x discret size
        # it's a binary tensor where 1 denotes the chosen action
        action_cls = tensordict['action'].argmax(-1)

        # action mapping in shape action_dim x discret size.
        # The row index is built on the mapping's own device. That is always a valid
        # placement for an index, and it avoids allocating on the CPU and copying to
        # the device on every call.
        dim_idx = torch.arange(action_cls.shape[-1], device=self.action_mapping.device)
        tensordict['action'] = self.action_mapping[dim_idx, action_cls]

        if tensordict_out is not None:
            tensordict_out.update(tensordict, inplace=True)
            return tensordict_out
        return tensordict


class DisretizedProbabilisticActor(ProbabilisticActor):
    """ProbabilisticActor whose one-hot action is replaced by a row of ``action_mapping``.

    After the wrapped modules run, ``tensordict['action']`` is reduced with
    ``argmax(-1)``. The resulting indices select along dim 0 of ``action_mapping``,
    and the selection is written back to ``tensordict['action']``.
    """

    def __init__(
            self,
            module: TensorDictModule,
            in_keys: Union[NestedKey, Sequence[NestedKey]],
            out_keys: Optional[Sequence[NestedKey]] = None,
            *,
            spec: Optional[TensorSpec] = None,
            action_mapping: Optional[torch.Tensor] = None,
            **kwargs,
    ):
        super().__init__(module, in_keys, out_keys=out_keys, spec=spec, **kwargs)
        # Kept as a frozen nn.Parameter (it is listed in parameters()/state_dict but
        # receives no gradient).
        self.action_mapping = nn.Parameter(action_mapping, requires_grad=False)

    @dispatch(auto_batch_size=False)
    @set_skip_existing(None)
    def forward(
            self,
            tensordict: TensorDictBase,
            tensordict_out: TensorDictBase | None = None,
            **kwargs: Any,
    ) -> TensorDictBase:
        """Run the wrapped modules and map the arg-max of ``action`` through the table.

        Raises:
            RuntimeError: if any extra keyword arguments are passed.
        """
        if kwargs:
            raise RuntimeError(
                f"TensorDictSequential does not support keyword arguments other than 'tensordict_out' or in_keys: {self.in_keys}. Got {kwargs.keys()} instead."
            )
        for submodule in self.module:
            tensordict = self._run_module(submodule, tensordict)

        chosen = tensordict['action'].argmax(-1)
        tensordict['action'] = self.action_mapping[chosen]

        if tensordict_out is None:
            return tensordict
        tensordict_out.update(tensordict, inplace=True)
        return tensordict_out