import os
import json
import logging
import shutil

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
import torch.distributed

from src.lm_model import CoSTPLMLightningModule
from src.data import CoSTPLMDataModule, CoSTNodeTextDataModule, CoSTRelationTextDataModule

# Share tensors between processes through the file system rather than file descriptors.
# NOTE(review): presumably chosen to avoid "too many open files" errors with many
# DataLoader workers — confirm against the data modules.
torch.multiprocessing.set_sharing_strategy('file_system')
# Module-level logger used by the rank-zero-only `info` / `error` helpers below.
logger = logging.getLogger(__name__)

@rank_zero_only
def info(msg):
    """Log ``msg`` at INFO level through the module logger.

    The ``rank_zero_only`` decorator restricts the call to the rank-zero
    process, so a multi-process run does not repeat the message once per rank.
    """
    logger.log(logging.INFO, msg)
    
@rank_zero_only
def error(msg):
    """Log ``msg`` at ERROR level through the module logger.

    The ``rank_zero_only`` decorator restricts the call to the rank-zero
    process, so a multi-process run does not repeat the message once per rank.
    """
    logger.log(logging.ERROR, msg)
    

class EmbeddingWriter(BasePredictionWriter):
    """Prediction writer that collects embeddings from every rank and saves them.

    Each prediction batch is expected to be a
    ``(data_split_index, embedding_index, embedding)`` tuple of tensors.
    Rows whose split index is ``0`` are written to ``write_path`` (plus a
    whitened copy at ``whitening-<name>``); rows whose split index is ``1``
    are written to ``inductive-<name>`` (plus ``inductive-whitening-<name>``),
    all in the directory of ``write_path``. Rows are ordered by ascending
    ``embedding_index`` before saving.

    Args:
        write_path: Destination file of the split-0 embedding tensor.
        write_interval: Interval forwarded to ``BasePredictionWriter``.
        whitening_dim: Number of leading components kept by the whitening
            projection.
        is_node: ``True`` for node embeddings, ``False`` for relation
            embeddings; only affects the log messages.
    """

    def __init__(self,
                 write_path,
                 write_interval,
                 whitening_dim,
                 is_node):
        super().__init__(write_interval)
        self.write_path = write_path
        self.write_interval = write_interval
        self.whitening_dim = whitening_dim
        self.is_node = is_node

    def _prefixed_path(self, prefix):
        """Return ``write_path`` with ``<prefix>-`` prepended to its file name."""
        directory, filename = os.path.split(self.write_path)
        return os.path.join(directory, f'{prefix}-{filename}')

    def _whitening(self, emb):
        """Project ``emb`` with an SVD-based linear transform and L2-normalize the rows.

        Args:
            emb: CPU tensor with one embedding per row.

        Returns:
            Float tensor with ``whitening_dim`` columns (at most) and unit-norm rows.
        """
        # Cast to float32 first: NumPy cannot represent bfloat16 tensors, and
        # for float32 input this is a no-op.
        emb = emb.float().numpy()

        mean = np.mean(emb, axis=0, keepdims=True)
        cov = np.cov(emb.T)
        u, s, vh = np.linalg.svd(cov)
        # NOTE(review): canonical whitening scales by 1/sqrt(s); this scales by
        # sqrt(s). Left unchanged because downstream results depend on it —
        # confirm this is intended.
        kernel, bias = np.dot(u, np.diag(s**0.5)), -mean
        kernel = kernel[:, :self.whitening_dim]
        emb_whitening = (emb + bias).dot(kernel)

        emb_whitening = torch.from_numpy(emb_whitening).float()
        emb_whitening = F.normalize(emb_whitening, p=2, dim=1)
        return emb_whitening

    def _save_split(self, embedding, embedding_index, mask, raw_path, whitening_path):
        """Save the rows selected by ``mask``, sorted by index, raw and whitened."""
        ordered = embedding[mask][embedding_index[mask].argsort()]
        torch.save(ordered, raw_path)
        torch.save(self._whitening(ordered), whitening_path)

    def write_on_epoch_end(self, trainer, model, prediction, batch_indices):
        """Gather the predictions of all ranks and write them from global rank zero."""
        # Collectives are only valid once a process group exists; a
        # single-process run has none and simply uses its own predictions.
        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if distributed:
            gathered = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(gathered, prediction)
            torch.distributed.barrier()
        else:
            gathered = [prediction]

        if trainer.is_global_zero:
            prediction = [batch for part in gathered for batch in part]
            data_split_index, embedding_index, embedding = zip(*prediction)
            data_split_index, embedding_index, embedding = (
                torch.cat([t.cpu() for t in parts], dim=0)
                for parts in (data_split_index, embedding_index, embedding)
            )

            info(f"Save {'Node' if self.is_node else 'Relation'} Embedding to {self.write_path}")
            self._save_split(embedding, embedding_index, data_split_index == 0,
                             self.write_path, self._prefixed_path('whitening'))

            inductive_mask = data_split_index == 1
            if torch.any(inductive_mask):
                ind_write_path = self._prefixed_path('inductive')
                info(f"Save Inductive {'Node' if self.is_node else 'Relation'} Embedding to {ind_write_path}")
                # Write to the inductive path; writing to `self.write_path`
                # here would overwrite the split-0 embedding saved above.
                self._save_split(embedding, embedding_index, inductive_mask,
                                 ind_write_path, self._prefixed_path('inductive-whitening'))

        # Keep the other ranks waiting until rank zero has finished writing.
        if distributed:
            torch.distributed.barrier()
    

class ScoreWriter(BasePredictionWriter):
    """Prediction writer that collects scores from every rank and saves them.

    Each prediction batch is expected to be a ``(query_index, score)`` tuple
    of tensors. The concatenated tensors are saved to ``write_path`` as a
    dict with the keys ``'query_index'`` and ``'score'``.

    Args:
        write_path: Destination file of the saved dict.
        write_interval: Interval forwarded to ``BasePredictionWriter``.
    """

    def __init__(self,
                 write_path,
                 write_interval):
        super().__init__(write_interval)
        self.write_path = write_path
        self.write_interval = write_interval

    def write_on_epoch_end(self, trainer, model, prediction, batch_indices):
        """Gather the predictions of all ranks and write them from global rank zero."""
        # Collectives are only valid once a process group exists; a
        # single-process run has none and simply uses its own predictions.
        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if distributed:
            gathered = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(gathered, prediction)
            torch.distributed.barrier()
        else:
            gathered = [prediction]

        if trainer.is_global_zero:
            prediction = [batch for part in gathered for batch in part]
            query_index, score = zip(*prediction)
            query_index, score = (
                torch.cat([t.cpu() for t in parts], dim=0)
                for parts in (query_index, score)
            )

            info(f"Save Score to {self.write_path}")
            torch.save({
                'query_index': query_index,
                'score': score
            }, self.write_path)

        # Keep the other ranks waiting until rank zero has finished writing.
        if distributed:
            torch.distributed.barrier()
    

def _generate_embeddings(config, model, datamodule_cls, filename, is_node, ckpt_path):
    """Run a prediction pass that writes text embeddings into ``config.output_dir``.

    Args:
        config: Experiment configuration (must already contain ``output_dir``).
        model: Lightning module used for prediction.
        datamodule_cls: Data module class exposing ``from_config(config)``.
        filename: File name of the embedding tensor inside ``config.output_dir``.
        is_node: ``True`` for node embeddings, ``False`` for relation embeddings.
        ckpt_path: Checkpoint whose weights are loaded before predicting.
    """
    writer = EmbeddingWriter(
        write_path=os.path.join(config.output_dir, filename),
        write_interval='epoch',
        whitening_dim=config.gnn.model.hidden_dim,
        is_node=is_node,
    )
    predict_trainer = pl.Trainer(
        accelerator=config.trainer.accelerator,
        devices=config.trainer.devices,
        default_root_dir=os.path.join(config.output_dir, 'pretrain-plm'),
        strategy=config.trainer.strategy,
        precision=config.plm.trainer.precision,
        num_sanity_val_steps=0,
        callbacks=[writer],
    )
    text_datamodule = datamodule_cls.from_config(config)
    # Pass the checkpoint by keyword: the third positional parameter of
    # `Trainer.predict` is `datamodule`, so a positional path is not used as
    # `ckpt_path` and prediction would run with the in-memory weights instead.
    predict_trainer.predict(model, datamodule=text_datamodule, ckpt_path=ckpt_path)


@hydra.main(config_path='../configs', config_name='configure', version_base=None)
def run(config: DictConfig) -> None:
    """Pretrain the PLM, export node/relation embeddings, and evaluate on the test set.

    Steps:
        1. Fit ``CoSTPLMLightningModule``, keeping the two best checkpoints by
           ``valid_mrr``.
        2. Generate node and relation text embeddings from the best checkpoint.
        3. Point the model at the saved embeddings and run the test loop on the
           best checkpoint.
        4. Copy the best checkpoint to ``<output_dir>/pretrain-plm.pt``.

    Args:
        config: Hydra configuration; ``output_dir`` is filled in from the Hydra
            runtime output directory.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    with open_dict(config):
        config.output_dir = hydra_cfg['runtime']['output_dir']

    pretrain_dir = os.path.join(config.output_dir, 'pretrain-plm')
    checkpoint_dir = os.path.join(pretrain_dir, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    info(OmegaConf.to_yaml(config))
    info(f'Experiment dir: {config.output_dir}')

    pl.seed_everything(config.seed)

    checkpoint_monitor = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{step}',
        monitor='valid_mrr',
        mode='max',
        save_top_k=2,
        every_n_epochs=1,
    )

    trainer_parameters = {
        'accelerator': config.trainer.accelerator,
        'devices': config.trainer.devices,
        'default_root_dir': pretrain_dir,
        'strategy': config.trainer.strategy,
        'precision': config.plm.trainer.precision,
        'callbacks': [checkpoint_monitor],
        'max_epochs': config.plm.pretrain.max_epochs,
    }
    if config.wandb.use:
        trainer_parameters["logger"] = WandbLogger(
            name=config.wandb.run_name + '_pretrain-plm',
            project=config.wandb.project_name,
            config=config,
            save_dir=os.path.join(pretrain_dir, 'wandb')
        )
    trainer = pl.Trainer(**trainer_parameters)
    model = CoSTPLMLightningModule.from_config(config, pretrain=True)
    datamodule = CoSTPLMDataModule.from_config(config)
    trainer.fit(model, datamodule)

    info('Complete PLM Pretraining.')
    info('Start Generating Embeddings.')

    model.predict_embedding(True)
    best_model_path = checkpoint_monitor.best_model_path

    node_embedding_name = 'pretrain-node-embedding.pt'
    relation_embedding_name = 'pretrain-relation-embedding.pt'
    _generate_embeddings(config, model, CoSTNodeTextDataModule,
                         node_embedding_name, True, best_model_path)
    _generate_embeddings(config, model, CoSTRelationTextDataModule,
                         relation_embedding_name, False, best_model_path)

    info('Complete Embeddings Generation.')
    info('Start Evaluation on Test Set.')

    # The inductive file names follow the `inductive-<name>` convention used by
    # EmbeddingWriter.
    model.set_node_embedding_path(os.path.join(config.output_dir, node_embedding_name))
    model.set_relation_embedding_path(os.path.join(config.output_dir, relation_embedding_name))
    model.set_inductive_node_embedding_path(
        os.path.join(config.output_dir, f'inductive-{node_embedding_name}'))
    model.set_inductive_relation_embedding_path(
        os.path.join(config.output_dir, f'inductive-{relation_embedding_name}'))
    trainer.test(model, datamodule, ckpt_path=best_model_path)

    # Copy from a single process so ranks do not write the same destination
    # file concurrently.
    if trainer.is_global_zero:
        shutil.copy(best_model_path, os.path.join(config.output_dir, 'pretrain-plm.pt'))
    

# Script entry point: `run` is wrapped by `hydra.main`, so it is called without
# arguments and receives its configuration from Hydra.
if __name__ == '__main__':
    run()