import os
import sys

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint
)
from pytorch_lightning.loggers import WandbLogger

import config_files._config_train_gene_encoder as config
from data.datamodules import DataModuleDistributed
from models._utils import set_seed
from models._visiumformer_spatial import VisiumformerSpatial 

# GPU indices come from the positional CLI arguments, e.g. `python train.py 0 1`;
# each one is converted to int so it can be passed to `pl.Trainer(devices=...)`.
gpu_list = [int(arg) for arg in sys.argv[1:]]
print(f"Using GPUs: {gpu_list}")

if __name__ == '__main__':
    set_seed(42)
    pl.seed_everything(42)
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Number of GPUs available: {num_gpus}.")
    else:
        print("No GPU available.")
    config = config.sweep_config
    # set for model
    model = VisiumformerSpatial(dim_model=config['dim_model'], 
                        nheads=config['nheads'], 
                        dim_feedforward=config['dim_feedforward'], 
                        nlayers=config['nlayers'],
                        dropout=config['dropout'],
                        batch_first=config['batch_first'], 
                        masking_p=config['masking_p'], 
                        n_tokens=config['n_tokens'],
                        context_length=config['context_length'],
                        warmup=config['warmup'],
                        lr=config['lr'],
                        decay=config['decay'],
                        batch_size=config['batch_size'],
                        max_epochs=config['max_epochs'],
                        autoregressive=config['autoregressive'],
                        pool=config['pool'],
                        supervised_task=config['supervised_task'],
                        learnable_pe=config['learnable_pe'],
                        spatial_aware=config['spatial_aware'])
    # load pretrained model, just for spatial-aware gene encoder pretraining
    if config['pretrained_path'] is not None:
        print("Loading pretrained spot model from", config['pretrained_path'])
        checkpoint = torch.load(config['pretrained_path'], map_location='cpu')
        model_state_dict = model.state_dict()
        filtered_state_dict = {
            k: v for k, v in checkpoint['state_dict'].items()
            if k in model_state_dict and model_state_dict[k].shape == v.shape
        }
        missing_layer = model.load_state_dict(filtered_state_dict, strict=False)
        if missing_layer.missing_keys:
            print(f"Missing keys: {missing_layer.missing_keys}")
        else:
            print("Pretrained model loaded successfully.")
    
    # set for logger
    wandb_logger = WandbLogger(project='PR')
    
    # set for checkpoint
    dirpath = 'path/to/save/checkpoints'
    if config['spatial_aware']:
        dirpath = os.path.join(dirpath, 'w_spatial')
    else:
        dirpath = os.path.join(dirpath, 'wo_spatial')
    os.makedirs(dirpath, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(monitor='train_loss', 
                                          mode='min', 
                                          every_n_epochs=1, 
                                          train_time_interval=None, 
                                          save_top_k=-1,
                                          dirpath=dirpath)
    
    # set for learning rate monitor
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # set for trainer
    trainer = pl.Trainer(
                        logger=wandb_logger,
                        devices=gpu_list,
                        num_nodes=1, 
                        accelerator='gpu',
                        max_epochs=1,
                        log_every_n_steps=1,
                        check_val_every_n_epoch=1,
                        strategy="ddp_find_unused_parameters_true",
                        callbacks=[checkpoint_callback, lr_monitor],
                        precision='bf16-mixed',
                        gradient_clip_val=1,
                        num_sanity_val_steps=0,
                        accumulate_grad_batches=2)

    # set for datamodule
    if config['spatial_aware']:
        path = os.path.join(config['data_path'], 'w_spatial')
    else:
        path = os.path.join(config['data_path'], 'wo_spatial')

    columns = ['tokenized_gene']
    
    print(f"Using path {path}")
    
    module = DataModuleDistributed(path=path, 
                        columns=columns,
                        sub_sample_frac=1,
                        batch_size=config['batch_size'],
                        world_size=trainer.world_size,
                        task_name='pretrain',
                        num_workers=8)
    
    # set for retake training
    if config['pretrained_path'] is not None and config['retake_training']:
        print(f"Training model from checkpoint!")
        trainer.fit(model=model, datamodule=module, ckpt_path=config['pretrained_path'])
        
    # set for training from scratch
    print(f"Training model from scratch")
    trainer.fit(model=model, datamodule=module)
