import argparse

import lightning as L
import numpy as np
import pandas as pd
import torch
import volimg.utils
from lightning.pytorch.loggers import WandbLogger
from pdisvae import evaluate, inference, kl
from pdisvae.models.linear import LinearDecoder, LinearEncoder
from torch.utils.data import DataLoader, TensorDataset
from torcheval.metrics.functional import mean_squared_error, r2_score

from voltage import make_video, prepare_data

## sessions
session_list = volimg.utils.session_list
print(session_list)  # echo the available sessions to stdout

## arguments
# One positional integer selects which (n_components, n_groups) configuration
# this run uses.
parser = argparse.ArgumentParser()
parser.add_argument("idx", type=int)
args = parser.parse_args()

n_components_list = [1, 2, 3, 4, 6, 12]
n_groups_list = [1, 2, 3, 4, 6, 12]
# Configurations: 12 components split into every group count, followed by
# (c, c) -- one component per group -- for each smaller component count.
n_components_n_groups_list = [
    *((12, g) for g in n_groups_list),
    *((c, c) for c in n_components_list[:-1]),
]

# np.unravel_index over a 1-D shape yields a 1-tuple; it raises ValueError when
# idx falls outside [0, len(n_components_n_groups_list)).
arg_index = np.unravel_index(args.idx, (len(n_components_n_groups_list),))
n_components_n_groups = n_components_n_groups_list[arg_index[0]]
# Run label such as "(12, 3)"; used in the result/checkpoint file names.
method = str(n_components_n_groups)
n_components, n_groups = n_components_n_groups
print(f"n_components: {n_components}")
print(f"n_groups: {n_groups}")


## data
# First session only; the outcome/trial_side filters are passed straight to
# prepare_data (presumably successful, left-side trials -- confirm in voltage).
x = prepare_data(session_list[0], outcome="success", trial_side="left")
# Two-name unpacking requires x to be 2-D: (samples, observed dimensions).
n_total_samples, obs_dim = x.shape
dataloader = DataLoader(
    TensorDataset(x),
    batch_size=128,
    shuffle=True,
)


## model
# group_rank is the number of components per group, and KLNormal receives both
# n_groups and group_rank. Use exact integer division and reject pairs that do
# not split evenly, rather than silently truncating the per-group size.
if n_components % n_groups:
    raise ValueError(
        f"n_components ({n_components}) must be divisible by "
        f"n_groups ({n_groups})"
    )
group_rank = n_components // n_groups

# Seed before constructing the modules so their initialisation is reproducible.
seed = 0
torch.manual_seed(seed)
encoder = LinearEncoder(
    obs_dim=obs_dim, n_components=n_components, n_total_samples=n_total_samples
)
# Freeze the encoder's log_std: no gradients are computed for it.
encoder.log_std.requires_grad = False

decoder = LinearDecoder(obs_dim=obs_dim, n_components=n_components)

kl_normal = kl.KLNormal(
    prior="normal",
    n_groups=n_groups,
    group_rank=group_rank,
    n_total_samples=n_total_samples,
)


## Lightning module
# Run tag; all outputs for this run go under results_<tag>/.
tag = "baseline"
# tag = "no_annealing"
results_file = f"results_{tag}"

# Training path (currently disabled). Uncomment through the two torch.save
# calls to (re)train and write the weights; as written, the script only loads
# previously saved weights below.
# n_epochs = 5000

# extra_beta = np.linspace(4, 0, n_epochs)
# # extra_beta = 4 * np.ones(n_epochs)
# learning_rate = 2e-4

# lit_btcvi = inference.LitBTCVI(encoder, decoder, kl_normal, extra_beta, learning_rate)

# wandb_logger = WandbLogger(
#     name=f"{method}",
#     project=f"pdisvae-{__file__.split('/')[-2]}",
#     save_dir=results_file,
#     tags=[tag],
# )

# trainer = L.Trainer(
#     logger=wandb_logger,
#     min_epochs=n_epochs,
#     max_epochs=n_epochs,
#     enable_progress_bar=False,
# )
# trainer.fit(
#     model=lit_btcvi,
#     train_dataloaders=dataloader,
# )

# torch.save(encoder.state_dict(), f"{results_file}/{method}_encoder.pth")
# torch.save(decoder.state_dict(), f"{results_file}/{method}_decoder.pth")

# Nothing in this script moves the models off the CPU, so restore checkpoint
# tensors onto the CPU even if they were saved from another device.
encoder.load_state_dict(
    torch.load(f"{results_file}/{method}_encoder.pth", map_location="cpu")
)
decoder.load_state_dict(
    torch.load(f"{results_file}/{method}_decoder.pth", map_location="cpu")
)


## evaluation
# One-row metrics table for this configuration, written to
# {results_file}/{method}.csv.
n_monte_carlo = 10
# These labels must match the df.at writes below exactly. For an undeclared
# label, df.at falls back to .loc enlargement: it adds a new column and leaves
# the declared one empty in the CSV.
df = pd.DataFrame(
    index=np.arange(1),
    columns=[
        "conditional log-likelihood",
        "reconstruction $R^2$",
        "reconstruction RMSE",
        "partial correlation",
    ],
)
with torch.no_grad():
    z_pred_mean, z_pred_log_std = encoder.forward(x)  # (n_samples, n_components)
    x_pred_mean = decoder.forward(z_pred_mean)

    # Mean of decoder.log_prob(x_pred_mean, x) over all of its returned entries.
    df.at[0, "conditional log-likelihood"] = (
        decoder.log_prob(x_pred_mean, x).mean().item()
    )
    # torcheval metrics take (input=prediction, target=ground truth).
    df.at[0, "reconstruction $R^2$"] = r2_score(x_pred_mean, x).item()
    df.at[0, "reconstruction RMSE"] = mean_squared_error(x_pred_mean, x).item() ** 0.5

    df.at[0, "partial correlation"] = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=n_groups,
        n_monte_carlo=n_monte_carlo,
    ).item()
df.to_csv(f"{results_file}/{method}.csv", index=False)


# Re-run the encoder on x without building an autograd graph, then export the
# predicted latent means.
with torch.no_grad():
    z_pred_mean, z_pred_log_std = encoder(x)

out_prefix = f"{results_file}/{method}"
torch.save(z_pred_mean, f"{out_prefix}_z_pred_mean.pth")

# Run evaluate.pca_ica_check on the latent means and store its result next to
# the other per-run outputs.
pca_ica_check_result = evaluate.pca_ica_check(
    z_pred_mean, n_groups=n_groups, n_monte_carlo=n_monte_carlo
)
torch.save(pca_ica_check_result, f"{out_prefix}_pca_ica_check_result.pth")

make_video(tag=tag, method=method, x=x)
