from math import log
import os

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
import wandb

from slot_attention.configs.get_config import compute_config_entries_set_envs_etc
from slot_attention.configs.get_config import read_cli_args_and_get_config
from slot_attention.const import MOD_FILE_PATHS
from slot_attention.data.data_coco import CocoDataModule

from slot_attention.data.data_mod_tetro import MOTetroDataModule
from slot_attention.data.data_mod_sprites import MOSpritesDataModule
from slot_attention.data.data_mod_clevr6 import MODClevrModule
from slot_attention.data.data_objectsroom import ObjectsRoomDataModule
from slot_attention.method import SlotAttentionMethod
from slot_attention.model.model_coco import CocoHopfieldModel
from slot_attention.model.model_hopf import HopfieldModel
from slot_attention.model.model_hopf_objectsroom import HopfieldModelObjectsroom
from slot_attention.model.model_hopf_tetro import HopfieldModelTetro
from slot_attention.model.model_hopf_sprites import HopfieldModelSprites
from slot_attention.model.model_slatn import SlotAttentionModel
from slot_attention.model.model_slatn_vit_encoder import SlotAttentionModelVitEncoder
from slot_attention.model.model_vit import VitModel
from slot_attention.model.model_utils import ImageLogCallback
from slot_attention.model.model_vito import VitoModel
def main():
    # read CLI args
    params = read_cli_args_and_get_config()
    run_train(params)

def run_train(params):
    params = compute_config_entries_set_envs_etc(params)
    
    # assert config
    assert params.num_slots > 1, "Must have at least 2 slots."

    if params.is_verbose:
        print(f"INFO: limiting the dataset to only images with `num_slots - 1` ({params.num_slots - 1}) objects.")
        if params.num_train_images:
            print(f"INFO: restricting the train dataset size to `num_train_images`: {params.num_train_images}")
        if params.num_val_images:
            print(f"INFO: restricting the validation dataset size to `num_val_images`: {params.num_val_images}")
    
    if params.use_hopfield:

        if params.dataset_name == "tetrominoes":
            model = HopfieldModelTetro(
                params=params,
                resolution=params.resolution,
                num_slots=params.num_slots,
                num_iterations=params.num_iterations,

                in_channels = params.in_channels,
                kernel_size = params.kernel_size,
                hidden_dims = params.hidden_dims,
                slot_size = params.slot_size,
                hidden_dims_query = params.hidden_dims_query,
                decoder_resolution = params.decoder_resolution,
                decoder_stride = params.decoder_stride,
                decoder_padding = params.decoder_padding,
                decoder_output_padding = params.decoder_output_padding,
                decoder_hidden_dims=params.decoder_hidden_dims,

                empty_cache=params.empty_cache,

                hopfield_steps_eps=params.hopfield_steps_eps,
                use_hopfield=params.use_hopfield,
                use_hopfield_norm_before=params.use_hopfield_norm_before,
                use_hopfield_project_before=params.use_hopfield_project_before,
                use_hopfield_actual=params.use_hopfield_actual,
                use_hopfield_qqq=params.use_hopfield_qqq,
                use_hopfield_reverse_softmax=params.use_hopfield_reverse_softmax,
                use_hopfield_only_qqq=params.use_hopfield_only_qqq,
                use_double_softmax=params.use_double_softmax,
                
                use_competetion=params.use_competetion,
                use_gumble_softmax=params.use_gumble_softmax,
                l1_loss=params.l1_loss,
                hopfield_beta=params.hopfield_beta,
                use_average_pool=params.use_average_pool,
                average_pool_size=params.average_pool_size,
                average_pool_stride=params.average_pool_stride,
                use_max_pool=params.use_max_pool,
                use_residual_path=params.use_residual_path,
                max_pool_size=params.max_pool_size,
                max_pool_stride=params.max_pool_stride
            )

        elif params.dataset_name == "multi_dsprites":
            
            model = HopfieldModelSprites(
                resolution=params.resolution,
                num_slots=params.num_slots,
                num_iterations=params.num_iterations,

                in_channels = params.in_channels,
                kernel_size = params.kernel_size,
                hidden_dims = params.hidden_dims,
                slot_size = params.slot_size,
                hidden_dims_query = params.hidden_dims_query,
                decoder_resolution = params.decoder_resolution,
                decoder_stride = params.decoder_stride,
                decoder_padding = params.decoder_padding,
                decoder_output_padding = params.decoder_output_padding,
                decoder_hidden_dims=params.decoder_hidden_dims,

                empty_cache=params.empty_cache,

                hopfield_steps_eps=params.hopfield_steps_eps,
                use_hopfield=params.use_hopfield,
                use_hopfield_norm_before=params.use_hopfield_norm_before,
                use_hopfield_project_before=params.use_hopfield_project_before,
                use_hopfield_actual=params.use_hopfield_actual,
                use_hopfield_qqq=params.use_hopfield_qqq,
                use_hopfield_reverse_softmax=params.use_hopfield_reverse_softmax,
                use_hopfield_only_qqq=params.use_hopfield_only_qqq,
                use_double_softmax=params.use_double_softmax,
                
                use_competetion=params.use_competetion,
                use_gumble_softmax=params.use_gumble_softmax,
                l1_loss=params.l1_loss,
                hopfield_beta=params.hopfield_beta,
                use_average_pool=params.use_average_pool,
                average_pool_size=params.average_pool_size,
                average_pool_stride=params.average_pool_stride,
                use_max_pool=params.use_max_pool,
                use_residual_path=params.use_residual_path,
                max_pool_size=params.max_pool_size,
                max_pool_stride=params.max_pool_stride
            )

        elif params.dataset_name == "objectsroom":
            
            model = HopfieldModelObjectsroom(
                params=params,
                resolution=params.resolution,
                num_slots=params.num_slots,
                num_iterations=params.num_iterations,

                in_channels = params.in_channels,
                kernel_size = params.kernel_size,
                hidden_dims = params.hidden_dims,
                slot_size = params.slot_size,
                hidden_dims_query = params.hidden_dims_query,
                decoder_resolution = params.decoder_resolution,
                decoder_stride = params.decoder_stride,
                decoder_padding = params.decoder_padding,
                decoder_output_padding = params.decoder_output_padding,
                decoder_hidden_dims=params.decoder_hidden_dims,

                empty_cache=params.empty_cache,

                hopfield_steps_eps=params.hopfield_steps_eps,
                use_hopfield=params.use_hopfield,
                use_hopfield_norm_before=params.use_hopfield_norm_before,
                use_hopfield_project_before=params.use_hopfield_project_before,
                use_hopfield_actual=params.use_hopfield_actual,
                use_hopfield_qqq=params.use_hopfield_qqq,
                use_hopfield_reverse_softmax=params.use_hopfield_reverse_softmax,
                use_hopfield_only_qqq=params.use_hopfield_only_qqq,
                use_double_softmax=params.use_double_softmax,
                
                use_competetion=params.use_competetion,
                use_gumble_softmax=params.use_gumble_softmax,
                l1_loss=params.l1_loss,
                hopfield_beta=params.hopfield_beta,
                use_average_pool=params.use_average_pool,
                average_pool_size=params.average_pool_size,
                average_pool_stride=params.average_pool_stride,
                use_max_pool=params.use_max_pool,
                use_residual_path=params.use_residual_path,
                max_pool_size=params.max_pool_size,
                max_pool_stride=params.max_pool_stride
            )


        elif params.dataset_name == "clevr6":
            model = HopfieldModel(
                resolution=params.resolution,
                num_slots=params.num_slots,
                num_iterations=params.num_iterations,

                in_channels = params.in_channels,
                kernel_size = params.kernel_size,
                hidden_dims = params.hidden_dims,
                slot_size = params.slot_size,
                hidden_dims_query = params.hidden_dims_query,
                decoder_resolution = params.decoder_resolution,
                decoder_stride = params.decoder_stride,
                decoder_padding = params.decoder_padding,
                decoder_output_padding = params.decoder_output_padding,
                decoder_hidden_dims=params.decoder_hidden_dims,

                empty_cache=params.empty_cache,

                hopfield_steps_eps=params.hopfield_steps_eps,
                use_hopfield=params.use_hopfield,
                use_hopfield_norm_before=params.use_hopfield_norm_before,
                use_hopfield_project_before=params.use_hopfield_project_before,
                use_hopfield_actual=params.use_hopfield_actual,
                use_hopfield_qqq=params.use_hopfield_qqq,
                use_hopfield_reverse_softmax=params.use_hopfield_reverse_softmax,
                use_hopfield_only_qqq=params.use_hopfield_only_qqq,
                use_double_softmax=params.use_double_softmax,
                
                use_competetion=params.use_competetion,
                use_gumble_softmax=params.use_gumble_softmax,
                l1_loss=params.l1_loss,
                hopfield_beta=params.hopfield_beta,
                use_average_pool=params.use_average_pool,
                average_pool_size=params.average_pool_size,
                average_pool_stride=params.average_pool_stride,
                use_max_pool=params.use_max_pool,
                use_residual_path=params.use_residual_path,
                max_pool_size=params.max_pool_size,
                max_pool_stride=params.max_pool_stride
            )
        
        elif params.dataset_name == "coco":
            model = CocoHopfieldModel(
                resolution=params.resolution,
                num_slots=params.num_slots,
                num_iterations=params.num_iterations,

                in_channels = params.in_channels,
                kernel_size = params.kernel_size,
                hidden_dims = params.hidden_dims,
                slot_size = params.slot_size,
                hidden_dims_query = params.hidden_dims_query,
                decoder_resolution = params.decoder_resolution,
                decoder_stride = params.decoder_stride,
                decoder_padding = params.decoder_padding,
                decoder_output_padding = params.decoder_output_padding,
                decoder_hidden_dims=params.decoder_hidden_dims,

                empty_cache=params.empty_cache,

                hopfield_steps_eps=params.hopfield_steps_eps,
                use_hopfield=params.use_hopfield,
                use_hopfield_norm_before=params.use_hopfield_norm_before,
                use_hopfield_project_before=params.use_hopfield_project_before,
                use_hopfield_actual=params.use_hopfield_actual,
                use_hopfield_qqq=params.use_hopfield_qqq,
                use_hopfield_reverse_softmax=params.use_hopfield_reverse_softmax,
                use_hopfield_only_qqq=params.use_hopfield_only_qqq,
                use_double_softmax=params.use_double_softmax,
                
                use_competetion=params.use_competetion,
                use_gumble_softmax=params.use_gumble_softmax,
                l1_loss=params.l1_loss,
                hopfield_beta=params.hopfield_beta,
                use_average_pool=params.use_average_pool,
                average_pool_size=params.average_pool_size,
                average_pool_stride=params.average_pool_stride,
                use_max_pool=params.use_max_pool,
                use_residual_path=params.use_residual_path,
                max_pool_size=params.max_pool_size,
                max_pool_stride=params.max_pool_stride
            )
        
        else:
            raise NotImplementedError
    
    elif params.use_vit:
        
        if params.use_vit_enc_only:
            
            model = SlotAttentionModelVitEncoder(
                params=params,
                resolution=params.resolution,
                num_slots=params.num_slots,
                num_iterations=params.num_iterations,
                in_channels=params.in_channels,
                kernel_size=params.kernel_size,
                slot_size=params.slot_size,
                hidden_dims=params.hidden_dims,
                decoder_padding=params.decoder_padding,
                decoder_resolution=params.decoder_resolution,
                decoder_hidden_dims=params.decoder_hidden_dims,
                decoder_stride=params.decoder_stride,
                decoder_output_padding=params.decoder_output_padding,
                empty_cache=params.empty_cache,
            )   
        else:
            
            model = VitModel(
                params=params,
            )   
    
    elif params.use_vito:
        
        model = VitoModel(params=params)
    
    else:
        model = SlotAttentionModel(
            params=params,
            resolution=params.resolution,
            num_slots=params.num_slots,
            num_iterations=params.num_iterations,
            in_channels=params.in_channels,
            kernel_size=params.kernel_size,
            slot_size=params.slot_size,
            hidden_dims=params.hidden_dims,
            decoder_padding=params.decoder_padding,
            decoder_resolution=params.decoder_resolution,
            decoder_hidden_dims=params.decoder_hidden_dims,
            decoder_stride=params.decoder_stride,
            decoder_output_padding=params.decoder_output_padding,
            empty_cache=params.empty_cache,
        )   
    
    datamod = get_datamod(params)
    print(f"Training set size (images must have {params.num_slots - 1} objects):", len(datamod.train_dataset))
    
    if params.use_compile:
        print(f"INFO: use torch.compile")
        model = torch.compile(model=model)
    
    method = SlotAttentionMethod(model=model, datamodule=datamod, params=params)

    logger = WandbLogger(project=params.wandb_project_name, 
                         name=params.wandb_logger_name, 
                         config=params, 
                         entity=params.wandb_entity_name,
                         save_dir=f'/system/user/publicwork/{os.environ["USER"]}/wandb',
                         log_model=False,
                         )
    
    # if params.log_gradients:
    #     logger.watch(model, log_freq=1)
    #     print(f'logging gradients since log_gradients is: {params.log_gradients}')
    
    trainer = Trainer(
        default_root_dir="/system/user/publicwork/{os.environ['USER']}/ptl",
        logger=logger if params.is_logger_enabled else False,
        strategy="ddp" if params.n_gpus > 1 else 'auto',
        accelerator="gpu",
        num_sanity_val_steps=params.num_sanity_val_steps,
        devices=params.n_gpus,
        max_epochs=params.max_epochs,
        log_every_n_steps=params.log_every_n_steps,
        callbacks=[LearningRateMonitor("step"), ImageLogCallback(),] if params.is_logger_enabled else [],
        gradient_clip_val=params.gradient_clip_val,
    )
    trainer.fit(method, datamod)

def get_datamod(params):
    
     # for multi-objects dataset (MOD)
    if params.dataset_name == "tetrominoes":
    
        datamod = MOTetroDataModule(
            params=params,
            train_batch_size=params.batch_size,
            val_batch_size=params.val_batch_size,
            num_workers=params.num_workers,
            train_val_perc=params.train_val_percent,
        )
    
    elif params.dataset_name == "clevr6":
        data_root = MOD_FILE_PATHS[params.dataset_name][params.mode]

        datamod = MODClevrModule(
            data_root=data_root,
            train_batch_size=params.batch_size,
            val_batch_size=params.val_batch_size,
            num_workers=params.num_workers,
            train_val_perc=params.train_val_percent,
        )
    
    elif params.dataset_name == "objectsroom":
        
        datamod = ObjectsRoomDataModule(
            train_batch_size=params.batch_size,
            val_batch_size=params.val_batch_size,
            num_workers=params.num_workers,
            train_val_perc=params.train_val_percent,
        )
        
    elif params.dataset_name == "multi_dsprites":
        data_root = MOD_FILE_PATHS[params.dataset_name][params.mode]
    
        datamod = MOSpritesDataModule(
            data_root=data_root,
            train_batch_size=params.batch_size,
            val_batch_size=params.val_batch_size,
            num_workers=params.num_workers,
            train_val_perc=params.train_val_percent,
        )    
        
    elif params.dataset_name == "coco":
    
        datamod = CocoDataModule(
            params=params,
            train_batch_size=params.batch_size,
            val_batch_size=params.val_batch_size,
            num_workers=params.num_workers,
        )    
    
    else:
        raise ValueError(f"Unknown dataset: {params.dataset_name}")
    
    return datamod

if __name__ == "__main__":
    main()
