import argparse
import os
import sys

import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint
)
from pytorch_lightning.loggers import WandbLogger
import torch
from torchvision import transforms

import config_files._config_finetune_stamp as config
from config_files._constants import mapping as project_mapping
from data.datamodules import DownstreamDataModuleDistributed
from models._finetune_stamp import FinetuneStamp
from models._utils import set_seed


# Command-line interface.
# GPU indices are declared as a positional argparse argument. Reading them
# straight from sys.argv[1:] (while argparse also parsed sys.argv) made
# argparse reject them as unrecognized arguments, and made int() fail on any
# "--flag" value, so no invocation carrying GPU ids or options could work.
# Usage: python <script> 0 1 --project PSC --sample <held_out_sample>
argparser = argparse.ArgumentParser()
argparser.add_argument('gpus', type=int, nargs='*', default=[],
                       help='CUDA device indices to train on')
argparser.add_argument('--project', type=str, default='PSC')
argparser.add_argument('--sample', type=str, default='')
argparser.add_argument('--data_path', type=str, default='path/to/downstream/data')
args = argparser.parse_args()

gpu_list = args.gpus
print(f"Using GPUs: {gpu_list}")

# All samples registered for the project. A sample named via --sample is
# removed from this list (it is handed to the datamodule separately as the
# test sample).
sample_list = list(project_mapping[args.project])
if args.sample:
    sample_list = [sample for sample in sample_list if sample != args.sample]

if __name__ == "__main__":
    # Seed the project's RNG helper and Lightning's global seeding for
    # reproducible runs.
    set_seed(42)
    pl.seed_everything(42)
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Number of GPUs available: {num_gpus}")
    else:
        print("No GPUs available.")

    clip_config = config.sweep_config
    visual_config = config.visual_config
    spot_config = config.spot_config

    # Fine-tuning always starts from pretrained weights; fail fast before the
    # model is built if none are configured.
    pretrained_path = clip_config['pretrained_path']
    if pretrained_path is None:
        raise ValueError("No pretrained model found!")

    model = FinetuneStamp(spot_config=spot_config,
                          visual_config=visual_config,
                          dim_output=clip_config['dim_output'],
                          temperature=clip_config['temperature'],
                          extract_layers=clip_config['extract_layers'],
                          function_layers=clip_config['function_layers'],
                          lr=clip_config['lr'],
                          warmup=5,
                          max_epochs=clip_config['max_epochs'],
                          pool=clip_config['pool'],
                          without_context=clip_config['without_context'],
                          margin=clip_config['margin'],
                          p=clip_config['p'],
                          eps=clip_config['eps'])

    print("Loading pretrained model")
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    # strict=False tolerates layers that differ between the pretrained and the
    # fine-tuning model; report both sides of the mismatch so it is visible.
    load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if load_result.missing_keys:
        print("\nMissing keys (layers that were not loaded):")
        for key in load_result.missing_keys:
            print(key)
    if load_result.unexpected_keys:
        print("\nUnexpected keys (checkpoint entries with no matching layer):")
        for key in load_result.unexpected_keys:
            print(key)

    wandb_dir = './wandb/finetune'
    os.makedirs(wandb_dir, exist_ok=True)
    wandb_logger = WandbLogger(project='PR_finetune', save_dir=wandb_dir)

    # One checkpoint directory per (vision backbone, project); the file is
    # named after the held-out sample, or "all_samples" when none is held out.
    dirpath = os.path.join('path/to/downstream/checkpoint/dir',
                           visual_config['model_name'], args.project)
    os.makedirs(dirpath, exist_ok=True)
    save_file_name = args.sample if args.sample else "all_samples"

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        dirpath=dirpath,
        filename=save_file_name,
    )
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=5,
        verbose=True
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # NOTE(review): the model's schedule receives clip_config['max_epochs'],
    # but the Trainer is capped at 20 epochs; confirm the two should differ.
    trainer = pl.Trainer(logger=wandb_logger,  # previously built but never passed in
                         # No GPU ids on the command line -> let Lightning use all
                         # visible GPUs (an empty list is rejected for accelerator='gpu').
                         devices=gpu_list if gpu_list else 'auto',
                         num_nodes=1,
                         accelerator='gpu',
                         max_epochs=20,
                         log_every_n_steps=1,
                         check_val_every_n_epoch=1,
                         strategy="ddp_find_unused_parameters_true",
                         callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
                         precision='bf16-mixed',
                         gradient_clip_val=1,
                         accumulate_grad_batches=1,
                         num_sanity_val_steps=0)

    path = os.path.join(args.data_path, args.project)
    columns = ['images', 'tokenized_gene']

    print(f"Using path {path}")

    # Resize with an int scales the shorter edge to 224 (no crop), then
    # normalizes with ImageNet mean/std.
    image_processor = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    module = DownstreamDataModuleDistributed(path=path,
                                             columns=columns,
                                             batch_size=clip_config['batch_size'],
                                             image_processor=image_processor,
                                             vision_model_name=visual_config['model_name'],
                                             num_workers=4,
                                             sample_list=sample_list,
                                             test_sample=args.sample,
                                             )

    # Exactly one training run: either resume the full trainer state from the
    # pretrained checkpoint, or fit starting from the weights loaded above.
    # (Previously the resume branch fell through into a second fit.)
    if clip_config['retake_training']:
        print("Training model from checkpoint!")
        trainer.fit(model=model, datamodule=module, ckpt_path=pretrained_path)
    else:
        print("Training model from scratch")
        trainer.fit(model=model, datamodule=module)