import torch

from metatensor.torch.atomistic import load_atomistic_model
from metatrain.utils.data import DiskDataset
from torch.utils.data import Subset
import numpy as np
from metatrain.utils.data import collate_fn  # noqa: E402
from metatensor.torch.atomistic import (  # noqa: E402
    ModelOutput,
    ModelEvaluationOptions
)
import tqdm
import matplotlib.pyplot as plt
import sys

import ase.io
import copy
from skipmd.ase.velocity_verlet import _convert_atoms_to_system
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

# Global matplotlib font size for every figure produced by this script.
plt.rcParams["font.size"] = 14


# Integer taken from the first command-line argument. It is interpolated into
# the model path, the dataset path, the index-file paths, the model output
# names and the names of the saved figures.
time_steps = int(sys.argv[1])

# Device and floating-point precision used for the model and all tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# The checkpoint file name is built from time_steps // 4
# (NOTE(review): presumably a conversion of steps to femtoseconds -- confirm).
model = load_atomistic_model(f"../models/water_{time_steps//4}fs_uq.pt").to(device)

# Full on-disk dataset for the selected number of time steps.
dataset = DiskDataset(f"../../water_train/data/water_scaled_{time_steps}.zip")

# Integer sample indices of each split, read from plain-text files.
training_indices = np.genfromtxt(f"indices_{time_steps}/training.txt", dtype=int)
validation_indices = np.genfromtxt(f"indices_{time_steps}/validation.txt", dtype=int)
test_indices = np.genfromtxt(f"indices_{time_steps}/test.txt", dtype=int)

# Views of the dataset restricted to each split.
training_dataset = Subset(dataset, training_indices)
validation_dataset = Subset(dataset, validation_indices)
test_dataset = Subset(dataset, test_indices)

# Every loader shares the same settings: small fixed batches, no shuffling (so
# iteration order is deterministic) and the metatrain collate function.
loader_options = dict(batch_size=4, shuffle=False, collate_fn=collate_fn)

training_dataloader = torch.utils.data.DataLoader(training_dataset, **loader_options)

validation_dataloader = torch.utils.data.DataLoader(
    validation_dataset, **loader_options
)

test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_options)

# Names of the model outputs requested on every forward pass, in order:
# the two predictions (delta-q and p), their last-layer features, and the two
# aleatoric uncertainty outputs.
output_names = [
    f"mtt::delta_{time_steps}_q",
    f"mtt::p_{time_steps}",
    f"mtt::aux::delta_{time_steps}_q_last_layer_features",
    f"mtt::aux::p_{time_steps}_last_layer_features",
    f"mtt::aleatoric_delta_{time_steps}_q",
    f"mtt::aleatoric_p_{time_steps}",
]

# Each output is requested per atom, with empty quantity and unit strings.
requested_outputs = {
    name: ModelOutput(quantity="", unit="", per_atom=True) for name in output_names
}

eval_options = ModelEvaluationOptions(length_unit="angstrom", outputs=requested_outputs)

# Overwrite the outputs declared in the model capabilities with the requested set.
model.capabilities().outputs = requested_outputs

# Accumulate the (unnormalised) Gram matrices F^T F of the last-layer features
# for the q and p heads. The feature width is hard-coded to 256.
# NOTE(review): the accumulation runs over the validation split, and
# training_dataloader is never consumed in the visible script -- confirm that
# this is the intended split.
feature_dim = 256
covariance_q = torch.zeros(feature_dim, feature_dim, dtype=dtype, device=device)
covariance_p = torch.zeros(feature_dim, feature_dim, dtype=dtype, device=device)
for batch in tqdm.tqdm(validation_dataloader):
    systems, _ = batch
    systems = [system.to(device=device, dtype=dtype) for system in systems]
    outputs = model(systems, eval_options, check_consistency=False)
    features_q = (
        outputs[f"mtt::aux::delta_{time_steps}_q_last_layer_features"].block().values
    )
    features_p = (
        outputs[f"mtt::aux::p_{time_steps}_last_layer_features"].block().values
    )
    covariance_q += features_q.T @ features_q
    covariance_p += features_p.T @ features_p

# Regularisation strength: 1e-4 times the mean absolute diagonal entry of each
# accumulated matrix, added on the diagonal before inversion.
identity = torch.eye(feature_dim, dtype=dtype, device=device)
inverse_covariance_q_scale = torch.mean(torch.abs(torch.diag(covariance_q)))
inverse_covariance_p_scale = torch.mean(torch.abs(torch.diag(covariance_p)))
inverse_covariance_q = torch.linalg.inv(
    covariance_q + 1e-4 * inverse_covariance_q_scale * identity
)
inverse_covariance_p = torch.linalg.inv(
    covariance_p + 1e-4 * inverse_covariance_p_scale * identity
)

# Collect, over the validation split, the prediction residuals together with
# the raw (unscaled) aleatoric and epistemic uncertainties of both heads.
p_residuals = []
q_residuals = []
p_uncertainties_aleatoric = []
q_uncertainties_aleatoric = []
p_uncertainties_epistemic = []
q_uncertainties_epistemic = []
for batch in tqdm.tqdm(validation_dataloader):
    systems, targets = batch
    systems = [system.to(device=device, dtype=dtype) for system in systems]
    targets = {
        name: target.to(device=device, dtype=dtype)
        for name, target in targets.items()
    }
    outputs = model(systems, eval_options, check_consistency=False)

    # Residuals = prediction - target, computed on the host in NumPy.
    p_predicted = outputs[f"mtt::p_{time_steps}"].block().values.squeeze(-1).cpu().numpy()
    p_reference = targets[f"mtt::p_{time_steps}"].block().values.squeeze(-1).cpu().numpy()
    p_residuals.append(p_predicted - p_reference)

    q_predicted = outputs[f"mtt::delta_{time_steps}_q"].block().values.squeeze(-1).cpu().numpy()
    q_reference = targets[f"mtt::delta_{time_steps}_q"].block().values.squeeze(-1).cpu().numpy()
    q_residuals.append(q_predicted - q_reference)

    # Aleatoric uncertainties are read straight from the model outputs.
    p_aleatoric = outputs[f"mtt::aleatoric_p_{time_steps}"].block().values
    q_aleatoric = outputs[f"mtt::aleatoric_delta_{time_steps}_q"].block().values
    p_uncertainties_aleatoric.append(p_aleatoric.squeeze(-1).cpu().numpy())
    q_uncertainties_aleatoric.append(q_aleatoric.squeeze(-1).cpu().numpy())

    # Epistemic uncertainties are the quadratic forms f^T C^-1 f of the
    # last-layer features with the regularised inverse covariance, one per row.
    ll_features_q = outputs[f"mtt::aux::delta_{time_steps}_q_last_layer_features"].block().values
    ll_features_p = outputs[f"mtt::aux::p_{time_steps}_last_layer_features"].block().values
    p_epistemic = torch.einsum(
        "ia, ab, ib -> i", ll_features_p, inverse_covariance_p, ll_features_p
    )
    q_epistemic = torch.einsum(
        "ia, ab, ib -> i", ll_features_q, inverse_covariance_q, ll_features_q
    )
    p_uncertainties_epistemic.append(p_epistemic.cpu().numpy())
    q_uncertainties_epistemic.append(q_epistemic.cpu().numpy())

# Stack the per-batch arrays along the first axis.
p_residuals = np.concatenate(p_residuals, axis=0)
q_residuals = np.concatenate(q_residuals, axis=0)

p_uncertainties_aleatoric = np.concatenate(p_uncertainties_aleatoric, axis=0)
q_uncertainties_aleatoric = np.concatenate(q_uncertainties_aleatoric, axis=0)

p_uncertainties_epistemic = np.concatenate(p_uncertainties_epistemic, axis=0)
q_uncertainties_epistemic = np.concatenate(q_uncertainties_epistemic, axis=0)

# Squared residual norm per row (sum of squares over the last axis).
p_squared_residuals = np.sum(p_residuals**2, axis=-1)
q_squared_residuals = np.sum(q_residuals**2, axis=-1)

# Starting guesses for the calibration scales: half of the mean ratio between
# the squared residuals and each raw uncertainty.
p_epistemic_scale_0 = np.mean(p_squared_residuals / p_uncertainties_epistemic) * 0.5
q_epistemic_scale_0 = np.mean(q_squared_residuals / q_uncertainties_epistemic) * 0.5

p_aleatoric_scale_0 = np.mean(p_squared_residuals / p_uncertainties_aleatoric) * 0.5
q_aleatoric_scale_0 = np.mean(q_squared_residuals / q_uncertainties_aleatoric) * 0.5

def inverse_softplus(
    y: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0
) -> torch.Tensor:
    """Invert a softplus with slope ``beta``, element-wise.

    For ``beta * y <= threshold`` the result is
    ``(beta * y + log1p(-exp(-beta * y))) / beta``; above the threshold the
    input is returned unchanged.

    Args:
        y: Softplus outputs to invert. Non-positive entries have no inverse
            and produce ``nan`` or ``-inf``.
        beta: Slope parameter of the softplus being inverted.
        threshold: Value of ``beta * y`` above which ``y`` is passed through.

    Returns:
        A tensor with the same shape as ``y``.
    """
    scaled = beta * y
    # log(exp(scaled) - 1) rewritten as scaled + log(1 - exp(-scaled)).
    exact = scaled + torch.log1p(-torch.exp(-scaled))
    return torch.where(scaled > threshold, y, exact / beta)

class ScaleNLL(torch.nn.Module):
    """Loss used to fit two positive scales that calibrate uncertainties.

    The predicted variance is modelled as
    ``a * uncertainties_aleatoric + e * uncertainties_epistemic`` where the
    effective scales are ``a = softplus(aleatoric_scale) + 1e-4`` and
    ``e = softplus(epistemic_scale) + 1e-4``. The learnable parameters
    ``aleatoric_scale`` and ``epistemic_scale`` are the unconstrained
    (pre-softplus) values.

    The constructor reads the module-level ``dtype`` and ``device`` and uses
    the module-level ``inverse_softplus``.

    Args:
        aleatoric_scale_0: Initial (positive) value of ``softplus(aleatoric_scale)``.
        epistemic_scale_0: Initial (positive) value of ``softplus(epistemic_scale)``.
    """

    def __init__(self, aleatoric_scale_0, epistemic_scale_0):
        super().__init__()
        # Store pre-softplus values so the softplus of each parameter starts
        # at the requested initial scale.
        self.aleatoric_scale = torch.nn.Parameter(
            inverse_softplus(
                torch.tensor(aleatoric_scale_0, dtype=dtype, device=device)
            )
        )
        self.epistemic_scale = torch.nn.Parameter(
            inverse_softplus(
                torch.tensor(epistemic_scale_0, dtype=dtype, device=device)
            )
        )

    def forward(self, squared_residuals, uncertainties_aleatoric, uncertainties_epistemic):
        """Return the mean of ``log(var) + squared_residuals / var``.

        ``var`` is the scaled sum of the two uncertainty contributions.

        Args:
            squared_residuals: Squared errors, broadcastable with the
                uncertainties.
            uncertainties_aleatoric: Unscaled aleatoric uncertainties.
            uncertainties_epistemic: Unscaled epistemic uncertainties.

        Returns:
            A scalar tensor with the averaged loss.
        """
        # Softplus keeps the scales positive; the 1e-4 offset keeps them away
        # from zero.
        aleatoric_scale = torch.nn.functional.softplus(self.aleatoric_scale) + 1e-4
        epistemic_scale = torch.nn.functional.softplus(self.epistemic_scale) + 1e-4
        # The same positive variance must appear in both the log term and the
        # denominator. Dividing by the raw parameters instead would make the
        # loss inconsistent and allow a negative denominator.
        variance = (
            aleatoric_scale * uncertainties_aleatoric
            + epistemic_scale * uncertainties_epistemic
        )
        return torch.mean(torch.log(variance) + squared_residuals / variance)


# Fit the aleatoric and epistemic calibration scales of the q head by
# minimising ScaleNLL on the validation statistics with L-BFGS.
print()
print(q_aleatoric_scale_0, q_epistemic_scale_0)
scale_q = ScaleNLL(q_aleatoric_scale_0, q_epistemic_scale_0)
q_squared_residuals_tensor = torch.tensor(q_squared_residuals, dtype=dtype, device=device)
q_uncertainties_aleatoric_tensor = torch.tensor(q_uncertainties_aleatoric, dtype=dtype, device=device)
q_uncertainties_epistemic_tensor = torch.tensor(q_uncertainties_epistemic, dtype=dtype, device=device)
optimizer_q = torch.optim.LBFGS(scale_q.parameters(), lr=0.1)


def closure():
    """Re-evaluate the q loss and its gradients for L-BFGS and log progress."""
    optimizer_q.zero_grad()
    loss = scale_q(
        q_squared_residuals_tensor,
        q_uncertainties_aleatoric_tensor,
        q_uncertainties_epistemic_tensor
    )
    loss.backward()
    print(torch.nn.functional.softplus(scale_q.aleatoric_scale).item(), torch.nn.functional.softplus(scale_q.epistemic_scale).item(), loss.item())
    return loss


# Ten outer L-BFGS steps; each step may call the closure several times.
for _ in range(10):
    optimizer_q.step(closure)

# The module parameters are pre-softplus values, so apply softplus to recover
# the positive scales, exactly as is done for the p head. Storing the raw
# parameters would apply the wrong (possibly negative) factors later on.
# NOTE(review): ScaleNLL.forward adds a further 1e-4 to each scale, which is
# not included here (nor for the p head) -- confirm whether it should be.
q_aleatoric_scale = torch.nn.functional.softplus(scale_q.aleatoric_scale).item()
q_epistemic_scale = torch.nn.functional.softplus(scale_q.epistemic_scale).item()

# Fit the aleatoric and epistemic calibration scales of the p head by
# minimising ScaleNLL on the validation statistics with L-BFGS.
print()
print(p_aleatoric_scale_0, p_epistemic_scale_0)
scale_p = ScaleNLL(p_aleatoric_scale_0, p_epistemic_scale_0)
p_squared_residuals_tensor = torch.tensor(
    p_squared_residuals, dtype=dtype, device=device
)
p_uncertainties_aleatoric_tensor = torch.tensor(
    p_uncertainties_aleatoric, dtype=dtype, device=device
)
p_uncertainties_epistemic_tensor = torch.tensor(
    p_uncertainties_epistemic, dtype=dtype, device=device
)
optimizer_p = torch.optim.LBFGS(scale_p.parameters(), lr=0.1)


def closure():
    """Re-evaluate the p loss and its gradients for L-BFGS and log progress."""
    optimizer_p.zero_grad()
    loss = scale_p(
        p_squared_residuals_tensor,
        p_uncertainties_aleatoric_tensor,
        p_uncertainties_epistemic_tensor,
    )
    loss.backward()
    # Log the current positive scales and the loss value.
    current_aleatoric = torch.nn.functional.softplus(scale_p.aleatoric_scale).item()
    current_epistemic = torch.nn.functional.softplus(scale_p.epistemic_scale).item()
    print(current_aleatoric, current_epistemic, loss.item())
    return loss


# Ten outer L-BFGS steps; each step may call the closure several times.
for i in range(10):
    optimizer_p.step(closure)

# The parameters are pre-softplus values; softplus recovers the positive scales.
p_aleatoric_scale = torch.nn.functional.softplus(scale_p.aleatoric_scale).item()
p_epistemic_scale = torch.nn.functional.softplus(scale_p.epistemic_scale).item()

# Repeat the evaluation on the test split, then apply the fitted scales to the
# raw uncertainties.
p_residuals = []
q_residuals = []
p_uncertainties_aleatoric = []
q_uncertainties_aleatoric = []
p_uncertainties_epistemic = []
q_uncertainties_epistemic = []
for batch in tqdm.tqdm(test_dataloader):
    systems, targets = batch
    systems = [system.to(device=device, dtype=dtype) for system in systems]
    targets = {
        name: target.to(device=device, dtype=dtype)
        for name, target in targets.items()
    }
    outputs = model(systems, eval_options, check_consistency=False)

    # Residuals = prediction - target, computed on the host in NumPy.
    p_predicted = outputs[f"mtt::p_{time_steps}"].block().values.squeeze(-1).cpu().numpy()
    p_reference = targets[f"mtt::p_{time_steps}"].block().values.squeeze(-1).cpu().numpy()
    p_residuals.append(p_predicted - p_reference)

    q_predicted = outputs[f"mtt::delta_{time_steps}_q"].block().values.squeeze(-1).cpu().numpy()
    q_reference = targets[f"mtt::delta_{time_steps}_q"].block().values.squeeze(-1).cpu().numpy()
    q_residuals.append(q_predicted - q_reference)

    # Raw aleatoric uncertainties, straight from the model outputs.
    p_aleatoric = outputs[f"mtt::aleatoric_p_{time_steps}"].block().values
    q_aleatoric = outputs[f"mtt::aleatoric_delta_{time_steps}_q"].block().values
    p_uncertainties_aleatoric.append(p_aleatoric.squeeze(-1).cpu().numpy())
    q_uncertainties_aleatoric.append(q_aleatoric.squeeze(-1).cpu().numpy())

    # Raw epistemic uncertainties: quadratic forms of the last-layer features
    # with the regularised inverse covariance, one value per row.
    ll_features_q = outputs[f"mtt::aux::delta_{time_steps}_q_last_layer_features"].block().values
    ll_features_p = outputs[f"mtt::aux::p_{time_steps}_last_layer_features"].block().values
    p_epistemic = torch.einsum(
        "ia, ab, ib -> i", ll_features_p, inverse_covariance_p, ll_features_p
    )
    q_epistemic = torch.einsum(
        "ia, ab, ib -> i", ll_features_q, inverse_covariance_q, ll_features_q
    )
    p_uncertainties_epistemic.append(p_epistemic.cpu().numpy())
    q_uncertainties_epistemic.append(q_epistemic.cpu().numpy())

# Stack the per-batch arrays along the first axis.
p_residuals = np.concatenate(p_residuals, axis=0)
q_residuals = np.concatenate(q_residuals, axis=0)

# Apply the fitted calibration scales to the stacked raw uncertainties.
p_uncertainties_aleatoric = (
    np.concatenate(p_uncertainties_aleatoric, axis=0) * p_aleatoric_scale
)
q_uncertainties_aleatoric = (
    np.concatenate(q_uncertainties_aleatoric, axis=0) * q_aleatoric_scale
)

p_uncertainties_epistemic = (
    np.concatenate(p_uncertainties_epistemic, axis=0) * p_epistemic_scale
)
q_uncertainties_epistemic = (
    np.concatenate(q_uncertainties_epistemic, axis=0) * q_epistemic_scale
)

# Squared residual norm per row (sum of squares over the last axis).
p_squared_residuals = np.sum(p_residuals**2, axis=-1)
q_squared_residuals = np.sum(q_residuals**2, axis=-1)

# Total uncertainty is the sum of the two scaled contributions.
p_uncertainties = p_uncertainties_aleatoric + p_uncertainties_epistemic
q_uncertainties = q_uncertainties_aleatoric + q_uncertainties_epistemic








################################################################################

sigma = 1.0


def func(x):
    """Evaluate x * exp(-x^2 / (2 sigma^2)) / (sigma * sqrt(2 pi)).

    ``sigma`` is the module-level value at call time.
    """
    return x * np.exp(-x**2 / (2 * sigma**2)) * 1.0 / (sigma * np.sqrt(2 * np.pi))


# 10000 geometrically spaced sample points between 0.01 and 10.
x = np.geomspace(0.01, 10, 10000)

from scipy.optimize import root_scalar

def pdf(x, sigma):
    """Evaluate x * exp(-x^2 / (2 sigma^2)) / (sigma * sqrt(2 pi)).

    Args:
        x: Scalar or NumPy array of evaluation points.
        sigma: Width parameter.

    Returns:
        The function value(s), with the same shape as ``x``.
    """
    gaussian_factor = np.exp(-x**2 / (2 * sigma**2))
    normalisation = sigma * np.sqrt(2 * np.pi)
    return x * gaussian_factor * 1.0 / normalisation

def find_where_pdf_is_c(c, sigma):
    """Locate the two points at which ``pdf(x, sigma)`` equals ``c``.

    One root is searched in ``[0, sigma]`` and the other in ``[sigma, 100]``
    (the upper search limit is hard-coded).

    Args:
        c: Target value of the pdf.
        sigma: Width parameter passed to ``pdf``.

    Returns:
        A tuple ``(root_below, root_above)`` with the root found in each
        bracket.

    Raises:
        ValueError: If ``c`` exceeds ``pdf(sigma, sigma)``.
    """
    peak = pdf(sigma, sigma)
    if c > peak:
        raise ValueError("c must be less than mode_value")

    def excess(x):
        # Signed distance between the pdf and the requested level.
        return pdf(x, sigma) - c

    root_below = root_scalar(excess, bracket=[0, sigma]).root
    root_above = root_scalar(excess, bracket=[sigma, 100]).root
    return root_below, root_above

def pdf_integral(sigma, c):
    """Return the closed-form mass between the two points where the pdf equals ``c``.

    The crossing points come from ``find_where_pdf_is_c(c, sigma)``, and the
    result is ``exp(-x1^2 / (2 sigma^2)) - exp(-x2^2 / (2 sigma^2))``.

    NOTE(review): this expression is the integral of
    ``x / sigma**2 * exp(-x**2 / (2 sigma**2))``, which differs from ``pdf`` as
    defined in this file by a factor that does not depend on x -- confirm that
    this is intended.

    Args:
        sigma: Width parameter.
        c: Level of the pdf that defines the integration limits.

    Returns:
        The value of the closed-form expression above.
    """
    lower, upper = find_where_pdf_is_c(c, sigma)
    two_variance = 2 * sigma**2
    return np.exp(-lower**2 / two_variance) - np.exp(-upper**2 / two_variance)

def find_fraction(sigma, fraction):
    """Find the pdf level whose enclosed mass equals ``fraction``.

    Solves ``pdf_integral(sigma, c) == fraction`` for ``c`` with
    ``root_scalar``, starting from two guesses just below
    ``pdf(sigma, sigma)``.

    Args:
        sigma: Width parameter.
        fraction: Target value of ``pdf_integral``.

    Returns:
        The level ``c`` found by the root finder.
    """
    peak = pdf(sigma, sigma)

    def mass_error(level):
        # Difference between the enclosed mass at this level and the target.
        return pdf_integral(sigma, level) - fraction

    solution = root_scalar(mass_error, x0=peak - 0.01, x1=peak - 0.02)
    return solution.root

from scipy.stats import norm

# Probability mass of a standard normal within 1, 2 and 3 standard deviations.
desired_fractions = [
    norm.cdf(n_sigma, 0.0, 1.0) - norm.cdf(-n_sigma, 0.0, 1.0)
    for n_sigma in (1, 2, 3)
]

# Grid of sigma values on which the isolines are solved numerically, plus one
# far-away abscissa that is reached by linear extrapolation.
sigmas = np.linspace(2e-5, 5e0, 5)
additional_sigma = 100.0

# For every target fraction: the lower and upper crossing points of the isoline
# enclosing that fraction, one value per sigma (extrapolated point included).
lower_bounds = []
upper_bounds = []
for desired_fraction in desired_fractions:
    lower = []
    upper = []
    for sigma in sigmas:
        isoline_value = find_fraction(sigma, desired_fraction)
        x1, x2 = find_where_pdf_is_c(isoline_value, sigma)
        lower.append(x1)
        upper.append(x2)

    # Extend both curves linearly from the last two grid points to
    # additional_sigma. The step must be the distance from the last grid point
    # to additional_sigma, so that the extrapolated value matches the abscissa
    # appended to `sigmas` below. Using additional_sigma itself as the step
    # would evaluate the line at sigmas[-1] + additional_sigma instead.
    grid_step = sigmas[-1] - sigmas[-2]
    extrapolation_step = additional_sigma - sigmas[-1]
    lower.append(lower[-1] + (lower[-1] - lower[-2]) / grid_step * extrapolation_step)
    upper.append(upper[-1] + (upper[-1] - upper[-2]) / grid_step * extrapolation_step)

    lower_bounds.append(np.array(lower))
    upper_bounds.append(np.array(upper))

# Append the extrapolated abscissa so that `sigmas` lines up with the bounds.
sigmas = np.concatenate([sigmas, np.array([additional_sigma])])

#############################################################


# Scatter the test-set squared residuals of q against the calibrated total
# uncertainty, overlay the diagonal and the isoline bounds, and save the figure.
plt.plot(q_uncertainties, q_squared_residuals, ".", markersize=1.0, rasterized=True)
# Diagonal y = x as a reference line.
plt.plot(sigmas, sigmas, color="k")
# One pair of thin lines (lower and upper bound) per target fraction.
for l, u in zip(lower_bounds, upper_bounds):
    plt.plot(sigmas, l, color="k", linewidth=0.5)
    plt.plot(sigmas, u, color="k", linewidth=0.5)
# Clip the axes to the range of the scattered data.
plt.xlim(np.min(q_uncertainties), np.max(q_uncertainties))
plt.ylim(np.min(q_squared_residuals), np.max(q_squared_residuals))
plt.title("q")
plt.xscale("log")
plt.yscale("log")
# NOTE(review): the y data are squared residuals and both labels say eV --
# confirm that the labels and units are the intended ones.
plt.xlabel("Uncertainty (eV)")
plt.ylabel("Error (eV)")
plt.tight_layout()
plt.savefig(f"q_{time_steps}_play.pdf", dpi=300)
# Clear the current figure so the next plot starts from an empty canvas.
plt.clf()

# Scatter the test-set squared residuals of p against the calibrated total
# uncertainty, overlay the diagonal and the isoline bounds, and save the figure.
plt.plot(p_uncertainties, p_squared_residuals, ".", markersize=1.0, rasterized=True)
# Diagonal y = x as a reference line.
plt.plot(sigmas, sigmas, color="k")
# One pair of thin lines (lower and upper bound) per target fraction.
for l, u in zip(lower_bounds, upper_bounds):
    plt.plot(sigmas, l, color="k", linewidth=0.5)
    plt.plot(sigmas, u, color="k", linewidth=0.5)
# Clip the axes to the range of the scattered data.
plt.xlim(np.min(p_uncertainties), np.max(p_uncertainties))
plt.ylim(np.min(p_squared_residuals), np.max(p_squared_residuals))
plt.title("p")
plt.xscale("log")
plt.yscale("log")
# NOTE(review): the y data are squared residuals and both labels say eV --
# confirm that the labels and units are the intended ones.
plt.xlabel("Uncertainty (eV)")
plt.ylabel("Error (eV)")
plt.tight_layout()
plt.savefig(f"p_{time_steps}_play.pdf", dpi=300)
# Clear the current figure.
plt.clf()
