import argparse
import torch
import torch.multiprocessing as mp
import os
from torch.utils.data import DataLoader, random_split
from latent_lang_diff.text_denoising_diffusion import GaussianDiffusion, Trainer
from latent_lang_diff.models import DiffusionTransformer
from dataset.dataset import create_df_dict_from_dir, create_config_dict_from_dir
from tableLatent.tableVAETransformer import TableVAETransformer
from sdmetrics.reports.single_table import QualityReport


# Exported at import time, before any of the model code below runs.
# NOTE(review): presumably read by the Hugging Face `tokenizers` library to
# turn off its internal parallelism -- confirm against the VAE/tokenizer code.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Constants
# Divisor applied to --tx_dim to derive the number of attention heads handed
# to DiffusionTransformer (heads = tx_dim // ATTN_HEAD_DIM).
ATTN_HEAD_DIM = 64

def _load_or_build_latent_dataset(transformer, df_dict, cache_path, save_requested):
    """Return the latent dataset for ``df_dict``, using an on-disk cache if configured.

    Args:
        transformer: Fitted object exposing ``df_to_latent(df_dict, return_format=...)``.
        df_dict: Mapping handed to ``transformer.df_to_latent``.
        cache_path: File path of the cached dataset, or ``None`` to disable caching.
        save_requested: When true and ``cache_path`` does not exist yet, the
            dataset is built and written to ``cache_path``.

    Returns:
        The object produced by ``df_to_latent(..., return_format='dataset')`` or
        the object previously stored at ``cache_path``.

    Raises:
        FileNotFoundError: (from ``torch.load``) if ``cache_path`` is given, does
            not exist, and saving was not requested.
    """
    if cache_path is None:
        return transformer.df_to_latent(df_dict, return_format='dataset')

    if save_requested and not os.path.exists(cache_path):
        # First run with caching enabled: build once, then persist for later runs.
        dataset = transformer.df_to_latent(df_dict, return_format='dataset')
        parent_dir = os.path.dirname(cache_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(dataset, cache_path)
        print("Dataset saved to checkpoint:", cache_path)
        return dataset

    # The cache already exists (or the caller explicitly asked to load it);
    # re-saving the object that was just loaded would be redundant.
    dataset = torch.load(cache_path)
    print("Dataset loaded from checkpoint:", cache_path)
    return dataset


def _build_quality_report_metadata(df):
    """Build the metadata dict passed to ``QualityReport.generate`` for ``df``.

    Columns with dtype ``object`` are tagged ``categorical``; every other
    column is tagged ``numerical``.
    """
    column_types = {
        col: {'sdtype': 'categorical' if df[col].dtype == 'object' else 'numerical'}
        for col in df.columns
    }
    # NOTE(review): the primary key is hard-coded to "user_id" regardless of the
    # table being evaluated -- confirm this is intended for all datasets.
    return {
        "primary_key": "user_id",
        "columns": column_types
    }


def run_distributed(rank, world_size, distributed, args):
    """
    Each process runs this function for distributed training.

    Fits the table VAE, builds (or loads) the latent dataset, trains the
    diffusion model when ``args.train_diffusion == 'yes'``, optionally samples
    synthetic tables and writes them as CSV, and finally saves the trainer.

    Args:
        rank: Index of this process; used as the CUDA device index and
            forwarded to ``Trainer``.
        world_size: Total number of processes; forwarded to ``Trainer``.
        distributed: Distributed-training flag; forwarded to ``Trainer``.
        args: Parsed command line namespace (see ``parse_args``).

    Raises:
        ValueError: If ``--saved_embedded_dataset`` is set without
            ``--saved_dataset_dir`` (there would be nowhere to save to).
    """
    # Fail fast: without a target path the save step can only crash, and it
    # would do so after the (potentially long) VAE fitting below.
    if args.saved_embedded_dataset and args.saved_dataset_dir is None:
        raise ValueError("--saved_embedded_dataset requires --saved_dataset_dir to be set")

    # Set up the VAE transformer
    transformer = TableVAETransformer()
    if args.vae_checkpoint is not None:
        # torch.load unpickles the file; only load checkpoints from trusted sources.
        loaded = torch.load(
            args.vae_checkpoint,
            map_location=lambda storage, loc: storage.cuda(rank) if torch.cuda.is_available() else storage,
        )
        transformer.load_checkpoint(loaded)

    # Load data and configuration
    config_dict = create_config_dict_from_dir(args.config_dir)
    train_val_df_dict, test_df_dict = create_df_dict_from_dir(args.csv_dir, test_size=args.validation_rate)

    # Fit the transformer and prepare datasets
    transformer.fit(train_val_df_dict, test_df_dict, config_dict,
                    retrain_vae=(args.retrain_vae == 'both'),
                    retrain_decoder_only=(args.retrain_vae == 'decoder'),
                    num_epochs=args.retrain_vae_epochs)

    # Reuse a previously embedded dataset when available to save time
    train_val_dataset = _load_or_build_latent_dataset(
        transformer,
        train_val_df_dict,
        args.saved_dataset_dir,
        args.saved_embedded_dataset,
    )

    # Split the dataset into training and validation sets
    train_size = int(0.8 * len(train_val_dataset))
    val_size = len(train_val_dataset) - train_size
    train_dataset, val_dataset = random_split(train_val_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.train_batch_size, shuffle=True)
    train_val_dataloader = DataLoader(train_val_dataset, batch_size=args.train_batch_size, shuffle=True)
    print(len(train_dataloader))

    # Define the diffusion transformer model
    model = DiffusionTransformer(
        tx_dim=args.tx_dim,
        tx_depth=args.tx_depth,
        heads=args.tx_dim // ATTN_HEAD_DIM,
        latent_dim=args.tx_dim,
        max_seq_len=args.max_seq_len,
        self_condition=args.self_condition,
        scale_shift=args.scale_shift,
        dropout=0 if args.disable_dropout else 0.1,
        class_conditional=False,
        seq2seq=True,
        seq2seq_context_dim=args.seq2seq_context_dim
    ).cuda(rank)

    # Set up the diffusion process
    diffusion = GaussianDiffusion(
        model=model,
        max_seq_len=args.max_seq_len,
        sampling_timesteps=args.sampling_timesteps,
        sampler=args.sampler,
        train_schedule=args.train_schedule,
        sampling_schedule=args.sampling_schedule,
        loss_type=args.loss_type,
        objective=args.objective,
        train_prob_self_cond=args.train_prob_self_cond,
        seq2seq_unconditional_prob=args.seq2seq_unconditional_prob,
        scale=args.scale,
    ).cuda(rank)

    # Initialize the trainer
    trainer = Trainer(
        args=args,
        diffusion=diffusion,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        train_val_dataloader=train_val_dataloader,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulate_every=args.gradient_accumulation_steps,
        train_lr=args.learning_rate,
        train_num_steps=args.num_train_steps,
        lr_schedule=args.lr_schedule,
        num_warmup_steps=args.lr_warmup_steps,
        ema_update_every=args.ema_update_every,
        ema_decay=args.ema_decay,
        adam_betas=(args.adam_beta1, args.adam_beta2),
        adam_weight_decay=args.adam_weight_decay,
        save_and_sample_every=args.save_and_sample_every,
        num_samples=args.num_samples,
        results_folder=args.output_dir,
        amp=args.amp,
        mixed_precision=args.mixed_precision,
        distributed=distributed,
        rank=rank,
        world_size=world_size
    )

    if args.diffusion_checkpoint is not None:
        trainer.load(args.diffusion_checkpoint, best=False, init_only=(args.init_only_diff == 'yes'))

    # Start training; the model name only tags the output CSV file names
    if args.train_diffusion == 'yes':
        trainer.train(train_dataloader, val_dataloader)
        if args.diffusion_checkpoint is None:
            model_name = "CTSyn"
        else:
            model_name = "CTSynFinetuned"
    else:
        model_name = "CTSynZeroshot"

    # Generate synthetic data if required
    if args.job_type in ['train_gen', 'gen']:
        # to_csv does not create missing directories, so make sure the target exists.
        os.makedirs(args.synth_data_dir, exist_ok=True)

        for dataset_name, df in train_val_df_dict.items():
            # NOTE(review): the result of this call was never used by the original
            # code; the call is kept in case df_to_latent has side effects that
            # sampling depends on -- confirm and remove if it does not.
            transformer.df_to_latent({dataset_name: df})

            pred_latents = trainer.sample_seq2seq(split="train", return_sample=True, num_samples=args.num_samples)
            recon_df = transformer.latent_to_df(pred_latents, dataset_name)

            # Score the synthetic table against the held-out split of the same dataset
            test_df = test_df_dict[dataset_name]
            report = QualityReport()
            metadata = _build_quality_report_metadata(test_df)

            report.generate(test_df, recon_df, metadata)
            print(report.get_details(property_name='Column Shapes'))
            print(report.get_details(property_name='Column Pair Trends'))

            full_path = os.path.join(args.synth_data_dir, f"{dataset_name}_{model_name}_default_{args.test_idx}.csv")
            recon_df.to_csv(full_path, index=False)

    # Save the trained model
    trainer.save(args.save_dir)

def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Distributed Diffusion Training")

    # Run selection
    parser.add_argument("--test_idx", type=int, default=0)
    parser.add_argument("--job_type", type=str, default="train_gen", choices=['train_gen', 'train', 'gen'])

    # Input data and embedded-dataset cache
    parser.add_argument("--csv_dir", type=str, default=None, help="Folder of CSV files")
    parser.add_argument("--config_dir", type=str, default=None, help="Folder of configuration files.")
    parser.add_argument("--saved_dataset_dir", type=str, default=None, help="Path of saved embedding dataset")
    parser.add_argument("--saved_embedded_dataset", action="store_true", default=False)
    parser.add_argument("--validation_rate", type=float, default=0.1)

    # Checkpoints and which stages to (re)train
    parser.add_argument("--vae_checkpoint", type=str, default=None, help="File path of pre-trained VAE")
    parser.add_argument("--synth_data_dir", type=str, default="synth_data", help="Folder for synthetic data.")
    parser.add_argument("--retrain_vae", type=str, default="no", choices=["no", "both", "decoder"])
    parser.add_argument("--retrain_vae_epochs", type=int, default=50)
    parser.add_argument("--diffusion_checkpoint", type=str, default=None, help="File path of pre-trained diffusion.")
    parser.add_argument("--train_diffusion", type=str, default="yes", choices=["yes", "no"])
    parser.add_argument("--init_only_diff", type=str, default="yes", choices=["yes", "no"])

    # Output locations and optimisation hyper-parameters
    parser.add_argument("--save_dir", type=str, default="saved_diff_models")
    parser.add_argument("--output_dir", type=str, default="../checkpoints/diffusion")
    parser.add_argument("--train_batch_size", type=int, default=512)
    parser.add_argument("--eval_batch_size", type=int, default=512)
    parser.add_argument("--num_train_steps", type=int, default=100000)
    parser.add_argument("--save_and_sample_every", type=int, default=10000)
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--clip_grad_norm", type=float, default=1.0)
    parser.add_argument("--lr_schedule", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=1000)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--ema_update_every", type=int, default=1)

    # Diffusion process and model architecture
    parser.add_argument("--objective", type=str, default="pred_x0", choices=["pred_noise", "pred_x0", "pred_v"])
    parser.add_argument("--loss_type", type=str, default="l2", choices=["l1", "l2", "smooth_l1"])
    parser.add_argument("--train_schedule", type=str, default="cosine", choices=["beta_linear", "simple_linear", "cosine", 'sigmoid'])
    parser.add_argument("--sampling_schedule", type=str, default=None, choices=["beta_linear", "cosine", "simple_linear", None])
    parser.add_argument("--scale", type=float, default=1.0)
    parser.add_argument("--sampling_timesteps", type=int, default=500)
    parser.add_argument("--max_seq_len", type=int, default=16)
    parser.add_argument("--self_condition", action="store_true", default=False)
    parser.add_argument("--train_prob_self_cond", type=float, default=0.5)
    parser.add_argument("--sampler", type=str, default="ddpm", choices=["ddpm", "ddim", "dpmpp"])
    parser.add_argument("--seq2seq_context_dim", type=int, default=64)
    parser.add_argument("--tx_dim", type=int, default=64)
    parser.add_argument("--tx_depth", type=int, default=6)
    parser.add_argument("--scale_shift", action="store_true", default=False)
    parser.add_argument("--disable_dropout", action="store_true", default=False)
    parser.add_argument("--seq2seq_unconditional_prob", type=float, default=0.1)
    parser.add_argument("--amp", action="store_true", default=False)
    parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"])

    # Distributed training parameters
    parser.add_argument("--distributed", action="store_true", help="Flag for enabling distributed training")
    parser.add_argument("--rank", type=int, default=0, help="Rank of the current process for distributed training")
    parser.add_argument("--world_size", type=int, default=1, help="Total number of processes for distributed training")
    parser.add_argument("--master_addr", type=str, default="localhost", help="Address of the master node for distributed training")
    parser.add_argument("--master_port", type=str, default="12355", help="Port of the master node for distributed training")
    parser.add_argument("--local_rank", type=int, default=0, help="Local rank passed by torch.distributed.launch for distributed training")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

if __name__ == "__main__":
    cli_args = parse_args()

    # Multi-process mode needs both the --distributed flag and more than one
    # visible CUDA device; anything else runs in the current process.
    spawn_workers = cli_args.distributed and torch.cuda.device_count() > 1

    if not spawn_workers:
        # Single process: this process acts as rank 0.
        run_distributed(0, cli_args.world_size, cli_args.distributed, cli_args)
    else:
        print("Activating distributed training!")
        # mp.spawn supplies the process index as the first positional argument
        # (the rank); the tuple below provides the remaining three parameters.
        mp.spawn(
            run_distributed,
            args=(cli_args.world_size, cli_args.distributed, cli_args),
            nprocs=cli_args.world_size,
            join=True,
        )
