import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import logging
from typing import Callable

import json


@dataclass
class ClimateConfig:
    """Settings for one run, grouped into file, data and model settings.

    The field order below defines the positional argument order of the
    dataclass-generated ``__init__``; do not reorder fields.
    """

    run_id: int  # presumably an identifier distinguishing runs — confirm with caller

    # File config
    data_folder: str  # presumably the directory holding the input CSV files — TODO confirm
    output_folder: str  # presumably passed as ``output_folder`` to save_results — confirm
    save_predictions: bool  # presumably forwarded to save_results(save_predictions=...) — confirm

    # Data config
    target: str  # assumed to name the data column being forecast — confirm
    horizon: int  # assumed: number of steps ahead to forecast — TODO confirm
    time_delay: int  # assumed: length of the input/lag window — TODO confirm

    # Model config
    model_name: str
    width: int  # assumed: model (hidden-layer) width — confirm
    activation: str
    regularization: float
    seed: int  # assumed: random seed for reproducibility — confirm
   

@dataclass(kw_only=True)
class ClimateOutput:
    predictions: dict[str, pd.DataFrame]
    metrics: dict[str, pd.Series]
    training_time: float 
    prediction_times: dict[str, float]

    def aggregate_metrics(self, aggr_fn: Callable = np.mean):
        aggregated = {label: aggr_fn(np.abs(split_metrics)) for
                      label, split_metrics in self.metrics.items()}
        return aggregated
    
    def get_output_dict(self):
        return {
            "metrics": self.aggregate_metrics(),
            "prediction_times": self.prediction_times,
            "training_time": self.training_time
        }
    

class ClimateSplit:

    _df: pd.DataFrame
    _timestamps: pd.Series

    def __init__(self, csv_filename: str, size_limit: int | None = None):
        df = pd.read_csv(csv_filename, index_col=0)
        timestamps = df.pop("Date Time")
        timestamps = pd.to_datetime(timestamps, format="%Y-%m-%d %H:%M:%S")

        self._df = df.iloc[:size_limit]
        self._timestamps = timestamps.iloc[:size_limit]

    @property
    def data(self):
        return self._df
    
    @property
    def timestamps(self):
        return self._timestamps
    

class Normalizer:
    _mean: pd.Series = None
    _std: pd.Series = None

    def fit(self, df: pd.DataFrame):
        self._mean = df.mean()
        self._std = df.std()

        zero_mask = self._std == 0
        if np.any(zero_mask):
            raise ValueError("Got zero std for columns:"
                             f"{self._std.index[zero_mask].to_list()}.") 

    def transform(self, df: pd.DataFrame, inverse=False) -> pd.DataFrame:
        if inverse:
            return df * self._std + self._mean
        else:
            return (df - self._mean) / self._std
    

def plot_results(splits: dict[str, ClimateSplit], output: ClimateOutput, target: str) -> plt.Figure:
    """Plot true vs. predicted ``target`` and the error series, one row per split.

    The left column shows the true series over the whole split and the
    predictions overlaid. The right column shows ``output.metrics[label]``.
    Predictions (and metrics) are assumed to line up with the *last*
    ``len(predictions)`` rows of the split; the leading rows that have no
    prediction are the "time delay".

    Raises:
        ValueError: if there is nothing to plot, or a split has more
            predictions than true values.
    """
    predictions = output.predictions
    metrics = output.metrics
    n_rows = len(predictions)
    if n_rows == 0:
        raise ValueError("No predictions to plot.")

    fig, axes = plt.subplots(n_rows, 2,
                             figsize=(12, 4 * n_rows),
                             sharey="col", squeeze=False)

    for (ax_pred, ax_metric), label in zip(axes, predictions):
        split = splits[label]
        split_values = split.data[target]
        timestamps = split.timestamps
        split_predictions = predictions[label][target]

        time_delay = len(split_values) - len(split_predictions)
        if time_delay < 0:
            # A negative offset would silently slice from the end of the series.
            raise ValueError(
                f"Split {label!r} has {len(split_predictions)} predictions "
                f"but only {len(split_values)} true values.")
        # Explicit positional slicing, independent of the index dtype.
        prediction_timestamps = timestamps.iloc[time_delay:]

        # Plot predictions.
        ax_pred.plot(timestamps, split_values, label="true")
        ax_pred.plot(prediction_timestamps, split_predictions, linestyle="--", label="prediction")

        # Plot errors.
        ax_metric.plot(prediction_timestamps, metrics[label], c="r", label="error")

        for ax in (ax_pred, ax_metric):
            ax.legend(loc="center left", bbox_to_anchor=[1.01, 0.5])
            # Pin the current ticks before relabeling so the rotated labels
            # stay attached to the same tick positions.
            ax.set_xticks(ax.get_xticks())
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
            ax.set_title(f"split: {label}")

    fig.tight_layout()
    return fig

def save_results(
    fig: plt.Figure,
    output: ClimateOutput,
    output_folder: str,
    base_filename: str,
    save_predictions: bool = False
):
    """Write a run's figure, summary and (optionally) raw predictions to disk.

    Writes ``<output_folder>/<base_filename>.pdf`` (the figure) and
    ``<output_folder>/<base_filename>.json`` (``output.get_output_dict()``).
    When ``save_predictions`` is true, it also writes one
    ``<base_filename>_<label>.npy`` per entry of ``output.predictions``.
    ``output_folder`` is created if it does not exist.
    """

    def _to_json_native(obj):
        # The json module only knows builtin types. Numpy scalars such as
        # np.float32/np.int64 (e.g. pandas reductions over float32 data)
        # must be unwrapped explicitly; anything else still fails loudly.
        if isinstance(obj, (np.generic, np.ndarray)):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    if output_folder:
        # Create the folder up front so a missing directory does not make
        # the save fail after the run has already finished.
        os.makedirs(output_folder, exist_ok=True)
    base_path = os.path.join(output_folder, base_filename)

    fig.savefig(f"{base_path}.pdf")
    with open(f"{base_path}.json", "w") as fout:
        json.dump(output.get_output_dict(), fout, indent=4, default=_to_json_native)

    if save_predictions:
        for label, prediction in output.predictions.items():
            # np.save stores only the values: DataFrame column names and
            # index are not preserved in the .npy file.
            np.save(f"{base_path}_{label}.npy", prediction)


def get_logger(filename: str, name: str = "climate_logger") -> logging.Logger:
    """Return logger ``name``, writing records of level DEBUG and above to ``filename``.

    The file is opened with ``mode="w"``, so it is truncated. Because
    ``logging.getLogger`` returns the same object for the same ``name``, a
    repeated call would otherwise stack another FileHandler. That would send
    every record to all earlier files too and leave their handles open.
    Instead, the handler installed by a previous call is removed and closed.
    Handlers added by other code are left alone.
    """
    logger = logging.getLogger(name)

    for old_handler in list(logger.handlers):
        if getattr(old_handler, "_added_by_get_logger", False):
            logger.removeHandler(old_handler)
            old_handler.close()

    handler = logging.FileHandler(filename, mode="w")
    formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s")
    handler.setFormatter(formatter)
    # Marker so a later call replaces only handlers this function created.
    handler._added_by_get_logger = True
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    return logger