import argparse
from pathlib import Path

import lightning as L
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset
from torcheval.metrics.functional import r2_score

from pdisvae import evaluate, inference, kl, utils
from pdisvae.models.linear import LinearDecoder, LinearEncoder
from pdisvae.models.lint import LintDecoder

## arguments
# A single integer index on the command line selects one (model name, seed)
# pair from the grid defined below.
parser = argparse.ArgumentParser()
parser.add_argument("idx", type=int)
args = parser.parse_args()

name_list = ["VAE", "ICA", "ISA-VAE", "$\\beta$-TCVAE", "PDisVAE"]
seed_list = np.arange(10)

# Decode idx as a row-major position in the (name, seed) grid, i.e.
# idx = name_position * len(seed_list) + seed_position.
# np.unravel_index raises ValueError when idx falls outside the grid.
grid_shape = (len(name_list), len(seed_list))
arg_index = np.unravel_index(args.idx, grid_shape)
name = name_list[arg_index[0]]
seed = seed_list[arg_index[1]]

print(f"name: {name}")
print(f"seed: {seed}")


## data
# NOTE: pd.read_pickle unpickles arbitrary objects; only point it at trusted files.
trial = 0
df_data = pd.read_pickle("data/data.pkl")

# Row `trial` provides the observations used for training (x_train) and the
# latents that the evaluation section aligns against and scores (z_train).
x_train = df_data.at[trial, "x_train"]
z_train = df_data.at[trial, "z_train"]
n_total_samples, obs_dim = x_train.shape

# Batches of 128 observations served in a fixed order (no shuffling).
train_dataset = TensorDataset(x_train)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=False)


## model
n_groups = 3
group_rank = 2
n_components = n_groups * group_rank

# Seed torch before the modules are built (presumably so that their initial
# parameters are reproducible per seed -- confirm against the model classes).
torch.manual_seed(seed)
encoder = LinearEncoder(
    obs_dim=obs_dim, n_components=n_components, n_total_samples=n_total_samples
)
# Exclude the encoder's log_std from gradient updates.
encoder.log_std.requires_grad = False

decoder = LinearDecoder(obs_dim=obs_dim, n_components=n_components)

# Per-method KL configuration: (prior, n_groups, group_rank).
# Methods that treat every latent dimension separately use n_components
# groups of rank 1; the group-wise methods use n_groups groups of group_rank.
kl_settings = {
    "VAE": ("normal", n_components, 1),
    "ICA": ("logcosh", n_components, 1),
    "ISA-VAE": ("lpnested", n_groups, group_rank),
    "$\\beta$-TCVAE": ("normal", n_components, 1),
    "PDisVAE": ("normal", n_groups, group_rank),
}
if name not in kl_settings:
    raise ValueError("name not recognized")

kl_prior, kl_n_groups, kl_group_rank = kl_settings[name]
kl_normal = kl.KLNormal(
    prior=kl_prior,
    n_groups=kl_n_groups,
    group_rank=kl_group_rank,
    n_total_samples=n_total_samples,
)

## Lightning module
tag = "baseline"
results_file = f"results_{tag}"
n_epochs = 5000

# Create the output directory up front so the checkpoint and result writes
# below do not depend on the logger having created it.
Path(results_file).mkdir(parents=True, exist_ok=True)

# Per-epoch extra beta weight: zero for most methods; for beta-TCVAE and
# PDisVAE it decays linearly from 4 to 0 over the course of training.
extra_beta = np.zeros(n_epochs)
if name in ["$\\beta$-TCVAE", "PDisVAE"]:
    extra_beta = np.linspace(4, 0, n_epochs)
learning_rate = 5e-4

lit_btcvi = inference.LitBTCVI(encoder, decoder, kl_normal, extra_beta, learning_rate)

# The W&B project is named after the directory that contains this script.
# Path handles both "/" and "\\" separators and relative invocations, and
# avoids reusing the same quote character inside an f-string, which is a
# SyntaxError on Python versions before 3.12.
experiment_dir = Path(__file__).resolve().parent.name
wandb_logger = WandbLogger(
    name=f"{name}_{seed}",
    project=f"icml2025-{experiment_dir}",
    save_dir=results_file,
    tags=[tag],
)

# min_epochs == max_epochs: always train for exactly n_epochs.
trainer = L.Trainer(
    logger=wandb_logger,
    min_epochs=n_epochs,
    max_epochs=n_epochs,
    enable_progress_bar=False,
)
trainer.fit(
    model=lit_btcvi,
    train_dataloaders=train_dataloader,
)

# Persist the trained weights next to the other results of this run.
torch.save(encoder.state_dict(), f"{results_file}/{name}_{seed}_encoder.pt")
torch.save(decoder.state_dict(), f"{results_file}/{name}_{seed}_decoder.pt")
# To re-run only the evaluation, skip trainer.fit and load the weights instead:
# encoder.load_state_dict(torch.load(f"{results_file}/{name}_{seed}_encoder.pt"))
# decoder.load_state_dict(torch.load(f"{results_file}/{name}_{seed}_decoder.pt"))


## evaluation
# Metrics are computed in a fixed order and then written out as a single-row
# CSV; the aligned latents and the PCA/ICA check result are saved separately.
n_monte_carlo = 10

with torch.no_grad():
    z_pred_mean, z_pred_log_std = encoder.forward(x_train)  # (n_samples, n_components)
    x_pred_mean = decoder.forward(z_pred_mean)

    # Reconstruction quality of the decoded posterior means.
    cond_log_lik = decoder.log_prob(x_pred_mean, x_train).mean().item()
    recon_r2 = r2_score(x_pred_mean, x_train).item()

    # Same estimator twice: once over the n_groups groups, once with every
    # latent component as its own group (reported as total correlation).
    partial_corr = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=n_groups,
        n_monte_carlo=n_monte_carlo,
    ).item()
    total_corr = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=encoder.n_components,
        n_monte_carlo=n_monte_carlo,
    ).item()

    # Align the predicted latents to z_train before scoring them.
    aligned_z_pred_mean = evaluate.align(z_pred_mean, z_train, n_groups=n_groups)
    latent_r2 = r2_score(aligned_z_pred_mean, z_train).item()

    pca_ica_check_result = evaluate.pca_ica_check(
        z_pred_mean,
        n_groups=n_groups,
        n_monte_carlo=n_monte_carlo,
    )

# Column name -> value, in the column order of the output CSV.
metrics = {
    "conditional log-likelihood": cond_log_lik,
    "reconstruction $R^2$": recon_r2,
    "partial correlation": partial_corr,
    "total correlation": total_corr,
    "latent $R^2$": latent_r2,
}
df = pd.DataFrame(index=np.arange(1), columns=list(metrics))
for column, value in metrics.items():
    df.at[0, column] = value

df.to_csv(f"{results_file}/{name}_{seed}.csv", index=False)
torch.save(aligned_z_pred_mean, f"{results_file}/{name}_{seed}_aligned_z_pred_mean.pth")
torch.save(
    pca_ica_check_result, f"{results_file}/{name}_{seed}_pca_ica_check_result.pth"
)
