import os
import sys

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint
)
from pytorch_lightning.loggers import WandbLogger

import config_files._config_train_stamp as config
from data.datamodules import DataModuleDistributed
from models._stamp import Stamp
from models._utils import set_seed

# GPU device indices are given as positional command-line arguments,
# e.g. `python <script> 0 1 2 3`; every argument must parse as an integer.
gpu_list = list(map(int, sys.argv[1:]))
print(f"Using GPUs: {gpu_list}")

if __name__ == "__main__":
    # Script entry point: build the Stamp model from the project config,
    # set up logging / checkpointing / the Lightning trainer, create the
    # distributed data module, and run exactly one training session
    # (either resumed from a checkpoint or started fresh).

    # Seed through both the project helper and Lightning.
    set_seed(42)
    pl.seed_everything(42)

    # Informational only: the Trainer below requests accelerator='gpu'
    # regardless of what is reported here.
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Number of GPUs available: {num_gpus}")
    else:
        print("No GPUs available.")

    clip_config = config.sweep_config
    visual_config = config.visual_config
    spot_config = config.spot_config
    pretrained_path = clip_config['pretrained_path']

    model = Stamp(
        spot_config=spot_config,
        visual_config=visual_config,
        dim_output=clip_config['dim_output'],
        temperature=clip_config['temperature'],
        extract_layers=clip_config['extract_layers'],
        function_layers=clip_config['function_layers'],
        lr=clip_config['lr'],
        warmup=clip_config['warmup'],
        max_epochs=clip_config['max_epochs'],
        pool=clip_config['pool'],
        without_context=clip_config['without_context'],
        margin=clip_config['margin'],
        p=clip_config['p'],
        eps=clip_config['eps'],
    )

    if pretrained_path is not None:
        print("Loading pretrained model")
        # Load the checkpoint onto the CPU first; only its weights are used here.
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        # Copy the stored weights into the freshly built model.
        model.load_state_dict(checkpoint['state_dict'])

    wandb_logger = WandbLogger(project='PR')
    dirpath = f"path/to/checkpoint/save/dir/{visual_config['model_name']}/"
    os.makedirs(dirpath, exist_ok=True)
    # save_top_k=-1 keeps every checkpoint written (one every 2 epochs)
    # instead of only the best ones according to 'train_loss'.
    checkpoint_callback = ModelCheckpoint(
        monitor='train_loss',
        mode='min',
        every_n_epochs=2,
        train_time_interval=None,
        save_top_k=-1,
        dirpath=dirpath,
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    # NOTE(review): max_epochs is hard-coded to 30 here while the model
    # receives clip_config['max_epochs'] -- confirm the two are meant to agree.
    trainer = pl.Trainer(
        logger=wandb_logger,
        devices=gpu_list,
        num_nodes=1,
        accelerator='gpu',
        max_epochs=30,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        strategy="ddp_find_unused_parameters_true",
        callbacks=[checkpoint_callback, lr_monitor],
        precision='bf16-mixed',
        gradient_clip_val=1,
        accumulate_grad_batches=2,
        num_sanity_val_steps=1,
    )

    path = clip_config['data_path']
    columns = [
        'images',
        'ref',
        'images_aug',
        'tokenized_gene',
        'batch_slide_id',
        'batch_dataset_id',
        'pos_label',
    ]

    print(f"Using path {path}")

    from torchvision import transforms

    # Both pipelines share the same normalisation statistics and differ
    # only in the target resize (224 for images, 16 for references).
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)
    image_processor = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    ref_processor = transforms.Compose([
        transforms.Resize(16),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # NOTE(review): num_workers scales with world_size; whether this is a
    # per-process or a global worker count depends on DataModuleDistributed
    # -- verify against its implementation.
    module = DataModuleDistributed(
        path=path,
        columns=columns,
        batch_size=clip_config['batch_size'],
        world_size=trainer.world_size,
        image_processor=image_processor,
        ref_processor=ref_processor,
        task_name='align',
        vision_model_name=visual_config['model_name'],
        num_workers=4*trainer.world_size,
    )

    # Run exactly one fit. Previously the second fit call was not guarded by
    # an `else`, so after resuming from a checkpoint the script went on to
    # report "from scratch" and invoke trainer.fit a second time.
    if pretrained_path is not None and clip_config['retake_training']:
        # Resume the full training state stored in the checkpoint via ckpt_path.
        print("Training model from checkpoint!")
        trainer.fit(model=model, datamodule=module, ckpt_path=pretrained_path)
    else:
        # Fresh trainer state; the model may still carry pretrained weights
        # loaded above when 'retake_training' is falsy.
        print("Training model from scratch")
        trainer.fit(model=model, datamodule=module)