from typing import Generator
import pathlib
import json

import tensorboard as tb
import tensorflow as tf
import pandas as pd
import torch

from . import DIGEST_ANALYSIS_METHODS
from path_learning.utils import ROOT_DIR
from path_learning.utils.log import LOGDIR
from analysis.utils import ANALYSIS_LOGDIR
from .analyzer import Analyzer, pick_analysis
from path_learning.utils.log import get_timestring


def pick_digester(name):
    """Return the digest analysis method registered under ``name``.

    Args:
        name: Key to look up in ``DIGEST_ANALYSIS_METHODS``.

    Returns:
        The object stored in ``DIGEST_ANALYSIS_METHODS`` for ``name``.

    Raises:
        KeyError: If ``name`` is not a key of ``DIGEST_ANALYSIS_METHODS``.
            The message lists the available keys.
    """
    try:
        return DIGEST_ANALYSIS_METHODS[name]
    except KeyError:
        # List the keys of the registry that was actually consulted; the
        # bare lookup error is suppressed because the message supersedes it.
        raise KeyError(
            f"Unknown digest analysis method '{name}'. "
            f"Must be one of {list(DIGEST_ANALYSIS_METHODS.keys())}"
        ) from None


class AnalyzerDigestResult(Analyzer):
    """Gather logged experiment results into DataFrames and run digest methods.

    Two long-format tables are built, both with the columns in
    ``self.columns``: one from the event files found in each task directory
    and one from the saved tangent-kernel tensors. ``analyze`` writes both to
    CSV and hands them to every configured digest method.

    The attributes ``exp_set_name`` and ``analysis_methods`` and the method
    ``save_kwargs`` are used but not defined here; presumably they are
    provided by ``Analyzer`` -- confirm against the base class.
    """

    def __init__(self, experiment_name: str, **kwargs):
        """Set the result-table schema, then defer to ``Analyzer.__init__``.

        Args:
            experiment_name: Forwarded unchanged to the base class.
            **kwargs: Forwarded unchanged to the base class.
        """
        self.results_dir = None
        # Column order must match the tuples built by ``_make_row``.
        self.columns = ("exp set name",
                        "exp seed",
                        "task uid",
                        "task dataset",
                        "step",
                        "result name",
                        "result"
                        )
        super().__init__(experiment_name, **kwargs)

    @staticmethod
    def create_analyzer_logdir(timestamp: str, exp_set_name: str) -> pathlib.Path:
        """Create and return ``ANALYSIS_LOGDIR / "<timestamp>_<exp_set_name>"``.

        Missing parent directories are created.

        Raises:
            FileExistsError: If the target directory already exists
                (``exist_ok=False``), so an earlier analysis is never
                written into.
        """
        logdir = ANALYSIS_LOGDIR / f"{timestamp}_{exp_set_name}"
        logdir.mkdir(parents=True, exist_ok=False)
        return logdir

    def analyze(self):
        """Gather all results, dump them to CSV and run the digest methods.

        Both tables are gathered before the output directory is created.
        Each entry of ``self.analysis_methods`` is looked up by its
        ``"name"`` and called once with the general table and once with the
        tangent-kernel table.
        """
        # Gather all event results in Dataframes
        df_tk = self.gather_tangent_kernel_results()
        df = self.gather_digest_results()

        timestamp = get_timestring()
        self.logdir = self.create_analyzer_logdir(timestamp, self.exp_set_name)
        self.save_kwargs()
        df.to_csv(self.logdir / "general_digest.csv", index=False)
        df_tk.to_csv(self.logdir / "tk_digest.csv", index=False)

        # Run plotting or further analysis on the gathered Dataframes
        for method in self.analysis_methods:
            method_fct = pick_digester(method["name"])
            method_fct(df, self.logdir, method["name"])
            method_fct(df_tk, self.logdir, method["name"])

    def get_digest_kwargs(self, task_dir):
        """Return the ``step_epoch_fraction`` configured for ``task_dir``.

        The first file (in sorted order) directly inside ``task_dir`` whose
        name contains ``"digest_kwargs"`` is parsed as JSON.

        Args:
            task_dir: Directory (``pathlib.Path``) of a single task.

        Returns:
            The ``"step_epoch_fraction"`` entry of that JSON object, or
            ``0.5`` if there is no such file or the entry is missing/null.
        """
        default_fraction = 0.5
        config_files = sorted(child for child in task_dir.iterdir()
                              if child.is_file() and "digest_kwargs" in child.name)
        if not config_files:
            return default_fraction

        with open(str(config_files[0]), "r", encoding="utf-8") as fp:
            digest_config = json.load(fp)
        print(f"digest config: {digest_config}")

        step_epoch_fraction = digest_config.get("step_epoch_fraction")
        # Only a missing/null entry falls back; an explicit 0 is respected.
        return default_fraction if step_epoch_fraction is None else step_epoch_fraction

    @staticmethod
    def _load_task_info(task_dir):
        """Return the parsed JSON of the ``info`` file in ``task_dir``, or None.

        Every regular file whose stem is ``"info"`` is parsed; if there are
        several, the last one in sorted order wins.
        """
        info = None
        for info_file in sorted(task_dir.iterdir()):
            if info_file.stem == "info" and info_file.is_file():
                with open(str(info_file), "r", encoding="utf-8") as fp:
                    info = json.load(fp)
        return info

    def _iter_task_dirs(self) -> Generator[tuple, None, None]:
        """Yield ``(seed_dir, task_dir, info)`` for every usable task directory.

        Walks ``ROOT_DIR / LOGDIR`` three levels deep: directories whose stem
        contains ``self.exp_set_name``, then sub-directories whose name
        contains ``"seed"``, then sub-directories whose stem contains
        ``"task"``. Entries are visited in sorted order so that the resulting
        tables are reproducible.

        Task directories without an ``info`` file are skipped with a printed
        notice, because their rows could not be labelled with a task uid and
        dataset.
        """
        start_path: pathlib.Path = ROOT_DIR / LOGDIR
        for exp_set_dir in sorted(start_path.iterdir()):
            if not (exp_set_dir.is_dir() and self.exp_set_name in exp_set_dir.stem):
                continue
            for seed_dir in sorted(exp_set_dir.iterdir()):
                # Match on the directory's own name, not on the full path, so
                # that "seed" appearing in a parent directory is irrelevant.
                if not (seed_dir.is_dir() and "seed" in seed_dir.name):
                    continue
                for task_dir in sorted(seed_dir.iterdir()):
                    if not (task_dir.is_dir() and "task" in task_dir.stem):
                        continue
                    info = self._load_task_info(task_dir)
                    if info is None:
                        print(f"skipping {task_dir}: no info file found")
                        continue
                    yield seed_dir, task_dir, info

    def _make_row(self, seed_dir, info, step, result_name, value) -> tuple:
        """Build one result row in the order given by ``self.columns``."""
        return (self.exp_set_name,
                str(seed_dir.stem),
                info["uid"],
                info["config"]["domains"]["target"]["dataset"],
                step,
                result_name,
                value,
                )

    def gather_tangent_kernel_results(self) -> pd.DataFrame:
        """Compare every saved tangent kernel of a task with the first one.

        For each task directory, files whose name contains
        ``"tangent_kernel"`` are collected from the directory itself and from
        its ``saved_tensors`` sub-directory (if present) and sorted by path.
        The first file is the reference; for every later file two rows are
        emitted:

        * ``"tangent_kernel_distance"``: ``torch.norm`` of the difference of
          the two tensors.
        * ``"tangent_kernel_trace_dist"``: ``torch.norm`` of the difference
          of their traces.

        Returns:
            DataFrame with the columns ``self.columns`` (empty if nothing was
            found).
        """
        rows = []
        for seed_dir, task_dir, info in self._iter_task_dirs():
            step_epoch_fraction = self.get_digest_kwargs(task_dir)

            search_dirs = [task_dir]
            saved_tensors_dir = task_dir / "saved_tensors"
            if saved_tensors_dir.is_dir():
                search_dirs.append(saved_tensors_dir)
            # NOTE(review): paths are sorted lexicographically; this is only
            # chronological if the file names are zero-padded -- confirm.
            tk_kernel_files = sorted(child
                                     for search_dir in search_dirs
                                     for child in search_dir.iterdir()
                                     if child.is_file() and "tangent_kernel" in child.name)
            if not tk_kernel_files:
                continue

            # NOTE(review): torch.load unpickles the file; only use on
            # trusted experiment logs. Loading to CPU lets the analysis run
            # on machines without the device the tensors were saved from.
            tensor_0 = torch.load(tk_kernel_files[0], map_location="cpu")["tangent_kernel"]
            trace_0 = torch.trace(tensor_0)

            distances = []
            # NOTE(review): i starts at 0 for the *second* file, so the first
            # comparison is labelled with step 0 -- looks like it may be off
            # by one; kept as is, confirm the intended step labelling.
            for i, kernel_file in enumerate(tk_kernel_files[1:]):
                tensor_i = torch.load(kernel_file, map_location="cpu")["tangent_kernel"]
                step = step_epoch_fraction * float(i)

                distance = float(torch.norm(tensor_0 - tensor_i))
                distances.append(distance)
                new_row = self._make_row(seed_dir, info, step,
                                         "tangent_kernel_distance", distance)
                print(f"new row: {new_row}")
                rows.append(new_row)

                trace_dist = float(torch.norm(trace_0 - torch.trace(tensor_i)))
                new_row = self._make_row(seed_dir, info, step,
                                         "tangent_kernel_trace_dist", trace_dist)
                print(f"new row: {new_row}")
                rows.append(new_row)
            print(f"distances: {distances}")

        # Build the frame once instead of growing it row by row.
        return pd.DataFrame(rows, columns=list(self.columns))

    def gather_digest_results(self) -> pd.DataFrame:
        """Collect the values logged in the event files of every task.

        Every regular file directly inside a task directory whose stem
        contains ``"events"`` is read with the TensorFlow summary iterator.
        For each event that carries at least one summary value, one row is
        emitted from the *first* value's ``tag`` and ``simple_value``; the
        event step is scaled by the task's ``step_epoch_fraction``.

        Returns:
            DataFrame with the columns ``self.columns`` (empty if nothing was
            found).
        """
        # tf.train.summary_iterator only exists in TF1; TF2 exposes the same
        # reader under tf.compat.v1.
        summary_iterator = getattr(tf.train, "summary_iterator", None)
        if summary_iterator is None:
            summary_iterator = tf.compat.v1.train.summary_iterator

        rows = []
        for seed_dir, task_dir, info in self._iter_task_dirs():
            step_epoch_fraction = self.get_digest_kwargs(task_dir)

            event_files = sorted(child for child in task_dir.iterdir()
                                 if "events" in child.stem and child.is_file())
            for event_file in event_files:
                print(f"event file: {event_file}")
                for e in summary_iterator(str(event_file)):
                    if len(e.summary.value) > 0:
                        # NOTE(review): only the first value of an event is
                        # recorded; further values, if any, are ignored.
                        first_value = e.summary.value[0]
                        rows.append(self._make_row(seed_dir, info,
                                                   step_epoch_fraction * float(e.step),
                                                   first_value.tag,
                                                   first_value.simple_value))

        # Build the frame once instead of growing it row by row.
        df = pd.DataFrame(rows, columns=list(self.columns))
        print(f"df_task_results: {df}")
        return df