from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


def discretize_actions(
        actions: Union[torch.Tensor, np.ndarray],
        n_bins: int = 21,
        one_hot: bool = False) -> Union[torch.Tensor, np.ndarray]:
    """Map continuous actions to the index of the nearest of ``n_bins`` levels.

    The representative levels are ``linspace(-1, 1, n_bins)``. An action is
    assigned to the level whose half-step neighbourhood contains it, so actions
    below ``-1`` land in bin 0 and actions above ``1`` land in bin
    ``n_bins - 1``.

    Args:
        actions: tensor or array of continuous actions (any shape).
        n_bins: number of levels; must be >= 2.
        one_hot: if True, replace every index with a float one-hot row of length
            ``n_bins`` (a new trailing axis).

    Returns:
        Same container kind as ``actions``: integer indices in
        ``[0, n_bins - 1]`` with the input's shape, or one-hot encodings with
        shape ``actions.shape + (n_bins,)``.
    """
    if isinstance(actions, torch.Tensor):
        repr_values = torch.linspace(-1, 1, n_bins, device=actions.device)
        step = repr_values[1] - repr_values[0]
        # Only the n_bins - 1 interior edges (midpoints between neighbouring
        # levels) are used; without an upper edge, out-of-range actions fall into
        # the last bin instead of producing the invalid index n_bins.
        edges = repr_values[:-1] + step / 2
        # right=True matches np.digitize's default (edges[i-1] <= x < edges[i]),
        # so both branches put a value lying exactly on an edge into the same bin.
        discrete_actions = torch.bucketize(actions, edges, right=True)
        if one_hot:
            discrete_actions = torch.eye(
                n_bins, device=actions.device)[discrete_actions]

    else:
        # discretize vectors into n_bins
        repr_values = np.linspace(-1, 1, n_bins)
        step = repr_values[1] - repr_values[0]
        edges = repr_values[:-1] + step / 2
        discrete_actions = np.digitize(actions, edges)
        if one_hot:
            discrete_actions = np.eye(n_bins)[discrete_actions]

    return discrete_actions


def interpret_discrete_actions(
        discrete_actions: Union[np.ndarray, torch.Tensor],
        n_bins: int = 21) -> Union[np.ndarray, torch.Tensor]:
    """Convert bin indices back into continuous actions.

    Index ``i`` maps to ``linspace(-1, 1, n_bins)[i]``, the representative value
    of that bin.

    Args:
        discrete_actions: tensor or array-like of bin indices. Non-integer dtypes
            are cast to int64 (truncating) in both the torch and NumPy branches.
        n_bins: number of levels.

    Returns:
        Continuous values with the same shape as ``discrete_actions``; a tensor
        for tensor input, otherwise a NumPy array or scalar.
    """
    if isinstance(discrete_actions, torch.Tensor):
        indices = discrete_actions.long()
        repr_values = torch.linspace(-1,
                                     1,
                                     n_bins,
                                     device=indices.device)
    else:
        # Cast like the torch branch's .long(), so float-typed index arrays are
        # accepted instead of raising IndexError.
        indices = np.asarray(discrete_actions).astype(np.int64)
        repr_values = np.linspace(-1, 1, n_bins)
    return repr_values[indices]


class Norm1Layer(nn.Module):
    """Rescale vectors along the last axis to (approximately) unit L2 length.

    ``eps`` is added to the norm before dividing, so an all-zero vector maps to
    zeros rather than NaN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, eps=1e-6):
        length = torch.norm(x, dim=-1, keepdim=True)
        denominator = length + eps
        return x / denominator


class MLP(torch.nn.Module):
    """Stack of ``nn.Linear`` layers, each followed by an activation module.

    Layout (in ``self.net``): ``in_dim -> hid_dim``, then ``n_hid_layers - 1``
    further ``hid_dim -> hid_dim`` layers, then ``hid_dim -> out_dim``. Every
    Linear except the last is followed by ``activation``; the last one by
    ``out_activation``. At least one hidden Linear is always built, even when
    ``n_hid_layers <= 0``. See ``_get_activation`` for accepted names.
    """

    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        out_dim: int,
        n_hid_layers: int,
        activation: str = 'relu',
        out_activation: str = 'relu',
    ):
        super().__init__()
        widths = [in_dim] + [hid_dim] * max(n_hid_layers, 1) + [out_dim]
        last_idx = len(widths) - 2
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            act_name = out_activation if idx == last_idx else activation
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(self._get_activation(name=act_name, dim=fan_out))
        self.net = nn.Sequential(*modules)

    @staticmethod
    def _get_activation(name: str, dim: int):
        """Build the activation/normalization module called ``name``.

        ``dim`` is only used by the LayerNorm variants ('ln', 'ln_affine').
        Raises ValueError for an unknown name.
        """
        factories = {
            'relu': lambda: nn.ReLU(),
            'none': lambda: nn.Identity(),
            'tanh': lambda: nn.Tanh(),
            'sigmoid': lambda: nn.Sigmoid(),
            'gelu': lambda: nn.GELU(),
            'mish': lambda: nn.Mish(),
            'leaky_relu': lambda: nn.LeakyReLU(0.2),
            'ln': lambda: nn.LayerNorm(normalized_shape=dim,
                                       elementwise_affine=False),
            'ln_affine': lambda: nn.LayerNorm(normalized_shape=dim),
            'norm1': lambda: Norm1Layer(),
        }
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f'Invalid activation={name} is detected.')
        return factory()

    def forward(self, x):
        return self.net(x)


class Policy(torch.nn.Module):
    """Action-prediction network built from MLP stages.

    Unless ``naive_bc`` is set, ``forward`` runs three MLPs:

    * ``encoder``: concat(state, domain) -> ``z`` of size ``latent_dim``;
    * ``core``: concat(z, condition) -> ``alpha`` of size ``latent_dim``;
    * ``head``: concat(alpha, domain[, state]) -> action output.

    With ``naive_bc`` a single MLP (``core``) maps concat(state, condition,
    domain) directly to the action output.

    The action output has ``out_dim`` tanh-squashed values, or, when
    ``discrete`` is set, ``out_dim * discrete_bins`` raw logits.

    When ``image_observation`` is set, an image encoder (and optionally an image
    decoder) is constructed here but not called in ``forward``. Presumably the
    caller encodes images and appends the result to the state — TODO confirm.

    Args:
        num_hidden_layers: hidden-layer counts for (encoder, core, head); the
            naive-BC MLP uses entry 1.
        z_norm: if non-empty, used as the encoder's output activation instead of
            ``repr_activation`` (any name accepted by MLP).
        input_image_state_into_decoder: if False, the last ``image_state_dim``
            state features are dropped before the state is fed to the head.
    """

    def __init__(
        self,
        state_dim: int,
        cond_dim: int,
        out_dim: int,
        domain_dim: int,
        latent_dim: int,
        hid_dim: int,
        num_hidden_layers: Tuple[int, int, int] = (3, 5, 3),
        activation: str = 'relu',
        repr_activation: str = 'relu',
        decode_with_state: bool = True,
        image_observation: bool = False,
        image_state_dim: int = 50,
        pretrained: bool = False,
        use_coord_conv: bool = False,
        z_norm: str = '',
        discrete: bool = False,
        discrete_bins: int = 21,
        naive_bc: bool = False,
        use_image_decoder: bool = False,
        input_image_state_into_decoder: bool = True,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.cond_dim = cond_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.domain_dim = domain_dim
        self.decode_with_state = decode_with_state

        self.image_state_dim = image_state_dim
        self.image_observation = image_observation
        self.z_norm = z_norm

        self.discrete = discrete
        self.discrete_bins = discrete_bins
        self.naive_bc = naive_bc
        self.input_image_state_into_decoder = input_image_state_into_decoder

        # Final action layer: raw per-bin logits when discrete, tanh otherwise.
        action_dim = out_dim * discrete_bins if discrete else out_dim
        action_activation = 'none' if discrete else 'tanh'

        if self.image_observation:
            self.image_encoder = ConvEncoder(
                in_channels=3,
                out_dim=image_state_dim,
                pretrained=pretrained,
                coord_conv=use_coord_conv,
            )
            if use_image_decoder:
                self.image_decoder = ConvDecoder(
                    image_latent_dim=image_state_dim)

        if self.naive_bc:
            self.encoder = self.decoder = None
            self.core = MLP(
                in_dim=state_dim + domain_dim + cond_dim,
                hid_dim=hid_dim,
                out_dim=action_dim,
                n_hid_layers=num_hidden_layers[1],
                activation=activation,
                out_activation=action_activation,
            )

        else:
            self.encoder = MLP(
                in_dim=state_dim + domain_dim,
                hid_dim=hid_dim,
                out_dim=latent_dim,
                n_hid_layers=num_hidden_layers[0],
                activation=activation,
                out_activation=self.z_norm if self.z_norm else repr_activation,
            )

            self.core = MLP(
                in_dim=latent_dim + cond_dim,
                hid_dim=hid_dim,
                out_dim=latent_dim,
                n_hid_layers=num_hidden_layers[1],
                activation=activation,
                out_activation=repr_activation,
            )

            head_in_dim = latent_dim + domain_dim
            if self.decode_with_state:
                head_in_dim += state_dim
                if not input_image_state_into_decoder:
                    # forward() drops the trailing image_state_dim features.
                    head_in_dim -= self.image_state_dim
            self.head = MLP(
                in_dim=head_in_dim,
                hid_dim=hid_dim,
                out_dim=action_dim,
                n_hid_layers=num_hidden_layers[2],
                activation=activation,
                out_activation=action_activation,
            )

    def forward(self, s, c, d):
        """Predict actions.

        Args:
            s: state features; c: condition vector (e.g. task ID); d: domain ID.
                All are concatenated along the last axis.

        Returns:
            ``(out, z, alpha)``; ``z`` and ``alpha`` are None in naive-BC mode.
            In the encoder/core/head path with ``discrete`` set, ``out``'s last
            axis is split into ``(out_dim, discrete_bins)``.
        """
        if self.naive_bc:
            input_tensor = torch.cat((s, c, d), dim=-1)
            out = self.core(input_tensor)
            # NOTE(review): discrete logits are NOT reshaped here, unlike the
            # encoder/core/head path; they stay flat (..., out_dim * bins).
            # Kept as-is so existing callers keep working — confirm intent.
            return out, None, None

        # encoding
        z = self.encoder(torch.cat((s, d), dim=-1))

        # core
        alpha = self.core(torch.cat((z, c), dim=-1))

        # decoding
        if self.decode_with_state:
            if not self.input_image_state_into_decoder:
                # Explicit end index: `s[..., :-0]` would be EMPTY when
                # image_state_dim == 0, while __init__ keeps the full state width.
                s = s[..., :s.shape[-1] - self.image_state_dim]
            alpha_d = torch.cat((alpha, d, s), dim=-1)
        else:
            alpha_d = torch.cat((alpha, d), dim=-1)

        out = self.head(alpha_d)

        if self.discrete:
            # (..., out_dim * bins) -> (..., out_dim, bins); same split as the
            # einops pattern 'b (n d) -> b n d' with n=out_dim, but also
            # accepts extra leading dimensions.
            out = out.reshape(*out.shape[:-1], self.out_dim,
                              self.discrete_bins)

        return out, z, alpha


class Discriminator(torch.nn.Module):
    """MLP classifier over latent vectors, optionally conditioned on ``c``.

    Outputs ``num_classes`` raw scores (no output activation). When
    ``task_cond`` is set, ``c`` (of size ``cond_dim``) is concatenated to ``z``
    along the last axis before classification.
    """

    def __init__(
        self,
        latent_dim: int,
        hid_dim: int,
        num_classes: int = 2,
        cond_dim: int = 2,
        num_hidden_layer: int = 4,
        task_cond: bool = False,
        activation: str = 'relu',
    ):
        super().__init__()
        self.task_cond = task_cond
        in_features = latent_dim + cond_dim if task_cond else latent_dim
        self.net = MLP(
            in_dim=in_features,
            hid_dim=hid_dim,
            out_dim=num_classes,
            n_hid_layers=num_hidden_layer,
            activation=activation,
            out_activation='none',
        )

    def forward(self, z, c=None):
        features = torch.cat((z, c), dim=-1) if self.task_cond else z
        return self.net(features)


class ReconstructionDecoder(nn.Module):
    """MLP mapping a latent vector back to a ``state_dim``-sized reconstruction.

    The output is unbounded (no output activation).

    Args:
        latent_dim: size of the input latent vector.
        state_dim: size of the reconstructed output.
        hid_dim: hidden width of the MLP.
        activation: hidden activation name (any name accepted by MLP).
        n_hid_layers: number of hidden layers; defaults to the previously
            hard-coded value of 3.
    """

    def __init__(self,
                 latent_dim: int,
                 state_dim: int,
                 hid_dim: int,
                 activation: str = 'relu',
                 n_hid_layers: int = 3):
        super().__init__()
        self.net = MLP(
            in_dim=latent_dim,
            hid_dim=hid_dim,
            out_dim=state_dim,
            n_hid_layers=n_hid_layers,
            activation=activation,
            out_activation='none',
        )

    def forward(self, x):
        return self.net(x)


class CoordConv(torch.nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        width,
    ):
        """Conv2d applied to the input plus two fixed coordinate channels.

        Channel 0 of the coordinate grid holds the column index / width. Channel
        1 holds the row index / width. Both lie in [0, (width - 1) / width].

        Args:
            in_channels: channels of the incoming feature map; the convolution
                itself sees ``in_channels + 2``.
            out_channels: output channels of the convolution.
            kernel_size: forwarded to ``nn.Conv2d``.
            stride: forwarded to ``nn.Conv2d``.
            padding: forwarded to ``nn.Conv2d``.
            width: side length of the coordinate grid. Inputs must therefore be
                ``width x width`` spatially.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels + 2,
            out_channels,
            kernel_size,
            stride,
            padding,
        )
        ramp = torch.arange(width) / width
        grid = torch.stack((
            ramp.expand(width, width),  # [i, j] = j / width
            ramp[:, None].expand(width, width),  # [i, j] = i / width
        ))
        # A buffer follows .to()/.cuda()/.half() with the module, so forward()
        # no longer copies it to the device on every call. persistent=False
        # keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer('coords', grid, persistent=False)

    def forward(self, x):
        # x: (B, C, H, W) with H == W == width.
        coords = self.coords.to(device=x.device, dtype=x.dtype)
        coords = coords.expand(x.shape[0], -1, -1, -1)
        x = torch.cat((x, coords), dim=-3)
        return self.conv(x)


class PretrainedImageEncoder(torch.nn.Module):
    """Frozen ``AutoencoderKL`` encoder that returns flattened latent features.

    Weights are loaded from "stabilityai/sd-vae-ft-mse" through diffusers. The
    import is done lazily here, so the rest of the module does not need diffusers.
    All VAE parameters have ``requires_grad`` disabled.
    """

    def __init__(self):
        super().__init__()
        # Multiplier applied to the encoded latents (presumably the usual Stable
        # Diffusion latent scale — confirm).
        self.scale_factor = 0.18215

        from diffusers.models import AutoencoderKL
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        self.vae.requires_grad_(False)  # freeze every VAE parameter

    def forward(self, x):
        # Assumes channel-last input (B, H, W, C) — TODO confirm with callers.
        pixels = x.permute(0, 3, 1, 2) / 255.
        # NOTE(review): SD VAEs are usually fed pixels in [-1, 1]; here they are
        # only divided by 255. Left unchanged so features stay consistent with
        # existing checkpoints — confirm intent.
        posterior = self.vae.encode(pixels).latent_dist
        # Deterministic mode of the posterior; .sample() would be the
        # stochastic alternative.
        latents = posterior.mode() * self.scale_factor
        return torch.flatten(latents, start_dim=1)


class ConvEncoder(torch.nn.Module):
    """Encode channel-last images into ``out_dim``-dimensional feature vectors.

    Without ``pretrained``, the images are processed in three steps:

    1. Permuted to channel-first and divided by 255.
    2. Passed through five stride-2 3x3 convolutions with GELU; the first one is
       optionally a 128x128 CoordConv.
    3. Flattened and projected by a Linear layer that expects 128 * 4 * 4
       inputs, i.e. a 128x128 image.

    With ``pretrained``, the raw input is handed to PretrainedImageEncoder and
    its flattened latent goes to the same Linear layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_dim: int,
        pretrained: bool = True,
        coord_conv: bool = False,
    ):
        super().__init__()
        self.pretrained = pretrained
        if pretrained:
            # NOTE(review): the Linear below expects 2048 features. Verify that
            # the pretrained latent has that size for the image resolution in
            # use; for a 128x128 image it is probably 4*16*16 = 1024.
            self.conv_net = PretrainedImageEncoder()
        else:
            if coord_conv:
                stem = CoordConv(in_channels, 64, 3, 2, 1, 128)
            else:
                stem = nn.Conv2d(in_channels, 64, 3, 2, 1)
            layers = [stem, nn.GELU()]  # -> (B, 64, 64, 64)
            # Remaining stages: 64->64, 64->128, 128->128, 128->128, each
            # halving the spatial size down to (B, 128, 4, 4).
            widths = (64, 64, 128, 128, 128)
            for c_in, c_out in zip(widths, widths[1:]):
                layers += [nn.Conv2d(c_in, c_out, 3, 2, 1), nn.GELU()]
            self.conv_net = nn.Sequential(*layers)

        self.linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, out_dim),
        )

    def forward(self, x):
        assert x.dim() == 4
        # The pretrained encoder does its own permute/scaling.
        images = x if self.pretrained else x.permute(0, 3, 1, 2) / 255.
        features = self.conv_net(images)
        return self.linear(features)


class ConvDecoder(torch.nn.Module):
    """Decode a flat latent vector into a channel-last image with values in [0, 1].

    A Linear+GELU layer produces a (128, 4, 4) feature map. Five stride-2
    transposed convolutions then double the spatial size each time, giving
    (3, 128, 128) with a final sigmoid. The result is permuted to
    (B, 128, 128, 3).
    """

    def __init__(
        self,
        image_latent_dim: int,
    ):
        super().__init__()

        self.linear = nn.Sequential(
            nn.Linear(image_latent_dim, 128 * 4 * 4),
            nn.GELU(),
            nn.Unflatten(1, (128, 4, 4)),
        )

        channel_plan = ((128, 128), (128, 128), (128, 64), (64, 64), (64, 3))
        last = len(channel_plan) - 1
        stages = []
        for idx, (c_in, c_out) in enumerate(channel_plan):
            stages.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                   padding=1, output_padding=1))
            stages.append(nn.Sigmoid() if idx == last else nn.GELU())
        self.deconv = nn.Sequential(*stages)

    def forward(self, x):
        feature_map = self.deconv(self.linear(x))
        # (B, C, H, W) -> (B, H, W, C)
        return feature_map.permute(0, 2, 3, 1)
