import torch
import argparse
import pickle
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch import tensor as tt
import time
from HFC import *
from Unet import My_UNet, UnetWrapper




# Parse command-line arguments.
# Numeric options use type=int so the default and a user-supplied value have
# the same type (with type=str the int default stayed an int while
# "--epochs 10" produced the string "10").
parser = argparse.ArgumentParser(description="Train a model with a specified dataset.")
parser.add_argument("--dataset", type=str, required=True, help="Dataset name (e.g., 'magic_ecdf')")
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs to train the model.")
parser.add_argument("--cv_seed", type=int, default=0, help="Seed for cross-validation.")
# Only these three model types are handled below; reject anything else up front
# instead of failing after the data has been loaded.
parser.add_argument("--model_type", type=str, default='Unet', choices=['MLP', 'Unet', 'CNN_Flow'],
                    help="Model type to use (e.g., 'MLP', 'Unet').")
args = parser.parse_args()

# Use dataset name to construct file paths and variable names
dataset_name = args.dataset
cv_seed = int(args.cv_seed)
csv_path = f"Data/{dataset_name}.csv"

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Check if the dataset file exists
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"Dataset file '{csv_path}' not found.")

# Load the dataset as a float32 matrix (one row per sample).
X_ecdf = pd.read_csv(csv_path).values.astype(np.float32)

# Split into train and test sets: 80/20 by default, 50/50 for MNIST.
# A single call with the correct test_size yields the same split as before
# (the shuffle depends only on n_samples and random_state).
test_size = 0.5 if dataset_name == 'mnist_ecdf' else 0.2
X_ecdf_train, X_ecdf_test = train_test_split(X_ecdf, test_size=test_size, random_state=cv_seed)


# Dataloaders: wrap the training split in a shuffled mini-batch loader.
batch_size = 128  # mini-batch size; adjust as needed
X_train_tensor = torch.tensor(X_ecdf_train, dtype=torch.float32)
train_dataset = TensorDataset(X_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Modelling with Hamiltonian Flow Copula.
# Support bounds handed to simulate_forward, and the upper limit of the
# sampled time horizon (larger for the image datasets).
MIN_SUPPORT, MAX_SUPPORT = 0, 1
_MAX_TIME_OVERRIDES = {'cifar_ecdf': 2.5, 'mnist_ecdf': 2.0}
MAX_TIME = _MAX_TIME_OVERRIDES.get(dataset_name, 1.5)
DATA_DIMS = X_ecdf_train.shape[1]  # number of columns per sample

print('training velocity field, dataset:', dataset_name, 'cv_seed:', cv_seed, 'epochs:', args.epochs, 'model_type:', args.model_type)

training_start_time = time.time()

# Training hyper-parameters.
lr = 1e-4
iterations = int(args.epochs)  # number of epochs (full passes over train_loader)
print_every = 50  # epoch interval used by the training loop

_model_name = str(args.model_type)
print(str(dataset_name), _model_name, _model_name == 'Unet')

# Velocity field model init: pick the network architecture from --model_type.
model_type = str(args.model_type)
if model_type == 'MLP':
    vf = HGF_copula(input_dim=DATA_DIMS, time_dim=1, hidden_dim=512, num_layers=6).to(device)
elif model_type == 'Unet':

    unet = My_UNet(
        T=1000,
        ch=64,
        ch_mult=[1, 2],
        attn=[1],
        num_res_blocks=2,
        dropout=0.1,
        in_channels=1,
        out_channels=1,
    )

    # UnetWrapper receives the flattened feature count (DATA_DIMS == X_ecdf.shape[1]).
    vf = UnetWrapper(unet, DATA_DIMS).to(device)

elif model_type == 'CNN_Flow':

    class Swish(nn.Module):
        """Swish activation: x * sigmoid(x)."""

        def forward(self, x):
            return x * torch.sigmoid(x)

    class HGF_copula_CNN_Flow(nn.Module):
        """Velocity field: a small CNN encoder on the image, concatenated with t, then an MLP.

        Args:
            input_shape: (channels, height, width) the flat input is reshaped to.
                The first conv layer expects 1 channel.
            time_dim: number of time features appended to the CNN features.
            hidden_dim: width of the hidden MLP layers.
            act: activation module reused between layers; a fresh Swish if None.
        """

        def __init__(self, input_shape=(1, 32, 32), time_dim=1, hidden_dim=512, act=None):
            super().__init__()
            self.input_shape = input_shape
            self.input_dim = input_shape[0] * input_shape[1] * input_shape[2]
            self.time_dim = time_dim
            # Avoid a module instance as a default argument (evaluated once at class creation).
            self.act = Swish() if act is None else act

            self.cnn = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # (16, 16, 16) for 32x32, (16, 14, 14) for 28x28
                self.act,
                nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # (32, 8, 8) for 32x32, (32, 7, 7) for 28x28
                self.act,
                nn.Flatten(),
            )
            # Dynamically compute the output size after CNN
            with torch.no_grad():
                dummy = torch.zeros(1, *input_shape)
                cnn_out_dim = self.cnn(dummy).shape[1]

            self.mlp = nn.Sequential(
                nn.Linear(cnn_out_dim + time_dim, hidden_dim),
                self.act,
                nn.Linear(hidden_dim, hidden_dim),
                self.act,
                nn.Linear(hidden_dim, hidden_dim),
                self.act,
                nn.Linear(hidden_dim, hidden_dim),
                self.act,
                nn.Linear(hidden_dim, self.input_dim),  # Final output layer
            )

        def forward(self, input, t):
            """Predict the velocity for flat inputs `input` at times `t`; returns (batch, features)."""
            sz = input.size()
            input = input.view(-1, *self.input_shape)  # Reshape to (batch, channels, height, width)
            cnn_feat = self.cnn(input)  # shape: (batch, cnn_out_dim)
            t = t.view(-1, self.time_dim).float()
            h = torch.cat([cnn_feat, t], dim=1)
            v_pred = self.mlp(h)
            return v_pred.view(*sz[:1], -1)  # (batch, features)

    # The CNN treats each sample as a single-channel square image; derive the side
    # from the data instead of hard-coding 32x32 (e.g. 784 -> 28x28, 1024 -> 32x32).
    side = int(round(DATA_DIMS ** 0.5))
    if side * side != DATA_DIMS:
        raise ValueError(
            f"CNN_Flow needs a square single-channel image, but the data has {DATA_DIMS} features."
        )

    # velocity field model init
    vf = HGF_copula_CNN_Flow(input_shape=(1, side, side), time_dim=1, hidden_dim=512).to(device)

else:
    # Without this, `vf` would be undefined and the script would fail with a NameError below.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'MLP', 'Unet' or 'CNN_Flow'.")


print('velocity field model:', str(args.model_type), 'parameter count:', sum(p.numel() for p in vf.parameters() if p.requires_grad))

# init optimizer
optim = torch.optim.Adam(vf.parameters(), lr=lr)

# Create the checkpoint directory up front so a missing folder does not crash
# the run at the first save, after training time has already been spent.
checkpoint_dir = 'Model_weights/HFC'
os.makedirs(checkpoint_dir, exist_ok=True)

# Train for exactly `iterations` epochs (the previous `range(iterations + 1)`
# ran one extra epoch whose weights were never saved).
vf.train()
start_time = time.time()
for i in range(iterations):
    epoch = i + 1
    loss_tracker = 0.0
    num_batches = 0
    for X_batch in train_loader:
        optim.zero_grad()
        # sample data (user's responsibility): in this case, (X_0,X_1) ~ pi(X_0,X_1) = N(X_0|0,I)q(X_1)
        X_0 = X_batch[0].to(device)
        V_0 = torch.randn_like(X_0)

        # sample time (user's responsibility): t = MAX_TIME * u**4 with u ~ U(0, 1),
        # which concentrates samples near t = 0.
        t = MAX_TIME * (torch.rand(X_0.shape[0], device=device) ** 4)

        # sample probability path
        X_t, V_t = simulate_forward(X_0, V_0, t, MIN_SUPPORT, MAX_SUPPORT)

        # flow matching l2 loss
        pred_error = vf(X_t, t) - V_t
        loss = torch.pow(pred_error, 2).mean()

        # optimizer step with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vf.parameters(), max_norm=1.0)
        optim.step()

        loss_tracker += loss.item()
        num_batches += 1

    elapsed = time.time() - start_time

    # Checkpoint every `print_every` epochs and always after the final epoch,
    # so runs whose length is not a multiple of `print_every` are still saved.
    if epoch % print_every == 0 or epoch == iterations:
        torch.save(vf.state_dict(), f'{checkpoint_dir}/HFC_{dataset_name}_seed_{args.cv_seed}_iter_{epoch}_{str(args.model_type)}.pt')

    # Log every epoch: mean batch loss, and one component of the last prediction error.
    print('| iter {:6d} | {:7.2f} s/epoch | loss {:8.7f} , pred_error {:8.4f}'
          .format(epoch, elapsed, loss_tracker / max(num_batches, 1), pred_error[-1, 0].item()))
    start_time = time.time()



print('time taken to train velocity field:', time.time()-training_start_time, 's, dataset:', dataset_name, 'cv_seed:', cv_seed , 'epochs:', args.epochs, 'model_type:', args.model_type)
