from typing import Any, Optional

import torch
from torch import nn
import torch.nn.functional as F

# TODO more customisable parameters


def CNN(image_size: int,
        feature_size: int,  # output features
        hidden_layers: int,
        hidden_channels: int,
        image_channels: int = 3,
        only_hidden_layers: bool = False,
        conv2d_kwargs: Optional[dict[str, Any]] = None,
        leaky_relu_kwargs: Optional[dict[str, Any]] = None,
        sigmoid_for_last_layer: bool = True) -> nn.Sequential:
    """Build a convolutional network as an ``nn.Sequential``.

    The network is a stack of ``hidden_layers`` blocks, each made of
    ``Conv2d -> BatchNorm2d -> LeakyReLU``. The first block outputs
    ``hidden_channels`` channels and every following block doubles that.

    Unless ``only_hidden_layers`` is set, a final ``Conv2d`` whose kernel
    size equals the tracked image size maps the result to ``feature_size``
    channels, and the last three dimensions are flattened.

    Args:
        image_size: Side length of the input images. It is halved once per
            hidden layer and must stay even at every step. This bookkeeping
            assumes each hidden convolution halves the spatial size, as the
            default stride-2 settings do.
        feature_size: Number of output channels of the final convolution.
        hidden_layers: Number of hidden blocks.
        hidden_channels: Output channels of the first hidden block.
        image_channels: Channels of the input image.
        only_hidden_layers: If true, return only the hidden blocks.
        conv2d_kwargs: Keyword arguments for every hidden ``nn.Conv2d``.
            ``None`` selects ``kernel_size=5, stride=2, padding=2``.
        leaky_relu_kwargs: Keyword arguments for every ``nn.LeakyReLU``.
            ``None`` selects ``negative_slope=0.1``.
        sigmoid_for_last_layer: Append ``nn.Sigmoid`` after the final
            layers, but only when ``feature_size == 1``.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        AssertionError: If the tracked image size is odd at some layer.
    """
    # Build the defaults per call instead of sharing one dict object
    # between every invocation through the signature.
    if conv2d_kwargs is None:
        conv2d_kwargs = {'kernel_size': 5, 'stride': 2, 'padding': 2}
    if leaky_relu_kwargs is None:
        leaky_relu_kwargs = {'negative_slope': 0.1}

    in_channels = image_channels
    out_channels = hidden_channels
    layers: list[nn.Module] = []
    for _ in range(hidden_layers):
        layers += [nn.Conv2d(in_channels, out_channels, **conv2d_kwargs),
                   nn.BatchNorm2d(out_channels),
                   nn.LeakyReLU(**leaky_relu_kwargs)]

        in_channels = out_channels
        out_channels *= 2
        assert image_size % 2 == 0, (
            f'image size {image_size} cannot be halved evenly')
        image_size //= 2

    if not only_hidden_layers:
        # A kernel as large as the remaining image collapses it to 1x1, so
        # flattening the last three dims leaves ``feature_size`` values.
        layers += [nn.Conv2d(in_channels, feature_size, image_size),
                   nn.Flatten(-3)]
        if feature_size == 1 and sigmoid_for_last_layer:
            layers.append(nn.Sigmoid())

    return nn.Sequential(*layers)


def DCNN(image_size: int,
         feature_size: int,  # input features
         hidden_layers: int,
         hidden_channels: int,
         image_channels: int = 3,
         only_hidden_layers: bool = False,
         convt2d_kwargs: Optional[dict[str, Any]] = None,
         leaky_relu_kwargs: Optional[dict[str, Any]] = None) -> nn.Sequential:
    """Build a transposed-convolution network as an ``nn.Sequential``.

    Unless ``only_hidden_layers`` is set, the network starts by unflattening
    the last dimension into ``(feature_size, 1, 1)`` and applying a
    ``ConvTranspose2d`` with kernel size ``image_size // 2 ** hidden_layers``.
    It is followed by ``hidden_layers`` blocks of
    ``BatchNorm2d -> LeakyReLU -> ConvTranspose2d``. Each block halves the
    channel count, except the last one, which outputs ``image_channels``.
    A final ``nn.Sigmoid`` is always appended.

    Args:
        image_size: Side length of the images to produce. Must be divisible
            by ``2 ** hidden_layers``.
        feature_size: Size of the last dimension of the input.
        hidden_layers: Number of hidden blocks.
        hidden_channels: Input channels of the last hidden block; the first
            block receives ``hidden_channels * 2 ** (hidden_layers - 1)``.
        image_channels: Channels of the produced image.
        only_hidden_layers: If true, omit the unflatten and first transposed
            convolution.
        convt2d_kwargs: Keyword arguments for every hidden
            ``nn.ConvTranspose2d``. ``None`` selects ``kernel_size=5,
            stride=2, padding=2, output_padding=1``.
        leaky_relu_kwargs: Keyword arguments for every ``nn.LeakyReLU``.
            ``None`` selects ``negative_slope=0.1``.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        AssertionError: If ``image_size`` is not divisible by
            ``2 ** hidden_layers``.
    """
    # Build the defaults per call instead of sharing one dict object
    # between every invocation through the signature.
    if convt2d_kwargs is None:
        convt2d_kwargs = {'kernel_size': 5,
                          'stride': 2,
                          'padding': 2,
                          'output_padding': 1}
    if leaky_relu_kwargs is None:
        leaky_relu_kwargs = {'negative_slope': 0.1}

    multiplier = 2 ** hidden_layers
    assert image_size % multiplier == 0, (
        f'image size {image_size} is not divisible by {multiplier}')

    channels = hidden_channels * multiplier // 2
    layers: list[nn.Module] = []
    if not only_hidden_layers:
        # Turn the feature vector into a 1x1 "image" and expand it to the
        # smallest spatial size in a single transposed convolution.
        layers = [
            nn.Unflatten(-1, (feature_size, 1, 1)),
            nn.ConvTranspose2d(feature_size, channels, image_size // multiplier)
        ]
    for i in reversed(range(hidden_layers)):
        # i == 0 marks the last block, which emits the image itself.
        out_channels = channels // 2 if i else image_channels
        layers += [nn.BatchNorm2d(channels),
                   nn.LeakyReLU(**leaky_relu_kwargs),
                   nn.ConvTranspose2d(channels, out_channels, **convt2d_kwargs)]

        channels = out_channels

    return nn.Sequential(*layers, nn.Sigmoid())


class CNNGetter(nn.Module):
    """Module that maps a state to concept scores using a ``CNN`` network.

    The constructor arguments are kept as attributes and the network is
    stored in ``self.model``.
    """

    def __init__(self,
                 input_size: int,
                 concept_size: int,
                 hidden_layers: int = 4,
                 hidden_channels: int = 32,
                 input_channels: int = 3,
                 sigmoid_for_last_layer: bool = True) -> None:
        super().__init__()

        # Remember the configuration this module was built with.
        self.input_size = input_size
        self.concept_size = concept_size
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.input_channels = input_channels
        self.sigmoid_for_last_layer = sigmoid_for_last_layer

        self.model = CNN(image_size=input_size,
                         feature_size=concept_size,
                         hidden_layers=hidden_layers,
                         hidden_channels=hidden_channels,
                         image_channels=input_channels,
                         sigmoid_for_last_layer=sigmoid_for_last_layer)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Run the network on ``state``.

        A 3-dimensional ``state`` is treated as a single unbatched sample:
        a leading dimension is added before the network runs and removed
        from its output.
        """
        if state.ndim != 3:
            return self.model(state)

        batched_output = self.model(state.unsqueeze(0))
        return batched_output.squeeze(0)

    def get(self, state: torch.Tensor) -> torch.Tensor:
        """Return the predicted concept for ``state``.

        With more than one concept this is the index of the largest network
        output along the last dimension; otherwise it is the raw output.
        """
        prediction = self(state)
        return prediction.argmax(-1) if self.concept_size > 1 else prediction


class DCNNPutter(nn.Module):
    """Encode a state, append a concept, and decode the result.

    ``self.encode`` is a ``CNN`` producing ``latent_size`` features and
    ``self.decode`` is a ``DCNN`` consuming ``latent_size + concept_size``
    features.
    """

    def __init__(self,
                 input_size: int,
                 concept_size: int,
                 latent_size: int,
                 hidden_layers: int = 4,  # for CNN and DCNN
                 hidden_channels: int = 32,
                 input_channels: int = 3) -> None:
        super().__init__()

        self.input_size = input_size
        self.concept_size = concept_size
        self.latent_size = latent_size
        # Stored like the other constructor arguments so the configuration
        # can be read back from the module.
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.input_channels = input_channels

        self.encode = CNN(input_size,
                          latent_size,
                          hidden_layers,
                          hidden_channels,
                          input_channels)

        self.decode = DCNN(input_size,
                           latent_size + concept_size,
                           hidden_layers,
                           hidden_channels,
                           input_channels)

    def latent(self,
               state: torch.Tensor,
               concept: torch.Tensor) -> torch.Tensor:
        """Return the encoded state concatenated with the concept features.

        Args:
            state: Batched input for ``self.encode``.
            concept: When ``concept_size > 1``, class indices that are
                one-hot encoded (``F.one_hot`` expects an integer tensor).
                Otherwise the values are used directly, reshaped to
                ``(batch, -1)``.

        Returns:
            The concatenation along the last dimension, in the dtype of the
            encoded state.
        """
        encoded_state = self.encode(state)
        if self.concept_size > 1:
            concept_features = F.one_hot(concept, self.concept_size)
        else:
            concept_features = torch.reshape(concept,
                                             (concept.shape[0], -1))
        # ``F.one_hot`` yields an integer tensor; cast explicitly rather
        # than relying on implicit type promotion inside ``torch.cat``.
        concept_features = concept_features.to(encoded_state.dtype)
        return torch.cat((encoded_state, concept_features), dim=-1)

    def forward(self,
                state: torch.Tensor,
                concept: torch.Tensor) -> torch.Tensor:
        """Decode the latent built from ``state`` and ``concept``.

        A 3-dimensional ``state`` is treated as a single unbatched sample:
        both inputs get a leading dimension, which is removed from the
        output again.
        """
        if state.ndim == 3:
            return self.forward(state.unsqueeze(0),
                                concept.unsqueeze(0)).squeeze(0)
        return self.decode(self.latent(state, concept))

    def put(self, state: torch.Tensor, concept: torch.Tensor) -> torch.Tensor:
        """Alias for calling the module on ``state`` and ``concept``."""
        return self(state, concept)


class DCNNPutterWithConceptEncoder(nn.Module):
    """Encode a state, add an encoded concept, and decode the result.

    ``self.encode`` is a ``CNN`` producing ``latent_size`` features,
    ``self.concept_encoder`` is a bias-free linear layer plus ``LeakyReLU``
    mapping ``concept_size`` values to ``latent_size`` features, and
    ``self.decode`` is a ``DCNN`` consuming ``latent_size + 1`` features.
    The extra feature is the optional ``scratch`` input.
    """

    def __init__(self,
                 input_size: int,
                 concept_size: int,
                 latent_size: int,
                 hidden_layers: int = 4,  # for CNN and DCNN
                 hidden_channels: int = 32,
                 input_channels: int = 3) -> None:
        super().__init__()

        self.input_size = input_size
        self.concept_size = concept_size
        self.latent_size = latent_size
        # Stored like the other constructor arguments so the configuration
        # can be read back from the module.
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.input_channels = input_channels

        self.encode = CNN(input_size,
                          latent_size,
                          hidden_layers,
                          hidden_channels,
                          input_channels)

        self.decode = DCNN(input_size,
                           latent_size + 1,
                           hidden_layers,
                           hidden_channels,
                           input_channels)

        self.concept_encoder = nn.Sequential(
            nn.Linear(concept_size, latent_size, bias=False),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def latent(self,
               state: torch.Tensor,
               concept: torch.Tensor,
               scratch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the sum of the encoded state and the encoded concept.

        When ``scratch`` is given it is reshaped to ``(batch, -1)`` and
        appended along the last dimension; when it is ``None`` the sum is
        returned as is.
        """
        encoded_state = self.encode(state)
        encoded_concept = self.concept_encoder(concept)
        latent = encoded_state + encoded_concept

        if scratch is not None:
            batch_size = latent.shape[0]
            latent = torch.cat((latent, scratch.reshape((batch_size, -1))), dim=-1)

        return latent

    def forward(self,
                state: torch.Tensor,
                concept: torch.Tensor,
                scratch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode the latent built from ``state``, ``concept`` and ``scratch``.

        A 3-dimensional ``state`` is treated as a single unbatched sample:
        every given input gets a leading dimension, which is removed from
        the output again.

        When ``scratch`` is ``None`` the scratch feature is filled with
        zeros, because the decoder is built for ``latent_size + 1`` input
        features and would otherwise reject the shorter latent.
        """
        if state.ndim == 3:
            return self.forward(
                state.unsqueeze(0),
                concept.unsqueeze(0),
                scratch.unsqueeze(0) if scratch is not None else None).squeeze(0)

        latent = self.latent(state, concept, scratch)
        if scratch is None:
            # Pad the one feature the decoder expects beyond ``latent_size``.
            padding = latent.new_zeros((*latent.shape[:-1], 1))
            latent = torch.cat((latent, padding), dim=-1)
        return self.decode(latent)

    def put(self,
            state: torch.Tensor,
            concept: torch.Tensor,
            scratch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Alias for calling the module on the given inputs."""
        return self(state, concept, scratch)


class Encoder(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``CNN`` network.

    The constructor arguments are kept as attributes and the network is
    stored in ``self.encoder``.
    """

    def __init__(self,
                 input_size: int,
                 latent_size: int,
                 hidden_layers: int = 4,
                 hidden_channels: int = 32,
                 input_channels: int = 3) -> None:
        super().__init__()

        # Remember the configuration this module was built with.
        self.input_size = input_size
        self.latent_size = latent_size
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.input_channels = input_channels

        self.encoder = CNN(image_size=input_size,
                           feature_size=latent_size,
                           hidden_layers=hidden_layers,
                           hidden_channels=hidden_channels,
                           image_channels=input_channels)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Return the output of the wrapped network for ``data``."""
        encoded = self.encoder(data)
        return encoded


class Decoder(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``DCNN`` network.

    The constructor arguments are kept as attributes and the network is
    stored in ``self.decoder``.
    """

    def __init__(self,
                 input_size: int,
                 latent_size: int,
                 hidden_layers: int = 4,
                 hidden_channels: int = 32,
                 input_channels: int = 3) -> None:
        super().__init__()

        # Remember the configuration this module was built with.
        self.input_size = input_size
        self.latent_size = latent_size
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.input_channels = input_channels

        self.decoder = DCNN(image_size=input_size,
                            feature_size=latent_size,
                            hidden_layers=hidden_layers,
                            hidden_channels=hidden_channels,
                            image_channels=input_channels)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Return the output of the wrapped network for ``data``."""
        decoded = self.decoder(data)
        return decoded


class Discriminator(nn.Module):
    """Convolutional network ending in a single-channel sigmoid output.

    ``self.model`` chains the hidden layers built by ``CNN`` (kernel 4,
    stride 2, padding 1, no bias, in-place ``LeakyReLU`` with slope 0.2)
    with a bias-free 4x4 convolution down to one channel and a sigmoid.

    ``concept_size`` is stored and passed to ``CNN`` as its feature size,
    but only the hidden layers are requested from ``CNN``.

    NOTE(review): the 4x4 head presumably expects the hidden layers to
    leave a 4x4 feature map, i.e. ``input_size == 4 * 2 ** hidden_layers``
    -- confirm against callers.
    """

    def __init__(self,
                 input_size: int,
                 concept_size: int,
                 hidden_layers: int = 4,
                 hidden_channels: int = 32,
                 input_channels: int = 3) -> None:
        super().__init__()

        # Remember the configuration this module was built with.
        self.input_size = input_size
        self.concept_size = concept_size
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.input_channels = input_channels

        conv_settings = {'kernel_size': 4,
                         'stride': 2,
                         'padding': 1,
                         'bias': False}
        activation_settings = {'negative_slope': 0.2,
                               'inplace': True}
        backbone = CNN(input_size,
                       concept_size,
                       hidden_layers,
                       hidden_channels,
                       input_channels,
                       only_hidden_layers=True,
                       conv2d_kwargs=conv_settings,
                       leaky_relu_kwargs=activation_settings)

        # Channel count coming out of the last hidden block.
        backbone_channels = self.hidden_channels * (2 ** (self.hidden_layers - 1))
        head = nn.Conv2d(backbone_channels, 1, 4, 1, 0, bias=False)

        self.model = nn.Sequential(backbone, head, nn.Sigmoid())

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Run the network on ``data`` and flatten the result to 1-D."""
        scores = self.model(data)
        return scores.view(-1)
