import json
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


@dataclass
class DataConfig:
    """Settings describing a single experiment run.

    The declaration order of the fields defines the positional constructor
    signature, so it must stay stable.
    """

    run_id: int

    # ---- Input / output locations ----
    data_folder: str
    output_folder: str
    save_predictions: bool

    # ---- Dataset setup ----
    target: str
    horizon: int
    time_delay: int

    # ---- Model hyperparameters ----
    model_name: str
    width: int
    activation: str
    regularization: float
    seed: int


@dataclass
class DataOutput:
    """Outputs of one run: per-split predictions, per-split errors and timings.

    The dict-valued fields are keyed by split label.
    """

    # Predicted frames, one per split label.
    predictions: dict[str, pd.DataFrame]
    # Error series, one per split label; absolute values are taken when aggregating.
    metrics: dict[str, pd.Series]
    training_time: float
    prediction_times: dict[str, float]

    def aggregate_metrics(self, aggr_fn: Callable = np.mean):
        """Collapse each split's error series into a single value.

        Args:
            aggr_fn: Reduction applied to the element-wise absolute value of
                every series in ``self.metrics`` (mean by default).

        Returns:
            A dict with the same keys as ``self.metrics`` mapping to the
            reduced values.
        """
        reduced = {}
        for split_label, errors in self.metrics.items():
            reduced[split_label] = aggr_fn(np.abs(errors))
        return reduced

    def get_output_dict(self):
        """Build the summary dict: aggregated metrics, prediction and training times."""
        summary = dict(metrics=self.aggregate_metrics())
        summary["prediction_times"] = self.prediction_times
        summary["training_time"] = self.training_time
        return summary


class DataSplit:
    """One data split loaded from CSV, with its "Date Time" column parsed separately.

    The first CSV column is used as the index. The "Date Time" column is
    removed from the frame and exposed, parsed with ``pd.to_datetime``, via
    ``timestamps``. ``size_limit`` (if given) keeps only the first
    ``size_limit`` rows of both.
    """

    _df: pd.DataFrame
    _timestamps: pd.Series

    def __init__(self, csv_filename: str, size_limit: Union[int, None] = None):
        frame = pd.read_csv(csv_filename, index_col=0)
        parsed_dates = pd.to_datetime(frame.pop("Date Time"))

        # The same positional window is applied to values and timestamps so
        # they stay row-aligned; slice(None, None) keeps everything.
        head = slice(None, size_limit)
        self._df = frame.iloc[head]
        self._timestamps = parsed_dates.iloc[head]

    @property
    def data(self):
        """Feature frame (without the "Date Time" column)."""
        return self._df

    @property
    def timestamps(self):
        """Parsed "Date Time" values, row-aligned with ``data``."""
        return self._timestamps


class Normalizer:
    """Column-wise z-score normalization for a DataFrame.

    ``fit`` learns each column's mean and (sample) standard deviation;
    ``transform`` maps ``df`` to ``(df - mean) / std`` or, with
    ``inverse=True``, back via ``df * std + mean``.
    """

    # Per-column statistics learned by ``fit``; None until a successful fit.
    _mean: Union[pd.Series, None] = None
    _std: Union[pd.Series, None] = None

    def fit(self, df: pd.DataFrame) -> None:
        """Learn per-column mean and standard deviation from ``df``.

        Raises:
            ValueError: if any column has a zero standard deviation (constant
                column) or an undefined one (NaN, e.g. a single-row frame or an
                all-NaN column). Dividing by either would yield inf/NaN. On
                failure the previously fitted statistics, if any, are kept.
        """
        mean = df.mean()
        std = df.std()

        zero_cols = std.index[std == 0].to_list()
        if zero_cols:
            raise ValueError(f"Got zero std for columns: {zero_cols}.")
        nan_cols = std.index[std.isna()].to_list()
        if nan_cols:
            raise ValueError(f"Got undefined (NaN) std for columns: {nan_cols}.")

        # Commit only after validation so a failed fit cannot leave a
        # half-initialized normalizer behind.
        self._mean = mean
        self._std = std

    def transform(self, df: pd.DataFrame, inverse=False) -> pd.DataFrame:
        """Normalize ``df`` with the fitted statistics, or undo it if ``inverse``.

        Raises:
            RuntimeError: if called before a successful ``fit``.
        """
        if self._mean is None or self._std is None:
            raise RuntimeError("Normalizer must be fit before calling transform.")
        if inverse:
            return df * self._std + self._mean
        return (df - self._mean) / self._std


def plot_results(
    splits: dict[str, DataSplit], output: DataOutput, target: str
) -> plt.Figure:
    """Plot true vs. predicted ``target`` and the per-step error for every split.

    The figure has one row per entry of ``output.predictions``:
      * left column: the true ``target`` series and the prediction;
      * right column: ``output.metrics[label]`` (the error series).
    Predictions may be shorter than the true series. They are aligned to the
    *end* of it, so the first ``len(values) - len(predictions)`` timestamps
    have no prediction.

    Args:
        splits: Data splits keyed by label; must contain every prediction label.
        output: Run output holding per-split predictions and error series.
        target: Name of the column to plot.

    Returns:
        The created figure.

    Raises:
        ValueError: if a split has more predictions than true values.
    """
    predictions = output.predictions
    metrics = output.metrics

    fig, axes = plt.subplots(
        len(predictions),
        2,
        figsize=(12, 4 * len(predictions)),
        sharey="col",
        squeeze=False,
    )

    for (ax_pred, ax_metric), label in zip(axes, predictions):
        split_values = splits[label].data[target]
        timestamps = splits[label].timestamps
        split_predictions = predictions[label][target]
        time_delay = len(split_values) - len(split_predictions)
        if time_delay < 0:
            raise ValueError(
                f"Split {label!r} has {len(split_predictions)} predictions "
                f"but only {len(split_values)} true values."
            )
        # Positional slice: the timestamps index comes from the CSV's first
        # column, so label/ambiguous ``[]`` slicing is avoided on purpose.
        pred_timestamps = timestamps.iloc[time_delay:]

        # Plot predictions.
        ax_pred.plot(timestamps, split_values, label="true")
        ax_pred.plot(
            pred_timestamps,
            split_predictions,
            linestyle="--",
            label="prediction",
        )

        # Plot errors.
        ax_metric.plot(pred_timestamps, metrics[label], c="r", label="error")

        for ax in [ax_pred, ax_metric]:
            ax.legend(loc="center left", bbox_to_anchor=[1.01, 0.5])
            # Pin the current ticks before relabelling them rotated; this
            # freezes the labels (fine for a saved figure, not for zooming).
            ax.set_xticks(ax.get_xticks())
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
            ax.set_title(f"split: {label}")

    fig.tight_layout()
    return fig


def save_results(
    fig: plt.Figure,
    output: DataOutput,
    output_folder: str,
    base_filename: str,
    save_predictions: bool = False,
):
    """Write the figure, the summary JSON and (optionally) raw predictions.

    All files share the prefix ``<output_folder>/<base_filename>``:
      * ``<prefix>.pdf``          - ``fig``
      * ``<prefix>.json``         - ``output.get_output_dict()``
      * ``<prefix>_<label>.npy``  - one per entry of ``output.predictions``,
        only when ``save_predictions`` is true.
    """
    base_path = os.path.join(output_folder, base_filename)
    # Create the directory that *contains* the output files. Creating
    # ``base_path`` itself would only leave an empty folder named after the
    # prefix. The parent also covers ``base_filename`` values with subdirs.
    Path(base_path).parent.mkdir(parents=True, exist_ok=True)

    fig.savefig(f"{base_path}.pdf")
    with open(f"{base_path}.json", "w", encoding="utf-8") as fout:
        json.dump(output.get_output_dict(), fout, indent=4)

    if save_predictions:
        for label, prediction in output.predictions.items():
            np.save(f"{base_path}_{label}.npy", prediction)


def get_logger(filename: str, name: str = "data_logger") -> logging.Logger:
    """Return the logger ``name``, writing DEBUG-and-above records to ``filename``.

    The log file is truncated (``mode="w"``). ``logging.getLogger`` returns
    the same object for the same name for the whole process. Any
    ``FileHandler`` attached by an earlier call is therefore detached and
    closed first. Otherwise each call would stack another handler, duplicating
    every record into older files and leaking their open file handles.
    """
    logger = logging.getLogger(name)
    for old_handler in list(logger.handlers):
        if isinstance(old_handler, logging.FileHandler):
            logger.removeHandler(old_handler)
            old_handler.close()

    handler = logging.FileHandler(filename, mode="w", encoding="utf-8")
    handler.setFormatter(
        logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s")
    )
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    return logger
