import argparse
import logging
import logging.config
import json
import random
from pathlib import Path
from time import time
from collections import OrderedDict

import torch

import numpy as np

from utils.functions import get_log_config
from models.get_model import get_model
from compute.lipschitz import compute_lipschitz_upper_bound_per_layer, compute_final_upper_bound


def rename_keys_resnet50(state_dict):
    """Map ``module.<block>``-prefixed checkpoint keys onto ``layers.<idx>`` names.

    Every key is expected to look like ``module.<block>.<rest>``. The leading
    ``module.<block>`` pair is swapped for the matching ``layers.<idx>`` prefix
    and ``<rest>`` is kept untouched.

    Parameters
    ----------
    state_dict
        Mapping from checkpoint parameter names to their values.

    Returns
    -------
        An ``OrderedDict`` holding the renamed keys in the iteration order of
        `state_dict`. Values are passed through unchanged.

    Raises
    ------
    KeyError
        If a key does not start with one of the known ``module.<block>`` prefixes.

    """
    replacements = {
        "module.conv1": "layers.0",
        "module.bn1": "layers.1",
        "module.layer1": "layers.4",
        "module.layer2": "layers.5",
        "module.layer3": "layers.6",
        "module.layer4": "layers.7",
        "module.fc": "layers.10",
    }

    renamed = OrderedDict()
    for key, value in state_dict.items():
        parts = key.split(".")
        prefix = ".".join(parts[:2])
        if prefix not in replacements:
            # Report the full offending key instead of only the two-part prefix.
            raise KeyError(
                f"Unexpected checkpoint key {key!r}: prefix {prefix!r} is not one of "
                f"{sorted(replacements)}"
            )
        # Joining the pieces avoids a trailing dot when nothing follows the prefix.
        renamed[".".join([replacements[prefix], *parts[2:]])] = value

    return renamed


def format_to_string(per_layer_Lipschitzness: list | dict | torch.Tensor):
    """Converts all torch.Tensor elements in cur_lip to float for json serialisation.

    Parameters
    ----------
    per_layer_Lipschitzness
        Object that needs to be searlised

    Returns
    -------
        `per_layer_Lipschitzness` but with all torch.Tensors converted to floats.

    Raises
    ------
    TypeError
        This exception is raised when `per_layer_Lipschitzness` type does not match torch.Tensor, dict or list

    """
    if isinstance(per_layer_Lipschitzness, torch.Tensor):
        return float(per_layer_Lipschitzness.item())
    if isinstance(per_layer_Lipschitzness, list):
        return [format_to_string(el) for el in per_layer_Lipschitzness]
    if isinstance(per_layer_Lipschitzness, dict):
        temp = {}
        for k, v in per_layer_Lipschitzness.items():
            temp[k] = format_to_string(v)
        return temp

    raise TypeError(
        f"Lipschitz per layer type {type(per_layer_Lipschitzness)} is unknown for this function."
    )


if __name__ == "__main__":
    # Script entry point: computes a per-layer and an overall upper Lipschitz
    # bound for a ResNet and writes the arguments, logs and results below --path.

    # Seed every RNG in use and request deterministic cuDNN behaviour.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--width", dest="width", type=int, choices=[18, 34, 50, 101, 152])
    parser.add_argument("--weights", dest="weights", type=str, choices=["INIT", "PRETRAINED"])

    parser.add_argument("--runtimestamp", dest="runtimestamp", default=int(time()), type=int)

    # --path is the root of every output; without it Path(None) would raise a
    # TypeError, so let argparse reject the call with a usage message instead.
    parser.add_argument("--path", dest="path", type=str, required=True)

    parser.add_argument("--checkpoint_path", dest="checkpoint_path", type=str, default=None)
    parser.add_argument("--checkpoint_epoch", dest="checkpoint_epoch", type=int, default=None)

    args = parser.parse_args()

    path = Path(args.path)
    if args.checkpoint_epoch is not None:
        run_name = f"upper_lip_{args.weights}+ResNet_{args.width}+checkpoint_{args.checkpoint_epoch}+runtimestamp_{args.runtimestamp}"
    else:
        run_name = f"upper_lip_{args.weights}+ResNet_{args.width}+runtimestamp_{args.runtimestamp}"

    pretrained = args.weights == "PRETRAINED"

    # save arguments
    path_to_args = path / "args"
    path_to_args.mkdir(parents=True, exist_ok=True)

    with (path_to_args / f"{run_name}.json").open(mode="w") as f:
        json.dump(vars(args), f)

    # make sure the run directory exists
    path_to_run_data = path / "runs" / run_name
    path_to_run_data.mkdir(parents=True, exist_ok=True)

    # setup logging
    log_config = get_log_config(path, run_name)
    logging.config.dictConfig(log_config)

    logging.info(f"Starting run {run_name}")
    logging.info("Run params:")
    logging.info(json.dumps(vars(args)))
    logging.info("-" * 50)

    # A checkpoint is only loaded when BOTH options are given. Surface the
    # half-specified case instead of silently ignoring it (the run name may
    # already mention a checkpoint epoch).
    if (args.checkpoint_path is None) != (args.checkpoint_epoch is None):
        logging.warning(
            "Only one of --checkpoint_path/--checkpoint_epoch was given; no checkpoint will be loaded."
        )

    # prepare computing
    imagenet_dims = [[3, 224, 224], 1000]
    model = get_model(f"ResNet_{args.width}", dims=imagenet_dims, pretrained=pretrained)

    # load the checkpoint if one is specified
    if args.checkpoint_path is not None and args.checkpoint_epoch is not None:
        if args.width != 50:
            logging.error("Checkpoint loading is only supported for ResNet50.")
            # `exit` is an interactive helper injected by the `site` module;
            # raising SystemExit works regardless of how the interpreter started.
            raise SystemExit(1)

        # load checkpoints
        # NOTE: torch.load unpickles the file, so only point this at checkpoints
        # from a trusted source.
        checkpoint_path = Path(args.checkpoint_path)
        state_dict = torch.load(
            checkpoint_path / "models" / f"checkpoint_{args.checkpoint_epoch}.pth.tar",
            map_location=torch.device("cpu"),
        )

        # reformat checkpoints for the current ResNet setup
        renamed_state_dict = rename_keys_resnet50(state_dict["state_dict"])

        # load the checkpoint
        model.load_state_dict(renamed_state_dict)
        logging.info(
            f"Start computing the upper Lipschitz bound for a pretrained model at epoch {args.checkpoint_epoch}..."
        )
    elif pretrained:
        logging.info(
            "Start computing the upper Lipschitz bound for a pretrained model at the last epoch..."
        )
    else:
        logging.info(
            "Start computing the upper Lipschitz bound for a model at initialisation..."
        )

    # computation

    # device = torch.device("cuda:0")
    device = torch.device("cpu")

    # The list starts with a constant 1.0 entry ahead of the per-layer values;
    # presumably a neutral factor for the final bound — confirm against
    # compute_final_upper_bound.
    per_layer_Lipschitzness = [torch.tensor(1.0, dtype=torch.float64)]
    num_layers = len(model.layers)
    for i, layer in enumerate(model.layers):
        logging.info(f"Processing layer {i+1}/{num_layers}...")
        cur_lip = compute_lipschitz_upper_bound_per_layer(
            layer.to(device), model.layer_input_shapes[i], torch.float64
        )
        logging.info(f"\tLipschitz for layer {i+1} = {cur_lip}")
        per_layer_Lipschitzness.append(cur_lip)

    # Collect json-serialisable results: per-layer values and the overall bound.
    t = {}
    t["upper_lip_per_layer"] = format_to_string(per_layer_Lipschitzness)
    t["upper_lip"] = float(compute_final_upper_bound(per_layer_Lipschitzness, torch.float64))

    logging.info(f"Upper Lipschitz for the model = {t['upper_lip']}")

    with (path_to_run_data / "lip.json").open("w") as f:
        json.dump(t, f)

    logging.info("Experiment finished successfuly!")
