import argparse
import os

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset
from torcheval.metrics.functional import r2_score

from pdisvae import evaluate, inference, kl
from pdisvae.models.dcnn import BurgessDecoder, BurgessEncoder

## arguments
# A single positional integer selects which method this run trains.
parser = argparse.ArgumentParser()
parser.add_argument("idx", type=int)
args = parser.parse_args()

name_list = ["VAE", "ICA", "ISA-VAE", "$\\beta$-TCVAE", "PDisVAE"]

# The index is unravelled against a one-dimensional grid of methods, so an
# index >= len(name_list) raises ValueError instead of selecting anything.
arg_index = np.unravel_index(args.idx, (len(name_list),))
name = name_list[arg_index[0]]

print(f"name: {name}")


## data
# x is cast to float32; z is loaded as stored (no dtype cast).
x = torch.load("data/x.pt").float()
z = torch.load("data/z.pt")

# x must be 4-D: the leading axis counts samples, the other three give the
# image size handed to the encoder/decoder.
n_total_samples, n_channels, height, width = x.shape
img_size = (n_channels, height, width)

dataloader = DataLoader(
    TensorDataset(x),
    shuffle=True,
    batch_size=805,
    pin_memory=True,
    num_workers=8,
)


## model
n_groups, group_rank = 2, 2
n_components = n_groups * group_rank  # latent size passed to both networks

# Seed the global torch RNG before the networks are constructed.
seed = 0
torch.manual_seed(seed)

encoder = BurgessEncoder(
    img_size=img_size,
    n_components=n_components,
    n_total_samples=n_total_samples,
)
decoder = BurgessDecoder(
    img_size=img_size,
    n_components=n_components,
)

# Exclude the encoder's log_std from gradient updates.
encoder.log_std.requires_grad = False

# KL-term configuration per method: (prior, n_groups, group_rank).
# Methods listed with group_rank=1 treat every latent component as its own
# group; the others use the n_groups x group_rank grouping defined above.
kl_settings = {
    "VAE": ("normal", n_components, 1),
    "ICA": ("logcosh", n_components, 1),
    "ISA-VAE": ("lpnested", n_groups, group_rank),
    "$\\beta$-TCVAE": ("normal", n_components, 1),
    "PDisVAE": ("normal", n_groups, group_rank),
}

if name not in kl_settings:
    raise ValueError("name not recognized")

kl_prior, kl_n_groups, kl_group_rank = kl_settings[name]
kl_normal = kl.KLNormal(
    prior=kl_prior,
    n_groups=kl_n_groups,
    group_rank=kl_group_rank,
    n_total_samples=n_total_samples,
)


## Lightning module
tag = "baseline"
results_file = f"results_{tag}"  # output directory for logs, weights and metrics

n_epochs = 5000

# One extra-beta value per epoch: all zeros by default, linearly annealed
# from 4 down to 0 for the beta-TCVAE and PDisVAE runs.
# NOTE(review): how LitBTCVI consumes this schedule is defined in
# pdisvae.inference — confirm there.
extra_beta = np.zeros(n_epochs)
if name in ["$\\beta$-TCVAE", "PDisVAE"]:
    extra_beta = np.linspace(4, 0, n_epochs)
learning_rate = 1e-3

lit_btcvi = inference.LitBTCVI(encoder, decoder, kl_normal, extra_beta, learning_rate)

# The W&B project is named after the directory that contains this script.
# Computing it outside the f-string avoids reusing double quotes inside a
# double-quoted f-string (a SyntaxError before Python 3.12), and os.path
# works with any path separator and with a bare relative __file__.
experiment_dir = os.path.basename(os.path.dirname(os.path.abspath(__file__)))

wandb_logger = WandbLogger(
    name=f"{name}",
    project=f"icml2025-{experiment_dir}",
    save_dir=results_file,
    tags=[tag],
)

# min_epochs and max_epochs are both n_epochs, so the epoch count is fixed.
trainer = L.Trainer(
    logger=wandb_logger,
    min_epochs=n_epochs,
    max_epochs=n_epochs,
    enable_progress_bar=False,
)
trainer.fit(
    model=lit_btcvi,
    train_dataloaders=dataloader,
)

# Persist the trained weights. Create the output directory explicitly rather
# than relying on the logger having created it, so a long training run is not
# lost to a missing directory at the very end.
os.makedirs(results_file, exist_ok=True)
torch.save(encoder.state_dict(), f"{results_file}/{name}_encoder.pt")
torch.save(decoder.state_dict(), f"{results_file}/{name}_decoder.pt")
# To evaluate a previous run without retraining, skip the fit above and
# restore the saved weights instead:
# encoder.load_state_dict(torch.load(f"{results_file}/{name}_encoder.pt"))
# decoder.load_state_dict(torch.load(f"{results_file}/{name}_decoder.pt"))


## evaluation
n_monte_carlo = 10
# Single-row table: one column per reported metric.
df = pd.DataFrame(
    index=np.arange(1),
    columns=[
        "conditional log-likelihood",
        "reconstruction MSE",
        "partial correlation",
        "total correlation",
        "latent $R^2$",
        "latent MSE",
    ],
)

# Put both networks in eval mode so any train/eval-dependent layers behave
# deterministically during evaluation (a no-op if there are none).
encoder.eval()
decoder.eval()

with torch.no_grad():
    # Call the modules rather than .forward() so registered hooks are honoured.
    z_pred_mean, z_pred_log_std = encoder(x)  # (n_samples, n_components)
    # Reconstruct from the posterior mean (no sampling).
    x_pred_mean = decoder(z_pred_mean)

    df.at[0, "conditional log-likelihood"] = (
        decoder.log_prob(x_pred_mean, x).mean().item()
    )
    df.at[0, "reconstruction MSE"] = F.mse_loss(x_pred_mean, x).item()

    # Partial correlation uses the model's group structure (n_groups groups).
    df.at[0, "partial correlation"] = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=n_groups,
        n_monte_carlo=n_monte_carlo,
        seed=seed,
    ).item()
    # Total correlation reuses the same estimator with one group per component.
    df.at[0, "total correlation"] = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=encoder.n_components,
        n_monte_carlo=n_monte_carlo,
        seed=seed,
    ).item()

    # Append one column of near-zero noise to z before alignment.
    # NOTE(review): presumably this pads z to n_components columns so it can
    # be compared with the predicted latents — confirm against data/z.pt.
    z_test = torch.concat([z, 1e-4 * torch.randn(n_total_samples, 1)], dim=1)
    aligned_z_pred_mean = evaluate.align(z_pred_mean, z_test, n_groups=n_groups)

    df.at[0, "latent $R^2$"] = r2_score(aligned_z_pred_mean, z_test).item()
    df.at[0, "latent MSE"] = F.mse_loss(aligned_z_pred_mean, z_test).item()

df.to_csv(f"{results_file}/{name}.csv", index=False)
torch.save(aligned_z_pred_mean, f"{results_file}/{name}_aligned_z_pred_mean.pt")
