import os
import json
import logging

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
import torch.distributed

from src.lm_model import CoSTPLMLightningModule, PLMReasonModel
from src.gnn_model import CoSTGNNLightningModule, GNNReasonModel
from src.data import CoSTGNNDataModule, CoSTPLMDataModule, CoSTNodeTextDataModule, CoSTRelationTextDataModule, CoSTPseudoFactDataModule


# Module-level logger; rank-aware wrappers around it are defined below.
logger = logging.getLogger(__name__)
# Share tensors between worker processes via the file system instead of file
# descriptors (PyTorch's default on Linux), which per the PyTorch docs is the
# workaround when the open-file-descriptor limit is too low.
torch.multiprocessing.set_sharing_strategy('file_system')

@rank_zero_only
def info(msg):
    """Log ``msg`` at INFO level, only on the rank-zero process (via Lightning's ``rank_zero_only``)."""
    logger.log(logging.INFO, msg)
    
@rank_zero_only
def error(msg):
    """Log ``msg`` at ERROR level, only on the rank-zero process (via Lightning's ``rank_zero_only``)."""
    logger.log(logging.ERROR, msg)


class EmbeddingWriter(BasePredictionWriter):
    """Prediction writer that gathers text embeddings from all ranks and saves them.

    Each element of ``prediction`` is expected to be a
    ``(data_split_index, embedding_index, embedding)`` triple of tensors.

    - Rows with split index 0 are saved raw to ``write_path``. They are saved
      whitened to ``whitening-<file name>`` in the same directory.
    - Rows with split index 1 (if any) are saved raw to ``inductive-<file name>``
      and whitened to ``inductive-whitening-<file name>``.

    Within each split, rows are ordered by ``embedding_index``.

    Args:
        write_path: Destination of the raw split-0 embedding tensor.
        write_interval: Forwarded to ``BasePredictionWriter``; only
            ``write_on_epoch_end`` is implemented here.
        whitening_dim: Number of leading principal directions kept by ``_whitening``.
        is_node: Only selects the "Node" / "Relation" wording of log messages.
    """

    def __init__(self,
                 write_path,
                 write_interval,
                 whitening_dim,
                 is_node):
        super().__init__(write_interval)
        self.write_path = write_path
        self.write_interval = write_interval
        self.whitening_dim = whitening_dim
        self.is_node = is_node

    def _prefixed_path(self, prefix):
        """Return ``write_path`` with ``prefix`` prepended to its file name."""
        directory, filename = os.path.split(self.write_path)
        return os.path.join(directory, f'{prefix}{filename}')

    def _whitening(self, emb):
        """Project ``emb`` onto its leading principal directions and L2-normalise each row.

        Args:
            emb: 2-D tensor of shape (num_rows, dim).

        Returns:
            float32 tensor of shape (num_rows, min(dim, whitening_dim)) with unit-norm rows.
        """
        # Compute statistics in float64. Under 16-mixed predict the embeddings
        # would otherwise be centred in half precision, and bfloat16 has no
        # numpy dtype at all.
        emb = emb.detach().cpu().double().numpy()

        mean = np.mean(emb, axis=0, keepdims=True)
        cov = np.cov(emb.T)  # rows of emb.T are the embedding dimensions
        u, s, _ = np.linalg.svd(cov)
        # NOTE(review): BERT-whitening uses kernel = u @ diag(1 / sqrt(s)). This
        # scales by sqrt(s) instead, i.e. a variance-weighted PCA projection
        # rather than a true whitening. It is kept unchanged so outputs stay
        # comparable with the existing ``whitening-*`` pretrain embeddings
        # (presumably produced by the same transform). Confirm which is intended.
        kernel = np.dot(u, np.diag(s ** 0.5))[:, :self.whitening_dim]
        emb_whitening = (emb - mean).dot(kernel)

        emb_whitening = torch.from_numpy(emb_whitening).float()
        return F.normalize(emb_whitening, p=2, dim=1)

    @staticmethod
    def _rows_of_split(data_split_index, embedding_index, embedding, split):
        """Return the rows of ``embedding`` in ``split``, ordered by ``embedding_index``.

        NOTE(review): assumes every index appears exactly once across ranks
        (no sampler padding / duplicates). TODO confirm.
        """
        mask = data_split_index == split
        return embedding[mask][embedding_index[mask].argsort()]

    def _save(self, rows, raw_path, whitening_path):
        """Save ``rows`` to ``raw_path`` and their whitened projection to ``whitening_path``."""
        torch.save(rows, raw_path)
        torch.save(self._whitening(rows), whitening_path)

    def write_on_epoch_end(self, trainer, model, prediction, batch_indices):
        """Gather every rank's predictions and write the embedding files from rank zero."""
        # Move to CPU before the object gather. all_gather_object pickles its
        # input and PyTorch documents GPU tensors as poorly supported there.
        # Doing this first also keeps other ranks' tensors off rank 0's GPU.
        prediction = [tuple(t.cpu() for t in item) for item in prediction]

        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if distributed:
            gathered = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(gathered, prediction)
            torch.distributed.barrier()
        else:
            # Single-process run without a process group: nothing to gather.
            gathered = [prediction]

        if trainer.is_global_zero:
            items = [item for rank_items in gathered for item in rank_items]
            data_split_index, embedding_index, embedding = (
                torch.cat(parts, dim=0) for parts in zip(*items)
            )
            kind = 'Node' if self.is_node else 'Relation'

            info(f"Save {kind} Embedding to {self.write_path}")
            self._save(
                self._rows_of_split(data_split_index, embedding_index, embedding, 0),
                self.write_path,
                self._prefixed_path('whitening-'),
            )

            if torch.any(data_split_index == 1):
                ind_write_path = self._prefixed_path('inductive-')
                info(f"Save Inductive {kind} Embedding to {ind_write_path}")
                # Written to ind_write_path. It previously went to write_path and
                # clobbered the split-0 embedding saved just above.
                self._save(
                    self._rows_of_split(data_split_index, embedding_index, embedding, 1),
                    ind_write_path,
                    self._prefixed_path('inductive-whitening-'),
                )

        if distributed:
            # Keep the other ranks waiting until rank 0 has finished writing.
            torch.distributed.barrier()
        

class PseudoFactWriter(BasePredictionWriter):
    """Prediction writer that concatenates the pseudo-fact tensors from all ranks into one file.

    Each element of ``prediction`` is expected to be a tensor. All of them are
    concatenated along dim 0 and saved to ``write_path`` by rank zero.

    Args:
        write_path: Destination file of the concatenated tensor.
        write_interval: Forwarded to ``BasePredictionWriter``; only
            ``write_on_epoch_end`` is implemented here.
    """

    def __init__(self,
                 write_path,
                 write_interval):
        super().__init__(write_interval)
        self.write_path = write_path
        self.write_interval = write_interval

    def write_on_epoch_end(self, trainer, model, prediction, batch_indices):
        """Gather every rank's predicted tensors and save their concatenation from rank zero."""
        # Move to CPU before the object gather. all_gather_object pickles its
        # input and PyTorch documents GPU tensors as poorly supported there.
        prediction = [x.cpu() for x in prediction]

        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if distributed:
            gathered = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(gathered, prediction)
            torch.distributed.barrier()
        else:
            # Single-process run without a process group: nothing to gather.
            gathered = [prediction]

        if trainer.is_global_zero:
            pseudo_fact = torch.cat(
                [x for rank_prediction in gathered for x in rank_prediction], dim=0
            )
            info(f"Save Pseudo Fact {pseudo_fact.shape} to {self.write_path}")
            torch.save(pseudo_fact, self.write_path)

        if distributed:
            # Keep the other ranks waiting until rank 0 has finished writing.
            torch.distributed.barrier()
        

# Epochs of each PLM / GNN fit inside one co-training round.
_COTRAIN_FIT_EPOCHS = 3


def _set_embedding_paths(model, node_path, relation_path, inductive_node_path, inductive_relation_path):
    """Point ``model`` at its transductive and inductive node/relation embedding files."""
    model.set_node_embedding_path(node_path)
    model.set_relation_embedding_path(relation_path)
    model.set_inductive_node_embedding_path(inductive_node_path)
    model.set_inductive_relation_embedding_path(inductive_relation_path)


def _trainer_parameters(config, default_root_dir, precision, **extra):
    """Return the ``pl.Trainer`` kwargs shared by every co-training stage, merged with ``extra``."""
    return {
        'accelerator': config.trainer.accelerator,
        'devices': config.trainer.devices,
        'default_root_dir': default_root_dir,
        'strategy': config.trainer.strategy,
        'precision': precision,
        'num_sanity_val_steps': 0,
        **extra,
    }


def _predict_stage(config, model, datamodule_cls, stage_name, precision, writer):
    """Run ``predict`` over ``datamodule_cls.from_config(config)``; ``writer`` persists the outputs.

    Trainer artefacts go to ``<output_dir>/<stage_name>``.
    """
    trainer = pl.Trainer(**_trainer_parameters(
        config, os.path.join(config.output_dir, stage_name), precision, callbacks=[writer]))
    datamodule = datamodule_cls.from_config(config)
    trainer.predict(model, datamodule)


def _fit_stage(config, model, datamodule, stage_name, precision):
    """Fit ``model`` under ``<output_dir>/<stage_name>``, keeping the best ``valid_mrr`` checkpoint.

    Returns:
        The fitted ``pl.Trainer``, so the caller can run further evaluation.
    """
    stage_dir = os.path.join(config.output_dir, stage_name)
    checkpoint_monitor = ModelCheckpoint(
        dirpath=os.path.join(stage_dir, 'checkpoint'),
        filename='{epoch}-{step}',
        monitor='valid_mrr',
        mode='max',
        save_top_k=1,
        every_n_epochs=1,
    )
    trainer_parameters = _trainer_parameters(
        config, stage_dir, precision,
        callbacks=[checkpoint_monitor],
        max_epochs=_COTRAIN_FIT_EPOCHS,
    )
    if config.wandb.use:
        trainer_parameters['logger'] = WandbLogger(
            name=config.wandb.run_name + f'_{stage_name}',
            project=config.wandb.project_name,
            # WandbLogger forwards this to wandb.init. Pass a plain container
            # rather than an OmegaConf DictConfig object.
            config=OmegaConf.to_container(config, resolve=True),
            save_dir=os.path.join(stage_dir, 'wandb'),
        )
    trainer = pl.Trainer(**trainer_parameters)
    trainer.fit(model, datamodule)
    return trainer


@hydra.main(config_path='../configs', config_name='configure', version_base=None)
def run(config: DictConfig) -> None:
    """Co-train the PLM and GNN reasoners by exchanging pseudo facts and embeddings.

    Both models are restored from ``pretrain-plm.pt`` / ``pretrain-gnn.pt`` in the
    Hydra output directory. Each of the ``config.cotrain.max_epochs`` rounds ``i``:

    1. The GNN predicts pseudo facts (``gnn-pseudo-fact-{i}.pt``) and the PLM is
       fit on them (``cotrain-plm-{i}``).
    2. The PLM re-encodes node and relation texts into
       ``cotrain-{node,relation}-embedding-{i}.pt`` and its whitened/inductive
       variants. Both models are then pointed at the new files.
    3. The PLM predicts pseudo facts (``plm-pseudo-fact-{i}.pt``). The GNN is fit
       on them (``cotrain-gnn-{i}``) and then tested.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    with open_dict(config):
        config.output_dir = hydra_cfg['runtime']['output_dir']
    output_dir = config.output_dir

    def out(name):
        """Return the path of ``name`` inside the experiment output directory."""
        return os.path.join(output_dir, name)

    os.makedirs(os.path.join(output_dir, 'cotrain', 'plm-checkpoint'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'cotrain', 'gnn-checkpoint'), exist_ok=True)

    info(OmegaConf.to_yaml(config))
    info(f'Experiment dir: {config.output_dir}')

    pl.seed_everything(config.seed)

    plm_model = CoSTPLMLightningModule.load_from_checkpoint(
        out('pretrain-plm.pt'),
        config=config, model=PLMReasonModel.from_config(config), pretrain=False)
    _set_embedding_paths(
        plm_model,
        out('whitening-pretrain-node-embedding.pt'),
        out('whitening-pretrain-relation-embedding.pt'),
        out('inductive-whitening-pretrain-node-embedding.pt'),
        out('inductive-whitening-pretrain-relation-embedding.pt'),
    )
    plm_datamodule = CoSTPLMDataModule.from_config(config, cotrain=True)

    gnn_model = CoSTGNNLightningModule.load_from_checkpoint(
        out('pretrain-gnn.pt'),
        config=config, model=GNNReasonModel.from_config(config), pretrain=False)
    _set_embedding_paths(
        gnn_model,
        out('whitening-pretrain-node-embedding.pt'),
        out('whitening-pretrain-relation-embedding.pt'),
        out('inductive-whitening-pretrain-node-embedding.pt'),
        out('inductive-whitening-pretrain-relation-embedding.pt'),
    )
    gnn_datamodule = CoSTGNNDataModule.from_config(config, cotrain=True)

    for i in range(config.cotrain.max_epochs):
        # 1) GNN -> pseudo facts -> PLM fit.
        gnn_pseudo_fact_path = out(f'gnn-pseudo-fact-{i}.pt')
        _predict_stage(
            config, gnn_model, CoSTPseudoFactDataModule, f'cotrain-gnn-predict-{i}', 32,
            PseudoFactWriter(write_path=gnn_pseudo_fact_path, write_interval='epoch'),
        )
        plm_model.set_pseudo_fact_path(gnn_pseudo_fact_path)
        _fit_stage(config, plm_model, plm_datamodule, f'cotrain-plm-{i}', '16-mixed')

        # 2) Re-encode node and relation texts with the updated PLM.
        plm_model.predict_embedding(True)
        for kind, datamodule_cls, is_node in (
            ('node', CoSTNodeTextDataModule, True),
            ('relation', CoSTRelationTextDataModule, False),
        ):
            writer = EmbeddingWriter(
                write_path=out(f'cotrain-{kind}-embedding-{i}.pt'),
                write_interval='epoch',
                whitening_dim=config.gnn.model.hidden_dim,
                is_node=is_node,
            )
            _predict_stage(config, plm_model, datamodule_cls,
                           f'cotrain-plm-generate-{kind}-{i}', '16-mixed', writer)

        # NOTE(review): the PLM started from the *whitened* pretrain embeddings
        # but from here on gets the raw cotrain embeddings. The GNN keeps using
        # whitened ones, and raw and whitened files can differ in width.
        # Confirm this asymmetry is intended.
        _set_embedding_paths(
            plm_model,
            out(f'cotrain-node-embedding-{i}.pt'),
            out(f'cotrain-relation-embedding-{i}.pt'),
            # Inductive files follow EmbeddingWriter's ``inductive-<file name>``
            # naming. The old ``inductive-pretrain-*-{i}.pt`` names were never written.
            out(f'inductive-cotrain-node-embedding-{i}.pt'),
            out(f'inductive-cotrain-relation-embedding-{i}.pt'),
        )
        _set_embedding_paths(
            gnn_model,
            out(f'whitening-cotrain-node-embedding-{i}.pt'),
            out(f'whitening-cotrain-relation-embedding-{i}.pt'),
            out(f'inductive-whitening-cotrain-node-embedding-{i}.pt'),
            out(f'inductive-whitening-cotrain-relation-embedding-{i}.pt'),
        )

        # 3) PLM -> pseudo facts -> GNN fit + test.
        plm_model.predict_embedding(False)
        plm_pseudo_fact_path = out(f'plm-pseudo-fact-{i}.pt')
        _predict_stage(
            config, plm_model, CoSTPseudoFactDataModule, f'cotrain-plm-predict-{i}', '16-mixed',
            PseudoFactWriter(write_path=plm_pseudo_fact_path, write_interval='epoch'),
        )
        gnn_model.set_pseudo_fact_path(plm_pseudo_fact_path)
        # Uses its own cotrain-gnn-{i} directory. It previously reused cotrain-plm-{i},
        # so GNN checkpoints could overwrite the PLM's '{epoch}-{step}' files.
        gnn_trainer = _fit_stage(config, gnn_model, gnn_datamodule, f'cotrain-gnn-{i}', 32)
        gnn_trainer.test(gnn_model, gnn_datamodule)


if __name__ == '__main__':
    # ``run`` is wrapped by ``hydra.main``, which composes the config from
    # ../configs/configure (plus CLI overrides) and passes it in.
    run()