import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import os
import argparse
from dataset.dataset import get_table_latent_dataloader, create_df_dict_from_dir, create_config_dict_from_dir
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tableLatent.perceive.encoderDecoders import LatentAutoEncoder
from tableLatent.perceive.vaeTrainer import VAETrainer
from tableLatent.feature.VAE import compute_loss

def run_distributed(rank, world_size, distributed, args):
    """
    Build the model, data loaders, optimizer and trainer, then train and save.

    This is the per-process entry point: it is called directly for a
    single-process run and is the target handed to ``mp.spawn`` otherwise.

    Args:
        rank: Index of the current process; forwarded to ``VAETrainer``.
        world_size: Number of processes; forwarded to ``VAETrainer``.
        distributed: Distributed-mode flag; forwarded to ``VAETrainer``.
        args: Parsed command-line namespace (see ``parse_args``).
    """
    # Hyper-parameters for the two halves of the autoencoder. Both dicts are
    # also handed to save_checkpoint at the end of the run.
    enc_cfg = dict(
        dim=args.encoding_dim,
        dim_latent=args.aggregated_dim,
        depth=args.model_depth,
        dim_head=args.dim_head,
        num_latents=args.num_latent,
        max_seq_len=args.max_input_seq_len,
        ff_mult=args.ff_mult,
    )
    dec_cfg = dict(
        lm_emb=args.encoding_dim,
        aggregated_dim=args.aggregated_dim,
        depth=args.model_depth,
        output_dim=args.output_dim,
        num_decoder_dim=1,
    )
    autoencoder = LatentAutoEncoder(encoder_params=enc_cfg, decoder_params=dec_cfg)

    # Train/test tables and their configs are read from the given directories.
    train_tables, test_tables = create_df_dict_from_dir(
        args.csv_dir, test_size=args.test_size
    )
    table_configs = create_config_dict_from_dir(args.config_dir)

    # Options common to both loaders; only the tables and the fixed_batch
    # setting differ between train and test.
    shared_loader_opts = dict(
        config_dict=table_configs,
        return_label=True,
        batch_size=args.batch_size,
    )
    train_loader = get_table_latent_dataloader(
        df_dict=train_tables,
        fixed_batch=args.fixed_batches,
        **shared_loader_opts,
    )
    # NOTE: passing transformers_dict=train_loader.dataset.transformers_dict
    # to the test loader was left disabled in the original code.
    test_loader = get_table_latent_dataloader(
        df_dict=test_tables,
        fixed_batch=True,
        **shared_loader_opts,
    )

    # AdamW, with the learning rate multiplied by `factor` once the monitored
    # value has stopped decreasing for `patience` scheduler steps.
    adamw = optim.AdamW(autoencoder.parameters(), lr=args.learning_rate)
    plateau_scheduler = ReduceLROnPlateau(
        adamw,
        mode='min',
        factor=args.factor,
        patience=args.patience,
        verbose=True,
    )

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    # The trainer receives rank/world_size/distributed and is expected to do
    # any distributed setup itself.
    trainer = VAETrainer(
        model=autoencoder,
        dataloader=train_loader,
        test_dataloader=test_loader,
        optimizer=adamw,
        criterion=compute_loss,
        scheduler=plateau_scheduler,
        device=target_device,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
        kl_reg=args.kl_reg,
    )

    # Optionally resume from a checkpoint; the two returned values are unused.
    if args.checkpoint_dir is not None:
        _, _ = trainer.load_checkpoint(args.checkpoint_dir)

    trainer.train(num_epochs=args.epochs)

    # Persist the trained model together with the configs used to build it.
    trainer.save_checkpoint(args.save_path, enc_cfg, dec_cfg)

def parse_args():
    """
    Build the command-line parser and parse ``sys.argv``.

    ``--csv_dir``, ``--config_dir`` and ``--save_path`` are required; every
    other option has a default. ``--fixed_batches`` and ``--distributed`` are
    boolean switches that default to ``False``.

    Returns:
        argparse.Namespace holding one attribute per option.
    """
    cli = argparse.ArgumentParser(description="Distributed Contrastive Learning with Perceiver Model")

    # One (flag, add_argument options) entry per option. The order is the
    # registration order, which is also the order shown by --help.
    option_table = (
        # Model parameters
        ('--encoding_dim', dict(type=int, default=768, help='Dimension of lm encoding')),
        ('--aggregated_dim', dict(type=int, default=64, help='Dimension of aggregated latent vector')),
        ('--model_depth', dict(type=int, default=4, help='Depth of the perceiver model')),
        ('--dim_head', dict(type=int, default=64, help='Dimension of the attention head')),
        ('--num_latent', dict(type=int, default=16, help='Number of latent embeddings')),
        ('--max_input_seq_len', dict(type=int, default=96, help='Max input sequence length')),
        ('--ff_mult', dict(type=int, default=4, help='Feedforward multiplier')),
        ('--output_dim', dict(type=int, default=128, help='Dimension of the decoded embedding.')),
        # Training parameters
        ('--learning_rate', dict(type=float, default=3e-4, help='Learning rate')),
        ('--epochs', dict(type=int, default=100, help='Number of training epochs')),
        ('--batch_size', dict(type=int, default=1024, help='Batch size')),
        ('--patience', dict(type=int, default=5, help='Patience for learning rate scheduler')),
        ('--factor', dict(type=float, default=0.95, help='Factor for learning rate reduction')),
        ('--fixed_batches', dict(action='store_true', help='Whether to use fixed batches')),
        ('--save_interval', dict(type=int, default=10, help='Save interval in epochs')),
        # File paths
        ('--csv_dir', dict(type=str, required=True, help='Directory containing CSV files')),
        ('--config_dir', dict(type=str, required=True, help='Directory containing config files')),
        ('--checkpoint_dir', dict(type=str, default=None, help='Directory for trained checkpoints')),
        ('--save_path', dict(type=str, required=True, help='Path to save the trained model')),
        # Distributed training
        ('--world_size', dict(type=int, default=1, help='Number of distributed processes')),
        ('--distributed', dict(action='store_true', help='Use distributed training if multiple GPUs are available')),
        # Other parameters
        ('--test_size', dict(type=float, default=0.2, help='Proportion of data to use for validation')),
        ('--kl_reg', dict(type=float, default=1e-4, help='KL regularization term')),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: parse the CLI, then either spawn one training
    # process per requested rank or train in the current process.
    args = parse_args()

    # Multi-process training is used only when it was requested AND more than
    # one GPU is visible to this process.
    use_distributed = args.distributed and torch.cuda.device_count() > 1

    if use_distributed:
        print("Activating distributed training!")
        # mp.spawn calls run_distributed(i, *args) for i in range(nprocs), so
        # the process index arrives as `rank`.
        mp.spawn(run_distributed,
                 args=(args.world_size, args.distributed, args),
                 nprocs=args.world_size,
                 join=True)
    else:
        if args.distributed:
            print("Distributed training requested but fewer than 2 GPUs are "
                  "available; falling back to single-process training.")
        # Exactly one process runs here, so describe it as such. Forwarding
        # the requested --distributed / --world_size values would tell the
        # trainer it belongs to a process group whose other ranks are never
        # started.
        run_distributed(0, 1, False, args)
