import argparse
import os
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning import Trainer, seed_everything

from src.datamodules.cars import CarsDataModule
from src.datamodules.pets import PetsDataModule
from src.datamodules.imagenette import ImagenetteDataModule
from src.datamodules.imagenet import ImagenetDataModule
from src.models.patch_importance import PatchImportance
from src.models.patch_importance_cnn import PatchImportanceCNN
from src.models.resnet_importance import ResNetImportanceModel

# Command-line interface: each entry is (flag, add_argument keyword options).
_CLI_OPTIONS = (
    ('--dataset', dict(type=str, default='pets',
                       choices=['cars', 'pets', 'imagenette', 'imagenet'],
                       help='Dataset to use')),
    ('--model_type', dict(type=str, default='ba', choices=['ba', 'fa', 'ra'],
                          help='Model type to use')),
    ('--embedding_size', dict(type=int, default=32, help='Embedding size')),
    ('--patch_size', dict(type=int, default=16, help='Patch size')),
    ('--checkpoint_path', dict(type=str, default='',
                               help='Path to the model checkpoint')),
    ('--threshold', dict(type=float, default=0.1, help='Threshold for RA')),
)


def _build_parser():
    """Return the argument parser for a patch-importance training run."""
    cli = argparse.ArgumentParser(description="Train Patch Importance Model")
    for flag, options in _CLI_OPTIONS:
        cli.add_argument(flag, **options)
    return cli


# Parse sys.argv once; every later section of the script reads `args`.
parser = _build_parser()
args = parser.parse_args()

# Seed: fix the global seed through Lightning so repeated runs are comparable.
seed_everything(111)

# Select datamodule. Every dataset uses the same batch size of 128.
if args.dataset == 'cars':
    datamodule = CarsDataModule(batch_size=128)
elif args.dataset == 'pets':
    datamodule = PetsDataModule(batch_size=128, target_type='category')
elif args.dataset == 'imagenette':
    datamodule = ImagenetteDataModule(batch_size=128)
elif args.dataset == 'imagenet':
    # ImageNet's location comes from the environment, not the CLI. Validate it
    # here so a missing/empty variable fails with an actionable message instead
    # of a bare KeyError (or an empty data_dir being passed on).
    imagenet_path = os.environ.get('IMAGENET_PATH')
    if not imagenet_path:
        raise ValueError(
            "The IMAGENET_PATH environment variable must be set to the "
            "ImageNet data directory when using --dataset imagenet"
        )
    datamodule = ImagenetDataModule(data_dir=imagenet_path, batch_size=128)
else:
    # Unreachable through the CLI (argparse restricts choices); kept as a guard.
    raise ValueError("Invalid dataset")

# Load model.
# 'fa' restores weights from a checkpoint file and 'ra' forwards the path to
# ResNetImportanceModel, so both need an existing checkpoint path up front.
# The two failure modes (nothing given vs. a path that does not exist) are
# reported separately so a mistyped path is easy to spot.
if args.model_type in ('fa', 'ra'):
    if not args.checkpoint_path:
        raise ValueError(
            f"--checkpoint_path must be provided for model type "
            f"'{args.model_type}' (fine-tuned or reference architectures)"
        )
    if not os.path.exists(args.checkpoint_path):
        raise ValueError(
            f"Checkpoint path {args.checkpoint_path!r} for model type "
            f"'{args.model_type}' does not exist"
        )

is_imagenet = args.dataset == 'imagenet'

if args.model_type == 'ba':
    # Train a PatchImportance model from scratch with a CNN importance scorer.
    model = PatchImportance(embedding_size=args.embedding_size, patch_size=args.patch_size, lr=0.001,
                            importance_method='cnn_image', aggregate_operation='sum',
                            num_classes=datamodule.num_classes, optimizer='adam', use_importance=True)
elif args.model_type == 'fa':
    # strict=False tolerates state-dict keys that do not match exactly.
    # NOTE(review): presumably because the checkpoint was saved by a different
    # model class (e.g. a 'ba' run) — confirm.
    model = PatchImportanceCNN.load_from_checkpoint(args.checkpoint_path, strict=False)
    model.freeze_weights(freeze_importance=True)
elif args.model_type == 'ra':
    # ImageNet is trained with SGD at lr=0.1; every other dataset uses Adam at lr=0.001.
    model = ResNetImportanceModel(num_classes=datamodule.num_classes,
                                  lr=0.1 if is_imagenet else 0.001,
                                  path=args.checkpoint_path, resnet='resnet50', threshold=args.threshold,
                                  optimizer='sgd' if is_imagenet else 'adam')
else:
    # Unreachable through the CLI (argparse restricts choices); kept as a guard.
    raise ValueError("Invalid model type")


# Checkpointing: keep only the single best checkpoint by validation accuracy,
# stored per dataset and named after the model type, embedding size and patch
# size. Assumes the model logs a "val/acc" metric — confirm in the model code.
checkpoint_dir = f'checkpoints/{args.dataset}/'
filename = '_'.join(
    ['checkpoint', args.model_type, str(args.embedding_size), str(args.patch_size)]
)
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=filename,
    monitor="val/acc",
    mode="max",
    save_top_k=1,
)

# Train and test: fit for a fixed 100 epochs, then evaluate the 'best'
# checkpoint tracked by the callback rather than the final-epoch weights.
trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=100)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule, ckpt_path='best')
