import argparse

import lightning as L
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset
from torcheval.metrics.functional import r2_score

from pdisvae import evaluate, inference, kl, utils
from pdisvae.models.linear import LinearDecoder, LinearEncoder
from pdisvae.models.lint import LintDecoder

## arguments
# A single positional integer selects one (structure, independence, prior,
# seed) combination from the option lists defined below.
parser = argparse.ArgumentParser()
parser.add_argument("idx", type=int)
args = parser.parse_args()

structure_list = ["linear", "LINT"][:1]  # the slice keeps only "linear"
independence_list = ["no", "strong", "weak"]
prior_list = ["normal", "logcosh"]
seed_list = np.arange(10)

# Decode the flat index into one index per option list. np.unravel_index uses
# C order by default, so the seed index varies fastest.
option_lists = (structure_list, independence_list, prior_list, seed_list)
arg_index = np.unravel_index(
    args.idx,
    tuple(len(options) for options in option_lists),
)
structure, independence, prior, seed = (
    options[position] for options, position in zip(option_lists, arg_index)
)

# Identifier used in result file names; the seed is appended separately.
method = f"{structure}_{independence}_{prior}"
print(f"structure: {structure}")
print(f"independence: {independence}")
print(f"prior: {prior}")
print(f"seed: {seed}")


## data
# Read the pickled table and pull out the observations and latents stored in
# the row labelled `trial`.
trial = 0
df_data = pd.read_pickle("data/data.pkl")
x_train, z_train = (
    df_data.at[trial, column] for column in ("x_train", "z_train")
)

# x_train is unpacked as two-dimensional: (number of samples, observed dim).
n_total_samples, obs_dim = x_train.shape

# Fixed-order mini-batches (shuffle is off).
train_dataloader = DataLoader(
    dataset=TensorDataset(x_train),
    batch_size=128,
    shuffle=False,
)


## model
# `basis` is only consumed by the LINT decoder branch below.
basis = utils.exp_basis(1, 5, 5)
n_groups = 3
group_rank = 2
n_components = n_groups * group_rank  # total latent dimensionality

# Seed torch's global RNG before any module is constructed.
torch.manual_seed(seed)
encoder = LinearEncoder(
    obs_dim=obs_dim, n_components=n_components, n_total_samples=n_total_samples
)
# Exclude the encoder's log_std from gradient updates.
encoder.log_std.requires_grad = False

# Decoder selection. Fail fast on an unknown structure (mirroring the
# independence check below) instead of leaving `decoder` undefined.
if structure == "linear":
    decoder = LinearDecoder(obs_dim=obs_dim, n_components=n_components)
elif structure == "LINT":
    convolved_history = utils.convolve_spikes_with_basis(x_train.T[None, :, :], basis)[
        0
    ]
    decoder = LintDecoder(convolved_history, n_components=n_components)
else:
    raise ValueError("structure not recognized")

# KL term configuration. "no" and "strong" use the same grouping (every
# component is its own rank-1 group); "weak" uses n_groups groups of
# group_rank components each.
# NOTE(review): "no" and "strong" are indistinguishable in this block; any
# difference between them must come from elsewhere (e.g. the training setup).
if independence in ("no", "strong"):
    kl_n_groups, kl_group_rank = n_components, 1
elif independence == "weak":
    kl_n_groups, kl_group_rank = n_groups, group_rank
else:
    raise ValueError("independence not recognized")

kl_normal = kl.KLNormal(
    prior=prior,
    n_groups=kl_n_groups,
    group_rank=kl_group_rank,
    n_total_samples=n_total_samples,
)


## Lightning module
# `results_file` is the directory the per-run files are read from and written
# to; its name is derived from the run tag.
tag = "fixed_std"
results_file = f"results_{tag}"

# The training pipeline below is disabled: this script currently reloads
# previously saved encoder/decoder weights instead of fitting them.
# n_epochs = 5000

# extra_beta = np.zeros(n_epochs)
# if independence != "no":
#     extra_beta = np.linspace(4, 0, n_epochs)
# learning_rate = 5e-4

# lit_btcvi = inference.LitBTCVI(encoder, decoder, kl_normal, extra_beta, learning_rate)

# wandb_logger = WandbLogger(
#     name=f"{method}_{seed}",
#     project=f"pdisvae-{__file__.split("/")[-2]}",
#     save_dir=results_file,
#     tags=[tag],
# )

# trainer = L.Trainer(
#     logger=wandb_logger,
#     min_epochs=n_epochs,
#     max_epochs=n_epochs,
#     enable_progress_bar=False,
# )
# trainer.fit(
#     model=lit_btcvi,
#     train_dataloaders=train_dataloader,
# )

# torch.save(encoder.state_dict(), f"{results_file}/{method}_{seed}_encoder.pt")
# torch.save(decoder.state_dict(), f"{results_file}/{method}_{seed}_decoder.pt")

# Restore the saved parameters, encoder first and then decoder.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
for module, checkpoint in (
    (encoder, f"{results_file}/{method}_{seed}_encoder.pt"),
    (decoder, f"{results_file}/{method}_{seed}_decoder.pt"),
):
    module.load_state_dict(torch.load(checkpoint))


## evaluation
# Compute summary metrics on the training data, store them in a one-row table
# and write the table plus two tensors/objects to `results_file`.
n_monte_carlo = 10
metric_columns = [
    "conditional log-likelihood",
    "reconstruction $R^2$",
    "partial correlation",
    "total correlation",
    "latent $R^2$",
]
df = pd.DataFrame(index=np.arange(1), columns=metric_columns)

# Gradient tracking is not needed for evaluation.
with torch.no_grad():
    z_pred_mean, z_pred_log_std = encoder.forward(x_train)  # (n_samples, n_components)
    x_pred_mean = decoder.forward(z_pred_mean)

    # Reconstruction quality of the decoder output at the encoder means.
    conditional_log_likelihood = decoder.log_prob(x_pred_mean, x_train).mean().item()
    reconstruction_r2 = r2_score(x_pred_mean, x_train).item()
    df.at[0, "conditional log-likelihood"] = conditional_log_likelihood
    df.at[0, "reconstruction $R^2$"] = reconstruction_r2

    # Same estimator at two group counts: n_groups, then one group per
    # component. The call order is kept fixed in case the Monte Carlo
    # estimate draws from the global RNG.
    # NOTE(review): "total correlation" is computed with
    # evaluate.partial_correlation -- confirm this is intended.
    for column, group_count in (
        ("partial correlation", n_groups),
        ("total correlation", encoder.n_components),
    ):
        df.at[0, column] = evaluate.partial_correlation(
            z_pred_mean,
            z_pred_log_std,
            n_groups=group_count,
            n_monte_carlo=n_monte_carlo,
        ).item()

    # Align the predicted latents to the reference latents before scoring them.
    aligned_z_pred_mean = evaluate.align(z_pred_mean, z_train, n_groups=n_groups)
    latent_r2 = r2_score(aligned_z_pred_mean, z_train).item()
    df.at[0, "latent $R^2$"] = latent_r2

    pca_ica_check_result = evaluate.pca_ica_check(
        z_pred_mean,
        n_groups=n_groups,
        n_monte_carlo=n_monte_carlo,
    )

# Persist the metrics row (without the index column) and the two artefacts.
df.to_csv(f"{results_file}/{method}_{seed}.csv", index=False)
torch.save(
    aligned_z_pred_mean,
    f"{results_file}/{method}_{seed}_aligned_z_pred_mean.pth",
)
torch.save(
    pca_ica_check_result,
    f"{results_file}/{method}_{seed}_pca_ica_check_result.pth",
)
