"""
This module is used for the resampling of positives in CIFAR10
to reproduce [1]

[1]: https://openreview.net/forum?id=rJzLciCqKm
"""

from typing import Any, Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from sklearn.model_selection import KFold, train_test_split
from torch import Tensor, optim
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from puupl.lib.architectures import CNN
from puupl.lib.postprocessing import TemperatureScaler


class Classifier(pl.LightningModule):
    """Binary classifier built around the project's ``CNN`` architecture.

    The module owns its training and validation tensors and builds its own
    dataloaders from them. The network outputs are treated as logits: the
    loss is binary cross-entropy with logits, and a logit above zero counts
    as a positive prediction when computing accuracy.
    """
    # pylint: disable=arguments-differ

    def __init__(self, x_train: Tensor, y_train: Tensor, x_val: Tensor, y_val: Tensor):
        """Store the data splits and set up the network and hyperparameters."""
        super().__init__()

        # Data splits served by train_dataloader / val_dataloader.
        self.x_train, self.y_train = x_train, y_train
        self.x_val, self.y_val = x_val, y_val

        # Optimisation hyperparameters.
        self.batch_size = 8192
        self.learning_rate = 1e-4
        self.weight_decay = 1e-5

        self.net = CNN()

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Return the network output (logits) for ``x``."""
        logits = self.net(x)
        return logits

    def training_step(  # type: ignore[override]
            self, batch: Tuple[Tensor, Tensor], idx: int) -> Tensor:
        """Return the binary cross-entropy loss of one training batch."""
        inputs, targets = batch
        return binary_cross_entropy_with_logits(self.net(inputs), targets)

    def validation_step(  # type: ignore[override]
            self, batch: Tuple[Tensor, Tensor], idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(logits, targets, loss)`` for one validation batch.

        The loss is flattened to one dimension so that the per-batch losses
        can be concatenated at the end of the epoch.
        """
        inputs, targets = batch
        logits = self.net(inputs)
        batch_loss = binary_cross_entropy_with_logits(logits, targets)
        return logits, targets, batch_loss.view(-1)

    def validation_epoch_end(  # type: ignore[override]
            self, outputs: List[Tuple[Tensor, Tensor, Tensor]]) -> None:
        """Aggregate the validation batches, then print and log loss and accuracy."""
        all_logits, all_targets, all_losses = (
            torch.cat(parts) for parts in zip(*outputs)
        )

        # A logit above zero is a positive prediction; a target above 0.5 is a positive label.
        hits = (all_logits > 0) == (all_targets > 0.5)
        val_acc = hits.float().mean()
        val_loss = all_losses.mean()

        tqdm.write(f'val_loss: {val_loss.item():.4f} - val acc: {val_acc.item():.4f}')
        self.log('val_loss', val_loss)
        self.log('val_acc', val_acc)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Return Adam plus a plateau scheduler that monitors ``val_loss``."""
        optimizer = optim.Adam(
            self.net.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, eps=1e-6)

        config: Dict[str, Any] = {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',
        }
        return config

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training tensors."""
        train_set = TensorDataset(self.x_train, self.y_train)
        return DataLoader(
            dataset=train_set,
            batch_size=self.batch_size,
            num_workers=4,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over the validation tensors.

        The batch size is twice the training batch size.
        """
        val_set = TensorDataset(self.x_val, self.y_val)
        return DataLoader(
            dataset=val_set,
            batch_size=2 * self.batch_size,
            num_workers=4,
            shuffle=False,
        )


def train_resampling_model() -> None:
    """Train cross-validated CIFAR10 classifiers and save out-of-fold predictions.

    The CIFAR10 training set is relabelled as a binary problem (classes
    0, 1, 8 and 9 are positive), normalised per colour channel and split
    into 5 folds. For every fold a fresh ``Classifier`` is trained on 90% of
    the other four folds; the remaining 10% is used for validation, early
    stopping and fitting the ``TemperatureScaler``. The scaled outputs for
    the held-out fold are collected, so every training image ends up with a
    prediction from a model that never saw it.

    The predictions are written to
    ``data/cifar10/resampling_predictions.npy``; nothing is returned.
    """
    trainset = torchvision.datasets.CIFAR10(
        root='data/cifar10', train=True, download=True, transform=None)

    # Binary targets: 1.0 for the positive classes, 0.0 for everything else.
    pos_classes = {0, 1, 8, 9}
    t_train = torch.tensor([t in pos_classes for t in trainset.targets]).float()

    # Compute the normalisation statistics while the images are still in
    # their original channels-last layout, so that reducing over axes
    # (0, 1, 2) yields one mean/std per colour channel. Doing this after the
    # axis swap would reduce over the channel axis instead.
    images = trainset.data
    x_mean, x_std = images.mean(axis=(0, 1, 2)), images.std(axis=(0, 1, 2))
    # swapaxes(1, 3) moves the channel axis to position 1 (it also exchanges
    # the two spatial axes, which is applied consistently to every image).
    x_train = torch.tensor(((images - x_mean) / x_std).swapaxes(1, 3)).float()

    preds = torch.zeros(len(t_train))
    kf = KFold(5, shuffle=True, random_state=87134135)
    for i, (train_idx, test_idx) in enumerate(kf.split(x_train, t_train)):
        # Carve a validation split out of the training folds; the seed
        # depends on the fold index so every fold gets a different split.
        xt, xv, yt, yv = train_test_split(x_train[train_idx], t_train[train_idx],
                                          test_size=0.1, random_state=193597195 + i)

        print('Starting fold', i)

        net = Classifier(xt, yt, xv, yv)
        trainer = pl.Trainer(
            gpus=1 if torch.cuda.is_available() else 0,
            max_epochs=150,
            auto_lr_find=True,
            fast_dev_run=False,
            # Accuracy must be maximised; without mode='max' the callback
            # would stop training once accuracy stops *decreasing*.
            callbacks=[EarlyStopping(monitor='val_acc', mode='max', patience=10)],
        )

        print('Tuning...')
        trainer.tune(net)
        print('Final learning rate:', net.learning_rate)
        trainer.fit(net)

        net.cpu()
        # Switch to inference mode so that layers such as dropout or batch
        # normalisation (if the CNN has any) behave deterministically.
        net.eval()
        with torch.no_grad():
            # Fit the scaler on the validation split, then apply it to the
            # held-out fold. NOTE(review): ts.scale presumably returns
            # calibrated probabilities, given the 0.5 threshold below.
            ts = TemperatureScaler().fit(net(xv.cpu()), None, yv.cpu())  # type: ignore
            preds[test_idx] = ts.scale(net(x_train[test_idx].cpu()))

    print('Final accuracy', torch.mean(
        ((preds > 0.5) == (t_train > 0.5)).float()
    ).item())

    np.save(file='data/cifar10/resampling_predictions.npy', arr=preds.numpy())


# Script entry point: run the cross-validated training when executed directly.
if __name__ == '__main__':
    train_resampling_model()
