from dataclasses import dataclass, field
from enum import Enum
from glob import glob
from os import makedirs
from os.path import join, exists, dirname
from typing import Optional

import numpy as np
import torch

from .mlpsde import MlpSde


class Problem(Enum):
    """Benchmark SDE systems that can supply training observations."""

    # Explicit integer values; members of this enum end up inside configs
    # that are written to disk with torch.save.
    lorenz = 1
    van_der_pol = 2


class Method(Enum):
    """How the neural SDE is fitted; consumed by ``train``."""

    euler = 1  # sdeint without the adjoint
    adjoint = 2  # sdeint with the adjoint method
    path_int = 3  # path integral importance sampling (sdeint disabled)


@dataclass
class Config:
    """Settings for a single training run.

    Every setting is a real dataclass field, so any of them can be overridden
    by keyword when constructing the config, and all appear in ``repr``.
    """

    # Seed handed to numpy and torch by main(); may be None.
    seed: Optional[int] = None

    test_problem: Problem = Problem.lorenz
    method: Method = Method.euler

    # Evaluated per instance: CUDA when available, otherwise CPU.
    device: str = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )

    # Root directory for cached datasets, saved configs and checkpoints.
    # Annotated so it is a genuine field (previously it was a bare class
    # attribute that the constructor could not set); declared last so the
    # positional order of the other fields is unchanged.
    data_dir: str = ".data"


@dataclass
class Data:
    """Observation times paired with the observed trajectories."""

    times: torch.Tensor
    # Presumably laid out as (time, batch, dim); only the last two axes are
    # read below — TODO confirm.
    obs: torch.Tensor

    @property
    def dim(self):
        """Length of the last axis of ``obs``."""
        return self.obs.shape[-1]

    @property
    def batches(self):
        """Length of the second-to-last axis of ``obs``."""
        return self.obs.shape[-2]

    def to(self, d):
        """Return a new ``Data`` whose tensors have both been passed through ``Tensor.to(d)``."""
        moved_times = self.times.to(d)
        moved_obs = self.obs.to(d)
        return Data(moved_times, moved_obs)


def get_data(cfg: Config):
    """Return the observations for ``cfg.test_problem`` on ``cfg.device``.

    The dataset is cached as ``<data_dir>/<problem>.pt``. On a cache miss it is
    generated on 1000 evenly spaced times in [0, 10] with 16 batches, using
    the matching path generator from ``.sde_systems``, and then saved.

    Raises:
        ValueError: if ``cfg.test_problem`` has no generator.
    """
    fname = join(cfg.data_dir, f"{cfg.test_problem.name}.pt")
    if exists(fname):
        # The cache holds a pickled ``Data`` object, not just tensors, so it
        # cannot be read with ``weights_only=True`` (the default since torch
        # 2.6). The file is produced by this function, so it is trusted.
        return torch.load(fname, map_location=cfg.device, weights_only=False)

    times = torch.linspace(0, 10, 1000)
    if cfg.test_problem == Problem.lorenz:
        from .sde_systems import make_lorenz_paths as make
    elif cfg.test_problem == Problem.van_der_pol:
        from .sde_systems import make_van_der_pol_paths as make
    else:
        raise ValueError("Unimplemented problem")

    data = Data(times, make(times, batches=16))
    # The cache directory may not exist yet when called outside main().
    makedirs(cfg.data_dir, exist_ok=True)
    torch.save(data, fname)
    return data.to(cfg.device)


def main(cfg: Config):
    """Save ``cfg`` alongside the run outputs, seed the RNGs if requested, and train.

    The config is written to ``<data_dir>/<problem>/<method>/cfg.pk``.
    """
    cfgname = join(cfg.data_dir, cfg.test_problem.name, cfg.method.name, "cfg.pk")
    makedirs(dirname(cfgname), exist_ok=True)
    # NOTE(review): every run writes the same cfg.pk, so only the latest run's
    # config is kept while checkpoints are numbered per run — confirm intent.
    torch.save(cfg, cfgname)

    # torch.manual_seed rejects None, so the default Config() (seed=None)
    # used to crash here; with no seed, leave the generators as they are.
    if cfg.seed is not None:
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)

    train(cfg)


def _next_run_path(run_dir):
    """Return an unused ``<run_dir>/<n>.pt`` path for a new training run.

    ``n`` starts at the number of entries already in ``run_dir`` (which also
    counts non-checkpoint files such as ``cfg.pk``) and is bumped past any
    name that already exists, so an earlier run's checkpoint is never
    overwritten even if files were removed from the directory.
    """
    n = len(glob(join(run_dir, "*")))
    while exists(join(run_dir, f"{n}.pt")):
        n += 1
    return join(run_dir, f"{n}.pt")


def train(cfg: Config):
    """Fit an ``MlpSde`` to the observations of ``cfg.test_problem``.

    Optimisation is delegated to ``.training.train``. ``cfg.method`` chooses
    between sdeint (with or without the adjoint) and path-integral sampling.
    On every callback, and once more after training, the training stats, the
    recorded drift discrepancies and sigmas, and the model itself are saved
    to a fresh numbered file under ``<data_dir>/<problem>/<method>/``.
    """
    data = get_data(cfg)
    sde = MlpSde([data.dim, 32, 256, 32, data.dim]).to(cfg.device)

    from .training import train as tr

    run_dir = join(cfg.data_dir, cfg.test_problem.name, cfg.method.name)
    makedirs(run_dir, exist_ok=True)
    fname = _next_run_path(run_dir)
    from .sde_systems import Lorenz

    # NOTE(review): the reference drift is always Lorenz, even when
    # cfg.test_problem is van_der_pol — confirm this is intended.
    lor = Lorenz()

    # Per-callback mean (over observations) of the squared Euclidean distance
    # between the Lorenz drift and the learned drift.
    drift_norms = []
    # Per-callback value of sde.sigma.
    sigmas = []

    def callback(step, res):
        """Log progress, record drift/sigma diagnostics and checkpoint to ``fname``."""
        print(
            f"step [{step}]; time [{res.clock_time[-1]:.1f}] mse [{res.mse[-1]:.2f}];\t\tloss [{res.loss[-1]:.2f}]"
        )
        with torch.no_grad():
            norm = (
                (lor.f(None, data.obs) - sde.f(None, data.obs)).square().sum(-1).mean()
            )
            drift_norms.append(norm)
            sigmas.append(sde.sigma.item())
            print(f"sigma [{sde.sigma.item()}]", f"l2diff [{norm}]")
        # Overwrites the same file each time, so it always holds the latest state.
        torch.save(
            {
                "stats": res,
                "drift_norms": drift_norms,
                "sigmas": sigmas,
                "checkpoint": sde,
            },
            fname,
        )

    res = tr(
        sde,
        data.times,
        data.obs,
        use_sdeint=cfg.method != Method.path_int,
        use_adjoint=cfg.method == Method.adjoint,
        callback=callback,
        n_samples=64,
    )

    # Final checkpoint after training finishes.
    callback(len(res.loss), res)


def sdeint_prior_weights(cfg: Config):
    """Plot sdeint sample paths of an untrained ``MlpSde`` with their fit weights.

    Five paths are sampled with ``torchsde.sdeint``, all starting from the
    first toy observation. Each path is scored by the negative squared error
    against the observations, summed over dimensions and observation times.
    The scores are min-max normalised and plotted with ``PriorPathWeights``
    under the name ``"randwalk"``. A previously saved result in
    ``<data_dir>/sdeint_prior.pt`` is reused if present.
    """
    fname = join(cfg.data_dir, "sdeint_prior.pt")
    if not exists(fname):
        data = toy_data(cfg)
        from torch.nn.functional import mse_loss
        from torchsde import sdeint

        sde = MlpSde([data.dim, 32, 256, 32, data.dim])
        batches = 5
        times = torch.linspace(data.times.min(), data.times.max(), 1000)

        with torch.no_grad():
            paths = sdeint(
                sde,
                data.obs[0].repeat(batches, 1),
                times,
            )
            # Index of the last solver time at or before each observation time.
            obs_idx = torch.bucketize(data.times, times, right=True) - 1
            weights = (
                -mse_loss(
                    paths[obs_idx],
                    data.obs.repeat(1, batches, 1),
                    reduction="none",
                )
                .sum(-1)
                .sum(0)
            )
            # Min-max normalise to [0, 1]. Skip the division when every path
            # scored the same, which would otherwise turn all weights into NaN.
            weights -= weights.min()
            span = weights.max()
            if span > 0:
                weights /= span

        obj = {
            "sde": sde,
            "paths": paths.numpy(),
            "weights": weights.numpy(),
            "obs": data.obs.numpy(),
        }
        # NOTE: saving is disabled, so the cached branch below only runs if
        # the file was produced some other way.
        # torch.save(obj, fname)
    else:
        # The file holds a pickled nn.Module and numpy arrays, which
        # ``weights_only=True`` (the torch >= 2.6 default) refuses to load.
        obj = torch.load(fname, weights_only=False)
        sde = obj["sde"]

    from .plotting import PriorPathWeights

    PriorPathWeights(
        torch.no_grad()(lambda x: sde.f(None, torch.FloatTensor(x)).numpy()),
        obj["paths"],
        obj["weights"],
        obj["obs"],
    ).plot("randwalk")


def piis_prior_weights(cfg: Config):
    """Plot path-integral bridge samples of an untrained ``MlpSde`` with their weights.

    Paths and weights come from ``PathInt(BrownianBridge(...), sde.f)`` over
    the toy observations, using five samples. The weights are min-max
    normalised and plotted with ``PriorPathWeights`` under the name
    ``"bridge"``. A previously saved result in ``<data_dir>/piis_prior.pt``
    is reused if present.
    """
    fname = join(cfg.data_dir, "piis_prior.pt")
    if not exists(fname):
        data = toy_data(cfg)
        from .piis import PathInt
        from .brownianbridge import BrownianBridge

        sde = MlpSde([data.dim, 32, 256, 32, data.dim])
        batches = 5
        times = torch.linspace(data.times.min(), data.times.max(), 1000)

        with torch.no_grad():
            paths, weights = PathInt(
                BrownianBridge(
                    times,
                    data.times,
                    data.obs.squeeze(1),
                    sde.sigma,
                    batch_shape=torch.Size([batches]),
                ),
                sde.f,
            ).generate_samples()
            # Min-max normalise to [0, 1]. Skip the division when all weights
            # are equal, which would otherwise turn every weight into NaN.
            weights -= weights.min()
            span = weights.max()
            if span > 0:
                weights /= span

        obj = {
            "sde": sde,
            # NOTE(review): transposed presumably so the first two axes match
            # the sdeint path layout used by sdeint_prior_weights — confirm.
            "paths": paths.transpose(0, 1).numpy(),
            "weights": weights.numpy(),
            "obs": data.obs.numpy(),
        }
        # NOTE: saving is disabled, so the cached branch below only runs if
        # the file was produced some other way.
        # torch.save(obj, fname)
    else:
        # The file holds a pickled nn.Module and numpy arrays, which
        # ``weights_only=True`` (the torch >= 2.6 default) refuses to load.
        obj = torch.load(fname, weights_only=False)
        sde = obj["sde"]

    from .plotting import PriorPathWeights

    PriorPathWeights(
        torch.no_grad()(lambda x: sde.f(None, torch.FloatTensor(x)).numpy()),
        obj["paths"],
        obj["weights"],
        obj["obs"],
    ).plot("bridge")


def toy_data(*args):
    """Build a tiny 2-D toy dataset on 10 evenly spaced times in [0, 1].

    Coordinates are ``(3 t, 3 sin(2 pi t))``. Any positional arguments (for
    example a config, as in ``toy_data(cfg)``) are accepted and ignored.
    The returned ``Data`` has ``obs`` of shape (10, 1, 2).
    """
    t = torch.linspace(0, 1, 10)
    coords = torch.stack((t * 3, 3 * torch.sin(t * 2 * np.pi)), dim=-1)  # (10, 2)
    # Insert a singleton axis in the middle -> (10, 1, 2).
    return Data(t, coords.unsqueeze(1))


if __name__ == "__main__":
    import sys

    # Command-line name -> (announcement, training method).
    methods = {
        "adj": ("Running adjoint training method", Method.adjoint),
        "pi": ("Running path integral importance sampling", Method.path_int),
        "euler": ("Running euler training method", Method.euler),
    }

    if len(sys.argv) < 2:
        print("USAGE:")
        print("python -m pathint METHOD [DEVICE]")
        print()
        print("where METHOD is one of:")
        print("  adj:    use the adjoint method")
        print("   pi:    use path integral importance sampling")
        print("euler:    use sdeint without the adjoint method")
        print()
        print("and DEVICE optionally overrides the torch device (e.g. cpu, cuda).")
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(-1)

    choice = methods.get(sys.argv[1].lower())
    if choice is None:
        print("Unknown method!")
        sys.exit(-2)
    announcement, method = choice
    print(announcement)

    # Only override the device when one is given; otherwise keep the
    # Config default (CUDA if available, else CPU).
    overrides = {"device": sys.argv[2]} if len(sys.argv) > 2 else {}

    main(
        Config(
            seed=1,
            test_problem=Problem.lorenz,
            method=method,
            **overrides,
        )
    )

    # Alternative prior-weight plotting experiments, disabled by default:
    # np.random.seed(1)
    # torch.manual_seed(1)
    # sdeint_prior_weights(
    #     Config(
    #         seed=1,
    #         test_problem=Problem.lorenz,
    #         method=Method.path_int,
    #     )
    # )
    # np.random.seed(1)
    # torch.manual_seed(1)
    # piis_prior_weights(
    #     Config(
    #         seed=1,
    #         test_problem=Problem.lorenz,
    #         method=Method.path_int,
    #     )
    # )
