import os
import logging
import shutil

import hydra
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
import torch.distributed

from src.gnn_model import CoSTGNNLightningModule
from src.data import CoSTGNNDataModule, CoSTScoreDataModule

# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when many DataLoader workers are used — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
# Module-level logger named after this module; used by info() and error().
logger = logging.getLogger(__name__)

@rank_zero_only
def info(msg):
    """Emit ``msg`` at INFO level on the module logger.

    The function is wrapped by ``rank_zero_only``, so that decorator decides
    on which process the message is actually logged.
    """
    logger.log(logging.INFO, msg)
    
@rank_zero_only
def error(msg):
    """Emit ``msg`` at ERROR level on the module logger.

    The function is wrapped by ``rank_zero_only``, so that decorator decides
    on which process the message is actually logged.
    """
    logger.log(logging.ERROR, msg)
    

class ScoreWriter(BasePredictionWriter):
    """Prediction-writer callback that stores query indices and scores on disk.

    At the end of a predict epoch the predictions of every rank are gathered,
    concatenated, and written with ``torch.save`` by the global-zero process.
    """

    def __init__(self,
                 write_path,
                 write_interval):
        """Create the writer.

        Args:
            write_path: Destination handed to ``torch.save``.
            write_interval: Forwarded to ``BasePredictionWriter`` and also
                kept on the instance as ``write_interval``.
        """
        super().__init__(write_interval)
        self.write_path = write_path
        self.write_interval = write_interval

    def write_on_epoch_end(self, trainer, model, prediction, batch_indices):
        """Collect the predictions of all ranks and save them from rank zero.

        Args:
            trainer: Trainer running the prediction; only ``is_global_zero``
                is read.
            model: Unused.
            prediction: Iterable of ``(query_index, score)`` pairs; each
                member of a pair is concatenated along dim 0 with ``torch.cat``.
            batch_indices: Unused.
        """
        # The collectives below raise when no process group exists (e.g. a
        # single-device run), so fall back to the local predictions then.
        distributed = (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )

        if distributed:
            gathered = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(gathered, prediction)
            torch.distributed.barrier()
        else:
            gathered = [prediction]

        if trainer.is_global_zero:
            # Flatten the per-rank lists into one list of pairs in a single
            # pass (repeated list addition would be quadratic).
            pairs = [pair for rank_pairs in gathered for pair in rank_pairs]
            query_index, score = zip(*pairs)
            # NOTE(review): if the gathered tensors live on different CUDA
            # devices, torch.cat will fail — confirm predict_step returns
            # CPU tensors.
            query_index = torch.cat(query_index, dim=0)
            score = torch.cat(score, dim=0)

            info(f"Save Score to {self.write_path}")
            torch.save({
                'query_index': query_index,
                'score': score
            }, self.write_path)

        if distributed:
            # Keep the other ranks waiting until rank zero finished writing.
            torch.distributed.barrier()
    

@hydra.main(config_path='../configs', config_name='configure', version_base=None)
def run(config: DictConfig) -> None:
    """Fit ``CoSTGNNLightningModule`` in pretrain mode, test it and export it.

    Steps: record Hydra's runtime output directory in ``config.output_dir``,
    build the checkpoint callback and the trainer (optionally with a W&B
    logger), fit on ``CoSTGNNDataModule``, run the test loop on the best
    checkpoint, and copy that checkpoint to ``<output_dir>/pretrain-gnn.pt``.

    Args:
        config: Hydra configuration. The keys read here are ``seed``,
            ``trainer.*``, ``gnn.trainer.precision``,
            ``gnn.pretrain.max_epochs`` and ``wandb.*``.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    # open_dict lets us add the new `output_dir` key to the config.
    with open_dict(config):
        config.output_dir = hydra_cfg['runtime']['output_dir']

    work_dir = os.path.join(config.output_dir, 'pretrain-gnn')
    checkpoint_dir = os.path.join(work_dir, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    info(OmegaConf.to_yaml(config))
    info(f'Experiment dir: {config.output_dir}')

    pl.seed_everything(config.seed)

    # Keep the two checkpoints with the highest `valid_mrr`.
    checkpoint_monitor = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{step}',
        monitor='valid_mrr',
        mode='max',
        save_top_k=2,
        every_n_epochs=1,
    )

    trainer_parameters = {
        'accelerator': config.trainer.accelerator,
        'devices': config.trainer.devices,
        'default_root_dir': work_dir,
        'strategy': config.trainer.strategy,
        'precision': config.gnn.trainer.precision,
        'callbacks': [checkpoint_monitor],
        'max_epochs': config.gnn.pretrain.max_epochs,
        # 'num_sanity_val_steps': 0
    }
    if config.wandb.use:
        # NOTE(review): the run-name suffix says 'plm' although this script
        # pretrains the GNN — confirm whether that is intended.
        wandb_logger = WandbLogger(
            name=config.wandb.run_name + '_pretrain-plm',
            project=config.wandb.project_name,
            config=config,
            save_dir=os.path.join(work_dir, 'wandb')
        )
        trainer_parameters["logger"] = wandb_logger
    trainer = pl.Trainer(**trainer_parameters)

    model = CoSTGNNLightningModule.from_config(config, pretrain=True)
    # Tell the model where its embedding files live inside the output dir.
    model.set_node_embedding_path(os.path.join(config.output_dir, 'whitening-pretrain-node-embedding.pt'))
    model.set_relation_embedding_path(os.path.join(config.output_dir, 'whitening-pretrain-relation-embedding.pt'))
    model.set_inductive_node_embedding_path(os.path.join(config.output_dir, 'inductive-whitening-pretrain-node-embedding.pt'))
    model.set_inductive_relation_embedding_path(os.path.join(config.output_dir, 'inductive-whitening-pretrain-relation-embedding.pt'))
    datamodule = CoSTGNNDataModule.from_config(config)
    trainer.fit(model, datamodule)

    info('Start Evaluation on Test Set.')

    trainer.test(model, datamodule, ckpt_path=checkpoint_monitor.best_model_path)

    # With a multi-process strategy every rank executes this function, so
    # only the global-zero process exports the checkpoint. Otherwise several
    # processes would write the same destination file concurrently.
    if trainer.is_global_zero:
        shutil.copy(checkpoint_monitor.best_model_path, os.path.join(config.output_dir, 'pretrain-gnn.pt'))
    
    
# Script entry point; run() is called without arguments because the
# @hydra.main decorator supplies `config`.
if __name__ == '__main__':
    run()