import argparse
import os
import shutil
import sys

import torch
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from expman import Experiment

### from model import ResNet, ODENet
### from model import ResNet, ODENet, ODENetSpatial, ODENetSpatialPlusTwoMul
from model import (
    ResNet,
    ODENet,
    ODENetSpatial,
    ODENetSpatialPlusTwoMul,
    ODENetSigmoid,
    ODENetSiLU,
    ODENetSiLUReLU,
    ODENetSiLUSigmoid,
    ODENetVaryingGroup,
    ODENetVaryingNorm,
    ODENetVaryingDrop,
    ODENetFixedDrop,
    ODENetSoftplus,
    ODENetReLUSoftplus,
    ODENetELU,
    ODENetReLUELU,
    ODENetCTNR,
    ODENetCNTR,
    ODENetCNRT,
    ODENetinit1en2,
    ODENetinit1en1,
    ODENetinit1ep0,
    ODENetinit1ep1,
    ODENetinit1ep2,
    ODENetinit1en3,
    ODENetinit1ep3,
    ODENetinit1ep4,
    ODENetzerobiasall,
    ODENetzerobiasconv,
    ODENetzerobiasdense,
    ODENetinit3ep0,
    ODENetinit3ep1,
    ODENetinit3ep2,
    ODENetzerobiasnone,
)
from utils import load_dataset


def save_checkpoint(exp, state, is_best):
    """Write ``state`` to the experiment's "last" checkpoint file.

    When ``is_best`` is truthy, the file that was just written is also
    copied to the experiment's "best" checkpoint path.

    Args:
        exp: experiment object; ``exp.ckpt(name)`` supplies the file path.
        state: any object accepted by ``torch.save``.
        is_best: whether this checkpoint should also become the "best" one.
    """
    last_path = exp.ckpt("last")
    torch.save(state, last_path)
    if not is_best:
        return
    shutil.copyfile(last_path, exp.ckpt("best"))


def train(loader, model, optimizer, args):
    """Run one training epoch over ``loader`` and return averaged metrics.

    Gradients are accumulated over ``args.batch_accumulation`` consecutive
    batches before each optimizer step. A trailing, incomplete accumulation
    window is flushed with a final step at the end of the epoch.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: network exposing ``nfe(reset=True)``; presumably it returns the
            solver function-evaluation count since the previous reset
            (confirm in ``model``).
        optimizer: optimizer over ``model``'s parameters.
        args: namespace providing ``device`` and ``batch_accumulation``.

    Returns:
        dict with "loss" (mean of per-batch mean losses), "acc", "nfe-f" and
        "nfe-b" (mean forward/backward NFE per batch). All values are 0.0 if
        ``loader`` yields no batches.
    """
    model.train()
    optimizer.zero_grad()

    nfe_forward = 0
    nfe_backward = 0

    n_correct = 0
    n_processed = 0
    n_batch_processed = 0

    total_loss = 0.0

    # Defaults so an empty loader returns zeros instead of raising
    # UnboundLocalError at the return statement.
    accuracy = avg_loss = avg_nfe_forward = avg_nfe_backward = 0.0

    progress = tqdm(loader)
    for x, y in progress:
        x, y = x.to(args.device), y.to(args.device)
        p = model(x)
        loss = F.cross_entropy(p, y)
        batch_loss = loss.item()  # single host sync, reused for the postfix
        total_loss += batch_loss

        n_correct += (y == p.argmax(dim=1)).sum().item()
        n_processed += y.shape[0]

        # Evaluations spent in the forward pass.
        nfe_forward += model.nfe(reset=True)

        # NOTE: the loss is not divided by batch_accumulation, so gradients
        # are summed (not averaged) over each accumulation window.
        loss.backward()

        # Evaluations triggered by the backward pass.
        nfe_backward += model.nfe(reset=True)
        n_batch_processed += 1

        if n_batch_processed % args.batch_accumulation == 0:
            optimizer.step()
            optimizer.zero_grad()

        accuracy = n_correct / n_processed
        avg_loss = total_loss / n_batch_processed
        avg_nfe_forward = nfe_forward / n_batch_processed
        avg_nfe_backward = nfe_backward / n_batch_processed

        progress.set_postfix(
            {
                "loss": f"{batch_loss:4.3f}|{avg_loss:4.3f}",
                "acc": f"{n_correct:4d}/{n_processed:4d} ({accuracy:.2%})",
                "NFE-F": f"{avg_nfe_forward:3.1f}",
                "NFE-B": f"{avg_nfe_backward:3.1f}",
            }
        )

    # Apply gradients left over from an incomplete accumulation window;
    # otherwise the next epoch's zero_grad() would silently discard them.
    if n_batch_processed % args.batch_accumulation != 0:
        optimizer.step()
        optimizer.zero_grad()

    return {
        "loss": avg_loss,
        "acc": accuracy,
        "nfe-f": avg_nfe_forward,
        "nfe-b": avg_nfe_backward,
    }


def evaluate(loader, model, args):
    """Evaluate ``model`` on ``loader`` in eval mode without tracking gradients.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: network exposing ``nfe(reset=True)``; presumably it returns the
            solver function-evaluation count since the previous reset
            (confirm in ``model``).
        args: namespace providing ``device``.

    Returns:
        dict with "test_loss" (summed cross-entropy divided by the number of
        samples), "test_acc", and "test_nfe" (mean forward NFE per batch).
        All values are 0.0 if ``loader`` yields no batches.
    """
    model.eval()

    nfe_forward = 0

    n_correct = 0
    n_batches = 0
    n_processed = 0

    total_loss = 0.0

    # Defaults so an empty loader returns zeros instead of raising
    # UnboundLocalError at the return statement.
    logloss = accuracy = nfe = 0.0

    progress = tqdm(loader)
    # No backward pass happens here, so don't build the autograd graph:
    # it would only cost memory and time during the forward pass.
    with torch.no_grad():
        for x, y in progress:
            x, y = x.to(args.device), y.to(args.device)
            p = model(x)
            nfe_forward += model.nfe(reset=True)
            # Sum (not mean) so the running loss is exact per-sample even when
            # the last batch is smaller.
            loss = F.cross_entropy(p, y, reduction="sum")
            total_loss += loss.item()

            n_correct += (y == p.argmax(dim=1)).sum().item()
            n_processed += y.shape[0]
            n_batches += 1

            logloss = total_loss / n_processed
            accuracy = n_correct / n_processed
            nfe = nfe_forward / n_batches
            metrics = {
                "loss": f"{logloss:4.3f}",
                "acc": f"{n_correct:4d}/{n_processed:4d} ({accuracy:.2%})",
                "nfe": f"{nfe:3.1f}",
            }
            progress.set_postfix(metrics)

    return {"test_loss": logloss, "test_acc": accuracy, "test_nfe": nfe}


def _build_model(args, in_ch, out):
    """Instantiate the network selected by ``args.model``."""
    # Every ODE variant shares the same constructor signature, so map the CLI
    # name to its class instead of repeating the call once per variant.
    ode_models = {
        "odenet": ODENet,
        "odenetspatial": ODENetSpatial,
        "odenetspatialplustwomul": ODENetSpatialPlusTwoMul,
        "odenetsigmoid": ODENetSigmoid,
        "odenetsilu": ODENetSiLU,
        "odenetsilurelu": ODENetSiLUReLU,
        "odenetsilusigmoid": ODENetSiLUSigmoid,
        "odenetvaryinggroup": ODENetVaryingGroup,
        "odenetvaryingnorm": ODENetVaryingNorm,
        "odenetvaryingdrop": ODENetVaryingDrop,
        "odenetfixeddrop": ODENetFixedDrop,
        "odenetsoftplus": ODENetSoftplus,
        "odenetrelusoftplus": ODENetReLUSoftplus,
        "odenetelu": ODENetELU,
        "odenetreluelu": ODENetReLUELU,
        "odenetctnr": ODENetCTNR,
        "odenetcntr": ODENetCNTR,
        "odenetcnrt": ODENetCNRT,
        "odenetinit1en2": ODENetinit1en2,
        "odenetinit1en1": ODENetinit1en1,
        "odenetinit1ep0": ODENetinit1ep0,
        "odenetinit1ep1": ODENetinit1ep1,
        "odenetinit1ep2": ODENetinit1ep2,
        "odenetinit1ep3": ODENetinit1ep3,
        "odenetinit1ep4": ODENetinit1ep4,
        "odenetinit1en3": ODENetinit1en3,
        "odenetinit3ep0": ODENetinit3ep0,
        "odenetinit3ep1": ODENetinit3ep1,
        "odenetinit3ep2": ODENetinit3ep2,
        "odenetzerobiasall": ODENetzerobiasall,
        "odenetzerobiasnone": ODENetzerobiasnone,
        "odenetzerobiasconv": ODENetzerobiasconv,
        "odenetzerobiasdense": ODENetzerobiasdense,
    }
    common_model_params = dict(
        out=out,
        downsample=args.downsample,
        n_filters=args.filters,
        dropout=args.dropout,
        norm=args.norm,
    )
    ode_cls = ode_models.get(args.model)
    if ode_cls is None:
        # Any non-ODE name (the CLI only offers "resnet") builds the baseline.
        return ResNet(in_ch, **common_model_params)
    return ode_cls(
        in_ch,
        method=args.method,
        tol=args.tol,
        adjoint=args.adjoint,
        **common_model_params,
    )


def _build_optimizer(model, args):
    """Create the optimizer selected by ``args.optim``."""
    if args.optim == "sgd":
        return SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
    if args.optim == "adam":
        return Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    raise ValueError(f"Unknown optimizer: {args.optim!r}")


def _build_scheduler(optimizer, args, start_epoch):
    """Create the LR scheduler selected by ``args.lrschedule``."""
    if args.lrschedule == "fixed":
        # No-op schedule, so the training loop can call step() uniformly.
        return LambdaLR(optimizer, lr_lambda=lambda epoch: 1)
    if args.lrschedule == "plateau":
        # mode="max" because it is stepped with test accuracy.
        return ReduceLROnPlateau(optimizer, mode="max", patience=args.patience)
    if args.lrschedule == "cosine":
        # --lrcycle defaults to 0, which CosineAnnealingLR cannot handle (it
        # takes a modulo/division by T_max and raises ZeroDivisionError), so
        # fall back to a single cycle spanning the whole run.
        t_max = args.lrcycle or args.epochs
        # last_epoch = start_epoch - 2: the constructor performs one implicit
        # step, leaving the schedule at start_epoch - 1 completed epochs.
        return CosineAnnealingLR(optimizer, t_max, last_epoch=start_epoch - 2)
    raise ValueError(f"Unknown LR schedule: {args.lrschedule!r}")


def main(args):
    """Train and evaluate the model described by ``args``.

    Exits early (status 0) when the experiment already has a log and
    ``--resume`` was not given. After every epoch it saves the "last"
    checkpoint (plus "best" when test accuracy improves), pushes the epoch's
    metrics to the experiment log and steps the LR scheduler.
    """
    root = "runs_" + args.dataset
    exp = Experiment(
        args, root=root, main="model", ignore=("cuda", "device", "epochs", "resume")
    )

    print(exp)
    if os.path.exists(exp.path_to("log")) and not args.resume:
        print("Skipping ...")
        sys.exit(0)

    train_data, test_data, in_ch, out = load_dataset(args)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

    model = _build_model(args, in_ch, out).to(args.device)
    optimizer = _build_optimizer(model, args)

    if args.resume:
        ckpt_path = exp.ckpt("last")
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and
        # vice versa) instead of failing inside torch.load.
        ckpt = torch.load(ckpt_path, map_location=args.device)
        print("Loaded: {}".format(ckpt_path))
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optim"])
        start_epoch = ckpt["epoch"] + 1
        best_accuracy = exp.log["test_acc"].max()
        print("Resuming from epoch {}: {}".format(start_epoch, exp.name))
    else:
        # Baseline accuracy of the untrained model.
        metrics = evaluate(test_loader, model, args)
        best_accuracy = metrics["test_acc"]
        start_epoch = 1

    # Built after resuming so schedulers pick up the restored optimizer LRs.
    scheduler = _build_scheduler(optimizer, args, start_epoch)

    # initial = epochs already completed, so the bar ends exactly at total.
    progress = trange(
        start_epoch, args.epochs + 1, initial=start_epoch - 1, total=args.epochs
    )
    for epoch in progress:
        metrics = {"epoch": epoch}

        progress.set_postfix({"Best ACC": f"{best_accuracy:.2%}"})
        progress.set_description("TRAIN")
        train_metrics = train(train_loader, model, optimizer, args)

        progress.set_description("EVAL")
        test_metrics = evaluate(test_loader, model, args)

        is_best = test_metrics["test_acc"] > best_accuracy
        best_accuracy = max(test_metrics["test_acc"], best_accuracy)

        metrics.update(train_metrics)
        metrics.update(test_metrics)

        save_checkpoint(
            exp,
            {
                "epoch": epoch,
                "params": vars(args),
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
                "metrics": metrics,
            },
            is_best,
        )

        exp.push_log(metrics)
        if args.lrschedule == "plateau":
            # ReduceLROnPlateau needs the monitored quantity.
            scheduler.step(metrics["test_acc"])
        else:
            scheduler.step()


if __name__ == "__main__":
    # Network architectures selectable with -m/--model: the ResNet baseline
    # followed by every ODENet variant imported from ``model``.
    model_choices = (
        "resnet",
        "odenet",
        "odenetspatial",
        "odenetspatialplustwomul",
        "odenetsigmoid",
        "odenetsilu",
        "odenetsilurelu",
        "odenetsilusigmoid",
        "odenetvaryinggroup",
        "odenetvaryingnorm",
        "odenetvaryingdrop",
        "odenetfixeddrop",
        "odenetsoftplus",
        "odenetrelusoftplus",
        "odenetelu",
        "odenetreluelu",
        "odenetctnr",
        "odenetcntr",
        "odenetcnrt",
        "odenetinit1en2",
        "odenetinit1en1",
        "odenetinit1ep0",
        "odenetinit1ep1",
        "odenetinit1ep2",
        "odenetinit1ep3",
        "odenetinit1ep4",
        "odenetinit1en3",
        "odenetinit3ep0",
        "odenetinit3ep1",
        "odenetinit3ep2",
        "odenetzerobiasall",
        "odenetzerobiasconv",
        "odenetzerobiasdense",
        "odenetzerobiasnone",
    )
    # "group", "batch", then "G256", "G128", ..., "G2", "G1".
    norm_choices = ("group", "batch") + tuple(
        f"G{2 ** k}" for k in range(8, -1, -1)
    )

    parser = argparse.ArgumentParser(description="ODENet/ResNet training")

    # Data.
    parser.add_argument(
        "--dataset",
        type=str,
        choices=("mnist", "cifar10", "cifar100", "tiny-imagenet-200"),
        default="mnist",
    )
    parser.add_argument(
        "--augmentation",
        type=str,
        choices=("none", "crop+flip+norm", "crop+jitter+flip+norm"),
        default="none",
    )

    # Architecture.
    parser.add_argument(
        "-m", "--model", type=str, choices=model_choices, default="odenet"
    )
    parser.add_argument(
        "-d",
        "--downsample",
        type=str,
        choices=("ode2", "ode", "residual", "convolution", "minimal", "one-shot"),
        default="residual",
    )
    parser.add_argument(
        "-n", "--norm", type=str, choices=norm_choices, default="group"
    )
    parser.add_argument("-f", "--filters", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0)

    # Optimisation.
    parser.add_argument("-e", "--epochs", type=int, default=100)
    parser.add_argument("-b", "--batch-size", type=int, default=128)
    parser.add_argument("--batch-accumulation", type=int, default=1)
    parser.add_argument(
        "-o", "--optim", type=str, choices=("sgd", "adam"), default="sgd"
    )
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument(
        "--lrschedule",
        type=str,
        choices=("fixed", "plateau", "cosine"),
        default="plateau",
    )
    parser.add_argument("--lrcycle", type=int, default=0)
    parser.add_argument("-p", "--patience", type=int, default=10)
    parser.add_argument("--wd", type=float, default=0, help="weight decay")

    # Hardware: CUDA is on unless --no-cuda is passed.
    parser.add_argument("--no-cuda", dest="cuda", action="store_false", default=True)

    # ODE solver.
    parser.add_argument("--method", default="dopri5", choices=("dopri5", "adams"))
    parser.add_argument("-t", "--tol", type=float, default=1e-3)
    parser.add_argument("-a", "--adjoint", default=False, action="store_true")

    # Run bookkeeping.
    parser.add_argument("-r", "--resume", action="store_true", default=False)
    parser.add_argument("-s", "--seed", type=int, default=23)
    args = parser.parse_args()

    # Seed the CPU generator and every CUDA device's generator.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    use_cuda = args.cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if use_cuda else "cpu")

    main(args)
