from dataclasses import dataclass
from typing import Callable, Dict, List

import numpy as np
from matplotlib import pyplot as plt, rcParams
from matplotlib.pylab import cm
import torch

from .training import TrainingResult
from .__main__ import Data  # yuck... TODO refactor this rubbish!!


@dataclass
class LearningCurve:
    """Three lists of :class:`TrainingResult` objects, presumably one per training method.

    Plotting is not implemented yet; :meth:`plot` is a no-op placeholder.
    """

    # Results stored under the "euler" label.
    euler: List[TrainingResult]
    # Results stored under the "adjoint" label.
    adjoint: List[TrainingResult]
    # Results stored under the "piis" label.
    piis: List[TrainingResult]

    def plot(self):
        """Placeholder that currently draws nothing and returns ``None``."""
        # TODO: implement plotting of the learning curves.
        return None


@dataclass
class PriorPathWeights:
    """Sampled paths, a drift function and observations, bundled for plotting.

    Fields are documented by the ``#`` comments next to each of them.
    """

    # Drift function. ``plot`` calls it with an array of shape
    # (grid_size, grid_size, 2) holding (x, y) grid points and reads
    # ``[..., 0]`` / ``[..., 1]`` of the result as the (u, v) field components.
    drift: Callable[[np.ndarray], np.ndarray]
    # Indexed as ``paths[:, i, d]``: axis 1 enumerates paths and the last axis
    # holds (x, y) coordinates; axis 0 is presumably time — TODO confirm.
    paths: np.ndarray
    # Not used by ``plot``; presumably one weight per path — TODO confirm.
    weights: np.ndarray
    # Observation points; the last axis holds (x, y) coordinates.
    obs: np.ndarray

    def plot(self, name, *, grid_size=10, out_dir=".data/figures"):
        """Draw drift streamlines, every path and the observations; save as EPS.

        The streamplot grid spans the bounding box of ``paths`` and ``obs``.
        Paths are coloured with evenly spaced ``ocean`` colormap values,
        observations are red ``X`` markers, and axis ticks are hidden.

        Args:
            name: file stem; the figure is written to ``{out_dir}/{name}.eps``.
            grid_size: number of grid points per axis at which ``drift`` is
                evaluated for the streamplot (10 in the original behaviour).
            out_dir: output directory, created if it does not exist.
        """
        from os import makedirs

        makedirs(out_dir, exist_ok=True)

        ax: plt.Axes
        fig, ax = plt.subplots()
        try:
            # Bounding box over both the paths and the observations.
            minx, miny = [
                min(np.min(self.paths[..., d]), np.min(self.obs[..., d]))
                for d in (0, 1)
            ]
            maxx, maxy = [
                max(np.max(self.paths[..., d]), np.max(self.obs[..., d]))
                for d in (0, 1)
            ]
            # A complex step in mgrid means "number of points", endpoints included.
            steps = grid_size * 1j
            Y, X = np.mgrid[miny:maxy:steps, minx:maxx:steps]
            uv = self.drift(np.stack((X, Y), axis=-1))

            # Background streamlines drawn below everything else.
            ax.streamplot(X, Y, uv[..., 0], uv[..., 1], color="#BBBBBB", zorder=-1)

            colors = cm.ocean(np.linspace(0.3, 0.7, self.paths.shape[1]))
            for i, color in enumerate(colors):
                ax.plot(
                    self.paths[:, i, 0],
                    self.paths[:, i, 1],
                    color=color,
                    zorder=0,
                )
            ax.scatter(
                self.obs[..., 0],
                self.obs[..., 1],
                color="#DD3030",
                zorder=1,
                label="observations",
                marker="X",
            )

            ax.set_xticks([])
            ax.set_yticks([])
            ax.legend(loc="upper right")
            fig.savefig(f"{out_dir}/{name}.eps", format="eps", bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)


def _save_comparison_plot(filename, ylabel, xmax, pat_curve, adj_curve, ylim=None):
    """Plot a path-integral curve and an adjoint curve on one axes and save as EPS.

    Args:
        filename: output path of the ``.eps`` file.
        ylabel: y-axis label.
        xmax: right x-limit; the left x-limit is 0.
        pat_curve: ``(x, y)`` data for the "path int" curve.
        adj_curve: ``(x, y)`` data for the "adjoint" curve.
        ylim: optional ``(bottom, top)`` y-limits; autoscaled when ``None``.
    """
    ax: plt.Axes
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("time (minutes)")
        ax.set_ylabel(ylabel)
        ax.set_xlim(0, xmax)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.plot(*pat_curve, color="#ef233c", label="path int")
        ax.plot(*adj_curve, color="#023e8a", label="adjoint")
        ax.legend()
        fig.savefig(filename, format="eps", bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)


def plot_lorenz_exp():
    """Compare the adjoint and path-integral runs of the Lorenz experiment.

    Loads ``.data/lorenz/adjoint/1.pt``, ``.data/lorenz/path_int/1.pt`` and the
    dataset ``.data/lorenz.pt``. It then writes two EPS figures to
    ``.data/figures``, each plotted against wall-clock time:

    * ``lorenz-drift.eps``: the values stored under ``"drift_norms"``, divided
      by the mean squared norm of the Lorenz drift at the observations.
    * ``lorenz-mse.eps``: the ``stats.mse`` values.

    It also prints the ratio of the last (subsampled) adjoint clock time to the
    path-integral one, and sets the global ``rcParams["font.size"]`` to 12.
    """
    from os import makedirs

    from .sde_systems import Lorenz

    makedirs(".data/figures", exist_ok=True)

    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted files. On torch versions where ``weights_only`` defaults to True,
    # the pickled stats objects may need ``weights_only=False``; confirm.
    adj = torch.load(".data/lorenz/adjoint/1.pt", map_location="cpu")
    pat = torch.load(".data/lorenz/path_int/1.pt", map_location="cpu")
    # map_location matches the loads above, so a GPU-saved dataset also loads on CPU.
    obs = torch.load(".data/lorenz.pt", map_location="cpu").obs

    # Normaliser: mean squared Euclidean norm of the Lorenz drift at the observations.
    rel = Lorenz().f(None, obs).square().sum(-1).mean().item()

    # Every 10th clock-time entry, divided by 60 (presumably seconds -> minutes,
    # matching the "time (minutes)" axis label).
    adj_time = np.array(adj["stats"].clock_time[::10]) / 60
    pat_time = np.array(pat["stats"].clock_time[::10]) / 60
    print(adj_time[-1] / pat_time[-1])

    rcParams["font.size"] = 12
    xmax = 4 * pat_time[-1]

    # The last drift-norm entry is dropped, presumably to line up with the
    # subsampled clock times. TODO confirm.
    _save_comparison_plot(
        ".data/figures/lorenz-drift.eps",
        "err",
        xmax,
        (pat_time, np.array(pat["drift_norms"][:-1]) / rel),
        (adj_time, np.array(adj["drift_norms"][:-1]) / rel),
        ylim=(-0.2, 4),
    )
    _save_comparison_plot(
        ".data/figures/lorenz-mse.eps",
        "data mse",
        xmax,
        (pat_time, pat["stats"].mse),
        (adj_time, adj["stats"].mse),
    )


if __name__ == "__main__":
    # Script entry point. Because this module uses relative imports, run it as
    # a module (``python -m <package>.<module>``) rather than by file path.
    plot_lorenz_exp()
