from typing import Optional, Callable, Any, Dict, Iterable
from copy import deepcopy
import pathlib
import json
import torch
from torch.utils.tensorboard import SummaryWriter

from .log import get_logger
from .digest_methods import compute_fisher_information, hess_vp, tk_kernel, \
    loss_and_grad_mag, compute_input_jacobian_change, \
    compute_theta_jacobian_change, power_method_rho_hessian, power_method_rho_hessian_normed, \
    individial_grad_correlation, OnlineDigestAnalysis, sensitivity_update, compute_loss, model_weight_norm, \
    elasticity, plasticity, tk_kernel_riemann_norm, elasticity_riemann_norm, plasticity_riemann_norm
from .digest_methods_optimal_path import optimal_path_information

import torch.distributions


# Registry mapping a config "methods" name to the callable that computes it.
# Insertion order matters: TensorboardDigest.known_methods is derived from it.
# "train_error" and "test_error" share compute_loss; the digest decides which
# dataloader is passed in.
DIGEST_METHODS = dict(
    elasticity=elasticity,
    plasticity=plasticity,
    elasticity_riemann_norm=elasticity_riemann_norm,
    plasticity_riemann_norm=plasticity_riemann_norm,
    individual_grad_correlation=individial_grad_correlation,
    fisher_information=compute_fisher_information,
    hessian_grad_product=hess_vp,
    tangent_kernel=tk_kernel,
    tangent_kernel_riemann_norm=tk_kernel_riemann_norm,
    loss_and_grad_mag=loss_and_grad_mag,
    input_jacobian_change=compute_input_jacobian_change,
    theta_jacobian_change=compute_theta_jacobian_change,
    power_iteration_hessian=power_method_rho_hessian,
    power_iteration_hessian_normed=power_method_rho_hessian_normed,
    sensitivity_update=sensitivity_update,
    train_error=compute_loss,
    test_error=compute_loss,
    weight_norm=model_weight_norm,
    optimal_path_information=optimal_path_information,
)


class TensorboardDigest:
    """Run periodic digest analyses on a model and log the results.

    Scalar results are written to TensorBoard through a ``SummaryWriter`` rooted
    at ``logdir``; tensor results are stored with ``torch.save`` under
    ``logdir / "saved_tensors"``.

    Before entering the context manager, callers are expected to assign
    ``train_dataloader`` (its length is used in ``__enter__``), ``loss`` and
    ``device``; ``test_dataloader`` is needed for ``test_error`` and
    ``dataloader`` for the ``approx_2nd_order_*`` methods.

    Config keys read here: ``logdir`` (required), ``methods``,
    ``digest_frequency`` (default 800), ``digest_batch_level_frequency``
    (default 10) and ``batch_selection_frequency`` (default 200).
    """

    known_methods = tuple(DIGEST_METHODS)
    known_methods += ("approx_2nd_order_model", "approx_2nd_order_loss")

    # Methods computed per batch by ``digest_batch_level``; ``digest`` skips them.
    _BATCH_LEVEL_METHODS = ("approx_2nd_order_model", "approx_2nd_order_loss", "loss_and_grad_mag")
    # Output keys that are always persisted to disk rather than logged as scalars.
    _TENSOR_KEYS = ("tangent_kernel", "tangent_kernel_rnormed")

    def __init__(self, **config):
        """Validate ``config`` and set up the TensorBoard writer.

        Raises:
            NotImplementedError: if ``config["methods"]`` names an unknown method.
            KeyError: if ``config`` has no ``logdir``.
        """
        self.config = self.validate_config(config)
        self.digest_frequency = self.config.get("digest_frequency", 800)
        # NOTE(review): stored but not read by any method of this class;
        # digest_batch_level gates on digest_frequency instead -- confirm intent.
        self.digest_batch_level_freq = self.config.get("digest_batch_level_frequency", 10)
        self.batch_selection_frequency = self.config.get("batch_selection_frequency", 200)

        self.logger = get_logger("tensorboard_digest")
        # Normalise to a Path: the rest of the class joins sub-paths with ``/``.
        self.logdir = pathlib.Path(config["logdir"])
        self.writer = SummaryWriter(str(self.logdir))
        self.step: int = 0
        self.step_epoch_fraction: Optional[float] = None
        self.batch_level_step: int = 0
        self.loss: Optional[Callable] = None
        self.dataloader: Optional[Iterable] = None
        self.train_dataloader: Optional[Iterable] = None
        self.test_dataloader: Optional[Iterable] = None
        self.device = None
        self.batch_level_analyzer = OnlineDigestAnalysis(self.batch_selection_frequency)

    def create_tensor_logdir(self, logdir_name: str) -> pathlib.Path:
        """Create ``logdir / logdir_name`` and return it.

        Raises:
            FileExistsError: if the directory already exists (``exist_ok=False``).
        """
        logdir = self.logdir / logdir_name
        logdir.mkdir(parents=True, exist_ok=False)
        return logdir

    def save_kwargs(self) -> None:
        """Dump ``self.config`` to ``logdir/digest_kwargs.json``.

        Top-level ``pathlib.Path`` values are converted to ``str`` in place so
        that they are JSON serialisable.
        """
        for key, value in self.config.items():
            if isinstance(value, pathlib.Path):
                self.config[key] = str(value)
        with open(self.logdir / "digest_kwargs.json", "w", encoding="utf-8") as fp:
            json.dump(self.config, fp, indent=4)

    def __enter__(self):
        """Prepare the output directory, record the step fraction and save the config.

        Returns ``self`` so the digest can be used as ``with digest as d:``.
        """
        self.create_tensor_logdir("saved_tensors")
        self.logger.info(f"Digest reset. Will save results to {self.logdir}")
        # Fraction of an epoch between two digests; requires train_dataloader to be set.
        self.step_epoch_fraction = self.digest_frequency / len(self.train_dataloader)
        self.config["step_epoch_fraction"] = float(self.step_epoch_fraction)
        self.save_kwargs()
        return self

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        """Close the writer and reset step counters; exceptions are not suppressed."""
        self.logger.info(f"Exiting digest. Results were saved to {self.logdir}")
        self.logger.info("Closing Summary writer")
        # Flushes pending events to disk. torch's SummaryWriter lazily re-creates
        # its file writer on the next add_* call, so the digest can be reused.
        self.writer.close()
        self.step = 0
        self.batch_level_step = 0
        self.dataloader = None

    def __del__(self):
        # Best effort: ``writer`` may not exist if __init__ failed early.
        try:
            self.writer.close()
        except Exception:  # pylint: disable=broad-except
            pass

    def write(self, tag, args, step, fn=None):
        """Write a scalar under ``tag``: ``args[0]`` if ``fn`` is None, else ``fn(*args)``."""
        value = args[0] if fn is None else fn(*args)
        self.writer.add_scalar(tag, value, step)

    def write_from_dict(self, method: str, output_dict: dict, step: int):
        """Log every entry of ``output_dict`` produced by digest ``method``.

        ``int``/``float`` values are written as TensorBoard scalars tagged
        ``"<method>-<key>"``. ``torch.Tensor`` values, and anything stored under
        one of ``_TENSOR_KEYS``, are saved to
        ``logdir/saved_tensors/<key>_<step>.pt``.

        Raises:
            ValueError: if a value is neither a number nor a tensor.
        """
        for key, value in output_dict.items():
            tag = f"{method}-{key}"
            if isinstance(value, (int, float)):
                self.logger.info(f"{tag}, value: {float(value)} ")
                self.write(tag, [float(value)], step)
            elif isinstance(value, torch.Tensor) or key in self._TENSOR_KEYS:
                self.logger.info(f"saving: {self.logdir} {key}_{step}.pt")
                # The payload dict key is always "tangent_kernel", regardless of ``key``.
                torch.save({"tangent_kernel": value.data},
                           self.logdir / "saved_tensors" / f"{key}_{step}.pt")
            else:
                raise ValueError(
                    f"Output type is not correct. Please double check. "
                    f"({tag}: {type(value).__name__})"
                )

    def digest(self, batch_idx: int, model: torch.nn.Module, epoch=None, variable_name=None):
        """Run all configured epoch-level methods every ``digest_frequency`` batches.

        Each method runs on a deep copy of ``model``. Batch-level methods are
        skipped (see ``digest_batch_level``). ``test_error`` runs in eval mode on
        ``test_dataloader``; ``train_error`` runs in train mode on
        ``train_dataloader``; every other method runs in eval mode on
        ``train_dataloader`` and additionally receives ``epoch`` and
        ``variable_name`` as keyword arguments.
        """
        if batch_idx % self.digest_frequency != 0:
            return
        methods_dict = self.config.get("methods", {})
        for method in methods_dict:
            # Skip before indexing so a list-style "methods" config still works here.
            if method in self._BATCH_LEVEL_METHODS:
                continue
            # Copy the per-method kwargs so the caller's config is not mutated.
            kwargs = dict(methods_dict[method] or {})
            model_copy = deepcopy(model)
            if method == "test_error":
                model_copy.eval()
                dataloader = self.test_dataloader
            elif method == "train_error":
                model_copy.train()
                dataloader = self.train_dataloader
            else:
                # Important: Turn off online-updates in batch normalization
                model_copy.eval()
                dataloader = self.train_dataloader
                kwargs["epoch"] = epoch
                kwargs["variable_name"] = variable_name
            output_dict = DIGEST_METHODS[method](model_copy, self.loss, dataloader, self.device,
                                                 self.batch_selection_frequency, **kwargs)
            self.write_from_dict(method, output_dict, self.step)

        self.step += 1

    def digest_batch_level(self, model, data, target, batch_idx, kind: str):
        """Run the configured batch-level methods on the current batch.

        Results are written only when ``kind == "after"`` and a method returned
        a non-None dict; ``batch_level_step`` then advances once per call.
        ``model.train()`` is always called on the original model at the end.
        """
        # NOTE(review): gated on digest_frequency; digest_batch_level_freq is
        # otherwise unused -- confirm which frequency is intended.
        if batch_idx % self.digest_frequency == 0:
            for method in self.config.get("methods", []):
                # Only select methods suitable for batch level analysis
                if method not in self._BATCH_LEVEL_METHODS:
                    continue
                model_copy = deepcopy(model)
                # Important: Turn off online-updates in batch normalization
                model_copy.eval()
                if method == "loss_and_grad_mag":
                    model_copy.train()
                    output_dict = DIGEST_METHODS["loss_and_grad_mag"](model_copy, self.loss, data, target)
                elif method == "approx_2nd_order_model":
                    output_dict = self.batch_level_analyzer.approx_2nd_order_model(model_copy, self.dataloader,
                                                                                   self.device, kind)
                else:  # "approx_2nd_order_loss"
                    output_dict = self.batch_level_analyzer.approx_2nd_order_loss(model_copy, self.loss,
                                                                                  self.dataloader, self.device,
                                                                                  kind)
                if kind == "after" and output_dict is not None:
                    self.write_from_dict(method, output_dict, self.batch_level_step)
            # One step per digested batch, so all methods of a batch share a step.
            if kind == "after":
                self.batch_level_step += 1
        model.train()

    def validate_config(self, config: Dict) -> Dict:
        """Return ``config`` unchanged after checking every method name is known.

        Called during __init__ so that config problems surface as early as possible.

        Raises:
            NotImplementedError: if a method is not in ``known_methods``.
        """
        for method in config.get("methods", []):
            if method not in self.known_methods:
                raise NotImplementedError(f"unknown digest method {method} must be one of {self.known_methods}")
        return config


