from __future__ import annotations

import copy

import torch
from torch import nn
import torch.nn.functional as F

# TODO: expose more customisable parameters (e.g. kernel size, LeakyReLU slope).


def CNN(image_size: int,
        feature_size: int,  # output features
        hidden_layers: int,
        hidden_channels: int,
        channel_multiplier: int = 1,
        image_channels: int = 3) -> nn.Sequential:
    """Build a convolutional encoder that maps an image to a feature vector.

    Each of the ``hidden_layers`` stages is a stride-2 ``Conv2d`` (kernel 5,
    padding 2) followed by ``BatchNorm2d`` and ``LeakyReLU(0.1)``. Every stage
    halves the spatial size. The stage width starts at ``hidden_channels`` and
    is multiplied by ``channel_multiplier`` after each stage. A final
    ``Conv2d`` uses a kernel equal to the remaining spatial size, which
    collapses the map to 1x1 with ``feature_size`` channels. ``Flatten(-3)``
    then drops the channel/spatial structure.

    Args:
        image_size: Side length of the input image. The final kernel is
            square, so square inputs are assumed. The size must be even
            before every halving step.
        feature_size: Number of output features.
        hidden_layers: Number of stride-2 stages (0 means a single conv).
        hidden_channels: Channel width of the first stage.
        channel_multiplier: Factor applied to the width after each stage.
        image_channels: Number of channels of the input image.

    Returns:
        An ``nn.Sequential``. For a ``(N, image_channels, image_size,
        image_size)`` input it produces ``(N, feature_size)``.

    Raises:
        AssertionError: (via ``assert``) if the spatial size is odd before a
            halving step.
    """
    stages = []
    size = image_size
    width_in, width_out = image_channels, hidden_channels
    for _ in range(hidden_layers):
        stages.extend((
            nn.Conv2d(width_in, width_out, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(width_out),
            nn.LeakyReLU(0.1),
        ))
        # Kernel 5 / stride 2 / padding 2 halves a size exactly only when it
        # is even (an odd size 2k+1 would become k+1).
        assert size % 2 == 0
        size //= 2
        width_in, width_out = width_out, width_out * channel_multiplier

    # The kernel spans the whole remaining feature map, giving a 1x1 output.
    collapse = nn.Conv2d(width_in, feature_size, size)
    return nn.Sequential(*stages, collapse, nn.Flatten(-3))


def DCNN(image_size: int,
         feature_size: int,  # input features
         hidden_layers: int,
         hidden_channels: int,
         channel_multiplier: int = 1,
         image_channels: int = 3) -> nn.Sequential:
    """Build a transposed-convolutional decoder that mirrors ``CNN``.

    The input feature vector is unflattened to a 1x1 map. A first
    ``ConvTranspose2d`` expands it to ``image_size // 2**hidden_layers``
    pixels per side. Each of the ``hidden_layers`` stages then applies
    ``BatchNorm2d``, ``LeakyReLU(0.1)`` and a stride-2 ``ConvTranspose2d``
    that doubles the spatial size and divides the width by
    ``channel_multiplier``. The last stage outputs ``image_channels``. A final
    ``Sigmoid`` squashes outputs into [0, 1].

    Args:
        image_size: Side length of the output image. Must be divisible by
            ``2 ** hidden_layers``.
        feature_size: Number of input features.
        hidden_layers: Number of stride-2 upsampling stages. With 0, the first
            transposed conv maps straight to ``image_channels``.
        hidden_channels: Channel width of the shallowest hidden stage (the
            same meaning as in ``CNN``).
        channel_multiplier: Width factor between consecutive stages.
        image_channels: Number of channels of the output image.

    Returns:
        An ``nn.Sequential``. For ``(N, feature_size)`` input it produces
        ``(N, image_channels, image_size, image_size)``. When
        ``hidden_layers > 0`` the ``BatchNorm2d`` layers require a batch
        dimension.

    Raises:
        AssertionError: (via ``assert``) if ``image_size`` is not divisible by
            ``2 ** hidden_layers``.
    """
    multiplier = 2 ** hidden_layers
    assert image_size % multiplier == 0

    if hidden_layers > 0:
        # Width of the deepest hidden stage, matching CNN's last stage.
        channels = hidden_channels * channel_multiplier ** (hidden_layers - 1)
    else:
        # No hidden stages: emit the image channels directly. The general
        # formula would use a negative exponent here, which yields a float
        # channel count and the wrong output width.
        channels = image_channels

    layers = [
        nn.Unflatten(-1, (feature_size, 1, 1)),
        # From a 1x1 map, a kernel of size K produces a KxK map.
        nn.ConvTranspose2d(feature_size, channels, image_size // multiplier),
    ]
    for i in reversed(range(hidden_layers)):
        out_channels = channels // channel_multiplier if i else image_channels
        layers += [nn.BatchNorm2d(channels),
                   nn.LeakyReLU(0.1),
                   # kernel 5, stride 2, padding 2, output_padding 1:
                   # (H - 1) * 2 - 4 + 4 + 1 + 1 == 2 * H.
                   nn.ConvTranspose2d(channels, out_channels, 5, 2, 2, 1)]

        channels = out_channels

    return nn.Sequential(*layers, nn.Sigmoid())


class AutoEncoder(nn.Module):
    """Convolutional autoencoder: a ``CNN`` encoder paired with a ``DCNN`` decoder.

    Both halves are built from the same architecture hyperparameters. Those
    hyperparameters are also stored on the instance so that other modules can
    read them (e.g. ``latent_size``).
    """

    def __init__(self,
                 input_size: int,
                 latent_size: int,
                 hidden_layers: int = 4,
                 hidden_channels: int = 32,
                 channel_multiplier: int = 1,
                 input_channels: int = 3) -> None:
        super().__init__()

        self.input_size = input_size
        self.latent_size = latent_size
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.channel_multiplier = channel_multiplier
        self.input_channels = input_channels

        # CNN and DCNN take the same positional arguments in the same order.
        architecture = (input_size, latent_size, hidden_layers,
                        hidden_channels, channel_multiplier, input_channels)
        self.encode = CNN(*architecture)
        self.decode = DCNN(*architecture)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``state``.

        A 3-D input is treated as a single unbatched sample. A batch axis is
        added for the pass and removed from the result.
        """
        unbatched = state.ndim == 3
        batch = state.unsqueeze(0) if unbatched else state
        reconstruction = self.decode(self.encode(batch))
        return reconstruction.squeeze(0) if unbatched else reconstruction


class LinearGetter(nn.Module):
    """Linear read-out of a concept from an autoencoder's latent code.

    ``encode`` is the autoencoder's own encoder module, not a copy, so its
    parameters are shared with the autoencoder. ``classify`` is a plain
    ``nn.Linear`` from ``latent_size`` to ``concept_size``. No activation is
    applied to its output.
    """

    def __init__(self,
                 autoencoder: nn.Module,
                 concept_size: int) -> None:
        super().__init__()

        # Shared reference: training this module also updates the autoencoder's encoder.
        self.encode = autoencoder.encode

        self.concept_size = concept_size
        self.latent_size = autoencoder.latent_size

        self.classify = nn.Linear(self.latent_size, self.concept_size)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the linear concept read-out for ``state``.

        A 3-D input is treated as one unbatched sample. A batch axis is added
        for the pass and removed from the result.
        """
        single = state.ndim == 3
        batch = state.unsqueeze(0) if single else state
        prediction = self.classify(self.encode(batch))
        return prediction.squeeze(0) if single else prediction


class LinearPutter(nn.Module):
    """Edit a state by adding a linear concept shift in an autoencoder's latent space.

    The state is encoded with the autoencoder's encoder. The latent then gets
    ``shift(concept)`` added to it (a learned ``nn.Linear``), and the result
    is decoded with a frozen copy of the autoencoder's decoder.

    ``encode`` is the same module object as ``autoencoder.encode``; it is not
    copied, so its parameters are shared and remain trainable. ``decode`` is a
    deep copy with ``requires_grad`` disabled. Call :meth:`sync_autoencoder`
    to refresh it from the autoencoder.
    """

    def __init__(self,
                 autoencoder: nn.Module,
                 concept_size: int,
                 one_hot_concept: bool = False) -> None:
        """
        Args:
            autoencoder: Module exposing ``encode``, ``decode`` and
                ``latent_size``.
            concept_size: Width of the concept vector. When
                ``one_hot_concept`` is set, this is the number of classes.
            one_hot_concept: If true, ``forward`` expects integer class
                indices and one-hot encodes them before the shift.
        """
        super().__init__()

        self.concept_size = concept_size
        self.latent_size = latent_size = autoencoder.latent_size
        self.one_hot_concept = one_hot_concept

        self.encode = autoencoder.encode
        self.decode = copy.deepcopy(autoencoder.decode)
        # Freezes parameters only; buffers of the copy (if any) are untouched.
        self.decode.requires_grad_(False)

        self.shift = nn.Linear(concept_size, latent_size)

    def sync_autoencoder(self, autoencoder: nn.Module) -> None:
        """Load ``autoencoder.decode``'s current state into the frozen decoder copy."""
        self.decode.load_state_dict(autoencoder.decode.state_dict())

    def forward(self,
                concept: torch.Tensor,
                state: torch.Tensor,
                complement: torch.Tensor | None = None,
                ) -> tuple[torch.Tensor, None]:
        """Decode ``encode(state) + shift(concept)``.

        Args:
            concept: Concept vectors of width ``concept_size``. With
                ``one_hot_concept``, these are integer class indices instead.
            state: Batched input, or a single 3-D sample. A 3-D ``state`` is
                given a temporary batch axis, and ``concept`` is unsqueezed
                to match.
            complement: Ignored. It is accepted only for call-compatibility
                with putters that carry a complement state.

        Returns:
            ``(decoded, None)``. The second element is always ``None`` because
            this putter keeps no complement.
        """
        if state.ndim == 3:
            output, _ = self.forward(concept.unsqueeze(0), state.unsqueeze(0))
            return output.squeeze(0), None

        if self.one_hot_concept:
            concept = F.one_hot(concept, self.concept_size).float()

        return self.decode(self.encode(state) + self.shift(concept)), None


class LinearPutterWithComplement(nn.Module):
    """Concept putter whose latent edit also reads and updates a complement vector.

    A single learned ``nn.Linear`` (``shift``) maps the concatenation
    ``(concept, latent, complement)`` to ``(new_latent, new_complement)``.
    ``new_latent`` is decoded with a frozen copy of the autoencoder's decoder,
    and ``new_complement`` is returned alongside it. Unlike ``LinearPutter``,
    the new latent is a full linear map of the inputs, not ``latent + shift``.

    ``encode`` is the same module object as ``autoencoder.encode``; it is not
    copied, so its parameters are shared. ``decode`` is a deep copy with
    ``requires_grad`` disabled. Call :meth:`sync_autoencoder` to refresh it.
    """

    def __init__(self,
                 autoencoder: nn.Module,
                 concept_size: int,
                 complement_size: int = 0,
                 one_hot_concept: bool = False) -> None:
        """
        Args:
            autoencoder: Module exposing ``encode``, ``decode`` and
                ``latent_size``.
            concept_size: Width of the concept vector. When
                ``one_hot_concept`` is set, this is the number of classes.
            complement_size: Width of the complement vector (0 disables it).
            one_hot_concept: If true, ``forward`` expects integer class
                indices and one-hot encodes them.
        """
        super().__init__()

        self.concept_size = concept_size
        self.latent_size = latent_size = autoencoder.latent_size
        self.complement_size = complement_size
        self.one_hot_concept = one_hot_concept

        self.encode = autoencoder.encode
        self.decode = copy.deepcopy(autoencoder.decode)
        # Freezes parameters only; buffers of the copy (if any) are untouched.
        self.decode.requires_grad_(False)

        self.shift = nn.Linear(concept_size + latent_size + complement_size,
                               latent_size + complement_size)

        # Learned complement used when forward() is called without one.
        self.init_complement = nn.Parameter(torch.zeros(complement_size))

    def sync_autoencoder(self, autoencoder: nn.Module) -> None:
        """Load ``autoencoder.decode``'s current state into the frozen decoder copy."""
        self.decode.load_state_dict(autoencoder.decode.state_dict())

    def forward(self,
                concept: torch.Tensor,
                state: torch.Tensor,
                complement: torch.Tensor | None = None,
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the concept edit and return ``(decoded, new_complement)``.

        Args:
            concept: Concept vectors of width ``concept_size``. With
                ``one_hot_concept``, these are integer class indices instead.
            state: Batched input, or a single 3-D sample. For a 3-D
                ``state``, ``concept`` and ``complement`` are unsqueezed too,
                and the batch axis is removed from both outputs.
            complement: Optional complement of width ``complement_size``.
                When ``None``, ``init_complement`` is broadcast over the batch.

        Returns:
            The decoded first ``latent_size`` outputs of ``shift``, and the
            remaining ``complement_size`` outputs. The second value has the
            same width as ``complement`` and can be passed back in.
        """
        if state.ndim == 3:
            cpl = complement.unsqueeze(0) if complement is not None else None
            output, complement = self.forward(concept.unsqueeze(0),
                                              state.unsqueeze(0),
                                              cpl)
            return output.squeeze(0), complement.squeeze(0)

        if complement is None:
            # A broadcast view suffices: torch.cat below copies it anyway.
            complement = self.init_complement.expand(len(state), -1)

        if self.one_hot_concept:
            concept = F.one_hot(concept, self.concept_size).float()

        latent = self.encode(state)
        output = self.shift(torch.cat((concept, latent, complement), dim=-1))
        return (self.decode(output[..., :self.latent_size]),
                output[..., self.latent_size:])
