import os
from time import time
from tqdm import tqdm
from typing import *
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import delu, json
import argparse

# from tqdm import tqdm
# from copy import deepcopy
import pickle

import numpy as np
import pandas as pd
# from tabsyn.baselines.tabvfm.flow_matching import *
from .flow_matching import *
from .networks import *
from .preprocess import *
from sklearn.preprocessing import QuantileTransformer, OneHotEncoder


def main(args):
    """Train the conditional vector field on ``data/<dataname>/train.csv``.

    Numerical columns are quantile-transformed and discrete columns (plus the
    target column for ``binclass`` tasks) are one-hot encoded. The two parts
    are concatenated into one float tensor that is used to fit the vector
    field with the flow-matching loss.

    Args:
        args: Namespace providing ``seed``, ``dataname``, ``gpu``,
            ``batch_size`` and ``every`` (checkpoint interval in epochs).
            ``args.train`` is set to ``True`` as a side effect.

    Writes, under ``ckpt/<dataname>`` next to this file:
        * ``vf.pt`` -- final ``state_dict`` of ``v_t.net``
        * ``per_epoch/vf_<epoch>.pt`` -- intermediate checkpoints
        * ``vf_loss.pkl`` -- pickled list of ``[epoch, batch_loss]`` pairs
    """
    delu.random.seed(args.seed)
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    dataname = args.dataname

    # Fall back to the CPU instead of failing on machines without CUDA (or
    # when a negative GPU index is passed).
    if args.gpu >= 0 and torch.cuda.is_available():
        device = f'cuda:{args.gpu}'
    else:
        device = 'cpu'
        print('No usable CUDA device selected; training on CPU.')

    model_save_path = f'{curr_dir}/ckpt/{dataname}'
    per_epoch_path = f'{model_save_path}/per_epoch'
    real_data_path = f'data/{dataname}'
    os.makedirs(model_save_path, exist_ok=True)

    args.train = True

    with open(f'{real_data_path}/info.json', 'r') as f:
        info = json.load(f)

    data_train = pd.read_csv(f'{real_data_path}/train.csv')

    discrete_columns = [data_train.columns[i] for i in info['cat_col_idx']]
    numerical_columns = [data_train.columns[i] for i in info['num_col_idx']]
    if info['task_type'] == 'binclass':
        # Indexing with ``target_col_idx`` must yield an Index (``.tolist()``
        # is called on it) -- presumably a list of positions; verify against
        # the info.json schema.
        discrete_columns += data_train.columns[info['target_col_idx']].tolist()

    # Only fit the quantile transformer when there is something to transform:
    # fitting on a frame without columns raises, which previously made the
    # categorical-only code path unreachable.
    if numerical_columns:
        x_cont = QuantileTransformer().fit_transform(data_train[numerical_columns])
        d_cont = x_cont.shape[1]
    else:
        x_cont = None
        d_cont = 0

    encoder = OneHotEncoder()
    x_cat_dummy = encoder.fit_transform(data_train[discrete_columns]).toarray()
    # Read the per-column category counts from the fitted encoder so that
    # sum(K) always equals the width of the one-hot block.
    K = [len(categories) for categories in encoder.categories_]
    d_cat = x_cat_dummy.shape[1]
    d_total = d_cont + d_cat

    if x_cont is not None:
        features = torch.cat(
            [torch.from_numpy(x_cont), torch.from_numpy(x_cat_dummy)], dim=1
        )
    else:
        features = torch.from_numpy(x_cat_dummy)
    # The whole training set lives on the target device, so batches need no
    # further transfer inside the loop.
    features = features.float().to(device)

    n_epochs = 100
    # NOTE(review): no ``shuffle=True`` -- batches are visited in file order
    # every epoch; confirm this is intended.
    dataloader = DataLoader(TensorDataset(features), batch_size=args.batch_size)

    # Background on the method:
    # https://github.com/gle-bellier/flow-matching/blob/main/Flow_Matching.ipynb
    model = OTTabFlowMatching(d_cont, K)
    # Hidden width 512, following the TabSyn paper (per the original note).
    net = Net(d_total, 512).to(device)
    net_t = Regular_t().to(device)
    # net_t = VPdiffusion_t().to(device)
    v_t = CondVF(net, net_t, d_cont=d_cont, cat_list=K)

    losses = []
    optimizer = torch.optim.AdamW(v_t.parameters(), lr=1e-3)
    # NOTE(review): ``verbose`` looks deprecated in recent PyTorch releases --
    # confirm against the pinned version before removing it.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20, verbose=True)

    print(''.center(30, '#'))
    print('Training starts'.center(30, '-'))

    for epoch in tqdm(range(n_epochs)):
        epoch_loss_sum = 0.0
        n_seen = 0
        for batch in dataloader:
            x_1 = batch[0]
            loss = model.loss(v_t, x_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Weight by batch size so a short last batch does not skew the
            # epoch average fed to the scheduler.
            epoch_loss_sum += loss.item() * x_1.size(0)
            n_seen += x_1.size(0)
            losses.append([epoch, loss.detach().cpu().numpy()])

        scheduler.step(epoch_loss_sum / n_seen)

        # Intermediate checkpoints; the last epoch is covered by vf.pt below.
        if (epoch + 1) % args.every == 0 and (epoch + 1) != n_epochs:
            os.makedirs(per_epoch_path, exist_ok=True)
            torch.save(v_t.net.state_dict(), f'{per_epoch_path}/vf_{epoch+1}.pt')

    torch.save(v_t.net.state_dict(), f'{model_save_path}/vf.pt')
    with open(f"{model_save_path}/vf_loss.pkl", "wb") as file:
        pickle.dump(losses, file)

    print('Training finished'.center(30, '-'))
    print(''.center(30, '#'))

if __name__ == '__main__':
    # Command-line entry point. The module uses relative imports, so it has
    # to be launched as part of its package (``python -m <package>.<module>``).
    parser = argparse.ArgumentParser()
    # parser.add_argument('--config', metavar='FILE')
    parser.add_argument('--dataname', type = str, default = 'adult')
    parser.add_argument('--gpu', type = int, default=0, help='GPU device number.')
    # main() also reads args.seed, args.batch_size and args.every, which the
    # parser did not define before. The defaults below are placeholders --
    # NOTE(review): align them with the project's launcher.
    parser.add_argument('--seed', type = int, default=0, help='Random seed.')
    parser.add_argument('--batch_size', type = int, default=4096, help='Training batch size.')
    parser.add_argument('--every', type = int, default=10, help='Save a checkpoint every N epochs.')

    args = parser.parse_args()

    # Previously the arguments were parsed and then discarded.
    main(args)

