import torch
import argparse
import pickle
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch import tensor as tt
import time
import torch.optim as optim
from CDC import *
from torch.distributions.multivariate_normal import MultivariateNormal as tmvnorm
import time
from scipy import stats as scs
import ot
import hamiltorch

def W2(x, y):
    """Return the empirical 2-Wasserstein distance between two point clouds.

    Each row of ``x`` and ``y`` is treated as one sample carrying uniform
    mass. ``ot.dist`` is called with its default metric (squared Euclidean
    in POT), so the square root of the exact transport cost returned by
    ``ot.emd2`` is the W2 distance.

    Args:
        x: Tensor whose first dimension indexes samples.
        y: Tensor whose first dimension indexes samples.

    Returns:
        Tensor holding the W2 distance between the two samples.
    """
    n_x, n_y = x.shape[0], y.shape[0]
    # Build the uniform weights with the dtype/device of the samples instead
    # of the CPU float32 defaults, so all three solver inputs agree.
    weights_x = torch.ones(n_x, dtype=x.dtype, device=x.device) / n_x
    weights_y = torch.ones(n_y, dtype=y.dtype, device=y.device) / n_y
    return torch.sqrt(ot.emd2(weights_x, weights_y, ot.dist(x, y)))

# Standard normal distribution (zero mean, unit scale) shared at module level.
tnorm = torch.distributions.Normal(loc=0.0, scale=1.0)



# Parse command-line arguments
parser = argparse.ArgumentParser(description="Train a model with a specified dataset.")
parser.add_argument("--dataset", type=str, required=True, help="Dataset name (e.g., 'magic_ecdf')")
# Parsed as an int so that a non-numeric value is rejected by argparse with a
# usage message instead of failing later when the value is converted.
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs to train the model.")
parser.add_argument("--cv_seed", type=int, default=0, help="Seed for cross-validation.")
parser.add_argument("--GG_cdc", type=int,default=0, help="Use a Gaussian-Guided CDC; use a Gaussian as terminal distribution (True=1/False=0).")
# Only 0 and 1 are handled by the rest of the script; any other value would
# train without ever evaluating and then fail on the final summary print.
parser.add_argument("--test", type=int, default=0, choices=[0, 1], help="Use test set (True=1/False=0).")
args = parser.parse_args()

# Derive the CSV location and the split seed from the command-line arguments
dataset_name, cv_seed = args.dataset, int(args.cv_seed)
csv_path = f"Data/{dataset_name}.csv"

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Abort early when the dataset CSV is missing
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"Dataset file '{csv_path}' not found.")

# Load the CSV as float32, then move it to the Gaussian scale: values are
# clipped to [1e-5, 1 - 1e-5] (keeps the quantiles finite) before the
# standard normal quantile function is applied.
X_ecdf = pd.read_csv(csv_path).values.astype(np.float32)
X_ecdf = scs.norm.ppf(np.clip(X_ecdf, 1e-5, 1 - 1e-5))

# Split into train and test sets.
# digits/mnist hold out 50% of the rows, every other dataset 20%. The split
# is done once, so the shapes printed below are the ones actually used.
half_split = dataset_name in ['digits_ecdf', 'mnist_ecdf']
X_ecdf_train, X_ecdf_test = train_test_split(
    X_ecdf, test_size=0.5 if half_split else 0.2, random_state=cv_seed)

if args.test == 0:
    # Validation run: the first 80% of the training rows are used for
    # fitting and the remaining 20% for validation.
    n_fit = int(0.8 * len(X_ecdf_train))
    X_ecdf_train_ = X_ecdf_train[:n_fit]
    X_ecdf_val = X_ecdf_train[n_fit:]
    if half_split:
        print('Val, digits or mnist, 0.5 split')

    # The rest of the script only reads X_ecdf_train / X_ecdf_test, so they
    # are rebound here (as the robocup branch below does). Without this the
    # model would be fitted on the validation rows and scored on the real
    # test set during a validation run.
    X_ecdf_train = X_ecdf_train_
    X_ecdf_test = X_ecdf_val

    print('Testing on validation set of shape: ', X_ecdf_test.shape)
elif args.test == 1:
    if dataset_name == 'robocup_train':
        # robocup has its own test CSV: train on the whole loaded file and
        # evaluate on the separately loaded, identically transformed test file.
        X_ecdf_train_ = X_ecdf
        X_ecdf_val = scs.norm.ppf(pd.read_csv('Data/robocup_test.csv').values.astype(np.float32).clip(1e-5,1-1e-5))
        X_ecdf_train = X_ecdf_train_
        X_ecdf_test = X_ecdf_val
        print('robocup, train_min:', X_ecdf_train.min().item(), 'train_max:', X_ecdf_train.max().item(), 'test min:', X_ecdf_test.min().item(), 'test max:', X_ecdf_test.max().item())

    if half_split:
        print('Test, digits or mnist, 0.5 split')
    # Printed last so the reported shapes match the final split.
    print(dataset_name, ' actual test run with train data:', X_ecdf_train.shape, 'test data:', X_ecdf_test.shape)

# CPU float32 copy of the training rows, wrapped in a dataset and a shuffled loader
X_train_tensor = torch.tensor(X_ecdf_train, dtype=torch.float32)
train_dataset = TensorDataset(X_train_tensor)

batch_size = 512  # Adjust batch size as needed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# float32 versions of the train/test arrays placed on the selected device
X_ecdf_train = tt(X_ecdf_train, dtype=torch.float32).to(device)
X_ecdf_test = tt(X_ecdf_test, dtype=torch.float32).to(device)


# Time grid: num_timesteps evenly spaced points on [0, T_max]
T_max = 100
num_timesteps = 2
timesteps = torch.linspace(0, 1, num_timesteps).to(device) * T_max
num_epochs = int(args.epochs)

# Learning rate: dataset-specific overrides, 5e-5 for everything else
if dataset_name == 'cifar_ecdf':  # 24 AUG 1e-3
    lr = 1e-3
elif dataset_name in ['digits_ecdf', 'mnist_ecdf']:
    lr = 1e-4
else:
    lr = 0.00005

# The Gaussian-guided variant is enabled only by the exact flag value 1
GG_cdc = args.GG_cdc == 1

if GG_cdc:
    # Define OU covariance matrix for Gaussian copula CDC: the empirical
    # correlation between the columns of the training data
    corr_mat = torch.corrcoef(X_train_tensor.T).to(device)
    if dataset_name == 'Dry_Bean_ecdf':
        # Add a small multiple of the identity and rescale by 1 + 1e-5
        corr_mat = (corr_mat + 1e-5 * torch.eye(corr_mat.shape[0], device=corr_mat.device)) / (1+1e-5)
        print('downed down corr_mat a bit for the Dry_Bean_ecdf dataset')
    print("Using Gaussian-Guided CDC with OU covariance matrix. Shape: ", corr_mat.shape)

print(f"Training with dataset: {dataset_name}, of shape {X_train_tensor.shape}, epochs: {num_epochs}, num_timesteps: {num_timesteps}, GG_cdc: {GG_cdc}, cv_seed: {cv_seed}")



# Loss functions. `torch.nn` is referenced explicitly: this file never imports
# `nn` itself, so the bare name would only resolve if `from CDC import *`
# happened to export it.
ce_loss_fn = torch.nn.CrossEntropyLoss()
mse_loss_fn = torch.nn.MSELoss()



# define model
# NOTE(review): 'mnist_ecdf' is grouped with the image datasets elsewhere in
# this script but is not listed here, so it gets the ResNet model — confirm
# that this is intended.
if dataset_name in ['digits_ecdf', 'cifar_ecdf']:
    model = Unet_img_CDClassifier(input_dim = X_ecdf_train.shape[1],
                            device = device,
                            num_timesteps = num_timesteps, 
                            channels = 64, 
                            ch_mult = [1,2,2],
                            num_res_blocks = 2, 
                            attn_resolutions = [16,], 
                            time_steps = timesteps, 
                            corr_mat = None).to(device)
    print("Using U-Net model for dataset:", dataset_name)

    
else:
    # Moved to `device` explicitly, matching the U-Net branch; this is a
    # no-op if the constructor already placed the parameters there.
    model = ResNetCDClassifier(input_dim=X_train_tensor.shape[1],
                           device=device, 
                            num_timesteps=num_timesteps, 
                            time_steps=timesteps, 
                            hidden_dim=512, 
                            depth=6,
                            backbone="Resnet").to(device)

# HMC settings (step size, steps per sample), with dataset-specific overrides
if dataset_name == 'Dry_Bean_ecdf':
    # set the HMC jump much smaller so we get more accepted samples
    hmc_step_size, hmc_num_sims = 0.001, 100
elif dataset_name in ['mnist_ecdf', 'cifar_ecdf']:
    # set the HMC jump much smaller so we get more accepted samples
    hmc_step_size, hmc_num_sims = 0.05, 500
else:
    hmc_step_size, hmc_num_sims = 0.1, 50
print(dataset_name, 'hmc params: hmc_step_size | hmc_num_sims', hmc_step_size, hmc_num_sims)

if GG_cdc:
    # Attach the correlation matrix and its inverse to the model; it is
    # needed during training
    model.corr_mat_inv = torch.linalg.inv(corr_mat)
    model.corr_mat = corr_mat

# Count the parameters that will receive gradient updates
num_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Number of trainable parameters: {num_params}")

optimizer = optim.Adam(model.parameters(), lr=lr)

# Running sums of the logged losses: slot 0 total loss, slot 1 cross-entropy
loss_cum = torch.zeros(3, device=device)

# Wall-clock reference for the timing messages printed later
training_start_time = time.time()

def log_prob_func(x):
    """Log-density passed to ``hamiltorch.sample`` for a single point ``x``.

    Sums the model's estimated log density ratio at ``x`` and the log-density
    of ``x`` under a zero-mean multivariate normal with identity covariance.
    """
    # NOTE(review): the base density uses an identity covariance even when
    # GG_cdc is enabled — confirm that estimate_log_density_ratio already
    # accounts for the correlated terminal distribution in that case.
    cop_ratio = model.estimate_log_density_ratio(x.unsqueeze(0)).to(device)
    ll_indep_gauss = torch.distributions.MultivariateNormal(torch.zeros(x.shape[0], device=device), torch.eye(x.shape[0], device=device)).log_prob(x)
    return cop_ratio + ll_indep_gauss


if args.test == 0:
    # Create the output folders up front so a missing directory cannot abort
    # the run when the first samples / weights are written.
    os.makedirs('Model_samples/CDC/val', exist_ok=True)
    os.makedirs('Model_weights', exist_ok=True)

if GG_cdc:
    # The correlated noise distribution does not change between steps, so it
    # is built once here instead of once per training step.
    ou_noise_dist = tmvnorm(loc=torch.zeros(X_ecdf_train.shape[1], device=device),
                            covariance_matrix=model.corr_mat)

# Each "epoch" is a single optimisation step on one randomly drawn batch.
for epoch in range(num_epochs + 1):
    # Sample batch indices
    idx = torch.randint(0, X_ecdf_train.size(0), (batch_size,))
    x0 = X_ecdf_train[idx]  # (B, D)
    # Sample random timesteps for each sample in batch
    t_idx = torch.randint(0, num_timesteps, (x0.shape[0],)).to(device)

    # Simulate noisy data
    if GG_cdc:
        noise_OU = ou_noise_dist.sample((x0.shape[0],))
    else:
        noise_OU = torch.randn_like(x0)  # Standard Gaussian noise
    x_t, noise = sample_ou_noised_discrete(x0, t_idx, timesteps, noise=noise_OU)
    x_t.requires_grad_()

    # Clear the gradients left over from the previous step; without this
    # every backward() call adds onto the gradients of all earlier steps.
    optimizer.zero_grad()

    # Forward pass: the classifier predicts which timestep produced x_t
    logits, denoiser = model(x_t, t_idx, return_score=False)
    # Loss
    ce_loss = ce_loss_fn(logits, t_idx)
    total_loss = ce_loss

    # Backprop + optimize
    total_loss.backward()
    # Clip gradients to avoid exploding gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()
    with torch.no_grad():
        loss_cum[0] += total_loss
        loss_cum[1] += ce_loss

    if args.test == 0:
        if epoch % 100 == 0:
            if epoch % 1000 == 0:
                # Sampling: draw the HMC starting point from the terminal
                # distribution (correlated Gaussian when GG_cdc is enabled)
                if GG_cdc:
                    x_init = torch.distributions.multivariate_normal.MultivariateNormal(
                        loc=torch.zeros(X_ecdf_train.shape[1], device=device), covariance_matrix=corr_mat
                        ).sample((1,))
                else:
                    x_init = torch.randn((1, X_ecdf_train.shape[1]), device=device)

                params_init = x_init.to(device).squeeze()
                sims = hamiltorch.sample(log_prob_func=log_prob_func, params_init=params_init, num_samples=25, step_size=hmc_step_size, num_steps_per_sample=hmc_num_sims)
                sims = torch.stack(sims)
                # save as npy
                np.save(f'Model_samples/CDC/val/ratio_{dataset_name}_iter{epoch}_timesteps_{num_timesteps}_NNet_GG{args.GG_cdc}_sims.npy',
                        sims.cpu().numpy())

                # W2 between the first 25 training rows and the HMC samples;
                # "Truth" compares those rows with themselves as a reference.
                W2_sims = W2(X_train_tensor[:25].to(device).float(),sims.float())
                W2_truth = W2(X_train_tensor[:25].to(device).float(),X_train_tensor[:25].to(device).float())
                print(f"Epoch {epoch}, W2 Sims: {W2_sims.item():.4f}, W2 Truth: {W2_truth.item():.4f}")
                torch.save(model.state_dict(), f'Model_weights/ratio_{dataset_name}_iter{epoch}_timesteps_{num_timesteps}_NNet_GG{args.GG_cdc}_lr{lr}.pt')

            # LL eval on the first 1000 train / test rows
            with torch.no_grad():

                ll_train = model.estimate_log_density_ratio(X_ecdf_train[:1000]).to(device)
                ll_test = model.estimate_log_density_ratio(X_ecdf_test[:1000]).to(device)

                print(f"{epoch} -------------------- LL train {ll_train.mean().item():.5f} +- {ll_train.std().item():.5f}, LL Test {ll_test.mean().item():.5f} +- {ll_test.std().item():.5f}, CE={loss_cum[0].item():.5f}")

                # Restart the running loss sums after each report
                loss_cum.zero_()


if args.test == 1:
    # Create the output folders before anything is written, so a missing
    # directory cannot discard a finished training run.
    os.makedirs('Model_weights/Ratio', exist_ok=True)
    os.makedirs('Model_samples/Ratio', exist_ok=True)

    torch.save(model.state_dict(), f'Model_weights/Ratio/ratio_{dataset_name}_seed{cv_seed}_iter{epoch}_timesteps_{num_timesteps}_NNet_GG{args.GG_cdc}_lr{lr}.pt')

    print('time taken to train ratio:', time.time()-training_start_time, 's, dataset:', dataset_name, 'cv_seed:', cv_seed , 'epochs:', args.epochs)

    sampling_time = time.time()
    # Sampling

    def log_prob_func(x):
        """Log-density passed to ``hamiltorch.sample`` for a single point ``x``.

        Sums the model's estimated log density ratio at ``x`` and the
        log-density of ``x`` under a zero-mean multivariate normal with
        identity covariance.
        """
        cop_ratio = model.estimate_log_density_ratio(x.unsqueeze(0)).to(device)
        ll_indep_gauss = torch.distributions.MultivariateNormal(torch.zeros(x.shape[0], device=device), torch.eye(x.shape[0], device=device)).log_prob(x)
        return cop_ratio + ll_indep_gauss

    # Run 10 independent HMC chains of 100 samples each
    all_sims = []
    for _ in range(10):
        # Choose a new starting point each time
        if GG_cdc:
            x_init = torch.distributions.multivariate_normal.MultivariateNormal(
                loc=torch.zeros(X_ecdf_train.shape[1], device=device), covariance_matrix=corr_mat
                ).sample((1,))
        else:
            x_init = torch.randn((1, X_ecdf_train.shape[1]), device=device)
        params_init = x_init.to(device).squeeze()
        sims = hamiltorch.sample(log_prob_func=log_prob_func, params_init=params_init, num_samples=100, step_size=hmc_step_size, num_steps_per_sample=hmc_num_sims)
        all_sims.append(torch.stack(sims))
    all_sims = torch.cat(all_sims, dim=0)
    # save as npy
    np.save(f'Model_samples/Ratio/ratio_{dataset_name}_seed{cv_seed}_iter{epoch}_timesteps_{num_timesteps}_NNet_GG{args.GG_cdc}_lr{lr}_sims.npy',
            all_sims.cpu().numpy())

    print('time taken to sample ratio:', time.time()-sampling_time, 's, dataset:', dataset_name, 'cv_seed:', cv_seed , 'epochs:', args.epochs)

    # Number of slices the test set is evaluated in for the image datasets;
    # every other dataset is evaluated in a single call.
    ll_eval_chunks = {'mnist_ecdf': 12, 'cifar_ecdf': 2}
    with torch.no_grad():
        ll_train = model.estimate_log_density_ratio(X_ecdf_train[:1000]).to(device)
        if dataset_name in ll_eval_chunks:
            n_chunks = ll_eval_chunks[dataset_name]
            bounds = np.linspace(0, X_ecdf_test.shape[0], n_chunks + 1)
            # Concatenate rather than stack: the slices differ in length
            # whenever the test size is not divisible by the chunk count,
            # which makes torch.stack raise. The result is one entry per
            # test row, like the single-call branch below.
            # Assumes estimate_log_density_ratio returns one value per input
            # row along dim 0 — TODO confirm.
            ll_test = torch.cat([
                model.estimate_log_density_ratio(X_ecdf_test[int(lo):int(hi)]).to(device)
                for lo, hi in zip(bounds[:-1], bounds[1:])
            ], dim=0)
        else:
            ll_test = model.estimate_log_density_ratio(X_ecdf_test).to(device)
        print(f"{epoch} -------------------- LL train {ll_train.mean().item():.5f} +- {ll_train.std().item():.5f}, LL Test {ll_test.mean().item():.5f} +- {ll_test.std().item():.5f}, CE={loss_cum[0].item():.5f}")

    np.save(f'Model_samples/Ratio/ratio_{dataset_name}_seed{cv_seed}_iter{epoch}_timesteps_{num_timesteps}_NNet_GG{args.GG_cdc}_lr{lr}_LL.npy',
        ll_test.cpu().numpy())

# Final summary, printed in both validation and test mode
print('time taken to train+ sample ratio:', time.time()-training_start_time, 's, dataset:', dataset_name, 'cv_seed:', cv_seed , 'epochs:', args.epochs)
print(f"{epoch} -------------------- LL train {ll_train.mean().item():.5f} +- {ll_train.std().item():.5f}, LL Test {ll_test.mean().item():.5f} +- {ll_test.std().item():.5f}, CE={loss_cum[0].item():.5f}")
            


