import pytorch_lightning as pl

# from vpr_model_pair import VPRModel
from vpr_model import VPRModel
from dataloaders.GSVCitiesDataloader import GSVCitiesDataModule
import torch
import os

# Let PyTorch use faster, reduced-precision algorithms (e.g. TF32) for
# float32 matrix multiplications where the hardware supports them.
torch.set_float32_matmul_precision('high')

if __name__ == '__main__':
    # Data module settings for GSV-Cities training.
    # Batch sizes noted for other configurations (kept for reference):
    #   SALAD: 60 | BoQ: 100 | DK: 72 / 128 / 160 | EDT: 72
    # Image sizes noted for other configurations:
    #   SALAD: (224, 224) | BoQ: (280, 280)
    datamodule_settings = {
        'batch_size': 128,  # DK
        'img_per_place': 4,
        'min_img_per_place': 4,
        'shuffle_all': False,  # shuffle all images or keep shuffling in-city only
        'random_sample_from_each_place': True,
        'image_size': (224, 224),  # SALAD
        'num_workers': 10,
        'show_data_stats': True,
        # pitts30k_val, pitts30k_test, msls_val
        'val_set_names': ['pitts30k_val', 'pitts30k_test', 'msls_val', 'nordland'],
    }
    datamodule = GSVCitiesDataModule(**datamodule_settings)
    
    # ---- Encoder
    # Alternative: plain DINOv2 backbone (kept for reference)
    # backbone_arch='dinov2_vitb14',
    # backbone_config={
    #     # 'num_trainable_blocks': 4, # SALAD
    #     'num_trainable_blocks': 2, # BoQ
    #     'return_token': True,
    #     'norm_layer': True,
    # },
    #
    # Active: backbone with DDF-Adapter ('_da' architectures).
    # Alternative arch string: 'dinov2_vitb14_da'
    backbone_cfg = {
        'num_da_layers': 2,
        'hidden_dim': 48,
        'return_token': True,
        'norm_layer': True,
    }

    # ---- Aggregator
    # Alternatives (kept for reference):
    #
    # agg_arch='SALAD',
    # agg_config={
    #     'num_channels': 768,
    #     'num_clusters': 64,
    #     'cluster_dim': 128,
    #     'token_dim': 256,
    # },
    #
    # agg_arch='GeM',
    # agg_config={
    #     'p': 3,
    # },
    #
    # agg_arch='NetVLAD',
    # agg_config={
    #     'num_channels': 768,
    #     # 'num_clusters': 64,
    #     # 'cluster_dim': 128,
    #     # 'token_dim': 256,
    # },
    #
    # agg_arch='MixVPR',
    # agg_config={
    #     'in_channels': 768,
    #     'out_channels': 512,
    #     'in_h': 20,
    #     'in_w': 20,
    #     'mix_depth': 3,
    # },
    #
    # Active: 'DA' aggregator.
    aggregator_cfg = {
        'in_channels': 768,
        # 'num_layers': 1,
        # 'num_queries': 16, # number of learnable queries
        # 'hyper_dim': 128, # dimension of the hypernetwork output
    }

    # ---- Learning-rate schedule
    # Alternative presets (kept for reference):
    #   multistep (BoQ): 'warmup_steps': 5290, 'milestones': [10, 20, 30], 'gamma': 0.1
    #   linear (SALAD):  'start_factor': 1, 'end_factor': 0.2, 'total_iters': 4000
    #   multistep (EDT), alternative milestones: [3, 6, 9, 12]
    scheduler_cfg = {
        # multistep # EDT
        'warmup_steps': 0,
        'milestones': [8, 16, 24, 32],
        'gamma': 0.7,
        'start_factor': 1,
    }

    model = VPRModel(
        # Encoder
        backbone_arch='dinov3_vitb16_da',
        backbone_config=backbone_cfg,

        # Aggregator
        agg_arch='DA',
        agg_config=aggregator_cfg,

        # Optimisation
        lr=2e-4,  # BoQ (SALAD: 6e-5)
        optimizer='adamw',
        weight_decay=0.001,  # BoQ (SALAD: 9.5e-9)
        # NOTE(review): presumably unused when optimizer='adamw' - confirm in VPRModel
        momentum=0.9,
        lr_sched='multistep',  # BoQ (SALAD: 'linear')
        lr_sched_args=scheduler_cfg,

        # Loss functions
        loss_name='MultiSimilarityLoss',
        # example: TripletMarginMiner, MultiSimilarityMiner, PairMarginMiner
        miner_name='MultiSimilarityMiner',
        miner_margin=0.05,  # BoQ: 0.05 (SALAD: 0.1)
        faiss_gpu=False,
    )


    # Output location for this run: <base_dir>/<experiment_name>
    experiment_name = 'test'
    base_dir = '/home/dkim/VPR/DA2VPR/VPR_Project'
    save_dir = f'{base_dir}/{experiment_name}'

    # Create the run directory and the sub-directories for checkpoints and
    # wandb files up front; existing directories are left untouched.
    for run_dir in (
        save_dir,
        os.path.join(save_dir, 'checkpoints'),
        os.path.join(save_dir, 'wandb'),
    ):
        os.makedirs(run_dir, exist_ok=True)

    # Checkpointing through PyTorch Lightning: keep the 6 best checkpoints,
    # ranked by Recall@1 on pitts30k_val (higher is better). Only the model
    # weights are stored, and no separate "last" checkpoint is written.
    # The file name is the encoder architecture followed by the epoch and the
    # R1 / R5 values on pitts30k_val.
    ckpt_metrics_suffix = '_({epoch:02d})_R1[{pitts30k_val/R1:.4f}]_R5[{pitts30k_val/R5:.4f}]'
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        monitor='pitts30k_val/R1',
        mode='max',
        save_top_k=6,
        save_last=False,
        save_weights_only=True,
        dirpath=f'{save_dir}/checkpoints',
        filename=f'{model.encoder_arch}{ckpt_metrics_suffix}',
        auto_insert_metric_name=False,
    )

    # Early stopping on the same metric the checkpoint callback monitors.
    # NOTE(review): patience (30) equals the Trainer's max_epochs (30), so this
    # callback looks unable to trigger within a run - confirm that is intended.
    early_stop_settings = {
        'monitor': 'pitts30k_val/R1',
        'mode': 'max',        # higher R1 is better
        'min_delta': 0.001,   # smaller gains do not count as an improvement
        'patience': 30,       # validation checks without improvement before stopping
        'verbose': True,
    }
    early_stop_cb = pl.callbacks.EarlyStopping(**early_stop_settings)

    # WandbLogger setup.
    # The logger is only used if `logger=wandb_logger` is enabled in the
    # Trainer call; that line is currently commented out.
    wandb_logger = pl.loggers.WandbLogger(
        project='VPR_Project',
        name=experiment_name,
        log_model=True,
        save_dir=save_dir,
    )

    # ------------------
    # Instantiate the trainer.
    trainer_callbacks = [checkpoint_cb, early_stop_cb]  # best-model checkpoints + early stopping
    trainer = pl.Trainer(
        # Hardware: a single GPU on a single node
        accelerator='gpu',
        devices=1,
        num_nodes=1,
        precision='16-mixed',  # 16-bit mixed precision to reduce memory usage
        # precision=32, # proxyanchor

        # Schedule
        max_epochs=30,  # SALAD: 4, BoQ: max 40
        num_sanity_val_steps=0,  # 0 = skip the sanity validation pass before training
        check_val_every_n_epoch=1,  # run validation every epoch
        reload_dataloaders_every_n_epochs=1,  # reload the dataset each epoch to shuffle the order

        # Outputs and logging
        default_root_dir=save_dir,
        callbacks=trainer_callbacks,
        # logger=wandb_logger, # we log the results to wandb
        log_every_n_steps=20,
    )

    # Start training: fit the model using the data provided by the datamodule.
    trainer.fit(model=model, datamodule=datamodule)

    # Usage: python main.py
