import argparse
import logging
import logging.config
import json
import random
from pathlib import Path
from time import time
from collections.abc import Sequence
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader

import numpy as np

from utils.enums import Datasets, Devices
from utils.functions import get_log_config
from models.get_model import get_model
from models.base_model import SequentialModel
from compute.grad_wrt_to_inputs import compute_jac_norm


def rename_keys_resnet50(state_dict):
    """Translate checkpoint parameter names to the ``layers.<idx>`` naming scheme.

    Every key is expected to start with one of the known two-component
    prefixes (e.g. ``module.conv1``); that prefix is swapped for its
    ``layers.<idx>`` counterpart and the rest of the dotted path is kept.

    Args:
        state_dict: mapping from parameter names to values.

    Returns:
        An ``OrderedDict`` with the renamed keys, in the iteration order of
        ``state_dict``; values are passed through untouched.

    Raises:
        KeyError: if a key's first two dotted components are not a known prefix.
        IndexError: if a key contains no dot at all.
    """
    prefix_map = {
        "module.conv1": "layers.0",
        "module.bn1": "layers.1",
        "module.layer1": "layers.4",
        "module.layer2": "layers.5",
        "module.layer3": "layers.6",
        "module.layer4": "layers.7",
        "module.fc": "layers.10",
    }

    def _translate(old_key):
        # the first two components select the replacement, the remainder is kept as is
        parts = old_key.split(".")
        old_prefix = f"{parts[0]}.{parts[1]}"
        new_prefix = prefix_map[old_prefix]
        suffix = ".".join(parts[2:])
        return f"{new_prefix}.{suffix}"

    return OrderedDict((_translate(name), value) for name, value in state_dict.items())


def _append_norms_to_csv(csv_path: Path, first_index: int, norms: np.ndarray) -> int:
    """Append one ``index,norm`` row per entry of ``norms``; return the next unused index."""
    logging.info("Saving samples...")
    rows = [f"{first_index + offset},{norm}\n" for offset, norm in enumerate(norms)]
    with csv_path.open("a") as f:
        f.write("".join(rows))
    return first_index + len(rows)


def compute_lower_lip_for_resnet_on_imagenet(
    path_to_run_data: Path,
    model: SequentialModel,
    train_dataset: Sequence,
    batch_size: int,
    save_every_k_images: int,
    split_index: int,
    regular_split_len: int,
):
    """Compute per-image norms with ``compute_jac_norm`` and append them to a CSV file.

    The dataset is traversed in order (no shuffling) in batches of
    ``batch_size``. The values returned by ``compute_jac_norm(model, x_batch)``
    are buffered and appended to ``path_to_run_data / "lip_per_img.csv"`` as
    ``index,norm`` rows whenever more than ``save_every_k_images`` values are
    pending, and once more after the last batch. The file is opened in append
    mode, so any header has to be written by the caller beforehand.

    Args:
        path_to_run_data: directory that holds ``lip_per_img.csv``.
        model: model to evaluate; moved to ``Devices.GPU.value`` and put in eval mode.
        train_dataset: (split of the) dataset yielding ``(input, label)`` pairs.
        batch_size: number of images per batch.
        save_every_k_images: buffered values are flushed once their count
            exceeds this threshold.
        split_index: 1-based index of the split ``train_dataset`` corresponds to.
        regular_split_len: length of a regular split; the index written for the
            j-th sample is ``j + (split_index - 1) * regular_split_len``.
    """
    # setup random seed
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = Devices.GPU.value

    model = model.to(device)
    model.eval()

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
    )

    csv_path = path_to_run_data / "lip_per_img.csv"
    n_batches = len(train_dataloader)

    # Samples arrive in order (shuffle=False), so a running counter is enough
    # to map every norm to its index in the unsplit dataset.
    next_index = (split_index - 1) * regular_split_len

    norms = np.array([])

    # compute per-batch stats and save them to a file
    for batch_no, (x_batch, _) in enumerate(train_dataloader, start=1):
        logging.info(f"Processing batch {batch_no}/{n_batches}...")

        x_batch = x_batch.to(device)

        curr_norms = compute_jac_norm(model, x_batch)
        norms = np.concatenate([norms, curr_norms.cpu().detach().numpy()])

        # flush the buffer once enough values have accumulated
        if len(norms) > save_every_k_images:
            next_index = _append_norms_to_csv(csv_path, next_index, norms)
            norms = np.array([])

    # flush whatever is left after the last batch
    if len(norms) > 0:
        _append_norms_to_csv(csv_path, next_index, norms)

    logging.info("Processing done!")


if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, select one contiguous split
    # of the training set, build the ResNet (optionally loading a checkpoint)
    # and write the per-image results to <path>/runs/<run_name>/lip_per_img.csv.

    # setup random seed
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--width", dest="width", type=int, choices=[18, 34, 50, 101, 152])
    parser.add_argument("--weights", dest="weights", type=str, choices=["INIT", "PRETRAINED"])
    parser.add_argument("--nsplits", dest="nsplits", type=int, default=1)
    parser.add_argument("--split_index", dest="split_index", type=int, default=1)

    parser.add_argument("--batch_size", dest="batch_size", default=8, type=int)
    parser.add_argument("--save_every_k_images", dest="save_every_k_images", default=50, type=int)

    parser.add_argument("--runtimestamp", dest="runtimestamp", default=int(time()), type=int)

    parser.add_argument(
        "--dataset_path", dest="dataset_path", type=str, default="/local/home/stuff"
    )
    # The output root is dereferenced unconditionally below, so it must be given.
    parser.add_argument("--path", dest="path", type=str, required=True)

    parser.add_argument("--train_ds_len", dest="train_ds_len", type=int, default=1_281_167)
    parser.add_argument("--checkpoint_path", dest="checkpoint_path", type=str, default=None)
    parser.add_argument("--checkpoint_epoch", dest="checkpoint_epoch", type=int, default=None)

    parser.add_argument(
        "--use_convex_combo_ds", dest="use_convex_combo_ds", type=int, default=0, choices=[0, 1]
    )
    parser.add_argument(
        "--convex_combo_lambda", dest="convex_combo_lambda", type=float, default=0.5
    )
    parser.add_argument(
        "--convex_indices_filepath", dest="convex_indices_filepath", type=str, default=None
    )

    # parse args
    args = parser.parse_args()

    # Validate before anything is written to disk. Explicit checks are used
    # instead of `assert`, which is stripped when Python runs with -O.
    if args.train_ds_len <= args.nsplits:
        parser.error("--train_ds_len must be greater than --nsplits")
    # split index starts from 1 and goes to nsplits
    if not 1 <= args.split_index <= args.nsplits:
        parser.error("--split_index must satisfy 1 <= split_index <= nsplits")
    use_convex_combo = args.use_convex_combo_ds == 1
    if use_convex_combo and not 0.0 < args.convex_combo_lambda < 1.0:
        parser.error("--convex_combo_lambda must lie strictly between 0 and 1")

    path = Path(args.path)
    dataset_path = Path(args.dataset_path)

    ds_name = "ConvexComboImageNet" if use_convex_combo else "ImageNet"

    # run name: "+"-separated description of the configuration
    name_parts = [f"lower_lip_{args.weights}", ds_name, f"ResNet_{args.width}"]
    if args.checkpoint_epoch is not None:
        name_parts.append(f"checkpoint_{args.checkpoint_epoch}")
    name_parts.append(f"part_{args.split_index}-{args.nsplits}")
    name_parts.append(f"runtimestamp_{args.runtimestamp}")
    run_name = "+".join(name_parts)

    pretrained = args.weights == "PRETRAINED"

    # save arguments
    path_to_args = path / "args"
    path_to_args.mkdir(parents=True, exist_ok=True)

    with (path_to_args / f"{run_name}.json").open(mode="w") as f:
        json.dump(vars(args), f)

    # check paths
    path_to_run_data = path / "runs" / run_name
    path_to_run_data.mkdir(parents=True, exist_ok=True)

    # setup logging
    log_config = get_log_config(path, run_name)
    logging.config.dictConfig(log_config)

    logging.info(f"Starting run {run_name}")
    logging.info("Run params:")
    logging.info(json.dumps(vars(args)))
    logging.info("-" * 50)

    # determine which split to use
    regular_split_len = args.train_ds_len // args.nsplits
    # Every split starts at a multiple of the *regular* length; only the length
    # of the last split grows by the remainder. Deriving the offset from the
    # enlarged length would shift the start of the last split.
    train_start = (args.split_index - 1) * regular_split_len
    split_len = regular_split_len
    if args.split_index == args.nsplits:
        # add the remainder of the data to the last split
        split_len += args.train_ds_len % args.nsplits

    # prepare dataset and model
    common_ds_kwargs = {
        "one_hot_encode_y": False,
        "alpha_shuffle": 0.0,
        "dataset_path": dataset_path,
        "train_start": train_start,
        "train_len": split_len,
    }
    if use_convex_combo:
        convex_kwargs = {"convex_lambda": args.convex_combo_lambda}
        if args.convex_indices_filepath is not None:
            convex_kwargs["convex_indices_filepath"] = Path(args.convex_indices_filepath)
        train_dataset, _, _, _, dims = Datasets.ConvexComboImageNet.value(
            **convex_kwargs, **common_ds_kwargs
        )
    else:
        train_dataset, _, _, _, dims = Datasets.ImageNet.value(
            noise_scale=0.0, **common_ds_kwargs
        )
    model = get_model(f"ResNet_{args.width}", dims=dims, pretrained=pretrained)

    # A checkpoint is only loaded when both options are given; make a
    # half-specified request visible instead of silently ignoring it.
    if (args.checkpoint_path is None) != (args.checkpoint_epoch is None):
        logging.warning(
            "Both --checkpoint_path and --checkpoint_epoch are needed to load a checkpoint; "
            "no checkpoint will be loaded."
        )

    # load the checkpoint if one is specified
    if args.checkpoint_path is not None and args.checkpoint_epoch is not None:
        if args.width != 50:
            logging.error("Checkpoint loading is only supported for ResNet50.")
            raise SystemExit(1)

        # load checkpoints
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted checkpoints (or consider `weights_only=True` if the installed
        # torch version supports it and the checkpoint format allows it).
        checkpoint_path = Path(args.checkpoint_path)
        state_dict = torch.load(
            checkpoint_path / "models" / f"checkpoint_{args.checkpoint_epoch}.pth.tar",
            map_location=torch.device("cpu"),
        )

        # reformat checkpoints for the current ResNet setup
        renamed_state_dict = rename_keys_resnet50(state_dict["state_dict"])

        # load the checkpoint
        model.load_state_dict(renamed_state_dict)
        logging.info(
            f"Start computing the lower Lipschitz bound for a pretrained model at epoch {args.checkpoint_epoch}..."
        )
    elif pretrained:
        logging.info(
            "Start computing the lower Lipschitz bound for a pretrained model at the last epoch..."
        )
    else:
        logging.info(
            "Start computing the lower Lipschitz bound for a model at initialisation..."
        )

    # prepare the file for results
    with (path_to_run_data / "lip_per_img.csv").open("w") as f:
        f.write("indices,norms\n")

    compute_lower_lip_for_resnet_on_imagenet(
        path_to_run_data,
        model,
        train_dataset,
        args.batch_size,
        args.save_every_k_images,
        args.split_index,
        regular_split_len,
    )
    logging.info("Computation finished successfuly!")
