import argparse
import json
import random
import logging
import logging.config
from time import time
from pathlib import Path

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter


from models.get_model import get_model
from train_utils.dataclasses import RunSetup
from train_utils.run_epoch import run_epoch
from compute.param_vector import get_vector_of_params
from utils.functions import (
    build_run_name,
    get_log_config,
    read_state_dict,
    get_last_checkpointed_epoch,
    get_tb_layouts,
    get_results,
    get_scalar,
)
from utils.enums import Datasets, Losses, Optims, LR_Schedulers, Devices


def _seed_everything(seed: int) -> None:
    """Seed the torch, CUDA, NumPy and stdlib RNGs and pin cuDNN to deterministic mode."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _save_checkpoint(model, lr_scheduler, checkpoint_dir: Path, epoch: int) -> None:
    """Save the model and LR-scheduler state dicts as ``model_on_epoch_<epoch>``."""
    torch.save(
        {
            "model": model.state_dict(),
            "scheduler": lr_scheduler.state_dict(),
        },
        checkpoint_dir / f"model_on_epoch_{epoch}",
    )


def train_model(run: RunSetup, run_name: str):
    """Train a model from scratch, or resume a checkpointed run, as described by ``run``.

    The function seeds all RNGs, builds the dataset/model/optimiser/scheduler from
    the enum members stored on ``run``, runs an initial "epoch 0" pass (fresh runs
    only), then loops over epochs until either ``run.max_num_epochs`` is reached or
    the gradient norm reported by ``run_epoch`` drops to ``run.target_norm_grad``
    after at least ``run.min_num_epochs`` epochs. Checkpoints are written under
    ``run.path / "checkpoints" / run_name`` and TensorBoard logs under
    ``run.path / "runs" / run_name``.

    Args:
        run: Full run configuration (dataset, model, loss, optimiser, schedule,
            stopping criteria, paths, resume options).
        run_name: Identifier used for the checkpoint and TensorBoard sub-directories.

    Returns:
        None. Results are persisted as checkpoints, TensorBoard events and the CSV
        regenerated by ``get_results``.
    """
    _seed_everything(run.seed)

    # only the MSE loss asks the dataset loader for one-hot encoded targets
    one_hot_encode_y = run.loss == Losses.MSE

    # load dataset
    (
        train_dataset,
        _test_dataset,
        train_dataloader,
        test_dataloader,
        dataset_dims,
    ) = run.dataset.value(
        batch_size=run.batch_size,
        batch_size_test=run.batch_size_test,
        noise_scale=run.dataset_noise,
        one_hot_encode_y=one_hot_encode_y,
        alpha_shuffle=run.alpha_shuffle,
        dataset_path=run.dataset_path,
        num_workers=run.data_num_workers,
    )

    # init model
    device = run.device.value
    model = get_model(run.model_name, dataset_dims, seed=run.seed).to(device)

    # prepare for training
    loss_f = run.loss.value()
    optimiser = run.optim.value(model.parameters(), lr=run.lr)

    # lr decay
    lr_scheduler = run.lr_scheduler.value(optimiser, run.batch_size, len(train_dataset))

    checkpoint_dir = run.path / "checkpoints" / run_name
    tb_log_dir = run.path / "runs" / run_name

    # load model to continue training
    if run.load_model:
        # load model at init to get theta_0
        state_dict = read_state_dict(run_name, run.path, device, epoch=0)
        model.load_state_dict(state_dict["model"])
        theta_0 = get_vector_of_params(model)

        # get exact epoch number (-1 means "latest checkpoint available")
        if run.load_from_epoch == -1:
            load_from_epoch = get_last_checkpointed_epoch(run_name, run.path)
        else:
            load_from_epoch = run.load_from_epoch

        # Nothing left to train: bail out early instead of failing later on
        # metrics that are only ever assigned inside the training loop.
        if load_from_epoch >= run.max_num_epochs:
            logging.warning(
                f"Checkpoint epoch {load_from_epoch} is already >= max_num_epochs "
                f"{run.max_num_epochs}; nothing to train."
            )
            return

        # load model and scheduler
        logging.info(f"Continue training from epoch {load_from_epoch}...")
        state_dict = read_state_dict(run_name, run.path, device, epoch=load_from_epoch)
        model.load_state_dict(state_dict["model"])
        model = model.to(device)

        # Replay the scheduler to the step it had reached when checkpointed.
        # NOTE(review): for stock torch schedulers `_step_count` already counts the
        # initial step taken in the constructor, so this may advance one step too
        # far -- confirm against the schedulers in LR_Schedulers.
        lr_scheduler_updates = state_dict["scheduler"]["_step_count"]
        for _ in range(lr_scheduler_updates):
            lr_scheduler.step()

        # get the grad norm from previous epoch out of this run's TensorBoard logs
        norm_grad_of_params_this_epoch = get_scalar(
            get_results(tb_log_dir), "norm_grad_of_params"
        )[1][-1]

    # make sure the folder for checkpoints exists
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # setup layers to record lipschitz
    num_layers = len(model.layers)

    # NOTE: indexation of layers works in the following way:
    # layer 0 == input
    # layer 1 == first layer applied
    # layer 2 == first two layers applied
    # ...

    # Look at the most important layers for FF_ReLU
    # skip 0-th layer (just input), skip linear layers without activation applied
    # layers_to_look_at = list(range(num_layers+1))[2::2]

    # Look only at the last layer
    layers_to_look_at = [num_layers]

    # logging
    writer = SummaryWriter(log_dir=str(tb_log_dir.resolve()))
    writer.add_custom_scalars(get_tb_layouts(layers_to_look_at))

    if not run.load_model:
        # run 0th epoch; this is also where theta_0 comes from on a fresh run
        (
            train_loss,
            train_accuracy,
            test_loss,
            test_accuracy,
            norm_grad_of_params_this_epoch,
            theta_0,
            _,
        ) = run_epoch(
            0,
            model,
            train_dataloader,
            test_dataloader,
            loss_f,
            optimiser,
            lr_scheduler,
            device,
            layers_to_look_at,
            writer,
            0,
            run.min_num_epochs,
            run.max_num_epochs,
            compute_lip_this_epoch=(run.compute_L_for_first != 0),
        )
        # save model and scheduler
        _save_checkpoint(model, lr_scheduler, checkpoint_dir, 0)

    starting_epoch = 1
    if run.load_model:
        starting_epoch = load_from_epoch + 1

    # main training loop
    epoch = 0  # set epoch = 0 in case we train for 0 epochs (this variable is used later for final logs)
    # NOTE(review): the batch counter restarts from 0 on resumed runs as well --
    # confirm this is the intended x-axis for per-batch logs.
    batch_counter = 0

    for epoch in range(starting_epoch, run.max_num_epochs + 1):
        # Determine whether to compute Lipschitz this epoch. The gradient norm
        # used here is still the one from the previous epoch, i.e. it anticipates
        # that this epoch may be the last one before the stopping criterion fires.
        compute_lip_this_epoch = (
            (epoch % run.compute_L_every == 0)
            or (epoch <= run.compute_L_for_first)
            or (epoch == run.min_num_epochs)
            or (epoch == run.max_num_epochs)
            or (
                (epoch > run.min_num_epochs)
                and (norm_grad_of_params_this_epoch <= run.target_norm_grad)  ## stopping criteria
            )
        )

        # run epoch
        (
            train_loss,
            train_accuracy,
            test_loss,
            test_accuracy,
            norm_grad_of_params_this_epoch,
            _,
            batch_counter,
        ) = run_epoch(
            epoch,
            model,
            train_dataloader,
            test_dataloader,
            loss_f,
            optimiser,
            lr_scheduler,
            device,
            layers_to_look_at,
            writer,
            batch_counter,
            run.min_num_epochs,
            run.max_num_epochs,
            theta_0,
            compute_lip_this_epoch=compute_lip_this_epoch,
        )

        reached_target = norm_grad_of_params_this_epoch <= run.target_norm_grad

        # save model periodically, at the epoch bounds, and for the last epoch before stop
        if (
            (epoch % run.save_model_every == 0)
            or (epoch == run.min_num_epochs)
            or (epoch == run.max_num_epochs)
            or ((epoch > run.min_num_epochs) and reached_target)
        ):
            _save_checkpoint(model, lr_scheduler, checkpoint_dir, epoch)

        # stop the training loop
        if reached_target and (epoch >= run.min_num_epochs):
            break

    # final metrics for Lip
    final_metrics = {
        "hparam/train_loss": train_loss,
        "hparam/train_accuracy": train_accuracy,
        "hparam/test_loss": test_loss,
        "hparam/test_accuracy": test_accuracy,
        "hparam/last_epoch": epoch,
        "hparam/last_batch": batch_counter,
    }

    hparams = run.as_dict()

    writer.add_hparams(hparams, final_metrics)
    writer.flush()
    writer.close()

    # run get results to generate a csv file from TensorBoard logs
    logging.info("Parsing TensorBoard logs...")
    get_results(tb_log_dir, regen_csv=True)

    logging.info("Training finished successfully")


if __name__ == "__main__":
    # Command-line entry point: parse the run configuration, record it next to the
    # run outputs, validate the enum-valued options and launch training.

    # parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", dest="dataset", type=str)
    parser.add_argument("--dataset_noise", dest="dataset_noise", default=0.0, type=float)
    parser.add_argument("--batch_size", dest="batch_size", default=512, type=int)
    parser.add_argument("--batch_size_test", dest="batch_size_test", default=512, type=int)
    parser.add_argument("--data_num_workers", dest="data_num_workers", default=0, type=int)

    parser.add_argument("--min_num_epochs", dest="min_num_epochs", default=10_000, type=int)
    parser.add_argument("--max_num_epochs", dest="max_num_epochs", default=1_000_000, type=int)

    # save model every X epoch
    parser.add_argument("--save_model_every", dest="save_model_every", default=100, type=int)
    # compute Lipschitz every X epoch
    parser.add_argument("--compute_L_every", dest="compute_L_every", default=100, type=int)
    # compute Lipschitz for the first X epochs
    parser.add_argument("--compute_L_for_first", dest="compute_L_for_first", default=1000, type=int)

    parser.add_argument("--model_name", dest="model_name", type=str)
    parser.add_argument("--loss", dest="loss", type=str)
    parser.add_argument("--optim", dest="optim", type=str)
    parser.add_argument("--lr", dest="lr", type=float)
    parser.add_argument("--lr_scheduler", dest="lr_scheduler", type=str)
    parser.add_argument("--alpha_shuffle", dest="alpha_shuffle", default=0.0, type=float)
    parser.add_argument("--target_norm_grad", dest="target_norm_grad", type=float)

    parser.add_argument("--seed", dest="seed", type=int)

    parser.add_argument("--load_model", dest="load_model", default=0, choices=[0, 1], type=int)
    parser.add_argument("--load_from_epoch", dest="load_from_epoch", default=-1, type=int)
    parser.add_argument("--runtimestamp", dest="runtimestamp", default=int(time()), type=int)

    parser.add_argument("--device", dest="device", type=str, default="CPU")

    # Both paths are fed straight into Path(...) below, so they must be given;
    # requiring them yields a usage error instead of a TypeError from Path(None).
    parser.add_argument("--dataset_path", dest="dataset_path", type=str, required=True)
    parser.add_argument("--path", dest="path", type=str, required=True)

    args = parser.parse_args()

    # form run name
    run_name = build_run_name(
        args.dataset,
        args.model_name,
        args.loss,
        args.optim,
        args.lr,
        args.lr_scheduler,
        args.alpha_shuffle,
        args.seed,
        args.target_norm_grad,
        args.runtimestamp,
    )
    path = Path(args.path)
    dataset_path = Path(args.dataset_path)

    # save arguments
    path_to_args = path / "args"
    path_to_args.mkdir(parents=True, exist_ok=True)

    arg_file = path_to_args / f"{run_name}.json"
    if arg_file.exists():
        # keep the originally stored args and append this invocation under "new_args"
        with arg_file.open(mode="r") as f:
            arg_dict = json.load(f)

        arg_dict.setdefault("new_args", []).append(vars(args))

        with arg_file.open(mode="w") as f:
            json.dump(arg_dict, f)

    else:
        # dump arguments
        with arg_file.open(mode="w") as f:
            json.dump(vars(args), f)

    # setup logging
    logging.config.dictConfig(get_log_config(path, run_name))

    # Make some checks: every enum-valued option must name a member of its enum,
    # otherwise the Enum[...] lookups below would fail with a bare KeyError.
    enum_checks = (
        ("Dataset", "datasets", args.dataset, Datasets),
        ("Optimiser", "optimsers", args.optim, Optims),
        ("Loss", "losses", args.loss, Losses),
        ("LR scheduler", "LR schedulers", args.lr_scheduler, LR_Schedulers),
        ("Device", "devices", args.device, Devices),
    )
    for label, plural, value, enum_cls in enum_checks:
        if value not in enum_cls.__members__:
            logging.error(f"{label} {value} is unknown. Known {plural}: {enum_cls.__members__}")
            raise SystemExit(1)

    if args.min_num_epochs > args.max_num_epochs:
        logging.error(
            f"min_num_epochs {args.min_num_epochs} should be <= than max_num_epochs {args.max_num_epochs}"
        )
        raise SystemExit(1)

    logging.info(f"Starting run {run_name}")
    logging.info("Run params:")
    logging.info(json.dumps(vars(args)))
    logging.info("-" * 50)

    # start training
    train_model(
        RunSetup(
            dataset=Datasets[args.dataset],
            model_name=args.model_name,
            loss=Losses[args.loss],
            optim=Optims[args.optim],
            lr=args.lr,
            lr_scheduler=LR_Schedulers[args.lr_scheduler],
            target_norm_grad=args.target_norm_grad,
            path=path,
            dataset_path=dataset_path,
            seed=args.seed,
            device=Devices[args.device],
            min_num_epochs=args.min_num_epochs,
            max_num_epochs=args.max_num_epochs,
            save_model_every=args.save_model_every,
            compute_L_every=args.compute_L_every,
            compute_L_for_first=args.compute_L_for_first,
            dataset_noise=args.dataset_noise,
            batch_size=args.batch_size,
            batch_size_test=args.batch_size_test,
            alpha_shuffle=args.alpha_shuffle,
            data_num_workers=args.data_num_workers,
            load_model=args.load_model,
            load_from_epoch=args.load_from_epoch,
            runtimestamp=args.runtimestamp,
        ),
        run_name,
    )
