import argparse
import os
import numpy as np
import anndata as ad
import pandas as pd
import pdb
import torch
import wandb

from gmmot.utils.config_tools import get_paths
from gmmot.utils.data_tools import scLoader
from gmmot.utils.data_prep import get_data
from gmmot.vae.main import VAE


# Command-line interface of the training script. Every option has a default,
# so the script can be launched without arguments.
parser = argparse.ArgumentParser()

# --- model architecture ---
parser.add_argument("--latent_dim", default=16, type=int, help="latent dimension")
parser.add_argument("--fc_dim", default=100, type=int, help="number of nodes at the hidden layers")
parser.add_argument("--n_layer", default=4, type=int, help="number of hidden layers")
parser.add_argument("--network", default='deep', type=str, help="type of the model, either 'shallow' or 'deep'")
parser.add_argument("--p_drop", default=0.25, type=float, help="input probability of dropout")
parser.add_argument("--variational", action="store_true", help="enable variational mode")

# --- optimisation ---
parser.add_argument("--batch_size", default=1024, type=int, help="batch size")
# Float default so that the default and a value given on the command line
# have the same type (argparse does not convert non-string defaults).
parser.add_argument("--beta", default=1.0, type=float, help="KL regularization parameter")
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
parser.add_argument("--n_epoch", default=10, type=int, help="number of epochs")
parser.add_argument("--loss_mode", default='MSE-BCE', type=str, help="loss function, either MSE or MSE-BCE")
parser.add_argument("--random_seed", default=0, type=int, help="random seed")

# --- data ---
parser.add_argument("--n_gene", default=0, type=int, help="number of genes")
parser.add_argument("--class_label", default='', type=str, help="transcriptomic class label, e.g., 'OPC' ")
parser.add_argument("--prep_data", action="store_true", help="prepare data if not already done")
parser.add_argument("--toml_file", default='pyproject.toml', type=str, help="path to the toml file")

# --- run bookkeeping ---
parser.add_argument("--n_run", default=1, type=int, help="index of the run")
parser.add_argument("--use_wandb", action="store_true", help="use wandb for logging")

# --- hardware / data loading ---
# `--cuda` is a boolean switch; the old help text described a value-taking
# option ("use None for cpu"), which this flag never was.
parser.add_argument("--cuda", action="store_true", help="run on a CUDA GPU (default: CPU)")
parser.add_argument("--device_num", default=0, type=int, help="cuda device number")
parser.add_argument("--ws", default=1, type=int, help="world size for distributed training")
parser.add_argument("--nw", default=4, type=int, help="number of workers (CPUs)")
parser.add_argument("--dist_samp", action="store_true", help="use distributed sampler")
parser.add_argument("--prefetch_factor", default=2, type=int, help="prefetch factor")


def _select_device(args, rank):
    """Return the torch device requested by the command-line arguments.

    Args:
        args: Parsed arguments; reads ``cuda``, ``ws`` and ``device_num``.
        rank: Process rank, used as the GPU index when ``args.ws != 1``.

    Returns:
        torch.device: The CPU device when ``args.cuda`` is not set, otherwise
        one of the CUDA devices that passed the memory check below.

    Raises:
        RuntimeError: If CUDA was requested but no GPU passed the check.
        IndexError: If the requested index does not address one of the
            selected GPUs (negative indices are rejected as well).
    """
    if not args.cuda:
        return torch.device("cpu")

    # NOTE(review): torch.cuda.memory_allocated only reports memory held by
    # tensors of the current process, so this filter does not detect GPUs
    # that are busy with other jobs -- confirm that this is the intent.
    free_gpus = [
        i
        for i in range(torch.cuda.device_count())
        if torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i) > 0
    ]
    if not free_gpus:
        raise RuntimeError("No free GPU devices available.")

    # Single-process runs pick the GPU from --device_num, otherwise from rank.
    index = args.device_num if args.ws == 1 else rank
    if not 0 <= index < len(free_gpus):
        # Fail with an explicit message instead of a bare IndexError, and do
        # not let a negative value silently wrap around to the last GPU.
        raise IndexError(
            f"GPU index {index} is out of range; {len(free_gpus)} GPU(s) available."
        )

    device = torch.device(f"cuda:{free_gpus[index]}")
    print('---> Using GPU(s): ' + torch.cuda.get_device_name(device))
    return device


def main(args):
    """Load the dataset, build the data loaders and train the VAE.

    The function selects the compute device, loads (or prepares) the AnnData
    object, optionally restricts it to the cells whose ``class_label`` starts
    with ``args.class_label``, creates the output folder, stores the run
    parameters in ``parameter.text`` and finally initialises and trains the
    model.

    Args:
        args: Parsed command-line arguments (see the module-level ``parser``).
            ``args.class_label`` is set to ``'all'`` in place when it is empty.

    Raises:
        RuntimeError: If CUDA was requested but no GPU is available.
        IndexError: If the requested GPU index is out of range.
        ValueError: If the selected dataset contains no cells.
    """
    rank = 0
    device = _select_device(args, rank)

    dataset = 'ageing'
    config = get_paths(toml_file=args.toml_file, sub_file='scrna')

    if args.prep_data:
        print('Preparing the data...')
        adata = get_data(class_label=args.class_label, n_gene=args.n_gene, saving=True)
    else:
        data_file_name = f'{args.class_label}_file' if args.class_label else 'file_h5ad'
        data_file = config['paths']['data_path'] / dataset / config['scrna'][data_file_name]
        adata = ad.read_h5ad(data_file, backed='r+')

    # From here on the (possibly file-backed) AnnData object is open; the
    # ``finally`` clause releases the file handle even if a later step fails.
    try:
        n_gene = len(adata.var.index)
        print(f'Number of cells: {adata.X.shape[0]}')
        print(f'Number of genes: {adata.X.shape[1]}')

        if args.class_label:
            col = 'class_label'
            unique_labels = adata.obs[col].unique()
            # Prefix match on the label; non-string entries (e.g. NaN for
            # unlabelled cells) are skipped instead of raising AttributeError.
            types = [
                st
                for st in unique_labels
                if isinstance(st, str) and st.startswith(args.class_label)
            ]
            adata = adata[adata.obs[col].isin(set(types)).values]
        else:
            args.class_label = 'all'

        n_smp = adata.X.shape[0]
        if n_smp == 0:
            # Stop before creating output folders for a run that cannot train.
            raise ValueError(f"No cells found for class label '{args.class_label}'.")

        print(f'Dataset for {args.class_label} has {n_smp} samples with {n_gene} genes.')

        folder_name = f'run_{args.n_run}_zDim_{args.latent_dim}_p_drop_{args.p_drop}_fc_dim_{args.fc_dim}_' + \
                      f'nbatch_{args.batch_size}_variational_{args.variational}_nepoch_{args.n_epoch}_{args.class_label}_loss_{args.loss_mode}'

        saving_path = config['paths']['main_dir'] / config['paths']['saving_path'] / 'mouse'
        saving_folder = saving_path / folder_name
        # Creating the nested 'model' directory also creates saving_folder.
        os.makedirs(saving_folder / 'model', exist_ok=True)

        # Keep a record of the exact arguments next to the results.
        with open(saving_folder / 'parameter.text', "w") as f:
            for key, value in vars(args).items():
                f.write(f"{key}: {value}\n")

        print('preparing the data loaders')
        train_loader, test_loader, _ = scLoader(
            adata=adata,
            features=range(n_gene),
            batch_size=args.batch_size,
            random_seed=args.random_seed,
            world_size=args.ws,
            rank=rank,
            use_dist_sampler=args.dist_samp,
            num_workers=args.nw,
            prefetch_factor=args.prefetch_factor,
        )
    finally:
        adata.file.close()
    del adata

    # Train and save the model
    if args.use_wandb:
        run = wandb.init(project=f"mouse-ageing-experiments_{args.class_label}")
    else:
        run = None

    print('training the model')
    model = VAE(saving_folder=saving_folder, device=device)

    model.init_nn(
        input_dim=n_gene,
        network=args.network,
        fc_dim=args.fc_dim,
        lowD_dim=args.latent_dim,
        n_layer=args.n_layer,
        x_drop=args.p_drop,
        variational=args.variational,
    )

    model.train(
        train_loader=train_loader,
        validation_loader=test_loader,
        mode=args.loss_mode,
        lr=args.lr,
        beta=args.beta,
        n_epoch=args.n_epoch,
        wandb_run=run,
    )
    

if __name__ == "__main__":
    # Script entry point: parse the command line and start the training run.
    main(parser.parse_args())