import torch

from metatensor.torch.atomistic import load_atomistic_model
from metatrain.utils.data import DiskDataset
from torch.utils.data import Subset
import numpy as np
from metatrain.utils.data import collate_fn  # noqa: E402
from metatensor.torch.atomistic import (  # noqa: E402
    ModelOutput,
    ModelEvaluationOptions
)
import tqdm
import matplotlib.pyplot as plt
import sys

import ase.io
import copy
from skipmd.ase.velocity_verlet import _convert_atoms_to_system
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

# Global matplotlib style shared by every figure produced by this script.
plt.rcParams["font.size"] = 14


# Number of time steps predicted by the model, taken from the first CLI
# argument. It selects the model, the dataset, the index files and the output
# names below (``time_steps//4`` is what file names and plot titles label "fs").
time_steps = int(sys.argv[1])

# Use the GPU when there is one; every system is cast to single precision.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

model = load_atomistic_model(f"../models/water_{time_steps//4}fs_uq.pt").to(device)

dataset = DiskDataset(f"../../water_train/data/water_scaled_{time_steps}.zip")


def _read_split_indices(split):
    """Read the integer dataset indices of ``split`` from ``indices_{time_steps}/{split}.txt``.

    ``np.genfromtxt`` squeezes its result, so a file holding a single index
    comes back as a 0-d array. ``Subset`` cannot take ``len()`` of that, so
    the result is promoted to at least one dimension.
    """
    return np.atleast_1d(
        np.genfromtxt(f"indices_{time_steps}/{split}.txt", dtype=int)
    )


def _make_dataloader(subset):
    """Wrap ``subset`` in an unshuffled DataLoader (batch size 4, metatrain collate)."""
    return torch.utils.data.DataLoader(
        subset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn,
    )


training_indices = _read_split_indices("training")
validation_indices = _read_split_indices("validation")
test_indices = _read_split_indices("test")

training_dataset = Subset(dataset, training_indices)
validation_dataset = Subset(dataset, validation_indices)
test_dataset = Subset(dataset, test_indices)

training_dataloader = _make_dataloader(training_dataset)
validation_dataloader = _make_dataloader(validation_dataset)
test_dataloader = _make_dataloader(test_dataset)

# Every output (predictions, last-layer features and aleatoric heads) is
# requested per atom with an empty quantity/unit; keys keep the order below.
requested_outputs = {
    output_name: ModelOutput(quantity="", unit="", per_atom=True)
    for output_name in (
        f"mtt::delta_{time_steps}_q",
        f"mtt::p_{time_steps}",
        f"mtt::aux::delta_{time_steps}_q_last_layer_features",
        f"mtt::aux::p_{time_steps}_last_layer_features",
        f"mtt::aleatoric_delta_{time_steps}_q",
        f"mtt::aleatoric_p_{time_steps}",
    )
}
eval_options = ModelEvaluationOptions(length_unit="angstrom", outputs=requested_outputs)
# NOTE(review): this assigns onto the object returned by ``capabilities()``;
# whether the model sees the change depends on that object being shared — confirm.
model.capabilities().outputs = requested_outputs

# Gram matrices F^T F of the last-layer features, summed over every atom of the
# training set. Their ridge-regularised inverses define the epistemic
# uncertainty f^T C^-1 f evaluated on other structures further down.
#
# Accumulation and inversion are done in float64. With the relative 1e-4 ridge
# the condition number of the regularised matrix can still reach about
# n_features / 1e-4 (~2.6e6 for 256 features), enough to lose most of float32's
# precision in the inverse. The inverses are cast back to ``dtype`` so they can
# be contracted with the single-precision features later on.
covariance_q = None
covariance_p = None
with torch.no_grad():  # pure inference: do not record an autograd graph
    for systems, _ in tqdm.tqdm(training_dataloader):
        systems = [system.to(device=device, dtype=dtype) for system in systems]
        outputs = model(systems, eval_options, check_consistency=False)
        features_q = outputs[
            f"mtt::aux::delta_{time_steps}_q_last_layer_features"
        ].block().values.to(torch.float64)
        features_p = outputs[
            f"mtt::aux::p_{time_steps}_last_layer_features"
        ].block().values.to(torch.float64)
        if covariance_q is None:
            # Size the accumulators from the model's actual feature width
            # rather than a hard-coded 256.
            n_features_q = features_q.shape[-1]
            n_features_p = features_p.shape[-1]
            covariance_q = torch.zeros(
                n_features_q, n_features_q, dtype=torch.float64, device=device
            )
            covariance_p = torch.zeros(
                n_features_p, n_features_p, dtype=torch.float64, device=device
            )
        covariance_q += features_q.T @ features_q
        covariance_p += features_p.T @ features_p
if covariance_q is None:
    raise RuntimeError(
        "The training dataloader is empty: cannot build the feature covariance."
    )

# Ridge strength is 1e-4 times the mean diagonal entry of each Gram matrix.
inverse_covariance_q_scale = torch.mean(torch.abs(torch.diag(covariance_q)))
inverse_covariance_p_scale = torch.mean(torch.abs(torch.diag(covariance_p)))
inverse_covariance_q = torch.linalg.inv(
    covariance_q
    + 1e-4
    * inverse_covariance_q_scale
    * torch.eye(covariance_q.shape[0], dtype=torch.float64, device=device)
).to(dtype)
inverse_covariance_p = torch.linalg.inv(
    covariance_p
    + 1e-4
    * inverse_covariance_p_scale
    * torch.eye(covariance_p.shape[0], dtype=torch.float64, device=device)
).to(dtype)

def _collect_errors_and_uncertainties(dataloader):
    """Evaluate the model on ``dataloader`` and gather per-atom results.

    Returns a 6-tuple of numpy arrays, each concatenated over all batches
    along axis 0:

    * ``p_residuals``, ``q_residuals`` — prediction minus target of
      ``mtt::p_{time_steps}`` / ``mtt::delta_{time_steps}_q`` (last axis
      squeezed);
    * ``p_aleatoric``, ``q_aleatoric`` — the raw (uncalibrated) aleatoric
      outputs (last axis squeezed);
    * ``p_epistemic``, ``q_epistemic`` — the raw epistemic uncertainty
      ``f^T C^-1 f`` of each row ``f`` of the last-layer features, using the
      module-level ``inverse_covariance_p`` / ``inverse_covariance_q``.
    """
    p_key = f"mtt::p_{time_steps}"
    q_key = f"mtt::delta_{time_steps}_q"
    p_aleatoric_key = f"mtt::aleatoric_p_{time_steps}"
    q_aleatoric_key = f"mtt::aleatoric_delta_{time_steps}_q"
    p_features_key = f"mtt::aux::p_{time_steps}_last_layer_features"
    q_features_key = f"mtt::aux::delta_{time_steps}_q_last_layer_features"

    def as_numpy(tensor_map):
        # ``.block()`` values with the trailing axis squeezed, moved to host.
        return tensor_map.block().values.squeeze(-1).cpu().numpy()

    def epistemic(tensor_map, inverse_covariance):
        features = tensor_map.block().values
        return torch.einsum(
            "ia, ab, ib -> i", features, inverse_covariance, features
        ).cpu().numpy()

    p_res, q_res, p_ale, q_ale, p_epi, q_epi = [], [], [], [], [], []
    with torch.no_grad():  # pure inference: do not record an autograd graph
        for systems, targets in tqdm.tqdm(dataloader):
            systems = [system.to(device=device, dtype=dtype) for system in systems]
            targets = {
                name: target.to(device=device, dtype=dtype)
                for name, target in targets.items()
            }
            outputs = model(systems, eval_options, check_consistency=False)
            p_res.append(as_numpy(outputs[p_key]) - as_numpy(targets[p_key]))
            q_res.append(as_numpy(outputs[q_key]) - as_numpy(targets[q_key]))
            p_ale.append(as_numpy(outputs[p_aleatoric_key]))
            q_ale.append(as_numpy(outputs[q_aleatoric_key]))
            p_epi.append(epistemic(outputs[p_features_key], inverse_covariance_p))
            q_epi.append(epistemic(outputs[q_features_key], inverse_covariance_q))

    return tuple(
        np.concatenate(chunks, axis=0)
        for chunks in (p_res, q_res, p_ale, q_ale, p_epi, q_epi)
    )


# --- Calibration on the validation set -------------------------------------
# Each raw uncertainty is rescaled by the mean ratio of the squared residual
# (summed over the last axis) to that raw uncertainty.
(
    p_residuals,
    q_residuals,
    p_uncertainties_aleatoric,
    q_uncertainties_aleatoric,
    p_uncertainties_epistemic,
    q_uncertainties_epistemic,
) = _collect_errors_and_uncertainties(validation_dataloader)

p_squared_residuals = np.sum(p_residuals**2, axis=-1)
q_squared_residuals = np.sum(q_residuals**2, axis=-1)

p_epistemic_scale = np.mean(p_squared_residuals / p_uncertainties_epistemic)
q_epistemic_scale = np.mean(q_squared_residuals / q_uncertainties_epistemic)

p_aleatoric_scale = np.mean(p_squared_residuals / p_uncertainties_aleatoric)
q_aleatoric_scale = np.mean(q_squared_residuals / q_uncertainties_aleatoric)

# --- Evaluation on the test set with the validation scales applied ---------
(
    p_residuals,
    q_residuals,
    p_uncertainties_aleatoric,
    q_uncertainties_aleatoric,
    p_uncertainties_epistemic,
    q_uncertainties_epistemic,
) = _collect_errors_and_uncertainties(test_dataloader)

p_uncertainties_aleatoric = p_uncertainties_aleatoric * p_aleatoric_scale
q_uncertainties_aleatoric = q_uncertainties_aleatoric * q_aleatoric_scale

p_uncertainties_epistemic = p_uncertainties_epistemic * p_epistemic_scale
q_uncertainties_epistemic = q_uncertainties_epistemic * q_epistemic_scale

p_squared_residuals = np.sum(p_residuals**2, axis=-1)
q_squared_residuals = np.sum(q_residuals**2, axis=-1)

np.save(f"p_squared_residuals_{time_steps//4}.npy", p_squared_residuals)
np.save(f"q_squared_residuals_{time_steps//4}.npy", q_squared_residuals)
np.save(f"p_uncertainties_aleatoric_{time_steps//4}.npy", p_uncertainties_aleatoric)
np.save(f"q_uncertainties_aleatoric_{time_steps//4}.npy", q_uncertainties_aleatoric)
np.save(f"p_uncertainties_epistemic_{time_steps//4}.npy", p_uncertainties_epistemic)
np.save(f"q_uncertainties_epistemic_{time_steps//4}.npy", q_uncertainties_epistemic)

















# Scratch definitions. ``func`` and ``x`` are not referenced later in this
# file, and ``sigma`` is rebound by the isoline loop further down.
sigma = 1.0


def func(x):
    """Evaluate the same expression as ``pdf(x, sigma)``, reading the module-level ``sigma`` at call time."""
    return x * np.exp(-x**2/(2*sigma**2)) * 1.0/(sigma*np.sqrt(2*np.pi))


x = np.geomspace(0.01, 10, 10000)

from scipy.optimize import root_scalar

def pdf(x, sigma):
    """Return ``x * exp(-x**2 / (2 sigma**2)) / (sigma * sqrt(2 pi))``.

    This is ``sigma / sqrt(2 pi)`` times the Rayleigh density with scale
    ``sigma``, i.e. it is *not* normalised. Callers in this file only use it
    to locate level sets and its mode, for which that constant factor does
    not matter. It rises for ``x < sigma`` and decays for ``x > sigma``. Its
    maximum ``pdf(sigma, sigma) = exp(-1/2) / sqrt(2 pi)`` does not depend on
    ``sigma``. Works elementwise on numpy arrays.
    """
    gaussian_factor = np.exp(-x**2 / (2 * sigma**2))
    return x * gaussian_factor * 1.0 / (sigma * np.sqrt(2 * np.pi))

def find_where_pdf_is_c(c, sigma):
    """Return the two abscissae ``x1 <= sigma <= x2`` at which ``pdf(x, sigma) == c``.

    ``pdf`` increases on ``[0, sigma]`` and decreases beyond ``sigma``. So for
    ``0 < c <= pdf(sigma, sigma)`` there is exactly one crossing on each side
    of the mode, and each is found by bracketed root finding.

    Raises:
        ValueError: if ``c`` exceeds the mode value or is not positive.
    """
    mode_value = pdf(sigma, sigma)
    if c > mode_value:
        raise ValueError("c must be less than mode_value")
    if c <= 0:
        # c == 0 would "find" x = 0 and the bracket end; c < 0 has no crossing.
        raise ValueError("c must be positive")
    # pdf(100 * sigma, sigma) = 100 * exp(-5000) / sqrt(2 pi) underflows to 0,
    # so pdf - c < 0 at the upper end for any c > 0 and any sigma. Keeping 100
    # as a floor leaves the bracket identical to before for sigma <= 1.
    upper_end = max(100.0, 100.0 * sigma)
    where_below_mode = root_scalar(lambda x: pdf(x, sigma) - c, bracket=[0, sigma]).root
    where_above_mode = root_scalar(
        lambda x: pdf(x, sigma) - c, bracket=[sigma, upper_end]
    ).root
    return where_below_mode, where_above_mode

def pdf_integral(sigma, c):
    """Return the Rayleigh(``sigma``) probability mass between the two points where ``pdf`` equals ``c``.

    The Rayleigh CDF is ``1 - exp(-x**2 / (2 sigma**2))``, so the mass on
    ``[x1, x2]`` is ``exp(-x1**2 / (2 sigma**2)) - exp(-x2**2 / (2 sigma**2))``.
    This is the integral of the *normalised* Rayleigh density, which differs
    from ``pdf`` by the constant factor ``sigma / sqrt(2 pi)``. The crossing
    points ``x1``, ``x2`` do not depend on that factor.
    """
    lower_crossing, upper_crossing = find_where_pdf_is_c(c, sigma)
    two_sigma_squared = 2 * sigma**2
    return (
        np.exp(-lower_crossing**2 / two_sigma_squared)
        - np.exp(-upper_crossing**2 / two_sigma_squared)
    )

def find_fraction(sigma, fraction):
    """Return the level ``c`` such that ``pdf_integral(sigma, c) == fraction``.

    ``pdf_integral(sigma, c)`` is the Rayleigh(``sigma``) mass between the two
    points where ``pdf`` equals ``c``. It is 0 at the mode value and grows
    monotonically toward 1 as ``c`` decreases toward 0. The level is therefore
    found with Brent's method on a sign-changing bracket, which always
    converges. An open secant iteration is not guaranteed to stay inside
    ``(0, mode_value]``, where the crossings are defined.

    Raises:
        ValueError: if ``fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < fraction < 1.0:
        raise ValueError("fraction must be strictly between 0 and 1")
    mode_value = pdf(sigma, sigma)
    # At the mode the enclosed mass is exactly 0 (< fraction). At 1e-12 of the
    # mode value it is 1 - O(1e-13), which exceeds any fraction below that.
    lowest_level = 1e-12 * mode_value
    return root_scalar(
        lambda level: pdf_integral(sigma, level) - fraction,
        bracket=[lowest_level, mode_value],
        method="brentq",
    ).root

from scipy.stats import norm

# Probability mass of a standard normal within +-1, +-2 and +-3 standard
# deviations; each fraction defines one pair of isolines in the plots below.
desired_fractions = [
    norm.cdf(1, 0.0, 1.0) - norm.cdf(-1, 0.0, 1.0),  # 1 sigma
    norm.cdf(2, 0.0, 1.0) - norm.cdf(-2, 0.0, 1.0),  # 2 sigma
    norm.cdf(3, 0.0, 1.0) - norm.cdf(-3, 0.0, 1.0),  # 3 sigma
]

sigmas = np.linspace(2e-5, 5e0, 5)
# The bounds are extended out to this sigma by linear extrapolation. ``pdf``
# depends on x only through x / sigma and its mode value is sigma-independent,
# so both crossings scale linearly with sigma and the extrapolation is exact up
# to solver tolerance.
extrapolated_sigma = 100.0

lower_bounds = []
upper_bounds = []
for desired_fraction in desired_fractions:
    lower = []
    upper = []
    for sigma in sigmas:
        isoline_value = find_fraction(sigma, desired_fraction)
        x1, x2 = find_where_pdf_is_c(isoline_value, sigma)
        lower.append(x1)
        upper.append(x2)

    # Step from the last computed sigma to ``extrapolated_sigma``, so the
    # appended value is the bound *at* the sigma it is plotted against.
    step = extrapolated_sigma - sigmas[-1]
    for bounds in (lower, upper):
        slope = (bounds[-1] - bounds[-2]) / (sigmas[-1] - sigmas[-2])
        bounds.append(bounds[-1] + slope * step)

    lower_bounds.append(np.array(lower))
    upper_bounds.append(np.array(upper))

sigmas = np.concatenate([sigmas, np.array([extrapolated_sigma])])

#############################################################

def _plot_calibration(uncertainties, squared_residuals, title, filename):
    """Scatter squared residuals against calibrated uncertainties on log-log axes.

    Overlays the ``y = x`` line and every lower/upper isoline in
    ``lower_bounds`` / ``upper_bounds`` (evaluated at ``sigmas``). Axis limits
    are clipped to the scattered data. The figure is saved to ``filename`` and
    then cleared.
    """
    plt.plot(uncertainties, squared_residuals, ".", markersize=1.0, rasterized=True)
    plt.plot(sigmas, sigmas, color="k")
    for lower, upper in zip(lower_bounds, upper_bounds):
        plt.plot(sigmas, lower, color="k", linewidth=0.5)
        plt.plot(sigmas, upper, color="k", linewidth=0.5)
    plt.xlim(np.min(uncertainties), np.max(uncertainties))
    plt.ylim(np.min(squared_residuals), np.max(squared_residuals))
    plt.xscale("log")
    plt.yscale("log")
    plt.title(title)
    plt.xlabel("Uncertainty (eV)")
    # Every panel plots squared residuals, so all four share this label.
    # NOTE(review): "eV" looks inconsistent with position/momentum residuals;
    # confirm the intended units.
    plt.ylabel("Squared residual (eV)")
    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.clf()


_plot_calibration(
    p_uncertainties_aleatoric,
    p_squared_residuals,
    f"Momenta, aleatoric, ({time_steps//4} fs)",
    "p_aleatoric.pdf",
)
_plot_calibration(
    p_uncertainties_epistemic,
    p_squared_residuals,
    f"Momenta, epistemic, ({time_steps//4} fs)",
    "p_epistemic.pdf",
)
_plot_calibration(
    q_uncertainties_aleatoric,
    q_squared_residuals,
    f"Positions, aleatoric, ({time_steps//4} fs)",
    "q_aleatoric.pdf",
)
_plot_calibration(
    q_uncertainties_epistemic,
    q_squared_residuals,
    f"Positions, epistemic, ({time_steps//4} fs)",
    "q_epistemic.pdf",
)


# "Funny" structures: rescale the reference water cell uniformly (atoms move
# with it, scale_atoms=True) by factors 0.5 ... 1.5. For each structure, record
# the mean over atoms of every calibrated per-atom uncertainty.
structure = ase.io.read("../water/water.xyz")
cell_factors = np.linspace(0.5, 1.5, 11)
q_uncertainties_aleatoric = []
q_uncertainties_epistemic = []
p_uncertainties_aleatoric = []
p_uncertainties_epistemic = []
with torch.no_grad():  # pure inference: do not record an autograd graph
    for cell_factor in cell_factors:
        s = copy.deepcopy(structure)
        s.set_cell(s.get_cell() * cell_factor, scale_atoms=True)
        funny_system = _convert_atoms_to_system(s, dtype=dtype, device=device)
        funny_system = get_system_with_neighbor_lists(
            funny_system, model.requested_neighbor_lists()
        )
        outputs = model([funny_system], eval_options, check_consistency=False)
        p_uncertainties_aleatoric.append(
            outputs[f"mtt::aleatoric_p_{time_steps}"].block().values.squeeze(-1).cpu().numpy().mean()
        )
        q_uncertainties_aleatoric.append(
            outputs[f"mtt::aleatoric_delta_{time_steps}_q"].block().values.squeeze(-1).cpu().numpy().mean()
        )
        ll_features_q = outputs[f"mtt::aux::delta_{time_steps}_q_last_layer_features"].block().values
        ll_features_p = outputs[f"mtt::aux::p_{time_steps}_last_layer_features"].block().values
        p_uncertainties_epistemic.append(
            torch.einsum("ia, ab, ib -> i", ll_features_p, inverse_covariance_p, ll_features_p).cpu().numpy().mean()
        )
        q_uncertainties_epistemic.append(
            torch.einsum("ia, ab, ib -> i", ll_features_q, inverse_covariance_q, ll_features_q).cpu().numpy().mean()
        )

# Apply the validation-set calibration factors.
p_uncertainties_aleatoric = np.array(p_uncertainties_aleatoric) * p_aleatoric_scale
q_uncertainties_aleatoric = np.array(q_uncertainties_aleatoric) * q_aleatoric_scale
p_uncertainties_epistemic = np.array(p_uncertainties_epistemic) * p_epistemic_scale
q_uncertainties_epistemic = np.array(q_uncertainties_epistemic) * q_epistemic_scale

np.save(f"p_uncertainties_aleatoric_funny_{time_steps//4}.npy", p_uncertainties_aleatoric)
np.save(f"q_uncertainties_aleatoric_funny_{time_steps//4}.npy", q_uncertainties_aleatoric)
np.save(f"p_uncertainties_epistemic_funny_{time_steps//4}.npy", p_uncertainties_epistemic)
np.save(f"q_uncertainties_epistemic_funny_{time_steps//4}.npy", q_uncertainties_epistemic)


def _plot_funny_uncertainties(aleatoric, epistemic, filename):
    """Plot the mean aleatoric and epistemic uncertainty against ``cell_factors`` and save to ``filename``."""
    plt.plot(cell_factors, aleatoric, "o", rasterized=True, label="aleatoric")
    plt.plot(cell_factors, epistemic, "o", rasterized=True, label="epistemic")
    plt.xlabel("Cell scaling factor")
    plt.ylabel("Mean per-atom uncertainty")
    plt.legend()
    plt.savefig(filename, dpi=300)
    plt.clf()


_plot_funny_uncertainties(
    p_uncertainties_aleatoric, p_uncertainties_epistemic, "p_uncertainties_funny.pdf"
)
_plot_funny_uncertainties(
    q_uncertainties_aleatoric, q_uncertainties_epistemic, "q_uncertainties_funny.pdf"
)
