import argparse
import os

import lightning as L
import numpy as np
import torch
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

from pdisvae import inference, kl
from pdisvae.models.dcnn import BurgessDecoder, BurgessEncoder

## arguments
# Each job in a sweep passes one integer index on the command line; it selects
# which entry of n_groups_list this run trains with.
parser = argparse.ArgumentParser()
parser.add_argument("idx", type=int)
args = parser.parse_args()

n_groups_list = [1, 2, 3, 4, 6, 12]

# Map the flat job index onto the (currently one-dimensional) sweep grid;
# np.unravel_index raises ValueError for an index outside the grid.
sweep_shape = (len(n_groups_list),)
arg_index = np.unravel_index(args.idx, sweep_shape)
n_groups = n_groups_list[arg_index[0]]
method = str(n_groups)  # run identifier used in logger names and output files
print(f"n_groups: {n_groups}")


## data
batch_size = 256
img_size = (3, 64, 64)  # (channels, height, width); must match the transform output below

# Center-crop to 178x178, downsample to 64x64, then convert to a tensor.
preprocess_steps = [
    transforms.CenterCrop((178, 178)),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
]
compose = transforms.Compose(preprocess_steps)

dataset = CelebA(root="data", split="train", download=True, transform=compose)
n_total_samples = len(dataset)

# NOTE(review): shuffle=False means every training epoch visits the images in
# the same fixed order — confirm this is intentional.
# NOTE(review): with num_workers > 0 under the "spawn" start method (default on
# Windows/macOS), workers re-import this script, which has no
# `if __name__ == "__main__":` guard; presumably only run where "fork" is used.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)


## model
n_components = 12
# n_groups * group_rank is presumably meant to equal n_components, so the split
# must be exact; int(n_components / n_groups) would silently truncate otherwise.
if n_components % n_groups:
    raise ValueError(
        f"n_components ({n_components}) must be divisible by n_groups ({n_groups})"
    )
group_rank = n_components // n_groups

seed = 0
# Seed torch's global RNG before the networks are constructed so their
# initialization is reproducible for a given seed.
torch.manual_seed(seed)
# NOTE(review): n_total_samples receives the batch size here, not the dataset
# size held in the n_total_samples variable — confirm this is intended.
encoder = BurgessEncoder(img_size=img_size, n_components=n_components, n_total_samples=batch_size)
decoder = BurgessDecoder(img_size=img_size, n_components=n_components)
# Exclude the encoder's log_std parameter from gradient computation.
encoder.log_std.requires_grad = False

kl_normal = kl.KLNormal(
    prior="normal",
    n_groups=n_groups,
    group_rank=group_rank,
    # n_total_samples=n_total_samples,
)


## Lightning module
# Alternative settings used for other runs (kept for reference):
#   tag = "baseline"
#   n_epochs = 20
#   extra_beta = np.linspace(4, 0, n_epochs)  # linear annealing from 4 to 0
tag = "no_annealing"
results_file = f"results_{tag}"  # output directory for logs and saved weights

n_epochs = 50
learning_rate = 5e-4
# One extra-beta value per epoch, held constant at 4 (no annealing schedule).
extra_beta = np.full(n_epochs, 4.0)

lit_btcvi = inference.LitBTCVI(encoder, decoder, kl_normal, extra_beta, learning_rate)

# Name the W&B project after the directory that contains this script. The path
# expression is kept outside the f-string because reusing the same quote type
# inside an f-string is a SyntaxError before Python 3.12, and os.path handles
# relative __file__ values and platform-specific separators.
experiment_dir = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
wandb_logger = WandbLogger(
    name=f"{method}_{seed}",
    project=f"pdisvae-{experiment_dir}",
    save_dir=results_file,
    tags=[tag],
)

# Train for exactly n_epochs epochs (min_epochs == max_epochs), logging to W&B
# with the progress bar disabled.
trainer_kwargs = dict(
    logger=wandb_logger,
    min_epochs=n_epochs,
    max_epochs=n_epochs,
    enable_progress_bar=False,
)
trainer = L.Trainer(**trainer_kwargs)
trainer.fit(model=lit_btcvi, train_dataloaders=dataloader)

# Persist the trained weights. Create the output directory explicitly instead
# of relying on the logger having created save_dir.
# NOTE(review): file names omit the seed (unlike the W&B run name), so runs
# with different seeds would overwrite each other.
os.makedirs(results_file, exist_ok=True)
torch.save(encoder.state_dict(), os.path.join(results_file, f"{method}_encoder.pt"))
torch.save(decoder.state_dict(), os.path.join(results_file, f"{method}_decoder.pt"))
# To reload saved weights instead of retraining:
# encoder.load_state_dict(torch.load(os.path.join(results_file, f"{method}_encoder.pt")))
# decoder.load_state_dict(torch.load(os.path.join(results_file, f"{method}_decoder.pt")))
