from models.neural_vae_model import NeuralVAE
from utils_scripts.utils_torch import (
    set_random_seed, zscore_2d, PairwiseDataset, rescale_to_01, rescale_to_minus1_1,
    zscore_by_column, warmup_then_decay_lr, evaluate_vae_on_loader
)
from utils_scripts.disentangle_metrics import factorvae_score, compute_unsupervised_sap, compute_mig

from sklearn.linear_model import Ridge, Lasso
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision

import numpy as np
import wandb
import matplotlib.pyplot as plt
import random
import argparse


def parse_args(argv=None):
    """Parse the training hyperparameters from the command line.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            zero-argument calls behave exactly as before. Passing an explicit
            list lets notebooks, sweeps and tests build a configuration
            without touching the process command line.

    Returns:
        argparse.Namespace with the attributes ``batch_size``, ``latent_size``,
        ``group_rank``, ``num_epochs``, ``learning_rate``, ``kl_weight``,
        ``guidance_weight``, ``tc_weight`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description="Train a model with customizable parameters.")

    # (flag, type, default) for every supported option; all are optional.
    options = (
        ("--batch_size", int, 32),
        ("--latent_size", int, 16),
        ("--group_rank", int, 1),
        ("--num_epochs", int, 500),
        ("--learning_rate", float, 1e-3),
        ("--kl_weight", float, 0.002),
        ("--guidance_weight", float, 0.0),
        ("--tc_weight", float, 0.0),
        ("--seed", int, 2024),
    )
    for flag, value_type, default in options:
        parser.add_argument(flag, type=value_type, default=default)

    return parser.parse_args(argv)


# Resolve the run's hyperparameters once at import time; `config` exposes the
# same settings as a plain dict.
args = parse_args()
config = vars(args)

print("Training Configuration:", config, sep="\n")

# Report every visible CUDA device, or note that the run falls back to CPU.
if not torch.cuda.is_available():
    print("No GPU available. Using CPU.")
else:
    for i in range(torch.cuda.device_count()):
        print(f"[GPU {i}] {torch.cuda.get_device_name(i)}")

# Locations of the stimulus features, stimulus category ids and neural
# recordings used by this script.
neural_data_name = 'sorted_it_avg_data'
stimulus_file = 'datasets/stimulus_pose_features.npy'
stimulus_category_file = 'datasets/stimulus_category_ids.npy'
neural_data_file = f"datasets/{neural_data_name}.npy"

# Stimulus targets: the pose features with the category id appended as one
# extra trailing column (the ids are given a second axis so they concatenate).
stimulus_data_pos = np.load(stimulus_file).astype(np.float32)
stimulus_data_category = np.load(stimulus_category_file).astype(np.float32)[:, np.newaxis]
stimulus_data = np.concatenate([stimulus_data_pos, stimulus_data_category], axis=1)

# Only the first 58 columns of the neural array are kept.
neural_data = np.load(neural_data_file).astype(np.float32)[:, :58]

# Feature widths of the two arrays (size of their last axis).
neural_dim = neural_data.shape[-1]
stimulus_dim = stimulus_data.shape[-1]


def _vae_forward_losses(model, neural_batch, stimulus_batch, device):
    """Run one forward pass and return the tuple from ``model.guide_vae_loss``."""
    neural_batch = neural_batch.to(device)
    stimulus_batch = stimulus_batch.to(device)
    recon_x, recon_y_pose, recon_y_category, mu, z, logvar = model(neural_batch)
    return model.guide_vae_loss(
        recon_x, neural_batch, recon_y_pose, recon_y_category, stimulus_batch, z, mu, logvar
    )


def train_vae(model, train_dataloader, optimizer, num_epochs=200, device='cpu',
              record_training=False, scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs, testing it after every epoch.

    Each epoch performs one optimisation pass over ``train_dataloader`` and
    then accumulates the losses over the module-level ``test_dataloader``
    under ``torch.no_grad()``. On epochs 0, 5, 10, ... the model's
    ``state_dict`` is written to ``model_checkpoints/{exp_name}.pth`` if the
    summed test loss is lower than the best value checkpointed so far.
    ``evaluate_vae_on_loader`` is called on the test loader after the first
    epoch and after every 20th epoch.

    All reported losses are sums over batches, not per-batch means.

    Args:
        model: Module whose forward pass returns
            ``(recon_x, recon_y_pose, recon_y_category, mu, z, logvar)`` and
            that provides ``guide_vae_loss(...)`` returning
            ``(loss, recon_loss, label_loss, kl_loss, tc_loss)``.
        train_dataloader: Iterable of ``(neural_batch, stimulus_batch)`` pairs.
        optimizer: Optimizer that updates the model's parameters.
        num_epochs: Number of passes over ``train_dataloader``.
        device: Device the model and every batch are moved to.
        record_training: When true, the per-epoch losses are sent to
            ``wandb.log`` (presumably a wandb run has been initialised by the
            caller -- confirm).
        scheduler: Optional learning-rate scheduler, stepped once per epoch
            after the training batches. With ``None`` (the default) the
            learning rate is left entirely to ``optimizer``.

    Returns:
        The trained model, on ``device``.

    Note:
        ``test_dataloader``, ``exp_name`` and ``seed`` are read from module
        scope rather than passed in; they must be defined before this
        function is called.
    """
    import os

    model = model.to(device)
    # Start from +inf so the first eligible epoch always writes a checkpoint,
    # however large the summed test loss is.
    best_test_loss = float("inf")

    for epoch in range(num_epochs):
        model.train()
        training_epoch_loss = 0.0
        recon_epoch_loss, label_epoch_loss = 0.0, 0.0
        kl_epoch_loss, tc_epoch_loss = 0.0, 0.0

        for neural_batch, stimulus_batch in train_dataloader:
            loss, recon_loss, label_loss, kl_loss, tc_loss = _vae_forward_losses(
                model, neural_batch, stimulus_batch, device
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            training_epoch_loss += loss.item()
            recon_epoch_loss += recon_loss.item()
            label_epoch_loss += label_loss.item()
            kl_epoch_loss += kl_loss.item()
            tc_epoch_loss += tc_loss.item()

        # Only a genuine scheduler is stepped here. Stepping the optimizer a
        # second time would re-apply the last batch's gradients once more.
        if scheduler is not None:
            scheduler.step()

        # Switch to inference mode so that any layers that behave differently
        # during training do not distort the test metrics; model.train() at
        # the top of the loop restores training mode for the next epoch.
        model.eval()
        testing_epoch_loss, testing_label_epoch_loss = 0.0, 0.0
        with torch.no_grad():
            for neural_batch, stimulus_batch in test_dataloader:
                loss, _, label_loss, _, _ = _vae_forward_losses(
                    model, neural_batch, stimulus_batch, device
                )
                testing_epoch_loss += loss.item()
                testing_label_epoch_loss += label_loss.item()

        if record_training:
            wandb.log({
                "train_loss": training_epoch_loss,
                "test_loss": testing_epoch_loss,
                "test_label_loss": testing_label_epoch_loss,
                "recon_epoch_loss": recon_epoch_loss,
                "label_epoch_loss": label_epoch_loss,
                "kl_epoch_loss": kl_epoch_loss,
                "tc_epoch_loss": tc_epoch_loss,
            })

        # Checkpoint candidates are only considered every fifth epoch.
        if epoch % 5 == 0 and best_test_loss > testing_epoch_loss:
            best_test_loss = testing_epoch_loss
            checkpoint_path = f'model_checkpoints/{exp_name}.pth'
            # torch.save does not create missing directories.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)

        if epoch == 0 or (epoch + 1) % 20 == 0:
            print(f"Test Evaluation at epoch {epoch + 1}:")
            evaluate_vae_on_loader(model, test_dataloader, mode="test", device=device)

    print(f"Training Finished with seed {seed}")
    return model
