import torch

from torch.utils.data import DataLoader
from my_datasets.polynomial import PolynomialDataset

import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from function_encoder.model.mlp import MLP
from function_encoder.function_encoder import BasisFunctions, FunctionEncoder
from function_encoder.losses import basis_normalization_loss
from function_encoder.utils.training import train_step
from function_encoder.utils.experiment_saver import ExperimentSaver, create_visualization_data_polynomial


import tqdm

# Choose the compute device: CUDA when available, otherwise Apple MPS,
# otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"


# Seed PyTorch's global random number generator.
torch.manual_seed(42)

# Build the polynomial dataset and the training data loader.
dataset = PolynomialDataset(degree=3, n_points=100, n_example_points=100)
dataloader = DataLoader(dataset=dataset, batch_size=50)
# One long-lived iterator: training pulls batches from it with next().
# NOTE(review): this presumes the dataset can supply as many batches as the
# training schedule requests (e.g. an endless iterable dataset) - confirm.
dataloader_iter = iter(dataloader)

# Create model
def basis_function_factory():
    """Construct a new ``MLP`` instance to serve as one basis function.

    Returns:
        A freshly constructed ``MLP`` built with ``layer_sizes=[1, 32, 1]``
        (presumably 1 input, 32 hidden units, 1 output - confirm against
        the ``MLP`` constructor).
    """
    layer_sizes = [1, 32, 1]
    return MLP(layer_sizes=layer_sizes)


num_basis = 10
# Progressive training: the model starts with a single basis function; the
# remaining ones are appended one at a time during training.
first_basis_function = basis_function_factory()
basis_functions = BasisFunctions(first_basis_function)

model = FunctionEncoder(basis_functions)
model = model.to(device)

# Train model

# Diagnostics collected during training (used for plotting).
losses, scores = [], []
# Separate loader with a larger batch, used to estimate coefficient statistics.
dataloader_coeffs = DataLoader(dataset=dataset, batch_size=100)
dataloader_coeffs_iter = iter(dataloader_coeffs)


def compute_explained_variance(model, batch_iter=None, target_device=None):
    """Spectral diagnostics of the coefficients the model assigns to a batch.

    One batch is drawn, ``model.compute_coefficients`` is evaluated on its
    example points, and the eigenvalues of the coefficient covariance matrix
    and of the batch-averaged Gram matrix are returned.

    Args:
        model: object exposing ``compute_coefficients(example_X, example_y)``
            that returns ``(coefficients, G)``. ``coefficients`` is treated
            as a 2-D tensor with the batch on dim 0; ``G`` is averaged over
            dim 0 before its eigenvalues are taken.
        batch_iter: iterator yielding ``(_, _, example_X, example_y)``
            batches. Defaults to the module-level ``dataloader_coeffs_iter``.
        target_device: device the example tensors are moved to. Defaults to
            the module-level ``device``.

    Returns:
        Tuple ``(explained_variance_ratio, eigenvalues, gram_eigenvalues)``,
        each a 1-D tensor sorted in descending order.
        ``explained_variance_ratio`` is ``eigenvalues`` divided by their sum,
        so it is NaN when that sum is zero.

    Raises:
        StopIteration: if ``batch_iter`` is exhausted.
    """
    # Fall back to the script-level globals when no explicit source is given.
    if batch_iter is None:
        batch_iter = dataloader_coeffs_iter
    if target_device is None:
        target_device = device

    _, _, example_X, example_y = next(batch_iter)
    example_X = example_X.to(target_device)
    example_y = example_y.to(target_device)
    coefficients, G = model.compute_coefficients(example_X, example_y)

    # Covariance of the coefficients across the batch (normalised by the
    # batch size, not batch size - 1).
    centered = coefficients - coefficients.mean(dim=0, keepdim=True)
    coefficients_cov = torch.matmul(centered.T, centered) / coefficients.shape[0]

    # Only eigenvalues are needed, so eigvalsh is used instead of eigh.
    # It returns them in ascending order; flip to descending.
    eigenvalues = torch.linalg.eigvalsh(coefficients_cov).flip(0)
    explained_variance_ratio = eigenvalues / torch.sum(eigenvalues)

    gram_eigenvalues = torch.linalg.eigvalsh(G.mean(dim=0)).flip(0)

    return explained_variance_ratio, eigenvalues, gram_eigenvalues


def loss_function(model, batch):
    """Mean-squared prediction error of ``model`` on a single batch.

    Args:
        model: object exposing ``compute_coefficients(example_X, example_y)``
            and callable as ``model(X, coefficients)``.
        batch: four tensors ``(X, y, example_X, example_y)``.

    Returns:
        The value of ``torch.nn.functional.mse_loss`` between the model's
        prediction at ``X`` and the target ``y``.
    """
    # Move every tensor of the batch onto the training device.
    X, y, example_X, example_y = (tensor.to(device) for tensor in batch)

    # Coefficients come from the example set; the second return value
    # (the Gram matrix) is not needed for the loss.
    coefficients, _ = model.compute_coefficients(example_X, example_y)
    predictions = model(X, coefficients)

    return torch.nn.functional.mse_loss(predictions, y)


# Progressive training schedule: fit the first basis function, then
# repeatedly freeze everything learned so far, append one new basis function
# and fit only that one.
num_epochs = 1000


def _train_stage(stage, optimizer):
    """Run one training stage and record its diagnostics.

    Args:
        stage: 1-based index of the basis function being fitted; used only
            for the progress-bar label.
        optimizer: optimizer holding the parameters updated in this stage.

    Side effects:
        Appends one loss value per step to ``losses`` and one
        explained-variance-ratio tensor to ``scores``. Leaves ``model`` in
        eval mode.
    """
    # The first stage runs before any model.eval() call, while every later
    # stage follows one. Switching back to train mode here makes all stages
    # run in the same mode.
    model.train()
    with tqdm.tqdm(range(num_epochs), desc=f"basis {stage}/{num_basis}") as tqdm_bar:
        for _ in tqdm_bar:
            batch = next(dataloader_iter)
            loss = train_step(model, optimizer, batch, loss_function)
            losses.append(loss)
            tqdm_bar.set_postfix({"loss": f"{loss:.2e}"})

    # Snapshot the coefficient spectrum with the current number of basis functions.
    model.eval()
    with torch.no_grad():
        explained_variance_ratio, *_ = compute_explained_variance(model)
        scores.append(explained_variance_ratio)


# Stage 1: every parameter of the single-basis model is trainable.
_train_stage(1, torch.optim.Adam(model.parameters(), lr=1e-3))

# Stages 2..num_basis: train one newly added basis function at a time.
for stage in range(2, num_basis + 1):
    # Freeze everything learned so far.
    model.requires_grad_(False)

    # Create the new basis function, make sure it is trainable, and register
    # it with the model.
    new_basis_function = basis_function_factory()
    new_basis_function.requires_grad_(True)
    new_basis_function = new_basis_function.to(device)
    model.basis_functions.basis_functions.append(new_basis_function)

    # Give the optimizer only the parameters that are still trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=1e-3)

    _train_stage(stage, optimizer)

# Plot results

import matplotlib.pyplot as plt

# Matplotlib settings for publication-quality figures: small fonts and
# markers, 300 dpi PNG output, compact legends.
publication_style = {
    'font.size': 8,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.format': 'png',
    'lines.markersize': 3,
    'legend.fontsize': 8,
    'legend.handlelength': 1.0,
    'legend.handletextpad': 0.3,
    'legend.columnspacing': 0.5,
}
plt.rcParams.update(publication_style)

model.eval()

# savefig does not create missing directories, so make sure the output
# folder exists before any figure is written.
os.makedirs('plot_outputs', exist_ok=True)

with torch.no_grad():
    # ---- Figure 1: reconstruction of a single function -------------------
    dataloader_eval = DataLoader(dataset, batch_size=1)
    batch = next(iter(dataloader_eval))

    X, y, example_X, example_y = batch
    X = X.to(device)
    y = y.to(device)
    example_X = example_X.to(device)
    example_y = example_y.to(device)

    # Sort the query points along dim 1 so the curves are drawn left to
    # right; y is reordered with the same indices.
    idx = torch.argsort(X, dim=1, descending=False)
    X = torch.gather(X, dim=1, index=idx)
    y = torch.gather(y, dim=1, index=idx)

    coefficients, _ = model.compute_coefficients(example_X, example_y)
    y_pred = model(X, coefficients)

    # Drop the leading batch dimension (batch_size=1) and move to NumPy.
    X = X.squeeze(0).cpu().numpy()
    y_pred = y_pred.squeeze(0).cpu().numpy()
    y = y.squeeze(0).cpu().numpy()

    example_X = example_X.squeeze(0).cpu().numpy()
    example_y = example_y.squeeze(0).cpu().numpy()

    fig, ax = plt.subplots()
    ax.plot(X, y, label="True")
    ax.plot(X, y_pred, label="Predicted")
    ax.scatter(example_X, example_y, label="Data", color="red")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()
    plt.tight_layout()
    plt.savefig('plot_outputs/poly_func.png', dpi=300, bbox_inches='tight')
    plt.show()

    # ---- Figure 2: individual basis functions -----------------------------
    # The grid is sized from the number of basis functions so that every one
    # of them gets a panel (two rows, 1.5 inches of width per column).
    basis_fns = list(model.basis_functions.basis_functions)[:num_basis]
    n_rows = 2
    n_cols = max(1, (len(basis_fns) + n_rows - 1) // n_rows)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(1.5 * n_cols, 3), squeeze=False
    )
    axes = axes.flatten()
    # Evaluation grid on [-1, 1] with leading batch and trailing feature dims.
    X_plot = torch.linspace(-1, 1, 100).unsqueeze(1).unsqueeze(0).to(device)
    X_plot_np = X_plot[0].cpu().numpy()
    for i, (basis_ax, basis_fn) in enumerate(zip(axes, basis_fns)):
        basis_output = basis_fn(X_plot)
        basis_ax.plot(X_plot_np, basis_output[0].detach().cpu().numpy())
        basis_ax.set_title(f"φ{i+1}")
        basis_ax.set_xlabel('x')
        basis_ax.set_ylabel('φ(x)')
    # Hide panels left over when the count does not fill the grid.
    for unused_ax in axes[len(basis_fns):]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig('plot_outputs/poly_basis.png', dpi=300, bbox_inches='tight')
    plt.show()

    # ---- Figure 3: loss curve and coefficient spectra ---------------------
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))

    # Training loss over all stages.
    ax1.plot(losses)
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("MSE")
    ax1.grid(True)
    ax1.set_yscale("log")

    # Explained variance ratio recorded after each stage. The list is
    # converted to NumPy in place so later consumers see arrays, not tensors.
    scores[:] = [score.cpu().numpy() for score in scores]
    for n_trained, ratio in enumerate(scores, start=1):
        ax2.plot(
            range(1, len(ratio) + 1),
            ratio,
            marker="o",
            markersize=3,
            label=f"k = {n_trained}",
        )
    ax2.set_xlabel("Eigenvalue Index")
    ax2.set_ylabel("Explained Variance Ratio")
    ax2.set_yscale("log")
    ax2.legend()
    ax2.grid(True)

    # Eigenvalues of the coefficient covariance for the final model.
    # gram_eigenvalues is computed and kept but not plotted.
    _, eigenvalues, gram_eigenvalues = compute_explained_variance(model)
    eigenvalues = eigenvalues.cpu().numpy()
    gram_eigenvalues = gram_eigenvalues.cpu().numpy()

    ax3.plot(
        range(1, len(eigenvalues) + 1),
        eigenvalues,
        marker="o",
        markersize=3,
        label="Covariance Matrix",
    )
    ax3.set_xlabel("Eigenvalue Index")
    ax3.set_ylabel("Eigenvalue")
    ax3.legend()
    ax3.grid(True)

    plt.tight_layout()
    plt.savefig('plot_outputs/poly_progressive_unsize.png', dpi=300, bbox_inches='tight')
    plt.show()

# Persist the results of this run.
saver = ExperimentSaver()

# Evaluate every basis function on the plotting grid and keep the first
# (only) batch entry as a NumPy array.
basis_outputs = [
    basis_fn(X_plot)[0].detach().cpu().numpy()
    for basis_fn in model.basis_functions.basis_functions
]

# Bundle the arrays behind the reconstruction and basis-function figures.
viz_data = create_visualization_data_polynomial(
    X_sorted=X,
    y_sorted=y,
    y_pred=y_pred,
    example_X=example_X,
    example_y=example_y,
    basis_outputs=basis_outputs,
)

# Settings recorded alongside the results.
dataset_params = {
    "name": "poly_degree3",
    "n_points": 100,
    "n_example_points": 100,
    "degree": 3,
}
training_params = {
    "num_epochs": num_epochs,
    "learning_rate": 1e-3,
    "batch_size": 50,
}

# Assemble the experiment record and write it out.
experiment_data = saver.prepare_progressive_data(
    problem_type="polynomial",
    num_basis=num_basis,
    losses=losses,
    scores=scores,
    eigenvalues=eigenvalues,
    gram_eigenvalues=gram_eigenvalues,
    visualization_data=viz_data,
    dataset_params=dataset_params,
    training_params=training_params,
)

saver.save_experiment(
    "polynomial",
    "progressive",
    experiment_data,
    dataset_name="polynomial_d3",
)
