import argparse
from dataloader import DataLoaderFactory
import torch
from dataset import CausalTripletDataset
from torch.utils import data
from models import CausalDeltaEmbeddingModel, PatchWiseDeltaEmbeddingModel
import pytorch_lightning as pl
import signal
from threading import Timer
from pytorch_lightning.loggers import WandbLogger
from torchvision import transforms


def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. ``hidden_dim`` is only
        present on the namespace when ``--hidden_dim`` was given explicitly,
        because this script does not know the model's own default.
    """
    parser = argparse.ArgumentParser()

    # Data selection and preprocessing.
    parser.add_argument('--seed', type=int, default=2025, help='Random seed')
    parser.add_argument('--dataset', type=str, default='procthor', help='Dataset name', choices=['epickitchens', 'procthor'])
    parser.add_argument('--data_root', type=str, default='/home/datasets/procthor', help='Path to dataset')
    parser.add_argument('--ood', type=str, default='comp', help='Out-of-distribution setting', choices=['comp', 'noun'])
    parser.add_argument('--img_width', type=int, default=224, help='Width to resize images')
    parser.add_argument('--mask', action='store_true', default=False, help='Use image masks')
    parser.add_argument('--bbox', action='store_true', default=False, help='Use bounding boxes')

    # Data loading.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers')
    parser.add_argument('--prefetch_factor', type=int, default=3, help='Prefetch factor for data loading')
    parser.add_argument('--train_size', type=int, default=1000, help='Size of training set')

    # Experiment logging.
    parser.add_argument("--log", action="store_true", default=False)
    parser.add_argument("--wb_project", type=str, default="LocalTest")
    parser.add_argument("--wb_group", type=str, default="testing...")

    # Model and optimisation. Restricting --model to the names this script
    # can build rejects typos at parse time, like --dataset and --ood do.
    parser.add_argument("--model", type=str, default='CausalDeltaEmbeddingModel',
                        choices=['CausalDeltaEmbeddingModel', 'PatchWiseDeltaEmbeddingModel'])
    parser.add_argument("--backbone", default='vit_base_patch16_224.dino', type=str)
    parser.add_argument("--no_rebalance", default=False, action='store_true')
    parser.add_argument("--top_k", default=2, type=int, help='Top K patches for delta selection in PatchWise Model')
    parser.add_argument("--epochs", default=20, type=int)
    parser.add_argument("--proj_dim", default=512, type=int)
    # SUPPRESS keeps the attribute off the namespace unless the flag is
    # supplied, so no guessed default is ever handed to the model.
    parser.add_argument("--hidden_dim", default=argparse.SUPPRESS, type=int,
                        help='Hidden dimension for PatchWise Model')
    parser.add_argument("--alpha_contrast", default=1.0, type=float)
    parser.add_argument("--alpha_sparsity", default=1.0, type=float)

    return parser.parse_args(argv)


def timeout_handler(signum, frame):
    """Announce that the time limit was reached and exit with status 0.

    Args:
        signum: Signal number passed to the handler (unused).
        frame: Stack frame passed to the handler (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    notice = "\nReached time limit of 2 hours and 55 minutes. Stopping training gracefully..."
    print(notice)
    raise SystemExit(0)


def _loader_options(args):
    """Return the worker-related ``DataLoader`` keyword arguments for *args*."""
    options = {'num_workers': args.num_workers}
    if args.num_workers > 0:
        # DataLoader rejects prefetch_factor / persistent_workers when data is
        # loaded in the main process, so only pass them with real workers.
        options['prefetch_factor'] = args.prefetch_factor
        options['persistent_workers'] = True
    return options


def _run_training(args):
    """Load the data, build the selected model and fit it with Lightning."""
    # Resolve the model class first so an unsupported name fails before the
    # data is loaded.
    if args.model == 'CausalDeltaEmbeddingModel':
        model_class = CausalDeltaEmbeddingModel
        model_kwargs = {}
    elif args.model == 'PatchWiseDeltaEmbeddingModel':
        model_class = PatchWiseDeltaEmbeddingModel
        model_kwargs = {'top_k': args.top_k, 'proj_dim': args.proj_dim}
        # hidden_dim may be absent from args; forward it only when it is set
        # so the constructor otherwise uses its own default (assumes it has
        # one -- TODO confirm against models.PatchWiseDeltaEmbeddingModel).
        hidden_dim = getattr(args, 'hidden_dim', None)
        if hidden_dim is not None:
            model_kwargs['hidden_dim'] = hidden_dim
    else:
        raise ValueError("Model is not supported")

    # Set random seed for reproducibility
    pl.seed_everything(args.seed)

    torch.set_float32_matmul_precision('medium')
    df_train, df_test, df_valid, df_ood, dict_verb_index, dict_noun_index = \
        DataLoaderFactory.get_data_loaders(dataset=args.dataset,
                                           root=args.data_root,
                                           ood=args.ood,
                                           seed=args.seed,
                                           train_size=args.train_size,
                                           rebalance=not args.no_rebalance)

    print(f"#Train: {len(df_train)}")
    print(f"#Test: {len(df_test)}")
    print(f"#Valid: {len(df_valid)}")
    print(f"#OOD: {len(df_ood)}")

    transform = transforms.Compose([
        transforms.Resize((args.img_width, args.img_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Every split shares the same dataset configuration; only the frame differs.
    dataset_options = dict(dataset=args.dataset,
                           foldername=args.data_root,
                           dict_noun_index=dict_noun_index,
                           dict_verb_index=dict_verb_index,
                           single_image=False,
                           img_width=args.img_width,
                           bbox=args.bbox,
                           transform=transform)
    train_dataset = CausalTripletDataset(df=df_train, **dataset_options)
    test_dataset = CausalTripletDataset(df=df_test, **dataset_options)
    valid_dataset = CausalTripletDataset(df=df_valid, **dataset_options)
    ood_dataset = CausalTripletDataset(df=df_ood, **dataset_options)

    worker_options = _loader_options(args)
    eval_options = dict(batch_size=args.batch_size, shuffle=False, **worker_options)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                   shuffle=True, drop_last=True, **worker_options)
    test_loader = data.DataLoader(test_dataset, **eval_options)
    # NOTE(review): val_loader is built but never handed to trainer.fit below,
    # which validates on the test and OOD loaders -- confirm this is intended.
    val_loader = data.DataLoader(valid_dataset, **eval_options)
    ood_loader = data.DataLoader(ood_dataset, **eval_options)

    model = model_class(num_actions=len(dict_verb_index),
                        alpha_contrast=args.alpha_contrast,
                        alpha_sparsity=args.alpha_sparsity,
                        **model_kwargs)

    callbacks = model_class.get_callbacks(
        every_n_epochs=5,
        action_index_to_name={v: k for k, v in dict_verb_index.items()},
        object_index_to_name={v: k for k, v in dict_noun_index.items()},
    )

    logger = None
    if args.log:
        logger = WandbLogger(project=args.wb_project, group=args.wb_group, name=args.model)
        logger.experiment.config.update(vars(args))

    # The wall-clock limit is enforced by the SIGALRM alarm set in main(),
    # not by the Trainer's max_time option.
    trainer = pl.Trainer(default_root_dir=args.wb_project + '/' + args.wb_group,
                         accelerator="auto",
                         strategy="auto",
                         devices="auto",
                         max_epochs=args.epochs,
                         callbacks=callbacks,
                         check_val_every_n_epoch=1,
                         gradient_clip_val=1.0,
                         log_every_n_steps=10,
                         logger=logger)

    try:
        trainer.fit(model, train_loader, val_dataloaders=[test_loader, ood_loader])
    except SystemExit as e:
        if e.code == 0:
            print("Training stopped successfully after timeout")
        raise


def main(args):
    """Train the selected delta-embedding model under a wall-clock limit.

    Installs ``timeout_handler`` for SIGALRM, arms an alarm of 2 hours and
    55 minutes, runs the training pipeline, and cancels any pending alarm
    once the pipeline returns or raises.

    Args:
        args: Parsed command-line options, as produced by ``parse_args``.

    Raises:
        ValueError: If ``args.model`` is not one of the supported models.
        SystemExit: Propagated from ``timeout_handler`` when the alarm fires.
    """
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(2 * 60 * 60 + 55 * 60)  # Set the alarm for 2 hours and 55 minutes
    try:
        _run_training(args)
    finally:
        # Clear a still-pending alarm so it cannot fire after main() is done.
        signal.alarm(0)


if __name__ == '__main__':
    # Script entry point: parse the command line and pass the options to main().
    main(parse_args())