from argparse import ArgumentParser

import pytorch_lightning as pl
import torchmetrics
from torch import nn, optim
from torch.utils.data import DataLoader

from sde.datasets import ModuloDataset
from sde.models import MultiLayerFC


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` holding ``binary``, ``num_hidden_layers``,
        ``num_classes``, ``max_value``, ``lr``, ``eta_min``, ``num_epochs``
        and ``devices``.
    """
    cli = ArgumentParser()
    cli.add_argument('--binary', action='store_true')

    # (flag, converter, default) for every single-valued option, listed in
    # the order in which they are registered on the parser.
    scalar_options = (
        ('--num-hidden-layers', int, 5),
        ('--num-classes', int, 30),
        ('--max-value', int, 1000),
        ('--lr', float, 1e-3),
        ('--eta-min', float, 1e-5),
        ('--num-epochs', int, 300000),
    )
    for flag, converter, fallback in scalar_options:
        cli.add_argument(flag, default=fallback, type=converter)

    # One or more integers; the resulting list is forwarded as the trainer's
    # ``devices`` argument by the rest of the script.
    cli.add_argument('--devices', nargs='+', default=[1], type=int)

    return cli.parse_args()


def weight_init(model, sparsity=0.5, std=5):
    """Re-initialize the parameters of ``model`` in place.

    Every parameter whose name contains ``'weight'`` is filled with
    ``nn.init.sparse_``; every other parameter whose name contains ``'bias'``
    is set to zero. Parameters matching neither are left untouched.

    Args:
        model: Module whose ``named_parameters()`` are re-initialized. ``None``
            is accepted and treated as "nothing to initialize".
        sparsity: Fraction of entries per column that ``sparse_`` sets to zero.
        std: Standard deviation of the normal distribution used by ``sparse_``
            for the remaining entries.
    """
    if model is None:
        return

    for name, param in model.named_parameters():
        if 'weight' in name:
            # ``sparse_`` is the in-place initializer; the un-suffixed
            # ``nn.init.sparse`` alias is deprecated. It only accepts 2-D
            # tensors, so weights of other ranks would raise here.
            nn.init.sparse_(param, sparsity=sparsity, std=std)
        elif 'bias' in name:
            nn.init.constant_(param, 0.0)


class ModuloClassifier(pl.LightningModule):
    """Classifier that learns to memorize the hash table.

    The model is a ``MultiLayerFC`` classifier, optionally preceded by a
    "rank increaser": a second ``MultiLayerFC`` with sigmoid activations that
    expands the single input feature to ``rank_to_increase`` features.

    Args:
        num_hidden_layers: Second argument passed to the classifier
            ``MultiLayerFC`` (presumably its number of hidden layers --
            confirm against ``sde.models``).
        num_classes: Number of target classes; also the classifier's output size.
        num_epochs: Used as ``T_max`` of the cosine learning-rate schedule.
        lr: Initial learning rate of the Adam optimizer.
        eta_min: Minimum learning rate of the cosine schedule.
        rank_increase: Whether to put the rank increaser in front of the
            classifier.
        rank_to_increase: Output width of the rank increaser.
        rank_increase_layer: Second argument passed to the rank increaser
            ``MultiLayerFC``.
        max_value: Divisor used to scale the raw input before it enters the
            rank increaser. Must be a number when ``rank_increase`` is true,
            otherwise ``forward`` fails on the division.
    """

    def __init__(
            self,
            num_hidden_layers: int,
            num_classes: int,
            num_epochs,
            lr,
            eta_min,
            rank_increase=True,
            rank_to_increase=2000,
            rank_increase_layer=4,
            max_value=None):
        super().__init__()
        self.num_classes = num_classes
        self.lr = lr
        self.eta_min = eta_min
        self.num_epochs = num_epochs
        self.max_value = max_value

        if rank_increase:
            self.rank_increaser = MultiLayerFC(
                1, rank_increase_layer, rank_to_increase, activation='sigmoid')
            # init rank increaser; only done when it exists, since the
            # initializer iterates over the module's parameters.
            weight_init(self.rank_increaser)
            true_input_size = rank_to_increase
        else:
            self.rank_increaser = None
            true_input_size = 1

        self.classifier = MultiLayerFC(true_input_size, num_hidden_layers, self.num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

        # TODO save hyperparameter for easy reload

    def forward(self, x):
        """Return the classifier output for ``x``.

        With a rank increaser, ``x`` is first divided by ``max_value`` and
        shifted by -0.5, then expanded by the rank increaser. Without one,
        ``x`` is fed to the classifier unchanged.
        """
        if self.rank_increaser is not None:
            x = self.rank_increaser((x / self.max_value) - 0.5)
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss and accuracy for one batch.

        Only a training step is needed, because we overfit the hash table
        dataset.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)

        # logging to tensorboard by default
        self.log("training_loss", loss, prog_bar=True)

        # log accuracy
        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return Adam with a cosine-annealed learning rate.

        The schedule decays from ``lr`` towards ``eta_min`` with
        ``T_max=num_epochs``.
        """
        # NOTE(review): only the classifier's parameters are handed to the
        # optimizer, so the rank increaser keeps its initial weights.
        # Presumably intentional (fixed random feature expansion) -- confirm.
        optimizer = optim.Adam(self.classifier.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.num_epochs, eta_min=self.eta_min)
        return [optimizer], [scheduler]


def train(
        num_hidden_layers: int,
        num_classes: int,
        num_epochs=30,
        devices=2,
        lr=1e-4,
        eta_min=1e-6,
        dataset=None,
        max_value=None,
        batch_size=256,
        num_workers=10):
    """Fit a ``ModuloClassifier`` on ``dataset`` using GPU acceleration.

    Args:
        num_hidden_layers: Forwarded to ``ModuloClassifier``.
        num_classes: Forwarded to ``ModuloClassifier``.
        num_epochs: Maximum number of training epochs; also forwarded to the
            classifier for its learning-rate schedule.
        devices: Forwarded to ``pl.Trainer(devices=...)``.
        lr: Initial learning rate.
        eta_min: Minimum learning rate of the schedule.
        dataset: Dataset to train on. Required.
        max_value: Forwarded to ``ModuloClassifier``.
        batch_size: Batch size of the training ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes.

    Raises:
        ValueError: If ``dataset`` is ``None``.
    """
    # Validate before building the model so a missing dataset fails fast,
    # and with an explicit raise so the check survives ``python -O``.
    if dataset is None:
        raise ValueError("Please provide a valid dataset root")

    classifier = ModuloClassifier(
        num_hidden_layers=num_hidden_layers,
        num_classes=num_classes,
        num_epochs=num_epochs,
        lr=lr,
        eta_min=eta_min,
        max_value=max_value)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent workers when there are no workers.
        persistent_workers=num_workers > 0,
        shuffle=True)
    trainer = pl.Trainer(max_epochs=num_epochs, accelerator='gpu', devices=devices)
    trainer.fit(model=classifier, train_dataloaders=dataloader)


def main(arguments):
    """Create the ``ModuloDataset`` from the parsed options and run training.

    Args:
        arguments: Namespace providing ``max_value``, ``num_classes``,
            ``binary``, ``num_hidden_layers``, ``num_epochs``, ``devices``,
            ``lr`` and ``eta_min``.
    """
    training_data = ModuloDataset(
        arguments.max_value,
        arguments.num_classes,
        binary=arguments.binary,
    )

    # Collect the remaining options that are forwarded to ``train`` as-is.
    settings = {
        'num_hidden_layers': arguments.num_hidden_layers,
        'num_classes': arguments.num_classes,
        'num_epochs': arguments.num_epochs,
        'devices': arguments.devices,
        'lr': arguments.lr,
        'eta_min': arguments.eta_min,
        'max_value': arguments.max_value,
    }
    train(dataset=training_data, **settings)


if __name__ == "__main__":
    # Script entry point: parse the CLI, echo the effective configuration
    # (one "name: value" line per option), then start training.
    args = parse_args()

    print("Run with arguments")
    for option, setting in vars(args).items():
        print(f"{option}: {setting}")

    main(args)
