import argparse
import json
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping, Timer, ModelCheckpoint
from abcdefg import CGM
from abcdefg.util import hash_time
from abcdefg.data import SimulationDataset, SimulationDatasetIFG
import abcdefg.config as config


# Command-line interface of the training script.
# Boolean switches are all `store_true`, i.e. they default to False and the
# described behaviour is what happens when the switch IS passed.
parser = argparse.ArgumentParser('ABCDEFG')
# File IO
parser.add_argument('-i', '--in-dir', type=str, required=True, help='Input directory')
parser.add_argument('-g', '--graph', type=str, default='DAG.npy', help='File name of the graph.')
parser.add_argument('-o', '--out-dir', type=str, required=True, help='Output directory')
parser.add_argument('-d', '--data_id', type=int, required=True, help='ID of the dataset.')
# Model hyperparameters
parser.add_argument('-m', '--num_factors', required=True, type=int, help='Number of factors')
parser.add_argument('--num_layers',  type=int, default=2, help='Number of encoder layers')
parser.add_argument('--num_dec_layers', type=int, default=2, help='Number of decoder layers')
parser.add_argument('--hid_dim', type=int, default=1000, help='Number of hidden units in neural networks')
parser.add_argument('--graph-model', type=str, default='SPNFG', help='Name of the graph model.')
# NOTE: the two mask switches are negated when forwarded to the model config
# (sample_mask = not fix_mask, hard_mask = not soft_mask), so their help text
# describes what passing the switch turns OFF in the default behaviour.
parser.add_argument('--fix-mask', action="store_true",
                    help='Keep the causal mask fixed (disable random sampling with Gumbel softmax during training).')
parser.add_argument('--soft-mask', action="store_true",
                    help='Use a soft mask instead of a hard mask in the VAE model.')
parser.add_argument('--fix-factor', action="store_true",
                    help='Set the factors to be deterministic <=> Use an AE instead of a VAE.')
parser.add_argument('--mu-z', type=float, default=0.0, help='Mean of prior distribution of z')
parser.add_argument('--std-z', type=float, default=1e-3, help='Standard deviation of prior distribution of z')
parser.add_argument('--p-edge', type=float, default=0.5, help='Prior probability of node-to-factor connection.')
parser.add_argument('--var_type', type=str, default='const',
                    help="Type of node variables. Must be either 'const' or 'learned'.")
parser.add_argument('--noise-level', type=float, default=0.05,
                    help='Standard deviation of Gaussian noise in node observations.')
parser.add_argument('--nonlin', type=str, default='relu', help='Nonlinearity in neural networks.')
parser.add_argument('--loss-memory', type=float, default=0.5, help='Memory parameter in loss computation.')
# Graph parameters
parser.add_argument('--spn-target', type=str, default='factor')
parser.add_argument('--max-copies', type=int, default=8,
                    help='Maximum size of cartesian products in sum-product networks.')
parser.add_argument('--tau', type=float, default=1.0, help='Temperature parameter in Gumbel-Softmax.')
parser.add_argument('--p-conn', type=float, default=0.5,
                    help='Desired node-to-factor connection probability in sum-product networks.')
parser.add_argument('--sparsity_temp', type=float, default=0.0, help='Temperature parameter in sparsity regularization.')
parser.add_argument('--p-conn-int', type=float, default=0.5,
                    help='Desired node-to-factor connection probability in sum-product networks (for interventions).')
parser.add_argument('--sparsity_temp_int', type=float, default=0.0,
                    help='Temperature parameter in sparsity regularization  (for interventions).')
# Intervention
parser.add_argument('--disable-intv', action="store_true", help='disable intervention.')
parser.add_argument('--intv-target', type=str, default='node', help='Whether intervention targets are nodes or factors.')
parser.add_argument('--num-untargeted-intv', type=int, default=0, help='Number of untargeted interventions')
parser.add_argument('--soft_intervention', action="store_true", help='Whether intervention is soft or hard.')
# Training hyperparameters
parser.add_argument('-e', '--max_epochs', type=int, default=3000, help='Maximum number of training epochs.')
parser.add_argument('--min_epochs', type=int, default=10, help='Minimum number of training epochs.')
parser.add_argument('--every_n_epochs', type=int, default=10, help='Number of epochs between checkpoints')
parser.add_argument('--patience', type=int, default=50, help='Patience for early stopping.')
parser.add_argument('--stop_thred', type=float, default=1e-3, help='Threshold of ELBO change for early stopping.')
parser.add_argument('--n_train', type=float, default=0.8, help='Proportion of training data.')
parser.add_argument('--num_workers', type=int, default=0, help='Number of parallel workers for data loading.')
parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size.')
parser.add_argument('--lr_nn', type=float, default=5e-4, help='Learning rate for neural networks.')
parser.add_argument('--lr_fg', type=float, default=5e-3, help='Learning rate for factor graph model.')
parser.add_argument('--weight_decay', type=float, default=1e-3, help='Weight decay.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--coeff_kl_z', type=float, default=1e-8, help='Coefficient of KL divergence term for z.')
parser.add_argument('--coeff_kl_fg', type=float, default=1e-8,
                    help='Coefficient of KL divergence term for the factor graph.')
parser.add_argument('--coeff_l1_reg', type=float, default=1.0, help='L1 regularization coefficient of graph density.')
parser.add_argument('--coeff_l1_reg_int', type=float, default=10.0,
                    help='L1 regularization coefficient of intervention-to-factor density.')
parser.add_argument('--coeff-scheduler', type=str, default='sigmoid', help='Scheduler for loss coefficients.')
parser.add_argument('--coeff-scheduler-params', type=str, default=None, help='Parameters (.json) for the scheduler.')
parser.add_argument('--robust-loss', action="store_true", help='Use Huber loss to approximate Gaussian likelihood.')
parser.add_argument('--early-stop', action="store_true", help='Early stopping.')
parser.add_argument('--verbose', action="store_true", help='Used for debugging.')


def build_model(num_vars, args):
    """Push CLI options into the shared ``config`` module and build a ``CGM``.

    The function first registers model, scheduler, graph and optimizer
    settings through the ``config.set_*`` helpers and then instantiates the
    model, forwarding ``config.MODEL_HYPER_PARAMS`` as keyword arguments.

    Args:
        num_vars: Number of variables; forwarded as the first positional
            argument of ``CGM``.
        args: Parsed command-line namespace (see the module-level ``parser``).

    Returns:
        The newly constructed ``CGM`` instance.

    Raises:
        OSError: If ``args.coeff_scheduler_params`` is given but cannot be read.
        KeyError: If the scheduler-parameter JSON has no entry named after
            ``args.coeff_scheduler``.
    """
    z_prior = {
        'mean': args.mu_z,
        'std': args.std_z
    }
    loss_coeff = {
        'kl_z': args.coeff_kl_z,
        'kl_fg': args.coeff_kl_fg,
        'l1_reg': args.coeff_l1_reg,
        'l1_reg_int': args.coeff_l1_reg_int
    }
    # The CLI exposes "--fix-mask" / "--soft-mask" while the config expects the
    # opposite flags, hence the negations below.
    config.set_model_params(
        graph_model=args.graph_model,
        sample_mask=not args.fix_mask,
        hard_mask=not args.soft_mask,
        fix_factor=args.fix_factor,
        z_prior=z_prior,
        edge_prior=args.p_edge,
        nonlin=args.nonlin,
        lr_nn=args.lr_nn,
        lr_fg=args.lr_fg,
        loss_coeff=loss_coeff,
        coeff_scheduler=args.coeff_scheduler,
        loss_memory=args.loss_memory,
        var_type=args.var_type,
        noise_level=args.noise_level,
        robust_loss=args.robust_loss
    )

    # Optional scheduler parameters: the JSON file maps scheduler names to
    # keyword arguments; only the entry of the selected scheduler is used.
    if args.coeff_scheduler_params is not None:
        with open(args.coeff_scheduler_params, 'r', encoding='utf-8') as f:
            scheduler_params = json.load(f)
        config.set_scheduler_params(args.coeff_scheduler, **scheduler_params[args.coeff_scheduler])

    if args.graph_model == 'SPNFG':
        # With untargeted interventions, connection probability and sparsity
        # temperature are passed as (regular, intervention) pairs; otherwise
        # as plain scalars.
        has_untargeted_intv = args.num_untargeted_intv != 0
        p_conn = (args.p_conn, args.p_conn_int) if has_untargeted_intv else args.p_conn
        sparsity_temp = (
            (args.sparsity_temp, args.sparsity_temp_int) if has_untargeted_intv else args.sparsity_temp
        )
        config.set_graph_params(
            args.graph_model,
            tau=args.tau,
            spn_target=args.spn_target,
            max_copies=args.max_copies,
            p_conn=p_conn,
            sparsity_temp=sparsity_temp
        )
    else:
        config.set_graph_params(args.graph_model, tau=args.tau)

    config.set_optim_params(weight_decay=args.weight_decay)

    # NOTE(review): MODEL_HYPER_PARAMS is presumably populated by the
    # config.set_* calls above, so it must be read only after them.
    model = CGM(
        num_vars,
        args.num_untargeted_intv,
        args.num_factors,
        args.num_layers,
        args.num_dec_layers,
        args.hid_dim,
        args.batch_size,
        **config.MODEL_HYPER_PARAMS,
    )
    return model


def train():
    """Run the full training pipeline driven by the command-line arguments.

    Steps, in order:
      1. Parse arguments, create the output directory and seed all RNGs.
      2. Load the ground-truth graph(s) from the input directory if present.
      3. Build the train/validation split and the test set with their loaders.
      4. Build the model and hand it the ground-truth graphs.
      5. Fit with a Lightning ``Trainer`` (checkpointing, optional early
         stopping, CSV logging), then evaluate on the test set.
      6. Write ``run_time.json``, ``args.json`` and the learned graph into
         ``<out_dir>/<logger name>/<version>``.
    """
    # Preparation
    torch.set_float32_matmul_precision('highest')
    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    config.VERBOSE = args.verbose

    # Set random seed
    seed_everything(args.seed)

    # True graph (optional: stays None when the file is absent)
    true_graph = None
    graph_path = f'{args.in_dir}/{args.graph}'
    if os.path.exists(graph_path):
        true_graph = np.load(graph_path)
    # Intervention graph is only looked up when untargeted interventions are used.
    true_graph_int = None
    unknown_intv = args.num_untargeted_intv > 0
    graph_int_path = f'{args.in_dir}/DAG_int{args.data_id}.npy'
    if unknown_intv and os.path.exists(graph_int_path):
        true_graph_int = np.load(graph_int_path)

    # Load dataset
    print('Loading dataset.')
    dataset_class = SimulationDatasetIFG if args.intv_target == 'factor' else SimulationDataset
    dataset = dataset_class(
        args.in_dir,
        args.data_id,
        intervention=not args.disable_intv,
        unknown_intervention=unknown_intv,
        soft_intervention=args.soft_intervention,
        load_test=False
    )
    # --n_train is a proportion; the remainder goes to validation.
    n_train = int(len(dataset) * args.n_train)
    n_val = len(dataset) - n_train
    train_data, val_data = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    # Test data
    # NOTE(review): unlike the training dataset, `unknown_intervention` is not
    # passed here, so the dataset class default applies — confirm this is intended.
    test_data = dataset_class(
        args.in_dir,
        args.data_id,
        intervention=not args.disable_intv,
        soft_intervention=args.soft_intervention,
        load_test=True
    )
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    print(f'Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}')

    # Initialize model
    print('Creating model.')
    # The trailing columns reserved for untargeted interventions are not variables.
    num_vars = dataset.data.shape[1] - args.num_untargeted_intv
    model = build_model(num_vars, args)
    model.set_true_graph(true_graph, true_graph_int)

    # Define callbacks
    timer = Timer()
    checkpoint_saver = ModelCheckpoint(
        every_n_epochs=args.every_n_epochs,
        save_top_k=3,
        monitor="val_loss_accum",
    )
    callbacks = [timer, checkpoint_saver]
    if args.early_stop:
        early_stopping = EarlyStopping(
            'val_loss_accum',
            min_delta=args.stop_thred,
            patience=args.patience,
            mode='min'
        )
        callbacks.append(early_stopping)

    # Define logger
    version = hash_time()
    name = 'ABCDEFG_Basic' if args.graph_model == 'BasicFG' else 'ABCDEFG_SPN'
    # Mirrors the directory layout used by the CSV logger below.
    log_dir = os.path.join(args.out_dir, name, version)
    csv_logger = CSVLogger(
        save_dir=args.out_dir, name=name, version=version
    )

    # Build a trainer
    use_gpu = torch.cuda.is_available()
    trainer = Trainer(
        accelerator='gpu' if use_gpu else 'cpu',
        devices=-1 if use_gpu else 1,
        precision=32,
        logger=csv_logger,
        min_epochs=args.min_epochs,
        max_epochs=args.max_epochs,
        callbacks=callbacks
    )

    # Train model
    print('Start training.')
    trainer.fit(model, train_loader, val_loader)

    # Evaluate on the test set with the model as it is after fitting.
    # NOTE(review): no `ckpt_path` is passed, so no saved checkpoint is
    # reloaded here — confirm whether the best checkpoint should be tested.
    trainer.test(model, test_loader)

    # Save metrics and hyperparameters
    run_time = {
        'Training': timer.time_elapsed('train'),
        'Validation': timer.time_elapsed('validate'),
        'Testing': timer.time_elapsed('test'),
    }
    # Do not rely on the logger having created the directory already.
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'run_time.json'), 'w') as f:
        json.dump(run_time, f)
    with open(os.path.join(log_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f)

    # Save the graph
    model.save_graph(log_dir)

# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    train()
