"""
Generators for 1D Poisson solutions.

The file "fem_solution_matrices.npy" must be in the current working directory
(create_fem_dataset loads it by that relative name). It contains FEM
solution operators for a range of grid sizes.
"""
from collections import defaultdict
from typing import Sequence
import numpy as np
from sympy import *
import tqdm


# Shared sympy symbols. x is the spatial variable; A and B are the integration
# constants that solution_for_f solves for. u_0 and u_L are not referenced by
# the rest of the visible code (presumably boundary values -- TODO confirm).
x, u_0, u_L, A, B = symbols("x u_0 u_L A B")
# Pool of parameter symbols r0 .. r{num_roots - 1}. f_poly uses them as
# polynomial roots; f_sin and f_cos use them as series coefficients.
num_roots = 21
roots = symbols(f"r0:{num_roots}")


# Optimized for 0 BCs, which cuts down solution time by 80%
def solution_for_f(f: Expr) -> Expr:
    """Symbolically solve -u''(x) = f(x) on [0, 1] with u(0) = u(1) = 0.

    Args:
        f: Right-hand side as a sympy expression in the module-level symbol
            ``x`` (it may also contain other symbols such as the ``roots``).

    Returns:
        A single sympy expression for u(x) that satisfies the homogeneous
        Dirichlet conditions at x = 0 and x = 1.
    """
    # General solution: minus the double antiderivative of f, plus the linear
    # term A*x + B carrying the two integration constants.
    u_nh = -integrate(integrate(f, x), x) + A*x + B
    # Determine A and B by requiring u to vanish at both end points.
    homog = solve((u_nh.subs(x, 0), u_nh.subs(x, 1)), (A, B))
    u_sol = u_nh.subs(homog)
    return u_sol


def f_poly(n) -> Expr:
    """Return the product of (x - r_i) over the first ``n`` root symbols."""
    factors = [x - r for r in roots[:n]]
    return prod(factors)


def f_sin(n) -> Expr:
    """Return the sine series sum_{k=1..n} r_{k-1} * sin(k * pi * x)."""
    # The k-th mode (1-based) is weighted by the symbol roots[k - 1].
    terms = [roots[k - 1] * sin(k * pi * x) for k in range(1, n + 1)]
    return sum(terms)


def f_cos(n) -> Expr:
    """Return the cosine series sum_{k=1..n} r_{k-1} * cos(k * pi * x)."""
    # The k-th mode (1-based) is weighted by the symbol roots[k - 1].
    terms = [roots[k - 1] * cos(k * pi * x) for k in range(1, n + 1)]
    return sum(terms)


def create_dataset_for_expression(
    f: Expr, u_sol: Expr, p: int,
    N_grid: int, n_samples: int = 10_000, root_min=-1, root_max=2
) -> dict[str, np.array]:
    """Sample numeric (x, f, u) data from a symbolic forcing/solution pair.

    For every sample, ``p`` parameter values are drawn uniformly from
    [root_min, root_max) and substituted for the first ``p`` root symbols;
    ``f`` and ``u_sol`` are then evaluated on a uniform grid over [0, 1].

    Args:
        f: Symbolic forcing in ``x`` and the first ``p`` root symbols.
        u_sol: Symbolic solution in the same symbols.
        p: Number of root symbols that are sampled.
        N_grid: Number of grid points on [0, 1].
        n_samples: Number of random parameter draws.
        root_min: Lower bound of the sampling interval.
        root_max: Upper bound of the sampling interval.

    Returns:
        Dict with keys "x", "f" and "u", each of shape (n_samples, N_grid).
    """
    symbol_args = (x,) + roots[:p]
    eval_f = lambdify(symbol_args, f, "numpy")
    eval_u = lambdify(symbol_args, u_sol, "numpy")

    grid = np.linspace(0, 1, N_grid, dtype=np.float64)
    f_samples = np.empty((n_samples, N_grid), dtype=np.float64)
    u_samples = np.empty((n_samples, N_grid), dtype=np.float64)

    # Rows are views into the output arrays, so writes land in place.
    for f_row, u_row in zip(f_samples, u_samples):
        params = (np.random.rand(p) * (root_max - root_min) + root_min).tolist()
        f_row[:] = eval_f(grid, *params)
        u_row[:] = eval_u(grid, *params)
        # Pin the boundary values to exactly zero.
        u_row[0] = 0
        u_row[-1] = 0

    # Repeat the grid once per sample so "x" has the same shape as "f" and "u".
    tiled_grid = np.repeat(grid[np.newaxis, :], n_samples, axis=0)
    return {"x": tiled_grid, "f": f_samples, "u": u_samples}


def create_fem_dataset(
    n_samples: int = 10_000, which_n_grid: Sequence[int] | None = (22,)
) -> dict[int, dict[str, np.array]]:
    # Note that this is where the file is loaded. Check here if you have a file not found error.
    fem_solution_matrices = np.load("fem_solution_matrices.npy", allow_pickle=True)
    train_data = defaultdict(dict)
    if which_n_grid is None:
        which_n_grid = fem_solution_matrices.keys()
    for ng in which_n_grid:
        Cm_fem = fem_solution_matrices[()][ng]
        f_values = np.random.normal(size=(ng, n_samples))
        u = Cm_fem @ f_values
        u[0, :] = 0
        u[-1, :] = 0
        train_data[ng]["fem"] = {
            "x": np.repeat(np.linspace(0, 1, ng).reshape(1, -1), n_samples, 0),
            "f": f_values.T,
            "u": u.T,
        }
    return train_data


def create_dataset_dict(
    n_samples: int = 10_000,
    root_spread=1,
    normalize=True,
    which_n_grid: Sequence[int] | None = (22,),
) -> dict[int, dict[int | str, dict[str, np.ndarray]]]:
    """Assemble FEM, polynomial, sine and cosine datasets for each grid size.

    Args:
        n_samples: Number of samples in every dataset (FEM and analytic).
        root_spread: Polynomial roots are sampled from
            [-root_spread, 2 * root_spread + 1).
        normalize: If true, scale each dataset's "f" and "u" by the same
            factor so that "u" has norm 1.
        which_n_grid: Grid sizes to generate; passed to create_fem_dataset.

    Returns:
        Mapping ``ng -> key -> {"x", "f", "u"}`` where ``key`` is "fem",
        an int ``p`` in 1..8 (polynomial with p roots), ``"s{p}"`` (sine
        series with p terms) or ``"c{p}"`` (cosine series with p terms).
    """
    train_data = create_fem_dataset(n_samples, which_n_grid)

    # (progress label, expression builder, key prefix, root_min, root_max)
    families = (
        ("poly", f_poly, "", -root_spread, 2 * root_spread + 1),
        ("sin", f_sin, "s", -1, 1),
        ("cos", f_cos, "c", -1, 1),
    )
    # The symbolic solve does not depend on the grid size, so solve each
    # expression once and reuse it for every ng.
    solved = {}
    for ng in train_data:
        for label, build_f, prefix, r_min, r_max in families:
            for p in tqdm.tqdm(range(1, 9), desc=f"Making {label} data for {ng}"):
                if (label, p) not in solved:
                    f = build_f(p)
                    solved[label, p] = (f, solution_for_f(f))
                f, u = solved[label, p]
                # Polynomial datasets are keyed by the bare int p.
                key = f"{prefix}{p}" if prefix else p
                train_data[ng][key] = create_dataset_for_expression(
                    f, u, p, ng, n_samples, root_min=r_min, root_max=r_max)

    # Normalize all data
    if normalize:
        for ng, datasets in train_data.items():
            # Normalize the datasets to have total norm of 1.0.
            unorm = 1.0
            for p, dd in datasets.items():
                u_norm = np.linalg.norm(dd["u"])
                if u_norm == 0:
                    # Dividing by zero would fill the dataset with inf/NaN.
                    print(f"Skipping normalization of {ng} {p}: u has zero norm")
                    continue
                ratio = unorm / u_norm
                print(
                    f"Normalizing {ng} {p} by {ratio} {unorm} {u_norm}"
                )
                # Scale f and u by the same factor; the arrays are updated
                # in place inside train_data.
                dd["f"] *= ratio
                dd["u"] *= ratio
    return train_data
