import argparse
import sys
import json
import os
import torch
from torch.utils.data import DataLoader, random_split
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping, Timer, ModelCheckpoint
from abcdefg import CGM
from abcdefg.util import hash_time
import abcdefg.config as config
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from single_cell import PerturbDataset


# Command-line interface for training a CGM on single-cell perturbation data.
# Option names mix '-' and '_'; argparse maps both to underscore attribute
# names (e.g. '--log-dir' -> args.log_dir, '--num_layers' -> args.num_layers).
parser = argparse.ArgumentParser('ABCDEFG-Single-Cell')
######################################################################
# File IO
######################################################################
parser.add_argument(
    '--anndata', type=str, required=True,
    help='Input AnnData (.h5ad).'
)
parser.add_argument(
    '--save-data', type=str,
    help='Output directory for saving data.'
)
parser.add_argument(
    '--load-data', type=str,
    help='Input directory for loading previously saved data.'
)
parser.add_argument(
    '--log-dir', type=str, required=True,
    help='Log directory for saving training results and model checkpoints.'
)
parser.add_argument(
    '--data-id', type=int, default=0,
    help='ID of the dataset.'
)

######################################################################
# AnnData
######################################################################
parser.add_argument(
    '--preprocess', action="store_true",
    help='Preprocess AnnData.'
)
parser.add_argument(
    '--perturb-key', type=str,
    help='Key of perturbation information in AnnData.obs.'
)
parser.add_argument(
    '--n-genes', type=int, default=1000,
    help='Number of genes to keep in preprocessing.'
)
parser.add_argument(
    '--gene-perturbation', action="store_true",
    help='Perturbation targets specific genes instead of having indirect effects.'
)

######################################################################
# Model hyperparameters
######################################################################
parser.add_argument(
    '-m', '--num_factors', required=True, type=int,
    help='Number of factors'
)
parser.add_argument(
    '--num_layers', type=int, default=2,
    help='Number of encoder layers'
)
parser.add_argument(
    '--num_dec_layers', type=int, default=2,
    help='Number of decoder layers'
)
parser.add_argument(
    '--hid-dim', type=int, default=1000,
    help='Number of hidden units in neural networks'
)
parser.add_argument(
    '--graph-model', type=str, default='SPNFG',
    help='Name of the graph model.'
)
parser.add_argument(
    '--fix_factor', action="store_true",
    help='Set the factors to be deterministic <=> Use an AE instead of a VAE.'
)
parser.add_argument(
    '--mu_z', type=float, default=0.0,
    help='Mean of prior distribution of z'
)
parser.add_argument(
    '--std_z', type=float, default=1e-3,
    help='Standard deviation of prior distribution of z'
)
parser.add_argument(
    '--p_edge', type=float, default=0.5,
    help='Prior probability of node-to-factor connection.'
)
# The help text already documents the only valid values; `choices` makes
# argparse reject anything else up front instead of deep inside model setup.
parser.add_argument(
    '--var_type', type=str, default='const', choices=['const', 'learned'],
    help="Type of node variables. Must be either 'const' or 'learned'."
)
parser.add_argument(
    '--noise_level', type=float, default=0.05,
    help='Standard deviation of Gaussian noise in node observations.'
)
parser.add_argument(
    '--nonlin', type=str, default='relu',
    help='Nonlinearity in neural networks.'
)
parser.add_argument(
    '--loss_memory', type=float, default=0.9,
    help='Memory parameter in loss computation.'
)

######################################################################
# SPNFG parameters
######################################################################
parser.add_argument(
    '--spn_target', type=str, default='node',
    help='Target setting of the sum-product network graph model (passed to config.set_graph_params as spn_target).'
)
parser.add_argument(
    '--max_copies', type=int, default=8,
    help='Maximum size of cartesian products in sum-product networks.'
)
parser.add_argument(
    '--tau', type=float, default=1.0,
    help='Temperature parameter in Gumbel-Softmax.'
)
parser.add_argument(
    '--p_conn', type=float, default=0.5,
    help='Desired node-to-factor connection probability in sum-product networks.'
)
parser.add_argument(
    '--sparsity_temp', type=float, default=0.0,
    help='Temperature parameter in sparsity regularization.'
)
parser.add_argument(
    '--p_conn_int', type=float, default=0.5,
    help='Desired node-to-factor connection probability in sum-product networks (for interventions).'
)
parser.add_argument(
    '--sparsity_temp_int', type=float, default=0.0,
    help='Temperature parameter in sparsity regularization (for interventions).'
)

######################################################################
# Intervention
######################################################################
# NOTE(review): '--intv-target' is not read by build_model() or train() in
# this file. Confirm whether it is still needed.
parser.add_argument(
    '--intv-target', type=str, default='factor',
    help='Whether intervention targets are genes or factors.'
)
parser.add_argument(
    '--known-intervention', action="store_true",
    help='Whether intervention targets are known.'
)
parser.add_argument(
    '--hard-intervention', action="store_true",
    help='Use hard interventions (default: soft interventions).'
)
parser.add_argument(
    '--hold-out-pert', type=float, default=0.0,
    help='Proportion of perturbations to hold out for testing.'
)

######################################################################
# Training hyperparameters
######################################################################
parser.add_argument(
    '-e', '--max_epochs', type=int, default=2000,
    help='Maximum number of training epochs.'
)
parser.add_argument(
    '--min_epochs', type=int, default=10,
    help='Minimum number of training epochs.'
)
parser.add_argument(
    '--every_n_epochs', type=int, default=10,
    help='Number of epochs between checkpoints'
)
parser.add_argument(
    '--patience', type=int, default=100,
    help='Patience for early stopping.'
)
parser.add_argument(
    '--stop_thred', type=float, default=1e-3,
    help='Threshold of ELBO change for early stopping.'
)
parser.add_argument(
    '--p-train', type=float, default=0.8,
    help='Proportion of training data.'
)
parser.add_argument(
    '--num_workers', type=int, default=0,
    help='Number of parallel workers for data loading.'
)
parser.add_argument(
    '-b', '--batch_size', type=int, default=128,
    help='Batch size.'
)
parser.add_argument(
    '--lr_nn', type=float, default=1e-4,
    help='Learning rate for neural networks.'
)
parser.add_argument(
    '--lr_fg', type=float, default=1e-3,
    help='Learning rate for factor graph model.'
)
parser.add_argument(
    '--weight_decay', type=float, default=1e-3,
    help='Weight decay.'
)
parser.add_argument(
    '--seed', type=int, default=42,
    help='Random seed.'
)
parser.add_argument(
    '--coeff_kl_z', type=float, default=1e-6,
    help='Coefficient of KL divergence term for z.'
)
parser.add_argument(
    '--coeff_kl_fg', type=float, default=1e-6,
    help='Coefficient of KL divergence term for the factor graph.'
)
parser.add_argument(
    '--coeff_l1_reg', type=float, default=1.0,
    help='L1 regularization coefficient of graph density.'
)
parser.add_argument(
    '--coeff_l1_reg_int', type=float, default=10.0,
    help='L1 regularization coefficient of intervention-to-factor density.'
)
parser.add_argument(
    '--robust-loss', action="store_true",
    help='Use Huber loss to approximate Gaussian likelihood.'
)
parser.add_argument(
    '--early-stop', action="store_true",
    help='Early stopping.'
)
parser.add_argument(
    '--verbose', action="store_true",
    help='Used for debugging.'
)


def build_model(
    num_vars,
    num_intv,
    args
):
    """Configure the global ``config`` module from CLI arguments and build a CGM.

    Side effects: calls ``config.set_model_params``, ``config.set_graph_params``
    and ``config.set_optim_params``, which mutate module-level configuration
    before the model is constructed.

    Args:
        num_vars: Number of observed variables; ``train()`` passes
            ``dataset.n_gene``.
        num_intv: Number of interventions; ``train()`` passes
            ``dataset.n_intv``.
        args: Namespace produced by this script's argument parser.

    Returns:
        The constructed ``CGM`` instance.
    """
    z_prior = {
        'mean': args.mu_z,
        'std': args.std_z
    }
    loss_coeff = {
        'kl_z': args.coeff_kl_z,
        'kl_fg': args.coeff_kl_fg,
        'l1_reg': args.coeff_l1_reg,
        'l1_reg_int': args.coeff_l1_reg_int
    }
    config.set_model_params(
        graph_model=args.graph_model,
        fix_factor=args.fix_factor,
        z_prior=z_prior,
        edge_prior=args.p_edge,
        nonlin=args.nonlin,
        lr_nn=args.lr_nn,
        lr_fg=args.lr_fg,
        loss_coeff=loss_coeff,
        loss_memory=args.loss_memory,
        var_type=args.var_type,
        noise_level=args.noise_level,
        robust_loss=args.robust_loss
    )
    if args.graph_model == 'SPNFG':
        # With known intervention targets only the node-side settings are
        # passed. Otherwise a (node, intervention) pair is passed, whose second
        # entries come from the '--p_conn_int' / '--sparsity_temp_int' options
        # ("for interventions").
        p_conn = args.p_conn if args.known_intervention else (args.p_conn, args.p_conn_int)
        sparsity_temp = (
            args.sparsity_temp if args.known_intervention else (args.sparsity_temp, args.sparsity_temp_int)
        )
        config.set_graph_params(
            args.graph_model,
            tau=args.tau,
            spn_target=args.spn_target,
            max_copies=args.max_copies,
            p_conn=p_conn,
            sparsity_temp=sparsity_temp
        )
    else:
        config.set_graph_params(args.graph_model, tau=args.tau)

    config.set_optim_params(weight_decay=args.weight_decay)

    # NOTE(review): MODEL_HYPER_PARAMS is presumably populated by the
    # set_*_params calls above; confirm in abcdefg.config.
    return CGM(
        num_vars,
        num_intv,
        args.num_factors,
        args.num_layers,
        args.num_dec_layers,
        args.hid_dim,
        args.batch_size,
        **config.MODEL_HYPER_PARAMS,
    )


def train():
    """Parse CLI arguments, train a CGM on a perturbation dataset, and save results.

    Steps:
        1. Load the ``PerturbDataset`` and split it randomly into train and
           validation subsets (``--p-train``).
        2. If ``--hold-out-pert > 0``, load the held-out perturbations as a
           test set.
        3. Build the model and fit it with Lightning.
        4. On the global-zero process only, write ``run_time.json``,
           ``args.json`` and the learned graph into the CSV logger's run
           directory.
    """
    # Preparation
    torch.set_float32_matmul_precision('highest')
    args = parser.parse_args()
    if args.save_data is not None:
        os.makedirs(args.save_data, exist_ok=True)
    config.VERBOSE = args.verbose

    # Seed before loading/splitting so the random_split below is reproducible.
    seed_everything(args.seed)

    # Load dataset
    print('Loading dataset.')
    dataset = PerturbDataset(
        args.anndata,
        args.data_id,
        args.perturb_key,
        args.preprocess,
        args.gene_perturbation,
        args.known_intervention,
        args.hard_intervention,
        args.hold_out_pert,
        load_test=False,
        save_path=args.save_data,
        load_path=args.load_data
    )

    n_train = int(len(dataset) * args.p_train)
    n_val = len(dataset) - n_train
    train_data, val_data = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    # Test data
    test_data = None
    test_loader = None
    if args.hold_out_pert > 0:
        # Held-out data is read from wherever the training data was loaded
        # from, or else from where it was just saved.
        test_load_path = args.load_data if args.load_data is not None else args.save_data
        test_data = PerturbDataset(
            args.anndata,
            args.data_id,
            args.perturb_key,
            args.preprocess,
            args.gene_perturbation,
            args.known_intervention,
            args.hard_intervention,
            args.hold_out_pert,
            load_test=True,
            load_path=test_load_path
        )
        test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
        print(f'Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}')
    else:
        print('No held-out perturbation.\n')
        print(f'Train: {len(train_data)}, Validation: {len(val_data)}')

    # Initialize model
    print('Creating model.')
    model = build_model(
        dataset.n_gene, dataset.n_intv, args
    )

    # Define callbacks
    timer = Timer()
    checkpoint_saver = ModelCheckpoint(
        every_n_epochs=args.every_n_epochs,
        save_top_k=3,
        monitor="val_loss_accum",
    )
    callbacks = [timer, checkpoint_saver]
    if args.early_stop:
        early_stopping = EarlyStopping(
            'val_loss_accum',
            min_delta=args.stop_thred,
            patience=args.patience,
            mode='min'
        )
        callbacks.append(early_stopping)

    # Define logger
    version = hash_time()
    name = 'ABCDEFG_Basic' if args.graph_model == 'BasicFG' else 'ABCDEFG_SPN'
    csv_logger = CSVLogger(
        save_dir=args.log_dir, name=name, version=version
    )

    # Build a trainer
    trainer = Trainer(
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=-1 if torch.cuda.is_available() else 1,
        precision=32,
        logger=csv_logger,
        min_epochs=args.min_epochs,
        max_epochs=args.max_epochs,
        callbacks=callbacks
    )

    # Train model
    print('Start training.')
    trainer.fit(model, train_loader, val_loader)

    # Evaluate the final (in-memory) model on the held-out perturbations.
    # NOTE(review): this is not the best checkpoint; pass ckpt_path='best'
    # to trainer.test if evaluating the best validation model is intended.
    if test_loader is not None:
        print('Start testing.')
        trainer.test(model, test_loader)

    # With several GPUs (devices=-1) this function runs in every process, but
    # Lightning loggers only write on rank zero, so the run directory may not
    # exist elsewhere. Only the global-zero process saves the artifacts.
    if not trainer.is_global_zero:
        return

    # Use the logger's own directory so the artifacts always land next to
    # its metrics, and make sure it exists before writing.
    log_dir = csv_logger.log_dir
    os.makedirs(log_dir, exist_ok=True)

    # Save metrics and hyperparameters
    run_time = {
        'Training': timer.time_elapsed('train'),
        'Validation': timer.time_elapsed('validate'),
        'Testing': timer.time_elapsed('test'),
    }
    with open(os.path.join(log_dir, 'run_time.json'), 'w') as f:
        json.dump(run_time, f)
    with open(os.path.join(log_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f)

    # Save the graph
    model.save_graph(log_dir)


if __name__ == '__main__':
    # Script entry point: training starts only when this file is executed
    # directly, not when it is imported.
    train()
