import torch
import argparse
import pickle
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch import tensor as tt
import time
import pyvinecopulib as pv


# Parse command-line arguments.
parser = argparse.ArgumentParser(description="Train a model with a specified dataset.")
parser.add_argument("--dataset", type=str, required=True, help="Dataset name (e.g., 'magic_ecdf')")
# type=int so the default and user-supplied values have the same type
# (argparse only applies `type` to string defaults; a non-string default is used as-is).
parser.add_argument("--cv_seed", type=int, default=0, help="Seed for cross-validation.")
# choices= rejects an unknown backend up front, before the dataset is loaded.
parser.add_argument('--vine_lib', type=str, default='py', choices=['torch', 'py'],
                    help='Vine copula library to use torch/py (default: py, i.e. pyvinecopulib)')
parser.add_argument('--num_sims', type=int, default=1000, help='Number of simulations to generate (default: 1000)')
args = parser.parse_args()


# Derive the run configuration and the dataset file path from the CLI arguments.
dataset_name = args.dataset
cv_seed = int(args.cv_seed)  # int() also covers a seed passed as a string
csv_path = f"Data/{dataset_name}.csv"

# Pick the torch device; it is used by the torchvinecopulib backend.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Fail early with a clear message if the dataset file is missing.
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"Dataset file '{csv_path}' not found.")

# Load the dataset as a float32 matrix (header row is consumed by read_csv).
# Presumably each column already holds ECDF-transformed values in [0, 1] — TODO confirm.
X_ecdf = pd.read_csv(csv_path).values.astype(np.float32)

if dataset_name == 'cifar_ecdf':
    # Keep values strictly inside (0, 1). The original author noted this did not
    # resolve the issues seen on this dataset either.
    X_ecdf = X_ecdf.clip(1e-4, 1 - 1e-4)
    print('clipped cifar!')

# 80/20 train/test split. A single array is enough: passing X_ecdf a second time
# as dummy labels only produced two discarded copies. The split is the same
# because it depends only on the row count and random_state.
X_ecdf_train, X_ecdf_test = train_test_split(X_ecdf, test_size=0.2, random_state=cv_seed)

# Fit the vine copula on the training split, draw `args.num_sims` samples, and
# evaluate the log-density of every held-out test row, using the backend chosen
# by --vine_lib. All results are written as CSV files under Model_samples/Vine/.
output_dir = 'Model_samples/Vine'
# DataFrame.to_csv does not create missing parent directories.
os.makedirs(output_dir, exist_ok=True)

training_start_time = time.time()
if args.vine_lib == 'torch':
    # VineCop is not imported at the top of the file. Importing it here also keeps
    # torchvinecopulib optional for users of the 'py' backend.
    from torchvinecopulib import VineCop

    print("Using torchvinecopulib for Vine Copula fitting.")

    # Fit
    vc = VineCop(num_dim=X_ecdf_train.shape[1], is_cop_scale=True, num_step_grid=128)
    vc.fit(obs=tt(X_ecdf_train, device=device),
           mtd_vine="rvine",
           is_tll=True,
           mtd_tll='linear',
           device=device)
    print(f"Fitting time: {time.time() - training_start_time:.2f} seconds")

    # Samples. Move to host memory first: Tensor.numpy() raises for CUDA tensors.
    sampling_start_time = time.time()
    sims = vc.sample(num_sample=args.num_sims).cpu().numpy()
    print(f"Sampling time for seed {cv_seed}: {time.time() - sampling_start_time:.2f} seconds")
    sims_path = f'{output_dir}/{dataset_name}_simulated_seed_{cv_seed}.csv'
    sims_df = pd.DataFrame(sims, columns=[f"dim_{i}" for i in range(X_ecdf.shape[1])])
    sims_df.to_csv(sims_path, index=False)
    print(f"Samples saved to {sims_path}")

    # Log-likelihood of the held-out split.
    LL_start_time = time.time()
    ll_vine = vc.log_pdf(tt(X_ecdf_test, device=device)).cpu().numpy()
    print(f"Log-likelihood evaluation time: {time.time() - LL_start_time:.2f} seconds")
    ll_path = f'{output_dir}/{dataset_name}_log_likelihood_seed_{cv_seed}.csv'
    ll_vine = pd.DataFrame(ll_vine, columns=['log_likelihood'])
    ll_vine.to_csv(ll_path, index=False)
    print(f"Log-likelihood saved to {ll_path}")
elif args.vine_lib == 'py':
    print("Using pyvinecopulib for Vine Copula fitting.")
    num_threads = 24

    # Fit using only the nonparametric TLL pair-copula family. family_set expects
    # a list of BicopFamily values, so the single family is wrapped in a list.
    # threshold=0.05 is pyvinecopulib's thresholded-vine setting (see its docs).
    controls = pv.FitControlsVinecop(family_set=[pv.BicopFamily.tll],
                                     num_threads=num_threads,
                                     threshold=0.05)
    cop = pv.Vinecop.from_data(X_ecdf_train, controls=controls)
    print(f"Fitting time: {time.time() - training_start_time:.2f} seconds")

    # Samples. The file names for this backend carry a 'clip_' prefix and a
    # '_para' suffix that the torch branch does not use.
    sampling_start_time = time.time()
    u_sim = cop.simulate(args.num_sims, num_threads=num_threads)
    print(f"Sampling time for seed {cv_seed}: {time.time() - sampling_start_time:.2f} seconds")
    sims_path = f'{output_dir}/clip_{dataset_name}_simulated_seed_{cv_seed}_para.csv'
    sims_df = pd.DataFrame(u_sim, columns=[f"dim_{i}" for i in range(X_ecdf.shape[1])])
    sims_df.to_csv(sims_path, index=False)
    print(f"Samples saved to {sims_path}")

    # Log-likelihood of the held-out split (np.log gives -inf for any row whose
    # density evaluates to 0).
    LL_start_time = time.time()
    ll_vine = np.log(cop.pdf(X_ecdf_test, num_threads=num_threads))
    print(f"Log-likelihood evaluation time: {time.time() - LL_start_time:.2f} seconds")
    ll_path = f'{output_dir}/clip_{dataset_name}_log_likelihood_seed_{cv_seed}_para.csv'
    ll_vine = pd.DataFrame(ll_vine, columns=['log_likelihood'])
    ll_vine.to_csv(ll_path, index=False)
    print(f"Log-likelihood saved to {ll_path}")

else:
    raise ValueError("Invalid vine library specified. Use 'torch' or 'py'.")

# Elapsed time since fitting started. It includes sampling and log-likelihood
# evaluation, not just training.
print('time taken to train vine copula:', time.time()-training_start_time, 's, dataset:', dataset_name, 'cv_seed:', cv_seed)
