import argparse
from pathlib import Path

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset
from torcheval.metrics.functional import r2_score

from pdisvae import evaluate, inference, kl
from pdisvae.models.dcnn import BurgessDecoder, BurgessEncoder

## arguments
# A single positional integer selects one (independence, prior) combination,
# which makes the script easy to drive from a job array.
parser = argparse.ArgumentParser()
parser.add_argument("idx", type=int)
args = parser.parse_args()

independence_list = ["no", "strong", "weak"]
prior_list = ["normal", "logcosh"]

# idx is a flat, row-major index into the independence x prior grid:
# 0 -> ("no", "normal"), 1 -> ("no", "logcosh"), 2 -> ("strong", "normal"), ...
# np.unravel_index raises ValueError when idx falls outside the grid.
grid_shape = (len(independence_list), len(prior_list))
arg_index = np.unravel_index(args.idx, grid_shape)
independence_position, prior_position = arg_index
independence = independence_list[independence_position]
prior = prior_list[prior_position]

# Used as the run name and as the prefix of every output file.
method = f"{independence}_{prior}"
print(f"independence: {independence}")
print(f"prior: {prior}")


## data
# x holds the observations (cast to float32); z is loaded as stored and is only
# used as the reference latent during evaluation.
x = torch.load("data/x.pt").to(torch.float32)
z = torch.load("data/z.pt")

dataset = TensorDataset(x)
dataloader = DataLoader(
    dataset,
    batch_size=805,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# x must be 4-dimensional; the unpacking below fails otherwise.
# NOTE(review): axis order is taken to be (sample, channel, height, width) from
# the variable names — confirm against how data/x.pt was generated.
n_total_samples, *image_dims = x.shape
n_channels, height, width = image_dims
img_size = (n_channels, height, width)


## model
n_groups = 2
group_rank = 2
n_components = n_groups * group_rank  # total latent dimensionality

# Seed before building any module so that initialization is reproducible.
seed = 0
torch.manual_seed(seed)
encoder = BurgessEncoder(
    img_size=img_size, n_components=n_components, n_total_samples=n_total_samples
)
decoder = BurgessDecoder(img_size=img_size, n_components=n_components)
# Exclude the encoder's log_std from gradient updates.
encoder.log_std.requires_grad = False

# (n_groups, group_rank) handed to the KL term for each independence setting.
# "no" and "strong" deliberately share the fully factorized layout (one group
# per component); they differ only in the extra_beta schedule chosen later in
# this script. "weak" uses the grouped layout.
kl_layouts = {
    "no": (n_components, 1),
    "strong": (n_components, 1),
    "weak": (n_groups, group_rank),
}
if independence not in kl_layouts:
    raise ValueError("independence not recognized")

kl_n_groups, kl_group_rank = kl_layouts[independence]
kl_normal = kl.KLNormal(
    prior=prior,
    n_groups=kl_n_groups,
    group_rank=kl_group_rank,
    n_total_samples=n_total_samples,
)


## Lightning module
# Tag identifying the experiment variant. Alternatives kept for reference:
# tag = "baseline"
# tag = "no_annealing"
tag = "gd"
results_file = f"results_{tag}"
# Create the output directory before training starts: the checkpoints and
# metrics written after the (long) fit would otherwise fail on a missing
# directory and the whole run would be lost.
Path(results_file).mkdir(parents=True, exist_ok=True)

n_epochs = 5000

# Array of length n_epochs handed to LitBTCVI: all zeros when no independence
# is requested, otherwise decreasing linearly from 4 to 0.
# NOTE(review): presumably consumed as one value per epoch — confirm against
# inference.LitBTCVI.
extra_beta = np.zeros(n_epochs)
if independence != "no":
    extra_beta = np.linspace(4, 0, n_epochs)
    # extra_beta = 4 * np.ones(n_epochs)
# learning_rate = 5e-4
learning_rate = 1e-3

lit_btcvi = inference.LitBTCVI(encoder, decoder, kl_normal, extra_beta, learning_rate)

# Name of the directory containing this script; used to name the wandb project.
# Computed with pathlib instead of `__file__.split("/")[-2]` inside the
# f-string: reusing the enclosing quote type inside an f-string is a
# SyntaxError before Python 3.12, and splitting on "/" is not portable and
# fails when __file__ has no directory component.
experiment_name = Path(__file__).resolve().parent.name

wandb_logger = WandbLogger(
    name=method,
    project=f"pdisvae-{experiment_name}",
    save_dir=results_file,
    tags=[tag],
)

# min_epochs == max_epochs: always train for exactly n_epochs.
trainer = L.Trainer(
    logger=wandb_logger,
    min_epochs=n_epochs,
    max_epochs=n_epochs,
    enable_progress_bar=False,
)
trainer.fit(
    model=lit_btcvi,
    train_dataloaders=dataloader,
)

torch.save(encoder.state_dict(), f"{results_file}/{method}_encoder.pt")
torch.save(decoder.state_dict(), f"{results_file}/{method}_decoder.pt")
# To re-run only the evaluation, comment out the training/saving above and
# restore the saved weights instead:
# encoder.load_state_dict(torch.load(f"{results_file}/{method}_encoder.pt"))
# decoder.load_state_dict(torch.load(f"{results_file}/{method}_decoder.pt"))


## evaluation
# Compute summary metrics on the full dataset and write them as a single-row
# CSV, together with the aligned latent means.
n_monte_carlo = 10
metric_names = [
    "conditional log-likelihood",
    "reconstruction MSE",
    "partial correlation",
    "total correlation",
    "latent $R^2$",
    "latent MSE",
]
df = pd.DataFrame(index=np.arange(1), columns=metric_names)

with torch.no_grad():
    z_pred_mean, z_pred_log_std = encoder.forward(x)  # (n_samples, n_components)
    # Reconstruct from the posterior means (no sampling).
    x_pred_mean = decoder.forward(z_pred_mean)

    metrics = {}
    metrics["conditional log-likelihood"] = (
        decoder.log_prob(x_pred_mean, x).mean().item()
    )
    metrics["reconstruction MSE"] = F.mse_loss(x_pred_mean, x).item()

    # Both correlation metrics come from the same routine and differ only in
    # the grouping: the model's n_groups for "partial correlation", one group
    # per latent component for "total correlation".
    correlation_kwargs = dict(n_monte_carlo=n_monte_carlo, seed=seed)
    metrics["partial correlation"] = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=n_groups,
        **correlation_kwargs,
    ).item()
    metrics["total correlation"] = evaluate.partial_correlation(
        z_pred_mean,
        z_pred_log_std,
        n_groups=encoder.n_components,
        **correlation_kwargs,
    ).item()

    # Append one column of small Gaussian noise to the reference latents.
    # NOTE(review): presumably pads z up to n_components columns with a
    # non-constant dummy dimension — confirm against the shape of data/z.pt.
    noise_column = 1e-4 * torch.randn(n_total_samples, 1)
    z_test = torch.cat([z, noise_column], dim=1)
    aligned_z_pred_mean = evaluate.align(z_pred_mean, z_test, n_groups=n_groups)

    metrics["latent $R^2$"] = r2_score(aligned_z_pred_mean, z_test).item()
    metrics["latent MSE"] = F.mse_loss(aligned_z_pred_mean, z_test).item()

    for metric_name, metric_value in metrics.items():
        df.at[0, metric_name] = metric_value

df.to_csv(f"{results_file}/{method}.csv", index=False)
torch.save(aligned_z_pred_mean, f"{results_file}/{method}_aligned_z_pred_mean.pt")
