#!/usr/bin/env python3

import argparse
import copy
import os
from collections import defaultdict

import coolname
import numpy as np
import pandas as pd
import paths
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics.functional as MF
import utils
from search import configs, darts, nas_bench
from torch.utils import data
from torchvision import datasets, transforms


class MIDASSearch(pl.LightningModule):
    """Lightning module that searches a supernet's architecture parameters.

    Every training step performs, with manual optimization, one update of the
    architecture parameters (AdamW) followed by one update of the network
    weights (SGD with cosine annealing). Weight updates use ``batch[0]`` and
    architecture updates use ``batch[1]`` of each training batch.

    Per-sample statistics collection during validation is switched on by the
    module-level ``COLLECT_STATS`` flag (set by the ``__main__`` driver); when
    that flag is not defined, collection is disabled.
    """

    def __init__(self, cfg):
        """Build the supernet described by ``cfg``.

        Args:
            cfg: Experiment config, merged via ``utils.rec_update`` over a deep
                copy of ``configs.default_cfg``. A unique ``uname`` is generated
                when absent.

        Raises:
            ValueError: If ``search_space`` is neither ``"darts"`` nor
                ``"nasbench201"``.
        """
        super().__init__()

        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        cfg = utils.rec_update(copy.deepcopy(configs.default_cfg), cfg)
        if "uname" not in cfg:
            cfg["uname"] = f"{cfg['name']}-{'_'.join(coolname.generate(2))}"
        # 50_000 = CIFAR-10 train-set size; prepare_data splits it in half,
        # one half per optimizer.
        cfg["training"]["batches_per_epoch"] = int(
            50_000 / 2 / cfg["training"]["batch_size"]
        )
        self.save_hyperparameters(cfg)

        print(cfg["uname"])

        if self.hparams["search_space"] == "darts":
            self.net = darts.SuperNet(
                paths.CIFAR_NUM_CLASSES,
                self.hparams["model"],
            )
        elif self.hparams["search_space"] == "nasbench201":
            self.net = nas_bench.SuperNet(
                paths.CIFAR_NUM_CLASSES,
                self.hparams["model"],
            )
        else:
            raise ValueError("Incorrectly specified search space!")

        # Per-sample statistics gathered during stats-collecting validation runs.
        self.coeffs_dict = defaultdict(list)

    @staticmethod
    def _collect_stats():
        """Return the module-level ``COLLECT_STATS`` flag, defaulting to False.

        The flag is only assigned by the ``__main__`` driver; reading it through
        ``globals()`` keeps the module usable when imported elsewhere.
        """
        return bool(globals().get("COLLECT_STATS", False))

    @staticmethod
    def _accuracy(out, lbl):
        """Top-1 multiclass accuracy of scores ``out`` (argmax over the last dim) vs ``lbl``."""
        return MF.accuracy(
            out.argmax(dim=-1),
            lbl,
            task="multiclass",
            num_classes=paths.CIFAR_NUM_CLASSES,
        )

    def _set_trainable(self, *, weights, arch):
        """Set ``requires_grad`` on the supernet's weight and architecture parameters."""
        for param in self.net.weight_parameters():
            param.requires_grad = weights
        for param in self.net.arch_parameters():
            param.requires_grad = arch

    def forward(self, x):
        """Run the supernet on ``x`` and return its raw output."""
        return self.net(x)

    def training_step(self, batch, _):
        """One alternating search step: architecture update, then weight update.

        ``batch[0]`` and ``batch[1]`` are each ``(img, lbl)`` pairs. The parameter
        group that is not being updated has ``requires_grad`` disabled so no
        gradients accumulate on it.
        """
        opt_w, opt_a = self.optimizers()

        def weight_closure():
            img, lbl = batch[0]
            out, _ = self.net(img)
            loss = F.cross_entropy(out, lbl)
            acc = self._accuracy(out, lbl)

            self.log("train/weight_loss", loss, on_step=False, on_epoch=True)
            self.log("train/weight_accuracy", acc, on_step=False, on_epoch=True)

            # for prog bar only
            self.log("train_accuracy", acc, logger=False, prog_bar=True)
            self.log("train_loss", loss, logger=False, prog_bar=True)

            # ensure logging is on epoch
            self.log("step", self.current_epoch, on_step=False, on_epoch=True)

            self.manual_backward(loss)

        def arch_closure():
            img, lbl = batch[1]
            out, _ = self.net(img)
            loss = F.cross_entropy(out, lbl)

            self.log("train/arch_loss", loss, on_step=False, on_epoch=True)
            self.log(
                "train/arch_accuracy",
                self._accuracy(out, lbl),
                on_step=False,
                on_epoch=True,
            )

            # ensure logging is on epoch
            self.log("step", self.current_epoch, on_step=False, on_epoch=True)

            self.manual_backward(loss)

        opt_a.zero_grad()
        self._set_trainable(weights=False, arch=True)
        arch_closure()
        opt_a.step()

        opt_w.zero_grad()
        self._set_trainable(weights=True, arch=False)
        weight_closure()
        opt_w.step()

    def on_validation_epoch_start(self):
        """Start each stats-collecting validation epoch with an empty buffer.

        Sample ids restart from 0 every epoch, so keeping rows from an earlier
        epoch would duplicate ids in the written parquet file.
        """
        if self._collect_stats():
            self.coeffs_dict = defaultdict(list)

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and, if enabled, collect per-sample stats."""
        img, lbl = batch
        collect_stats = self._collect_stats()

        out, coeffs_dict = self.net(img, collect_stats=collect_stats)
        if collect_stats:
            # coeffs_dict["sample_id"] holds in-batch indices (they index lbl).
            # Offset by the configured loader batch size rather than img.shape[0],
            # so ids stay correct when the final batch is smaller.
            offset = self.hparams["training"]["batch_size"] * batch_idx
            coeffs_dict["class"] = [lbl[c].item() for c in coeffs_dict["sample_id"]]
            coeffs_dict["sample_id"] = [c + offset for c in coeffs_dict["sample_id"]]
            for key, val in coeffs_dict.items():
                self.coeffs_dict[key].extend(val)
        loss = F.cross_entropy(out, lbl)
        acc = self._accuracy(out, lbl)

        self.log("val/loss", loss)
        self.log("val/accuracy", acc)

        # for checkpoints
        self.log("accuracy", acc, prog_bar=True)

        # ensure logging is on epoch
        self.log("step", self.current_epoch)

    def get_progress_bar_dict(self):
        """Hide the loss from the progress bar.

        NOTE(review): this hook only exists in older Lightning versions; confirm
        it is still consulted by the installed version.
        """
        items = super().get_progress_bar_dict()
        items.pop("loss", None)
        return items

    def configure_optimizers(self):
        """Return SGD for network weights and AdamW for architecture parameters.

        The weight optimizer gets a per-epoch cosine-annealing schedule. Because
        automatic optimization is off, that schedule is stepped manually in
        ``on_train_epoch_end``.
        """
        optimizer = torch.optim.SGD(
            self.net.weight_parameters(),
            self.hparams["weights_optim"]["lr"],
            momentum=self.hparams["weights_optim"]["momentum"],
            weight_decay=self.hparams["weights_optim"]["weight_decay"],
        )

        arch_optimizer = torch.optim.AdamW(
            self.net.arch_parameters(),
            lr=self.hparams["arch_optim"]["lr"],
            betas=(0.5, 0.999),
            weight_decay=self.hparams["arch_optim"]["weight_decay"],
        )

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.hparams["training"]["sched_epochs"],
                eta_min=self.hparams["weights_optim"]["min_lr"],
            ),
            "interval": "epoch",
            # for logging with LearningRateMonitor
            "name": "train/weight_lr",
        }

        return [optimizer, arch_optimizer], [scheduler]

    def on_train_epoch_end(self):
        """Step the weight LR schedule (manual optimization does not do it)."""
        self.lr_schedulers().step()

    def on_validation_epoch_end(self):
        """When collecting stats, write them to ``<logger name>_<version>.parquet``."""
        if self._collect_stats():
            experiment_name = (
                f"{self.trainer.logger.name}_{self.trainer.logger.version}"
            )
            pd.DataFrame(self.coeffs_dict).to_parquet(f"{experiment_name}.parquet")

    def prepare_data(self):
        """Load CIFAR-10 train (augmented) and test sets; create the train split once.

        The split is two disjoint random halves of the train-set indices, stored
        in ``self.hparams["split"]`` and only drawn when absent, so a split
        restored with the hparams is reused.

        NOTE(review): Lightning runs ``prepare_data`` on one process per node and
        discourages assigning state here; confirm this is only used single-process.
        """
        self.train_dataset = datasets.CIFAR10(
            root=paths.cifar_dir,
            train=True,
            transform=transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(paths.CIFAR_MEAN, paths.CIFAR_STD),
                    utils.Cutout(self.hparams["dataset"]["cutout_size"]),
                ]
            ),
            download=True,
        )

        self.val_dataset = datasets.CIFAR10(
            root=paths.cifar_dir,
            train=False,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(paths.CIFAR_MEAN, paths.CIFAR_STD),
                ]
            ),
            download=True,
        )

        if "split" not in self.hparams:
            n = len(self.train_dataset)
            perm = np.random.permutation(n)
            self.hparams["split"] = (perm[: n // 2], perm[n // 2 :])

    def train_dataloader(self):
        """Shuffled loader over ``utils.ConcatDataset`` of the two train-split halves."""
        return data.DataLoader(
            dataset=utils.ConcatDataset(
                data.Subset(self.train_dataset, self.hparams["split"][0]),
                data.Subset(self.train_dataset, self.hparams["split"][1]),
            ),
            batch_size=self.hparams["training"]["batch_size"],
            shuffle=True,
            drop_last=True,
            num_workers=4,
            pin_memory=True,
        )

    def stats_dataloader(self):
        """Unshuffled loader over a random 10% of the training set, for stats collection.

        NOTE(review): the train dataset carries training augmentations (crop,
        flip, cutout); confirm stats are meant to be gathered on augmented images.
        """
        self.prepare_data()
        n = len(self.train_dataset)
        perm = np.random.permutation(n)[: n // 10]  # take 10% of the training set
        return data.DataLoader(
            dataset=data.Subset(self.train_dataset, perm),
            batch_size=self.hparams["training"]["batch_size"],
            drop_last=True,
            num_workers=4,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Unshuffled loader over the full CIFAR-10 test set."""
        return data.DataLoader(
            dataset=self.val_dataset,
            batch_size=self.hparams["training"]["batch_size"],
            shuffle=False,
            # Keep the final partial batch so validation metrics cover every
            # test image (drop_last=True silently skipped the remainder).
            drop_last=False,
            num_workers=4,
            pin_memory=True,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="Name of a config for the current experiment.")
    parser.add_argument(
        "--resume",
        help="Resume an experiment from a checkpoint file. "
        "Please don't specify the config when resuming.",
    )
    parser.add_argument(
        "--name",
        help="custom experiment name instead of the one generated from config.",
    )
    parser.add_argument("--seed", type=int, default=1, help="Random seed (default: 1).")

    args = parser.parse_args()

    utils.set_seed(seed=args.seed)

    if bool(args.config) == bool(args.resume):
        # Reports the usage error on stderr and exits with status 2.
        parser.error("Exactly one of --config and --resume can be active at once!")

    if args.config:
        # new experiment
        cfg = utils.build_config(configs, args.config.split(","))
        if args.name:
            cfg["name"] = args.name
        model = MIDASSearch(cfg)
    else:
        # resume
        model = MIDASSearch.load_from_checkpoint(args.resume)
    ckpt_path = args.resume

    # Options stored under hparams["args"] are copied onto args; whatever is
    # left after the pops below is forwarded to pl.Trainer.
    for k, v in model.hparams["args"].items():
        setattr(args, k, v)

    # configure logger and checkpoint callback
    tb_logger = pl.loggers.TensorBoardLogger(
        save_dir=args.default_save_path,
        name="",
        version=model.hparams["uname"],
    )

    os.makedirs(tb_logger.log_dir, exist_ok=True)

    args.logger = tb_logger
    args.callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=tb_logger.log_dir,
            filename="{epoch:02d}",
            save_top_k=0,
            save_last=True,
        ),
        pl.callbacks.ModelCheckpoint(
            dirpath=tb_logger.log_dir,
            filename="{epoch:02d}",
            save_top_k=-1,
            every_n_epochs=10,
        ),
        pl.callbacks.LearningRateMonitor(),
    ]

    # Drop the script-only options before handing the rest to the Trainer.
    args_dict = vars(args)
    for script_only_key in ("config", "resume", "name", "seed", "default_save_path"):
        args_dict.pop(script_only_key)

    # create the trainer
    trainer = pl.Trainer(**args_dict)

    # Module-level flag read by MIDASSearch.validation_step and
    # on_validation_epoch_end to toggle per-sample stats collection.
    COLLECT_STATS = False
    # ckpt_path restores the resumed run's state (weights, optimizers, epoch).
    trainer.fit(model, ckpt_path=ckpt_path)

    # Validate the weights produced by fit(). Passing the resume checkpoint here
    # would reload the pre-training weights and discard the training just done.
    trainer.validate(model)

    COLLECT_STATS = True

    trainer.validate(model, dataloaders=model.stats_dataloader())
