#!/usr/bin/env python3

import argparse
import copy
import os

import coolname
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics.functional as MF
from torch.utils import data
from torchvision import datasets, transforms

import decode
import paths
import utils
from retrain import configs, net


class DARTSRetrain(pl.LightningModule):
    """Retrain a decoded DARTS architecture on CIFAR-10.

    ``cfg`` is recursively merged over ``configs.default_cfg``; the decoded
    cells are stored in ``hparams["model"]["cells"]`` and ``hparams["uname"]``
    holds the run's unique name.
    """

    def __init__(self, cfg):
        super().__init__()
        cfg = utils.rec_update(copy.deepcopy(configs.default_cfg), cfg)
        if "uname" not in cfg:
            cfg["uname"] = f"{cfg['name']}-{'_'.join(coolname.generate(2))}"
        cfg["training"]["batches_per_epoch"] = int(
            50_000 / cfg["training"]["batch_size"]
        )

        self.save_hyperparameters(cfg)

        self.hparams["model"]["cells"] = decode.Decoder(
            self.hparams["decoding"],
        ).decode(self.hparams["decoding"]["type"])

        print("Architecture:")
        for cell in self.hparams["model"]["cells"]:
            print(cell)

        # Tag the run name with the last "-"-separated part of the stats file
        # name, minus its final 8 characters (presumably a file extension such
        # as ".feather" -- TODO confirm). Checkpoints store the already-tagged
        # uname and load_from_checkpoint re-runs __init__ with it, so only
        # append once; otherwise each resume would grow the name and change
        # the log directory.
        suffix = "-" + self.hparams["decoding"]["stats_df_path"].split("-")[-1][:-8]
        if not self.hparams["uname"].endswith(suffix):
            self.hparams["uname"] += suffix

        self.net = net.Net(
            paths.CIFAR_NUM_CLASSES,
            self.hparams["model"],
        )

        print(self.hparams["uname"])

    def forward(self, x):
        return self.net(x)

    def on_train_start(self):
        """Log the parameter count of ``self.net.cells`` once per fit.

        ``self.log`` cannot be used from ``__init__`` (no trainer is attached
        yet), so the value is sent to the logger directly from this hook.
        """
        if self.logger is not None:
            num_params = sum(p.numel() for p in self.net.cells.parameters())
            self.logger.log_metrics({"params": num_params}, step=0)

    @staticmethod
    def _accuracy(logits, lbl):
        """Top-1 multiclass accuracy of ``logits`` (argmax over the last dim) vs ``lbl``."""
        return MF.accuracy(
            logits.argmax(dim=-1),
            lbl,
            task="multiclass",
            num_classes=paths.CIFAR_NUM_CLASSES,
        )

    def training_step(self, batch, batch_idx):
        train_cfg = self.hparams["training"]
        # Linearly ramp the drop-path rate from drop_path_start towards
        # drop_path_end as the epoch counter advances.
        self.net.drop_path = train_cfg["drop_path_start"] + (
            (train_cfg["drop_path_end"] - train_cfg["drop_path_start"])
            * self.current_epoch
            / train_cfg["epochs"]
        )

        img, lbl = batch
        out, aux_out = self.net(img)
        base_loss = F.cross_entropy(out, lbl)
        aux_loss = F.cross_entropy(aux_out, lbl)
        # "aux_towers" acts as the weight of the auxiliary-head loss.
        loss = base_loss + aux_loss * train_cfg["aux_towers"]

        acc = self._accuracy(out, lbl)

        self.log("train/loss", base_loss, on_step=False, on_epoch=True)
        self.log("train/aux_loss", aux_loss, on_step=False, on_epoch=True)
        self.log("train/combo_loss", loss, on_step=False, on_epoch=True)
        self.log("train/accuracy", acc, on_step=False, on_epoch=True)
        self.log(
            "train/aux_accuracy",
            self._accuracy(aux_out, lbl),
            on_step=False,
            on_epoch=True,
        )

        # for prog bar only
        self.log("train_accuracy", acc, logger=False, prog_bar=True)

        # ensure logging is on epoch
        self.log("step", self.current_epoch, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, _):
        img, lbl = batch

        out, _ = self.net(img)
        loss = F.cross_entropy(out, lbl)
        acc = self._accuracy(out, lbl)

        self.log("val/loss", loss)
        self.log("val/accuracy", acc)
        self.log("val/error", 1 - acc)

        # for checkpoints
        self.log("accuracy", acc, prog_bar=True)

        # ensure logging is on epoch
        self.log("step", self.current_epoch, on_step=False, on_epoch=True)

    def get_progress_bar_dict(self):
        # don't show the loss
        # NOTE(review): newer Lightning releases no longer call this hook
        # (progress-bar metrics moved to the progress-bar callback) -- confirm
        # against the installed version.
        items = super().get_progress_bar_dict()
        items.pop("loss", None)
        return items

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.net.parameters(),
            self.hparams["optim"]["lr"],
            momentum=self.hparams["optim"]["momentum"],
            weight_decay=self.hparams["optim"]["weight_decay"],
        )

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.hparams["training"]["epochs"],
                eta_min=self.hparams["optim"]["min_lr"],
            ),
            # for logging with LearningRateMonitor
            "name": "train/weight_lr",
        }

        return [optimizer], [scheduler]

    def prepare_data(self):
        """Download CIFAR-10 (train and test splits) if not already present.

        Lightning runs this on a single process per node, so no state is
        assigned here; datasets are built in ``setup``.
        """
        datasets.CIFAR10(root=paths.cifar_dir, train=True, download=True)
        datasets.CIFAR10(root=paths.cifar_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train/val datasets on every process (reused if already built)."""
        if getattr(self, "train_dataset", None) is None:
            self.train_dataset = datasets.CIFAR10(
                root=paths.cifar_dir,
                train=True,
                transform=transforms.Compose(
                    [
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize(paths.CIFAR_MEAN, paths.CIFAR_STD),
                        utils.Cutout(self.hparams["dataset"]["cutout_size"]),
                    ]
                ),
                download=False,
            )

        if getattr(self, "val_dataset", None) is None:
            self.val_dataset = datasets.CIFAR10(
                root=paths.cifar_dir,
                train=False,
                transform=transforms.Compose(
                    [
                        transforms.ToTensor(),
                        transforms.Normalize(paths.CIFAR_MEAN, paths.CIFAR_STD),
                    ]
                ),
                download=False,
            )

    def _dataloader(self, dataset, train):
        """Wrap ``dataset`` in a DataLoader.

        Training shuffles and drops the last partial batch; evaluation keeps
        every sample so validation metrics cover the whole split.
        """
        return data.DataLoader(
            dataset=dataset,
            batch_size=self.hparams["training"]["batch_size"],
            shuffle=train,
            drop_last=train,
            num_workers=8,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._dataloader(self.train_dataset, train=True)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset, train=False)


if __name__ == "__main__":
    # Script entry point: start a new experiment from a named config, or
    # resume one from a checkpoint, then train and run a final validation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="Name of a config for the current experiment.")
    parser.add_argument(
        "--resume",
        help="Resume an experiment from a checkpoint file. "
        "Please don't specify the config when resuming.",
    )
    args = parser.parse_args()

    if bool(args.config) == bool(args.resume):
        # parser.error prints usage + message to stderr and exits.
        parser.error("Exactly one of --config and --resume can be active at once!")

    if args.config:
        # new experiment
        cfg = utils.build_config(configs, args.config.split(","))
        model = DARTSRetrain(cfg)
        ckpt_path = None
    else:
        # resume: rebuild the module from the checkpoint's hparams; ckpt_path
        # is later handed to Trainer.fit so training state is restored too.
        model = DARTSRetrain.load_from_checkpoint(args.resume)
        ckpt_path = args.resume

    # args from hparams handling
    for k, v in model.hparams["args"].items():
        setattr(args, k, v)

    # configure logger and checkpoint callback
    tb_logger = pl.loggers.TensorBoardLogger(
        save_dir=args.default_save_path,
        name="",
        version=model.hparams["uname"],
    )

    os.makedirs(tb_logger.log_dir, exist_ok=True)

    args.logger = tb_logger
    args.callbacks = [
        # keeps only "last" checkpoint (save_top_k=0, save_last=True)
        pl.callbacks.ModelCheckpoint(
            dirpath=tb_logger.log_dir,
            filename="{epoch:02d}",
            save_top_k=0,
            save_last=True,
        ),
        # keeps every checkpoint written every 50 epochs
        pl.callbacks.ModelCheckpoint(
            dirpath=tb_logger.log_dir,
            filename="{epoch:02d}",
            save_top_k=-1,
            every_n_epochs=50,
        ),
        pl.callbacks.LearningRateMonitor(),
    ]

    # Everything else on args is forwarded to the Trainer.
    script_only = ("config", "resume", "default_save_path")
    trainer_kwargs = {k: v for k, v in vars(args).items() if k not in script_only}

    trainer = pl.Trainer(**trainer_kwargs)

    # Without ckpt_path a "resumed" run would only reuse the weights and
    # restart the optimizer, LR schedule and epoch counter from scratch.
    trainer.fit(model, ckpt_path=ckpt_path)

    trainer.validate(model)
