import math
import argparse
import logging
import pprint
import os
import pickle
import warnings
import itertools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import matplotlib.figure
import matplotlib.pyplot as plt
import torch
import numpy as np
import matplotlib

import utils.logging
import depth_analysis
import utils.plotting
import networks
import utils.argparsers
import utils.path_config

# Alias for a (figure, axes) pair.
FigAx = Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]

# Project path mapping; the "logs" entry is read just below and the "results"
# entry is read by the script section at the bottom of this file.
paths = utils.path_config.get_paths()

# Promote every warning to an exception for the whole process.
warnings.filterwarnings("error")

# 15 sits between logging.DEBUG (10) and logging.INFO (20): this logger emits
# INFO and above and drops DEBUG records.
logging_level = 15

logger = logging.getLogger(__name__)
logger.setLevel(logging_level)

# Attach the project's standard stream handler and a file handler that is
# given the configured log directory.
log_dir_str = paths["logs"]
standard_streamhandler = utils.logging.get_standard_streamhandler()
standard_filehandler = utils.logging.get_standard_filehandler(log_dir_str)

logger.addHandler(standard_streamhandler)
logger.addHandler(standard_filehandler)

# Log each attached handler once at import time.
for handler in logger.handlers:
    logger.info(handler)


def _get_dataloader(dataset_name: str,
                    num_rows: int,
                    dim: int,
                    batch_size: int,
                    dtype: torch.dtype) -> torch.utils.data.DataLoader:
    dataloader_kwargs = {
        "batch_size": batch_size,
        "shuffle": True
    }
    if "unitcube" == dataset_name:
        x = torch.rand(num_rows, dim, dtype=dtype)
    elif "gaussian" == dataset_name:
        x = torch.randn(num_rows, dim, dtype=dtype)
    elif "dirichlet" == dataset_name:
        alpha = np.ones(dim,) / dim
        dirichlet_sample = torch.tensor(np.random.dirichlet(alpha, size=num_rows), dtype=dtype)
        x = dirichlet_sample
    else:
        raise ValueError(f"dataset_name {dataset_name} not configured")
    # x = x.to(device)
    x_max = x.max(1).values
    # dataset = torch.utils.data.TensorDataset(x.type(torch.FloatTensor),
    #                                          x_max.type(torch.FloatTensor))
    dataset = torch.utils.data.TensorDataset(x.type(dtype),
                                             x_max.type(dtype))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **dataloader_kwargs)
    return dataloader


def _init_weights_xavier(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)


def _init_weights_kaiming(m):
    if isinstance(m, torch.nn.Linear):
        # torch.nn.init.xavier_uniform_(m.weight)
        torch.nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)


def build_identifier(dim: int,
                     seed: int,
                     dataset_name: str,
                     benchmark_name: str,
                     residualization_name: Optional[str],
                     optimize_criterion_name: str,
                     evaluate_criterion_name: str) -> str:
    """Join the experiment settings into one underscore-separated identifier.

    The fields appear in argument order; a ``None`` residualization is
    rendered as the text ``None``.
    """
    fields = (dim,
              seed,
              dataset_name,
              benchmark_name,
              residualization_name,
              optimize_criterion_name,
              evaluate_criterion_name)
    return "_".join(f"{field}" for field in fields)


def invert_identifier(identifier: str) -> tuple:
    """Split an identifier built by ``build_identifier`` back into its fields.

    Returns ``(dim, seed, dataset_name, benchmark_name, residualization_name,
    optimize_criterion_name, evaluate_criterion_name)``. The first two fields
    are converted to ``int``; the residualization field is passed through
    ``utils.argparsers.none_or_str``. Fields are separated on ``"_"``, so the
    individual names must not contain underscores themselves.
    """
    # e.g. identifier = '9_111_unitcube_mediumapprox_None_linf_linf'
    parts = identifier.split("_")

    dim, seed = int(parts[0]), int(parts[1])
    return (dim,
            seed,
            str(parts[2]),
            str(parts[3]),
            utils.argparsers.none_or_str(parts[4]),
            str(parts[5]),
            str(parts[6]))


def build_arg_strs() -> List[str]:
    """Enumerate command-line argument strings for a fixed experiment grid.

    Every combination of the hard-coded settings is produced except the
    ``maxer`` benchmark combined with ``benchmark`` residualization. Each
    string is also printed, quoted and followed by a comma, so the output can
    be pasted into a list literal.
    """
    grid = {
        "dim": [8, 9],
        "seed": [112],
        "initialization_name": ["kaiming"],
        "dataset_name": ["dirichlet"],
        "benchmark_name": ["smallapprox", "mediumapprox", "bigapprox", "maxer"],
        "residualization_name": ["benchmark", None],
        "num_replications": [10],
    }
    flag_names = list(grid)

    collected = []
    for combo in itertools.product(*grid.values()):
        setting = dict(zip(flag_names, combo))
        skip = ("maxer" == setting["benchmark_name"]
                and "benchmark" == setting["residualization_name"])
        if skip:
            continue
        line = "".join(f" --{flag} {value}" for flag, value in setting.items())
        print(f'"{line}", ')
        collected.append(line)
    return collected


def generate_results(fitting_results_mult: list,
                     mults: List[float]) -> Dict[str, Any]:
    """Collect the terminal-epoch criterion values of each fit.

    ``fitting_results_mult[i]`` is the fit belonging to ``mults[i]`` and must
    provide ``"terminal_epoch"``, ``"optimize_criterion_values"`` and
    ``"evaluate_criterion_values"``; the latter two map ``"train"`` and
    ``"test"`` to per-epoch sequences.

    Returns a dict with ``"evaluate_criterion_values"`` and
    ``"optimize_criterion_values"`` (each mapping ``"train"``/``"test"`` to a
    tensor with one entry per multiplier) and ``"mults"``.
    """
    count = len(mults)

    def _blank() -> Dict[str, torch.Tensor]:
        # NaN-filled so that an entry that was never written is recognisable.
        return {"train": torch.full((count,), math.nan),
                "test": torch.full((count,), math.nan)}

    optimize_final = _blank()
    evaluate_final = _blank()

    for position in range(count):
        fit = fitting_results_mult[position]
        last_epoch = fit["terminal_epoch"]
        optimize_history = fit['optimize_criterion_values']
        evaluate_history = fit["evaluate_criterion_values"]

        for split in ("train", "test"):
            optimize_final[split][position] = optimize_history[split][last_epoch]
        for split in ("train", "test"):
            evaluate_final[split][position] = evaluate_history[split][last_epoch]

    return {
        "evaluate_criterion_values": evaluate_final,
        "optimize_criterion_values": optimize_final,
        "mults": mults
    }


def _mutate_dataloader_by_model(dataloader: torch.utils.data.DataLoader,
                                model: torch.nn.Module,
                                device: torch.device) -> torch.utils.data.DataLoader:
    dataloader_kwargs = {
        "batch_size": dataloader.batch_size,
        "shuffle": True
    }
    # y = dataloader.dataset.tensors[1]
    # device = y.device
    new_y_batches = [None] * len(dataloader)
    for batch_idx, (x, y) in enumerate(dataloader):
        # (x, y) = next(iter(dataloader))
        x, y = x.to(device), y.to(device)
        new_y_batches[batch_idx] = (y - model(x).reshape(y.shape)).detach()
    new_y = torch.cat(new_y_batches)
    dataset = torch.utils.data.TensorDataset(dataloader.dataset.tensors[0], new_y)
    new_dataloader = torch.utils.data.DataLoader(dataset,
                                                 **dataloader_kwargs)
    return new_dataloader


def initialize_weights(model: torch.nn.Module,
                       benchmark_model: torch.nn.Module,
                       initialization_name: Optional[str],
                       dim: Optional[int] = None) -> torch.nn.Module:
    """Initialise ``model`` in place according to ``initialization_name``.

    Args:
        model: network to initialise; for ``"smart"`` it must be indexable and
            iterable layer by layer (e.g. ``torch.nn.Sequential``).
        benchmark_model: reference network, only used by ``"smart"``.
        initialization_name: ``None`` (leave the weights as they are),
            ``"xavier"``, ``"kaiming"`` or ``"smart"``. ``"smart"`` applies the
            Xavier scheme and then copies the benchmark's first-layer weight
            matrix into the leading rows of the model's first layer.
        dim: input dimension used by ``"smart"`` to check the benchmark depth.
            When omitted it is taken from the column count of the benchmark's
            first-layer weight, so the function no longer relies on a
            module-level ``dim`` existing.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``initialization_name`` is not recognised.
        AssertionError: for ``"smart"``, if the two networks do not have the
            same sequence of layer types or the benchmark does not have
            ``1 + 2 * ceil(log2(dim))`` layers.
    """
    if initialization_name is None:
        return model

    if "xavier" == initialization_name:
        model.apply(_init_weights_xavier)
    elif "kaiming" == initialization_name:
        model.apply(_init_weights_kaiming)
    elif "smart" == initialization_name:
        model.apply(_init_weights_xavier)
        arch_model = [type(layer) for layer in model]
        arch_benchmark = [type(layer) for layer in benchmark_model]
        assert arch_model == arch_benchmark, \
            "model and benchmark_model must have the same layer types"
        if dim is None:
            dim = benchmark_model[0].weight.shape[1]
        num_relu = math.ceil(math.log2(dim))
        assert len(benchmark_model) == 1 + 2 * num_relu, \
            "benchmark_model depth does not match ceil(log2(dim))"
        # Only the first layer is seeded from the benchmark; the remaining
        # layers keep their Xavier initialisation.
        if num_relu > 0:
            to_assign = benchmark_model[0].weight
            with torch.no_grad():
                model[0].weight[:to_assign.shape[0], :] = to_assign.clone().detach()
    else:
        raise ValueError(f"initialization_name = {initialization_name} not configured")
    return model


def test_loop(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              optimize_criterion: Callable,
              evaluate_criterion: Callable,
              device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` on every batch of ``dataloader`` without training.

    Both criteria are called as ``criterion(y_pred, y)`` with the flattened
    model output, and must return a one-element tensor. Per-batch values are
    averaged with weights equal to the batch sizes.

    Args:
        model: network to evaluate.
        dataloader: yields ``(x, y)`` batches.
        optimize_criterion: loss used for optimisation.
        evaluate_criterion: loss used for reporting.
        device: device the batches are moved to.

    Returns:
        ``(optimize_loss, evaluate_loss)`` as zero-dimensional tensors.
    """
    num_batches = len(dataloader)

    sizes = torch.full((num_batches,), math.nan)
    optimize_losses = torch.full((num_batches,), math.nan)
    evaluate_losses = torch.full((num_batches,), math.nan)

    # No gradients are needed here; skipping the autograd graph saves memory.
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)

            y_pred = model(x).flatten()
            sizes[batch_idx] = y.shape[0]
            evaluate_losses[batch_idx] = evaluate_criterion(y_pred, y).item()
            optimize_losses[batch_idx] = optimize_criterion(y_pred, y).item()

    total_size = sizes.sum()
    optimize_loss = (sizes * optimize_losses).sum() / total_size
    evaluate_loss = (sizes * evaluate_losses).sum() / total_size
    return optimize_loss, evaluate_loss


# Despite its name this is not a pytest test: opt out of collection in case the
# function is imported into a test module.
test_loop.__test__ = False


def train_loop(model: torch.nn.Module,
               train_dataloader: torch.utils.data.DataLoader,
               optimize_criterion: Callable,
               evaluate_criterion: Callable,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run one epoch of training and report batch-size-weighted mean losses.

    For every batch the model output is flattened, ``optimize_criterion`` is
    back-propagated and the optimizer takes one step. ``evaluate_criterion``
    is computed on the same (pre-step) prediction for reporting only.

    Args:
        model: network being trained (updated in place through ``optimizer``).
        train_dataloader: yields ``(x, y)`` batches.
        optimize_criterion: differentiable loss, called as ``(y_pred, y)``.
        evaluate_criterion: reporting loss, called as ``(y_pred, y)``.
        optimizer: optimizer holding the parameters of ``model``.
        device: device the batches are moved to.

    Returns:
        ``(optimize_loss, evaluate_loss)`` as zero-dimensional tensors.
    """
    num_batches = len(train_dataloader)
    train_sizes = torch.full((num_batches,), math.nan)
    optimize_losses = torch.full((num_batches,), math.nan)
    evaluate_losses = torch.full((num_batches,), math.nan)

    for batch_idx, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)

        y_pred = model(x).flatten()
        optimize_loss_value = optimize_criterion(y_pred, y)

        optimizer.zero_grad()
        optimize_loss_value.backward()
        optimizer.step()

        # Reporting only: detach so no autograd graph is built on top of the
        # one that backward() has just consumed.
        with torch.no_grad():
            evaluate_loss_value = evaluate_criterion(y_pred.detach(), y)

        train_sizes[batch_idx] = y.shape[0]
        optimize_losses[batch_idx] = optimize_loss_value.item()
        evaluate_losses[batch_idx] = evaluate_loss_value.item()

    total_size = train_sizes.sum()
    optimize_loss_value = (train_sizes * optimize_losses).sum() / total_size
    evaluate_loss_value = (train_sizes * evaluate_losses).sum() / total_size
    return optimize_loss_value, evaluate_loss_value


def fit_fancy(num_epochs: int,
              train_dataloader: torch.utils.data.DataLoader,
              test_dataloader: torch.utils.data.DataLoader,
              model: torch.nn.Module,
              train_criterion: Callable,
              evaluate_criterion: Callable,
              optim_class: Callable[..., torch.optim.Optimizer],
              optim_kwargs: dict,
              early_stopping_params: Dict[str, Any],
              device: torch.device) -> Dict[str, Any]:
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping.

    Each epoch runs ``train_loop`` followed by ``test_loop`` and records both
    criteria for both splits. Training stops early once at least
    ``patience_epochs`` epochs have run and none of the test evaluate-losses
    of the preceding ``patience_epochs`` epochs exceeds the best test
    evaluate-loss seen so far by ``min_improvement`` or more.

    Args:
        num_epochs: maximum number of epochs; must be at least 1.
        train_dataloader: training batches.
        test_dataloader: held-out batches.
        model: network to train in place.
        train_criterion: loss that is optimised (previously this argument was
            ignored in favour of a module-level ``optimize_criterion``).
        evaluate_criterion: loss used for reporting and early stopping.
        optim_class: optimizer constructor, called as
            ``optim_class(model.parameters(), **optim_kwargs)``.
        optim_kwargs: keyword arguments for ``optim_class``.
        early_stopping_params: dict with ``"patience_epochs"`` and
            ``"min_improvement"``.
        device: device the batches are moved to.

    Returns:
        Dict with ``"optimize_criterion_values"`` and
        ``"evaluate_criterion_values"`` (each mapping ``"train"``/``"test"`` to
        a length-``num_epochs`` tensor, NaN past the terminal epoch),
        ``"model"`` and ``"terminal_epoch"`` (index of the last epoch run).

    Raises:
        ValueError: if ``num_epochs`` is smaller than 1.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    patience_epochs = early_stopping_params["patience_epochs"]
    min_improvement = early_stopping_params["min_improvement"]

    optimizer = optim_class(model.parameters(), **optim_kwargs)

    optimize_criterion_values = dict(test=torch.full((num_epochs,), math.nan),
                                     train=torch.full((num_epochs,), math.nan))
    evaluate_criterion_values = dict(test=torch.full((num_epochs,), math.nan),
                                     train=torch.full((num_epochs,), math.nan))
    best_err = math.inf

    for epoch_idx in range(num_epochs):
        optimize_loss_value, evaluate_loss_value = train_loop(model,
                                                              train_dataloader,
                                                              train_criterion,
                                                              evaluate_criterion,
                                                              optimizer,
                                                              device)
        optimize_criterion_values["train"][epoch_idx] = optimize_loss_value
        evaluate_criterion_values["train"][epoch_idx] = evaluate_loss_value

        optimize_loss_value, evaluate_loss_value = test_loop(model,
                                                             test_dataloader,
                                                             train_criterion,
                                                             evaluate_criterion,
                                                             device)
        optimize_criterion_values["test"][epoch_idx] = optimize_loss_value
        evaluate_criterion_values["test"][epoch_idx] = evaluate_loss_value

        best_err = min(best_err, evaluate_loss_value)

        if epoch_idx < patience_epochs:
            continue
        # Gap between each of the previous `patience_epochs` test losses and
        # the best loss so far; a small maximum gap means a plateau.
        window_start = max(0, epoch_idx - patience_epochs)
        improvements = evaluate_criterion_values["test"][window_start:epoch_idx] - best_err
        # The window is empty when patience_epochs == 0; never stop then.
        if improvements.numel() > 0 and improvements.max() < min_improvement:
            break

    fitting_results = {
        "optimize_criterion_values": optimize_criterion_values,
        "evaluate_criterion_values": evaluate_criterion_values,
        "model": model,
        "terminal_epoch": epoch_idx
    }
    return fitting_results


def run_one_replicate(mults: List[float],
                      optim_class: Callable[..., torch.optim.Optimizer],
                      early_stopping_params: Dict[str, Any],
                      dataset_name: str,
                      optimize_criterion: Callable,
                      evaluate_criterion: Callable,
                      num_rows_train: int,
                      num_rows_test: int,
                      dim: int,
                      batch_size: int,
                      residualization_name: Optional[str],
                      initialization_name: Optional[str],
                      dtype: torch.dtype) -> Dict[str, Any]:
    """Draw fresh data and fit one network per width multiplier.

    A benchmark network is built from the module-level ``betas``; for each
    entry of ``mults`` a ReLU network whose hidden widths are the benchmark's
    hidden widths scaled by that multiplier (rounded up) is trained with
    ``fit_fancy``.

    NOTE: besides its arguments this function reads the module-level names
    ``betas``, ``device``, ``include_bias``, ``num_epochs`` and
    ``optim_kwargs``, which are only bound when this file runs as a script.

    Args:
        mults: width multipliers to try.
        optim_class: optimizer constructor forwarded to ``fit_fancy``.
        early_stopping_params: forwarded to ``fit_fancy``; replaced by
            ``patience_epochs=5, min_improvement=0.0`` when
            ``residualization_name`` is ``"benchmark"``.
        dataset_name: forwarded to ``_get_dataloader``.
        optimize_criterion: loss that is optimised.
        evaluate_criterion: loss used for reporting.
        num_rows_train: number of training rows to draw.
        num_rows_test: number of test rows to draw.
        dim: input dimension.
        batch_size: batch size of both loaders.
        residualization_name: ``"benchmark"`` to fit the residual of the
            benchmark network instead of the raw target, or ``None``.
        initialization_name: forwarded to ``initialize_weights``.
        dtype: dtype of the data and of the networks.

    Returns:
        The summary produced by ``generate_results``.

    Raises:
        ValueError: if ``residualization_name`` is not recognised.
    """
    raw_train_dataloader = _get_dataloader(dataset_name, num_rows_train, dim, batch_size, dtype)
    raw_test_dataloader = _get_dataloader(dataset_name, num_rows_test, dim, batch_size, dtype)

    model = networks.build_shallowest_network(betas, dim).to(device).to(dtype)
    benchmark_model = networks.canonicalize_network(model)
    # Widths of every linear layer except the output layer.
    benchmark_hidden_layer_widths = [layer.out_features
                                     for layer in benchmark_model
                                     if isinstance(layer, torch.nn.Linear)][:-1]
    if "benchmark" == residualization_name:
        train_dataloader = _mutate_dataloader_by_model(raw_train_dataloader, benchmark_model, device)
        test_dataloader = _mutate_dataloader_by_model(raw_test_dataloader, benchmark_model, device)
        early_stopping_params = dict(patience_epochs=5,
                                     min_improvement=0.0)
    elif residualization_name is None:
        train_dataloader = raw_train_dataloader
        test_dataloader = raw_test_dataloader
    else:
        raise ValueError(f"Do not know about residualization_name {residualization_name}")

    fitting_results_mult = [None] * len(mults)
    for idx, mult in enumerate(mults):
        hidden_layer_widths = [int(math.ceil(mult * width))
                               for width in benchmark_hidden_layer_widths]
        logger.info(f"hidden_layer_widths = {hidden_layer_widths}")
        layers = networks.build_relu_layers(dim,
                                            hidden_layer_widths,
                                            output_dim=1,
                                            include_bias=include_bias)
        model = torch.nn.Sequential(*layers).to(dtype).to(device)
        model = initialize_weights(model, benchmark_model, initialization_name)
        fitting_results_mult[idx] = fit_fancy(num_epochs,
                                              train_dataloader,
                                              test_dataloader,
                                              model,
                                              optimize_criterion,
                                              evaluate_criterion,
                                              optim_class,
                                              optim_kwargs,
                                              early_stopping_params,
                                              device)

    results = generate_results(fitting_results_mult, mults)
    return results


def get_criterion(criterion_name: str) -> Callable:
    """Return the loss callable registered under ``criterion_name``.

    Args:
        criterion_name: ``"linf"`` for the largest absolute difference between
            the two flattened arguments, or ``"l2"`` for ``torch.nn.MSELoss``.

    Returns:
        A callable taking ``(y1, y2)`` and returning a zero-dimensional tensor.

    Raises:
        ValueError: if ``criterion_name`` is not recognised.
    """
    if "linf" == criterion_name:
        def criterion(y1, y2):
            # Flatten rather than squeeze: squeezing a one-element batch gives
            # a 0-d tensor, which torch.linalg.norm rejects when `ord` is set.
            residual = y1.reshape(-1) - y2.reshape(-1)
            return torch.linalg.norm(residual, ord=math.inf)
    elif "l2" == criterion_name:
        criterion = torch.nn.MSELoss()
    else:
        raise ValueError(f"'{criterion_name}' not configured")
    return criterion


if __name__ == "__main__":
    # Script entry point: parse the experiment settings, build the benchmark
    # coefficients, run the requested number of replicates and pickle the
    # list of per-replicate results. Several names bound here (dim, device,
    # betas, include_bias, num_epochs, optim_kwargs, optimize_criterion) are
    # read as module globals by the functions above, so they stay at module
    # level.

    def _parse_bool(text: str) -> bool:
        """Parse a command-line boolean (``type=bool`` maps "False" to True)."""
        lowered = text.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11)
    # BooleanOptionalAction takes no value; passing `type` to it is deprecated
    # and, with warnings promoted to errors in this module, would abort.
    parser.add_argument("--cuda_wanted", default=True,
                        action=argparse.BooleanOptionalAction)
    parser.add_argument("--dim", type=int, default=8)
    parser.add_argument("--num_epochs", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--include_bias", type=_parse_bool, default=True)
    parser.add_argument("--num_rows", type=int, default=10_000)
    parser.add_argument("--dataset_name", type=str, default="unitcube")
    parser.add_argument("--benchmark_name", type=str, default="maxer")
    parser.add_argument("--mult_name", type=str, default="above")
    parser.add_argument("--dtype", type=str, default="float32")
    parser.add_argument("--initialization_name",
                        type=utils.argparsers.none_or_str,
                        default=None)
    parser.add_argument("--residualization_name",
                        type=utils.argparsers.none_or_str,
                        default=None)
    parser.add_argument("--optimize_criterion_name",
                        type=str,
                        default="linf",
                        choices=["l2", "linf"])
    parser.add_argument("--evaluate_criterion_name",
                        type=str,
                        default="linf",
                        choices=["l2", "linf"])
    parser.add_argument("--num_replications", type=int, default=1)

    parser.add_argument("--mode", type=str, default="", help="Swallow PyCharm args")
    parser.add_argument("--port", type=str, default="", help="Swallow PyCharm args")
    parser.add_argument("-f", type=str, default="", help="Swallow IPython arg")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    # Seed numpy's global generator as well so numpy-based sampling is
    # reproducible; numpy only accepts seeds in [0, 2**32).
    np.random.seed(args.seed % (2 ** 32))

    cuda_available = torch.cuda.is_available()
    cuda_used = args.cuda_wanted and cuda_available
    device = torch.device("cuda") if cuda_used else torch.device("cpu")
    logger.debug(f"torch.__version__ = {torch.__version__}")
    if cuda_used:
        current_device = torch.cuda.current_device()
        current_device_properties = torch.cuda.get_device_properties(current_device)
        logger.info(f"Running on {current_device_properties}")
        logger.debug(f"torch.version.cuda = {torch.version.cuda}")
    logger.info(pprint.pformat(args.__dict__))

    dim = args.dim
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    include_bias = args.include_bias
    dataset_name = args.dataset_name
    benchmark_name = args.benchmark_name
    mult_name = args.mult_name
    initialization_name = args.initialization_name
    residualization_name = args.residualization_name
    optimize_criterion_name = args.optimize_criterion_name
    evaluate_criterion_name = args.evaluate_criterion_name

    if "float16" == args.dtype:
        dtype = torch.float16
    else:
        if "float32" != args.dtype:
            logger.warning(f"unrecognised dtype {args.dtype}; using float32")
        dtype = torch.float32

    num_rows = args.num_rows
    num_rows_train = num_rows
    num_rows_test = num_rows
    num_replications = args.num_replications
    experiment_ident = build_identifier(dim,
                                        args.seed,
                                        dataset_name,
                                        benchmark_name,
                                        residualization_name,
                                        optimize_criterion_name,
                                        evaluate_criterion_name)
    logger.info(f"running experiment {experiment_ident}")

    optim_kwargs = {"lr": 0.005,
                    "betas": (0.9, 0.999)}
    optimize_criterion = get_criterion(optimize_criterion_name)
    evaluate_criterion = get_criterion(evaluate_criterion_name)

    # Index sets handed to depth_analysis.get_dk_situation for each
    # approximating benchmark.
    approx_ks = {
        "bigapprox": list(range(dim)),
        "mediumapprox": [0, 1, dim - 2, dim - 1],
        "smallapprox": [0, 1, dim - 1],
    }
    if "maxer" == benchmark_name:
        betas = torch.zeros((dim + 1,))
        betas[-1] = 1.0
    elif benchmark_name in approx_ks:
        ks = approx_ks[benchmark_name]
        soln = depth_analysis.get_dk_situation(dim, ks)
        betas = torch.zeros((dim + 1,))
        argmin = soln["argmin"]
        # torch.tensor() warns when handed an existing tensor, and warnings
        # are errors in this module, so only convert non-tensor values.
        if not isinstance(argmin, torch.Tensor):
            argmin = torch.tensor(argmin)
        betas[:-1] = argmin
    else:
        raise ValueError(f"I do not know about benchmark {benchmark_name}")

    optim_class = torch.optim.Adam
    early_stopping_params = dict(patience_epochs=5,
                                 min_improvement=.005)
    if "above" == mult_name:
        mult_step = 1
        max_mult = 8 * mult_step
        mults = list(range(1, max_mult + 1, mult_step))
    else:
        mults = [0, .25, .50, .75] + [1, 2, 3]

    results_list = [None] * num_replications
    for replicate_index in range(num_replications):
        logger.info(f"replicate {replicate_index + 1} / {num_replications}")
        results_list[replicate_index] = run_one_replicate(mults,
                                                          optim_class,
                                                          early_stopping_params,
                                                          dataset_name,
                                                          optimize_criterion,
                                                          evaluate_criterion,
                                                          num_rows_train,
                                                          num_rows_test,
                                                          dim,
                                                          batch_size,
                                                          residualization_name,
                                                          initialization_name,
                                                          dtype)
    pickle_filename = experiment_ident + ".pkl"
    pickle_filedir = paths["results"]
    pickle_fullfilename = os.path.join(pickle_filedir, pickle_filename)
    logger.info(f"Saving results to {pickle_fullfilename}")
    with open(pickle_fullfilename, 'wb') as handle:
        pickle.dump(results_list,
                    handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
    logger.info("Done")
