import copy
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from puupl.lib.utils import ConfigurationException


class Architecture(nn.Module, ABC):
    """
    Neural network that is trained on PU data.

    This is the abstract interface every architecture implements:
    a training forward pass, prediction sampling for pseudo-labeling,
    weight re-initialization, and a ``device`` property.
    """

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """
        Perform the forward pass used for training.
        You can return a tensor of arbitrary size
        with samples on the last dimension; they
        will be averaged after a sigmoid to get
        per-sample predictions.
        """

    @abstractmethod
    def sample_predictions(self, datagen: DataLoader,
                           dest_device: Optional[Union[str, torch.device]] = None,
                           ) -> Tensor:
        """
        Sample predictions used to compute uncertainty
        for pseudo-labeling. Return one or more logits
        per sample. Samples are on the last dimension.
        Implementations are responsible for switching the
        model to training mode themselves if necessary.
        If ``dest_device`` is provided, move the predictions
        to that device before returning them.
        """

    @abstractmethod
    def reset_model(self) -> None:
        """
        Re-initialize the weights of the neural network.
        """

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """
        Return the device containing the parameters.
        """


class MLP(Architecture):
    """
    Fully-connected network that outputs one logit per sample.

    The layers live in a flat ``nn.Sequential`` stored in ``self.model``:
    an input linear layer (with optional batchnorm), ``n_hidden_layers``
    hidden blocks (optional dropout, linear, optional batchnorm, ReLU)
    and a final linear layer with bias and a single output unit.
    """

    def __init__(self, input_shape: int, layer_shape: int, n_hidden_layers: int,
                 batchnorm: bool = True, pdrop: Optional[float] = None,
                 reset_to_same_weights: bool = False,
                 mc_dropout_pl_samples: Optional[int] = None,
                 **kwargs: Any) -> None:
        """
        Build the network.

        input_shape: number of input features; inputs are reshaped to
            ``(batch, input_shape)`` before the first layer.
        layer_shape: output width of the input layer and of every
            hidden layer.
        n_hidden_layers: number of hidden blocks after the input layer
            (0 is allowed).
        batchnorm: add ``BatchNorm1d`` after every linear layer except
            the final one.
        pdrop: dropout probability used in front of every hidden linear
            layer, or None to disable dropout.
        reset_to_same_weights: when True, ``reset_model`` restores the
            weights the network had right after construction instead of
            drawing new ones.
        mc_dropout_pl_samples: number of stochastic forward passes per
            batch performed by ``sample_predictions``; None disables
            MC dropout.
        kwargs: accepted and ignored.

        Raises ConfigurationException if ``mc_dropout_pl_samples`` is
        given without ``pdrop``.
        """
        super().__init__()
        self.input_shape = input_shape
        self.layer_shape = layer_shape
        self.n_hidden_layers = n_hidden_layers
        self.reset_to_same_weights = reset_to_same_weights
        self.pdrop = pdrop
        self.mc_dropout_pl_samples = mc_dropout_pl_samples
        if self.mc_dropout_pl_samples is not None and self.pdrop is None:
            raise ConfigurationException(
                'specify both mc_dropout_pl_samples and pdrop for MC dropout'
            )

        # build model architecture
        def dense(idx: int, din: int, dout: int, batchnorm: bool,
                  pdrop: Optional[float] = None, bias: bool = False) -> List[nn.Module]:
            # Kiryo et al. do not use biases and do use batchnorm between linear layers
            block: List[nn.Module] = []

            # idx == 0 is the input layer: dropout is never applied to the raw input
            if pdrop is not None and idx > 0:
                block.append(nn.Dropout(pdrop))

            block.append(nn.Linear(in_features=din, out_features=dout, bias=bias))

            if batchnorm:
                block.append(nn.BatchNorm1d(num_features=dout))

            return block

        # NOTE(review): no activation is inserted after the input layer, only
        # after the hidden blocks -- kept as is to preserve the architecture.
        layers = dense(0, input_shape, layer_shape, batchnorm, self.pdrop)
        for i in range(n_hidden_layers):
            layers.extend(dense(i + 1, layer_shape, layer_shape, batchnorm, self.pdrop))
            layers.append(nn.ReLU())

        # final layer with bias and without batchnorm; its index is derived from
        # n_hidden_layers (not the loop variable) so n_hidden_layers == 0 works
        layers.extend(dense(n_hidden_layers + 1, layer_shape, 1, batchnorm=False, bias=True))

        self.model = nn.Sequential(*layers)
        self.model_state: Optional[Dict[str, Tensor]] = None
        if self.reset_to_same_weights:
            # state_dict() returns tensors that share storage with the live
            # parameters, so a deep copy is needed to keep the initial weights
            self.model_state = copy.deepcopy(self.model.state_dict())

    @property
    def device(self) -> torch.device:
        """
        Return the device of the first parameter of the network.
        """
        return next(self.model.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        """
        Return one logit per sample as a 1-D tensor; ``x`` is reshaped
        to ``(len(x), input_shape)`` first.
        """
        return self.model(x.view(len(x), self.input_shape)).view(-1)

    def sample_predictions(self, datagen: DataLoader,
                           dest_device: Optional[Union[str, torch.device]] = None,
                           ) -> Tensor:
        """
        Compute logits for every batch produced by ``datagen``.

        Each batch must support ``batch['x']`` and yield the input tensor.
        Predictions are moved to ``dest_device`` (default: the device of
        the model). The result has samples on the last dimension:
        shape ``(1, n_samples)`` without MC dropout and
        ``(mc_dropout_pl_samples, n_samples)`` with MC dropout.

        With MC dropout only the dropout layers are switched to training
        mode; they are put back to the mode of ``self.model`` afterwards,
        even if a forward pass raises.
        """
        dest_device = dest_device or self.device

        def batch_input(batch: Any) -> Tensor:
            return batch['x'].to(self.device).view(-1, self.input_shape)

        if self.mc_dropout_pl_samples is None:
            # concatenate along the sample axis (instead of stacking batches on
            # a new axis) so that batches of different sizes are supported
            per_batch = [
                self.model(batch_input(batch)).view(-1).to(dest_device).unsqueeze(0)
                for batch in datagen
            ]
            return torch.cat(per_batch, dim=1)

        # enable dropout
        dropout_layers = [
            layer for layer in self.model.modules() if isinstance(layer, nn.Dropout)
        ]
        old_train = self.model.training
        for layer in dropout_layers:
            layer.train(True)

        try:
            # get samples from dropout
            samples = []
            for batch in datagen:
                x = batch_input(batch)  # move and reshape once per batch
                samples.append(torch.stack([
                    self.model(x).view(-1).to(dest_device)
                    for _ in range(self.mc_dropout_pl_samples)
                ]))
        finally:
            # reset dropout to old state
            for layer in dropout_layers:
                layer.train(old_train)

        return torch.cat(samples, dim=1)

    def reset_model(self) -> None:
        """
        Re-initialize the weights: restore the snapshot taken at
        construction when ``reset_to_same_weights`` is set, otherwise call
        ``reset_parameters`` on every submodule that provides it.
        """
        def weight_reset(m: nn.Module) -> None:
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                reset_parameters()

        print('I am re-setting model weights to train again from scratch')
        if self.reset_to_same_weights:
            assert self.model_state is not None
            self.model.load_state_dict(self.model_state)  # type: ignore
        else:
            self.model.apply(weight_reset)


class CNN(Architecture):
    """
    Configurable convolutional network that outputs one logit per sample.

    The network (``self.model``, a flat ``nn.Sequential``) consists of
    ``n_conv_blocks`` blocks of ``conv_block_size`` 3x3 convolutions each,
    where the last convolution of a block has stride 2, followed by either
    a flatten or a global average pooling and a fully-connected head.
    """

    def __init__(
        self,
        channels_in: int = 3,
        n_conv_blocks: int = 3,
        conv_block_size: int = 3,
        first_conv_block_channels: int = 32,
        conv_block_channel_growth_factor: int = 2,
        after_conv_flatten_filters: Optional[int] = None,
        head_units: int = 96,
        head_layers: int = 2,
        pdrop: Optional[float] = None,
        mc_dropout_pl_samples: Optional[int] = None,
        reset_to_same_weights: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Build the network.

        channels_in: number of channels of the input images.
        n_conv_blocks: number of convolutional blocks.
        conv_block_size: number of convolutions per block.
        first_conv_block_channels: output channels of the first block.
        conv_block_channel_growth_factor: factor by which the number of
            channels grows from one block to the next.
        after_conv_flatten_filters: if given, the convolutional output is
            flattened and this value is used as the input size of the
            head; if None, global average pooling is used instead.
        head_units: width of the hidden layers of the head.
        head_layers: number of hidden layers of the head.
        pdrop: dropout probability used after every convolution, or
            None (or 0) to disable dropout.
        mc_dropout_pl_samples: number of stochastic forward passes per
            batch performed by ``sample_predictions``; None disables
            MC dropout.
        reset_to_same_weights: when True, ``reset_model`` restores the
            weights the network had right after construction.
        kwargs: accepted and ignored.

        Raises ValueError if ``mc_dropout_pl_samples`` is given without
        ``pdrop``.
        """
        super().__init__()

        self.channels_in = channels_in
        self.n_conv_blocks = n_conv_blocks
        self.conv_block_size = conv_block_size
        self.first_conv_block_channels = first_conv_block_channels
        self.conv_block_channel_growth_factor = conv_block_channel_growth_factor
        self.after_conv_flatten_filters = after_conv_flatten_filters
        self.reset_to_same_weights = reset_to_same_weights
        self.head_units = head_units
        self.head_layers = head_layers
        self.pdrop = pdrop

        self.mc_dropout_pl_samples = mc_dropout_pl_samples
        if self.mc_dropout_pl_samples is not None and self.pdrop is None:
            raise ValueError('specify both mc_dropout_pl_samples and pdrop for MC dropout')

        self.model = self._build_model()
        print(self.model)
        self.model_state: Optional[Dict[str, Tensor]] = None
        if self.reset_to_same_weights:
            # state_dict() returns tensors that share storage with the live
            # parameters, so a deep copy is needed to keep the initial weights
            self.model_state = copy.deepcopy(self.model.state_dict())

    @staticmethod
    def get_conv_layer(
            in_channels: int, out_channels: int, kernel_size: int,
            padding: int, stride: int, pdrop: Optional[float],
    ) -> List[nn.Module]:
        """
        Return the modules of one convolutional layer, in order:
        ``Conv2d``, ``Dropout`` (only when ``pdrop`` is truthy),
        ``BatchNorm2d`` and ``ReLU``.
        """
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, padding=padding, stride=stride)
        bn = nn.BatchNorm2d(num_features=out_channels)
        activ = nn.ReLU()

        if pdrop:
            drop = nn.Dropout(pdrop)
            return [conv, drop, bn, activ]
        else:
            return [conv, bn, activ]

    def _build_model(self) -> nn.Module:
        """
        Assemble the convolutional blocks and the fully-connected head
        into a single ``nn.Sequential``.
        """
        layers: List[nn.Module] = []
        chin, chout = self.channels_in, self.first_conv_block_channels
        for i in range(self.n_conv_blocks):
            for _ in range(self.conv_block_size - 1):
                layers.extend(self.get_conv_layer(chin, chout, 3, 1, 1, self.pdrop))
                chin = chout

            # the last convolution of every block downsamples with stride 2
            layers.extend(self.get_conv_layer(chin, chout, 3, 1, 2, self.pdrop))
            chin = chout

            # grow the channels for the next block (there is none after the last)
            if i < self.n_conv_blocks - 1:
                chout = int(chout * self.conv_block_channel_growth_factor)

        if self.after_conv_flatten_filters is not None:
            layers.append(nn.Flatten())
            # the caller supplies the flattened size; it is not inferred here
            chout = self.after_conv_flatten_filters
        else:
            class GlobalAveragePooling(nn.Module):
                def forward(self, x: Tensor) -> Tensor:
                    # average over the last two (spatial) dimensions
                    return x.mean(dim=(-1, -2))

            layers.append(GlobalAveragePooling())

        for _ in range(self.head_layers):
            layers.append(nn.Linear(chout, self.head_units))
            layers.append(nn.ReLU())
            chout = self.head_units

        layers.append(nn.Linear(chout, 1))

        return nn.Sequential(*layers)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the first parameter of the network.
        """
        return next(self.model.parameters()).device

    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """
        Return one logit per sample as a 1-D tensor; extra arguments are
        ignored.
        """
        # return logits
        return self.model(x).view(-1)

    def sample_predictions(self, datagen: DataLoader,
                           dest_device: Optional[Union[str, torch.device]] = None,
                           ) -> Tensor:
        """
        Compute logits for every batch produced by ``datagen``.

        Each batch must support ``batch['x']`` and yield the input tensor.
        Predictions are moved to ``dest_device`` (default: the device of
        the model). Without MC dropout the result is 1-D with one logit
        per sample; with MC dropout it has shape
        ``(mc_dropout_pl_samples, n_samples)``.

        With MC dropout only the dropout layers are switched to training
        mode; they are put back to the mode of ``self.model`` afterwards,
        even if a forward pass raises.
        """
        dest_device = dest_device or self.device
        if self.mc_dropout_pl_samples is None:
            samples = [
                self.model(xx['x'].to(self.device)).view(-1).to(dest_device)
                for xx in datagen
            ]
            return torch.cat(samples, dim=0)

        # enable dropout
        dropout_layers = [
            layer for layer in self.model.modules() if isinstance(layer, nn.Dropout)
        ]
        old_train = self.model.training
        for layer in dropout_layers:
            layer.train(True)

        try:
            # get samples from dropout
            samples = []
            for xx in datagen:
                x = xx['x'].to(self.device)  # transfer once per batch
                samples.append(torch.stack([
                    self.model(x).view(-1).to(dest_device)
                    for _ in range(self.mc_dropout_pl_samples)
                ]))
        finally:
            # reset dropout to old state
            for layer in dropout_layers:
                layer.train(old_train)

        return torch.cat(samples, dim=1)

    def reset_model(self) -> None:
        """
        Re-initialize the weights: restore the snapshot taken at
        construction when ``reset_to_same_weights`` is set, otherwise call
        ``reset_parameters`` on every submodule that provides it.
        """
        def weight_reset(m: nn.Module) -> None:
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                reset_parameters()

        print('I am re-setting model weights to train again from scratch')
        if self.reset_to_same_weights:
            assert self.model_state is not None
            self.model.load_state_dict(self.model_state)  # type: ignore
        else:
            self.model.apply(weight_reset)


class KiryoCNN(Architecture):
    """
    9 layer CNN as used as baseline architecture in Kiryo 2017
    Chainer ref. implementation: https://github.com/kiryor/nnPUlearning/blob/master/model.py
    """
    def __init__(self,
                 pdrop: Optional[float] = None,
                 mc_dropout_pl_samples: Optional[int] = None,
                 reset_to_same_weights: bool = False, **kwargs: Any) -> None:
        """
        Build the network.

        pdrop: dropout probability used after every convolution, or
            None (or 0) to disable dropout.
        mc_dropout_pl_samples: number of stochastic forward passes per
            batch performed by ``sample_predictions``; None disables
            MC dropout.
        reset_to_same_weights: when True, ``reset_model`` restores the
            weights the network had right after construction.
        kwargs: accepted and ignored.

        Raises ValueError if ``mc_dropout_pl_samples`` is given without
        ``pdrop``.
        """
        super().__init__()

        self.pdrop = pdrop
        self.mc_dropout_pl_samples = mc_dropout_pl_samples
        if self.mc_dropout_pl_samples is not None and self.pdrop is None:
            raise ValueError('specify both mc_dropout_pl_samples and pdrop for MC dropout')

        self.model = self._build_model()
        print(self.model)
        self.model_state: Optional[Dict[str, Tensor]] = None
        self.reset_to_same_weights = reset_to_same_weights
        if self.reset_to_same_weights:
            # state_dict() returns tensors that share storage with the live
            # parameters, so a deep copy is needed to keep the initial weights
            self.model_state = copy.deepcopy(self.model.state_dict())

    def _build_model(self) -> nn.Module:
        """
        Assemble the nine convolutional layers and the three linear
        layers into a single ``nn.Sequential``.
        """
        # (in_channels, out_channels, kernel_size, padding, stride) per conv layer
        conv_specs = [
            # 1
            (3, 96, 3, 1, 1),
            (96, 96, 3, 1, 1),
            (96, 96, 3, 1, 2),
            # 4
            (96, 192, 3, 1, 1),
            (192, 192, 3, 1, 1),
            (192, 192, 3, 1, 2),
            # 7
            (192, 192, 3, 1, 1),
            (192, 192, 1, 0, 1),
            (192, 10, 1, 0, 1),
        ]
        layers: List[nn.Module] = []
        for chin, chout, kernel_size, padding, stride in conv_specs:
            layers.extend(
                CNN.get_conv_layer(chin, chout, kernel_size, padding, stride, self.pdrop)
            )

        # linear layer part
        # 640 = 10 channels x 8 x 8, i.e. what the two stride-2 convolutions
        # above produce for 32x32 inputs
        layers.extend([
            nn.Flatten(),
            nn.Linear(in_features=640, out_features=1000), nn.ReLU(),
            nn.Linear(in_features=1000, out_features=1000), nn.ReLU(),
            nn.Linear(in_features=1000, out_features=1),
        ])

        return nn.Sequential(*layers)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the first parameter of the network.
        """
        return next(self.model.parameters()).device

    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """
        Return one logit per sample as a 1-D tensor; extra arguments are
        ignored.
        """
        # return logits
        return self.model(x).view(-1)

    def sample_predictions(self, datagen: DataLoader,
                           dest_device: Optional[Union[str, torch.device]] = None,
                           ) -> Tensor:
        """
        Compute logits for every batch produced by ``datagen``.

        Each batch must support ``batch['x']`` and yield the input tensor.
        Predictions are moved to ``dest_device`` (default: the device of
        the model). Without MC dropout the result is 1-D with one logit
        per sample; with MC dropout it has shape
        ``(mc_dropout_pl_samples, n_samples)``.

        With MC dropout only the dropout layers are switched to training
        mode; they are put back to the mode of ``self.model`` afterwards,
        even if a forward pass raises.
        """
        dest_device = dest_device or self.device
        if self.mc_dropout_pl_samples is None:
            samples = [
                self.model(xx['x'].to(self.device)).view(-1).to(dest_device)
                for xx in datagen
            ]
            return torch.cat(samples, dim=0)

        # enable dropout
        dropout_layers = [
            layer for layer in self.model.modules() if isinstance(layer, nn.Dropout)
        ]
        old_train = self.model.training
        for layer in dropout_layers:
            layer.train(True)

        try:
            # get samples from dropout
            samples = []
            for xx in datagen:
                x = xx['x'].to(self.device)  # transfer once per batch
                samples.append(torch.stack([
                    self.model(x).view(-1).to(dest_device)
                    for _ in range(self.mc_dropout_pl_samples)
                ]))
        finally:
            # reset dropout to old state
            for layer in dropout_layers:
                layer.train(old_train)

        return torch.cat(samples, dim=1)

    def reset_model(self) -> None:
        """
        Re-initialize the weights: restore the snapshot taken at
        construction when ``reset_to_same_weights`` is set, otherwise call
        ``reset_parameters`` on every submodule that provides it.
        """
        def weight_reset(m: nn.Module) -> None:
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                reset_parameters()

        print('I am re-setting model weights to train again from scratch')
        if self.reset_to_same_weights:
            assert self.model_state is not None
            self.model.load_state_dict(self.model_state)  # type: ignore
        else:
            self.model.apply(weight_reset)


class LSTM(Architecture):
    """
    LSTM encoder followed by a fully-connected head that outputs one
    logit per sequence.

    The head receives the final hidden state of the last LSTM layer
    (both directions concatenated when the LSTM is bidirectional).
    """

    def __init__(self, input_size: int, lstm_hidden_size: int, lstm_num_layers: int,
                 lstm_dropout: Optional[float], bidirectional: bool,
                 mlp_dropout: Optional[float], mlp_layers: int,
                 mlp_batchnorm: bool, mlp_units: int, mc_dropout_pl_samples: Optional[int],
                 reset_to_same_weights: bool, **kwargs: Any) -> None:
        """
        Build the network.

        input_size: number of features per time step.
        lstm_hidden_size: hidden size of the LSTM.
        lstm_num_layers: number of stacked LSTM layers.
        lstm_dropout: dropout between LSTM layers; None means 0.
        bidirectional: whether the LSTM is bidirectional.
        mlp_dropout: dropout probability in front of every hidden linear
            layer of the head, or None to disable it.
        mlp_layers: number of hidden layers of the head; with 0 the head
            is a single linear layer.
        mlp_batchnorm: add ``BatchNorm1d`` after every hidden linear
            layer of the head.
        mlp_units: width of the hidden layers of the head.
        mc_dropout_pl_samples: number of stochastic forward passes per
            batch performed by ``sample_predictions``; None disables
            MC dropout.
        reset_to_same_weights: when True, ``reset_model`` restores the
            weights the network had right after construction.
        kwargs: accepted and ignored.
        """
        super().__init__()

        self.mc_dropout_pl_samples = mc_dropout_pl_samples
        self.bidirectional = bidirectional

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            dropout=lstm_dropout if lstm_dropout is not None else 0,
            bidirectional=bidirectional,
            batch_first=False,
        )

        def mlp_layer(fin: int, fout: int) -> List[nn.Module]:
            ll: List[nn.Module] = []
            if mlp_dropout is not None:
                ll.append(nn.Dropout(mlp_dropout))

            ll.append(nn.Linear(fin, fout))

            if mlp_batchnorm:
                ll.append(nn.BatchNorm1d(fout))
            ll.append(nn.ReLU())
            return ll

        # the head sees both directions when the LSTM is bidirectional
        head_in = lstm_hidden_size * (2 if self.bidirectional else 1)
        if mlp_layers > 0:
            head_layers = mlp_layer(head_in, mlp_units)
            for _ in range(mlp_layers - 1):
                head_layers.extend(mlp_layer(mlp_units, mlp_units))
            head_layers.append(nn.Linear(mlp_units, 1))
        else:
            head_layers = [nn.Linear(head_in, 1)]

        self.head = nn.Sequential(*head_layers)

        self.reset_to_same_weights = reset_to_same_weights
        # always defined so that the attribute exists in both configurations
        self.model_state: Optional[Dict[str, Tensor]] = None
        if self.reset_to_same_weights:
            # state_dict() returns tensors that share storage with the live
            # parameters, so a deep copy is needed to keep the initial weights
            self.model_state = copy.deepcopy(self.state_dict())

        print(self)

    def forward(self, x: Tensor) -> Tensor:
        """
        Return one logit per sequence as a 1-D tensor.

        The LSTM is built with ``batch_first=False``, so ``x`` is expected
        in the time-major layout that ``nn.LSTM`` uses in that mode.
        """
        _, (hidden, _) = self.lstm(x)

        # the last two entries of the hidden state belong to the two
        # directions of the last layer of a bidirectional LSTM
        mlp_in = torch.cat(
            [hidden[-2], hidden[-1]],
            dim=1
        ) if self.bidirectional else hidden[-1]

        phat = self.head(mlp_in)

        return phat.view(-1)

    def sample_predictions(self, datagen: DataLoader,
                           dest_device: Optional[Union[str, torch.device]] = None,
                           ) -> Tensor:
        """
        Compute logits for every batch produced by ``datagen``.

        Each batch must support ``batch['x']`` and yield the input tensor.
        Predictions are moved to ``dest_device`` (default: the device of
        the model). The result has samples on the last dimension:
        shape ``(1, n_samples)`` without MC dropout and
        ``(mc_dropout_pl_samples, n_samples)`` with MC dropout.

        With MC dropout the LSTM and every dropout layer (including those
        nested inside the head) are switched to training mode; they are
        put back to the mode of this module afterwards, even if a forward
        pass raises.
        """
        dest_device = dest_device or self.device
        if self.mc_dropout_pl_samples is None:
            # concatenate along the sample axis (instead of stacking batches on
            # a new axis) so that batches of different sizes are supported
            per_batch = [
                self(batch['x'].to(self.device)).to(dest_device).unsqueeze(0)
                for batch in datagen
            ]
            return torch.cat(per_batch, dim=1)

        # enable dropout; modules() is needed because the dropout layers of
        # the head are nested inside its nn.Sequential and are therefore not
        # direct children of this module
        stochastic_layers = [
            layer for layer in self.modules()
            if isinstance(layer, (nn.Dropout, nn.LSTM))
        ]
        old_train = self.training
        for layer in stochastic_layers:
            layer.train(True)

        try:
            # get samples from dropout
            samples = []
            for batch in datagen:
                x = batch['x'].to(self.device)  # transfer once per batch
                samples.append(torch.stack([
                    self(x).to(dest_device)
                    for _ in range(self.mc_dropout_pl_samples)
                ]))
        finally:
            # reset dropout to old state
            for layer in stochastic_layers:
                layer.train(old_train)

        return torch.cat(samples, dim=1)

    def reset_model(self) -> None:
        """
        Re-initialize the weights: restore the snapshot taken at
        construction when ``reset_to_same_weights`` is set, otherwise call
        ``reset_parameters`` on every submodule that provides it.
        """
        def weight_reset(m: nn.Module) -> None:
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                reset_parameters()

        print('I am re-setting model weights to train again from scratch')
        if self.reset_to_same_weights:
            assert self.model_state is not None
            self.load_state_dict(self.model_state)  # type: ignore
        else:
            self.apply(weight_reset)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the first parameter of the network.
        """
        return next(self.parameters()).device


class Ensemble(Architecture):
    """
    Ensemble of independently initialized networks that share the same
    backbone architecture and constructor arguments.
    """

    def __init__(self, ensemble_size: int, backbone: str, **kwargs: Any) -> None:
        """
        Build ``ensemble_size`` members of the given backbone.

        ensemble_size: number of members.
        backbone: one of 'mlp', 'cnn', 'kiryo_cnn' or 'lstm'.
        kwargs: forwarded unchanged to the constructor of every member.

        Raises ValueError if ``backbone`` is not a known name.
        """
        super().__init__()

        backbones: Dict[str, Any] = {
            'mlp': MLP,
            'cnn': CNN,
            'kiryo_cnn': KiryoCNN,
            'lstm': LSTM,
        }

        self.ensemble: List[Architecture] = []
        for i in range(ensemble_size):
            if backbone not in backbones:
                raise ValueError(f'unknown backbone {backbone}')
            emember: Architecture = backbones[backbone](**kwargs)

            # the list is a plain python list, so every member is also
            # registered explicitly to expose its parameters to this module
            self.add_module(f'emember_{i}', emember)
            self.ensemble.append(emember)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the first member of the ensemble.
        """
        return self.ensemble[0].device

    def forward(self, x: Tensor) -> Tensor:
        """
        Return the outputs of all members stacked on a new first
        dimension.
        """
        return torch.stack([emember(x) for emember in self.ensemble])

    def sample_predictions(self, datagen: DataLoader,
                           dest_device: Optional[Union[str, torch.device]] = None,
                           ) -> Tensor:
        """
        Compute the logits of every member for every batch produced by
        ``datagen``.

        Each batch must support ``batch['x']`` and yield the input tensor.
        Predictions are moved to ``dest_device`` (default: the device of
        the ensemble). The result has one row per member and samples on
        the last dimension.
        """
        # honor the requested destination instead of always using self.device
        dest_device = dest_device or self.device

        samples = []
        for xx in datagen:
            x = xx['x'].to(self.device)  # transfer once per batch
            samples.append(torch.stack([
                emember(x).view(-1).to(dest_device)
                for emember in self.ensemble
            ]))
        return torch.cat(samples, dim=1)

    def reset_model(self) -> None:
        """
        Re-initialize the weights of every member.
        """
        for emember in self.ensemble:
            emember.reset_model()


# Registry mapping the ``class`` name used in a configuration to the
# architecture class implementing it; consumed by ``get_architecture``.
architectures_dict = {
    'mlp': MLP,
    'ensemble': Ensemble,
    'cnn': CNN,
    'kiryo_cnn': KiryoCNN,
    'lstm': LSTM,
}


def get_architecture(config: Dict[str, Any]) -> Architecture:
    """
    Instantiate the architecture described by ``config``.

    ``config['class']`` selects the entry of ``architectures_dict`` and
    the optional ``config['params']`` mapping is passed to its
    constructor as keyword arguments; a missing or empty ``params``
    entry means no arguments. An unknown class name raises KeyError.
    """
    architecture_cls = architectures_dict[config['class']]

    constructor_kwargs = config.get('params')
    if not constructor_kwargs:
        constructor_kwargs = {}

    return architecture_cls(**constructor_kwargs)
