import os

import torch.distributed
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import logging
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from transformers import (
    AutoModel, 
    AutoTokenizer,
    get_constant_schedule,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup
)
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import DictConfig

from src.metrics import MRMetric, MRRMetric, HitsMetric


logger = logging.getLogger(__name__)


@rank_zero_only
def info(msg, *args, **kwargs):
    """Log ``msg`` at INFO level on the rank-zero process only.

    Extra positional/keyword arguments are forwarded unchanged to
    ``logger.info`` so callers may use lazy %-formatting or logging options
    such as ``exc_info``. Existing single-argument calls behave as before.
    """
    logger.info(msg, *args, **kwargs)


@rank_zero_only
def error(msg, *args, **kwargs):
    """Log ``msg`` at ERROR level on the rank-zero process only.

    Extra positional/keyword arguments are forwarded unchanged to
    ``logger.error``.
    """
    logger.error(msg, *args, **kwargs)


class PLMReasonModel(nn.Module):
    """Bi-encoder that scores (head, relation) queries against candidate tails.

    Head and tail texts are embedded with ``node_encoder`` and relation texts
    with ``relation_encoder``. A small MLP (``fusion_net``) maps the
    element-wise product of the head and relation embeddings to a query
    embedding. Scores are dot products between query and tail embeddings.
    """

    # Short config name -> Hugging Face model id passed to ``from_pretrained``.
    _PRETRAINED_MODELS = {
        # https://github.com/microsoft/unilm/tree/master/minilm
        'minilm': 'microsoft/MiniLM-L12-H384-uncased',
        # NOTE(review): the original comment linked
        # https://huggingface.co/microsoft/deberta-v3-base, but the id loaded
        # here is 'deberta-base' -- confirm which checkpoint is intended.
        'deberta': 'deberta-base',
        'bert': 'bert-base-uncased',
    }
    _CACHE_DIR = './cache/'

    def __init__(self,
                 node_encoder: AutoModel,
                 relation_encoder: AutoModel,
                 tokenizer: AutoTokenizer,
                 max_length: int,
                 pooling_method: str):
        """
        Args:
            node_encoder: transformer used for head and tail texts; its
                ``config.hidden_size`` fixes the embedding width.
            relation_encoder: transformer used for relation texts.
            tokenizer: tokenizer shared by both encoders.
            max_length: truncation length passed to the tokenizer.
            pooling_method: 'cls', 'max' or 'mean' (see ``_pooling``).
        """
        super().__init__()
        self.node_encoder = node_encoder
        self.relation_encoder = relation_encoder
        self.tokenizer = tokenizer
        self.hidden_size = node_encoder.config.hidden_size
        self.max_length = max_length
        self.pooling_method = pooling_method

        # Learnable log-scale for the contrastive loss: logits are multiplied by
        # exp(contrastive_temperature), i.e. temperature = 1 / exp(contrastive_temperature).
        # Also serves as a handle on the model's current device.
        self.contrastive_temperature = nn.Parameter(torch.tensor([0.0]))

        self.fusion_net = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size),
                                        nn.ReLU(),
                                        nn.Linear(self.hidden_size, self.hidden_size))

    @classmethod
    def from_config(cls, config: DictConfig):
        """Build the model from ``config.plm.model`` and ``config.data``.

        The relation encoder starts as a deep copy of the node encoder, so both
        begin from the same pretrained weights but are trained independently.

        Raises:
            ValueError: if ``config.plm.model.name`` is not a supported model.
        """
        name = config.plm.model.name
        if name not in cls._PRETRAINED_MODELS:
            raise ValueError(
                f"Unsupported PLM name {name!r}; expected one of "
                f"{sorted(cls._PRETRAINED_MODELS)}"
            )
        model_id = cls._PRETRAINED_MODELS[name]
        encoder = AutoModel.from_pretrained(model_id, cache_dir=cls._CACHE_DIR)
        tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cls._CACHE_DIR)

        return cls(
            encoder,
            deepcopy(encoder),
            tokenizer,
            config.data.max_length,
            config.plm.model.pooling_method
        )

    def _pooling(self, last_hidden_states, attention_mask):
        """Reduce token states of shape (batch, seq, hidden) to one vector per sequence.

        Padding positions (``attention_mask == 0``) are excluded from 'max' and
        'mean' pooling. 'cls' takes the first token's state.

        Raises:
            ValueError: if ``self.pooling_method`` is not 'cls', 'max' or 'mean'.
        """
        if self.pooling_method == 'cls':
            return last_hidden_states[:, 0, :]

        padding = ~attention_mask[..., None].bool()
        if self.pooling_method == 'max':
            # Out-of-place fill: the encoder's output tensor is left untouched.
            return last_hidden_states.masked_fill(padding, -1e4).max(dim=1).values
        if self.pooling_method == 'mean':
            summed = last_hidden_states.masked_fill(padding, 0.0).sum(dim=1)
            return summed / attention_mask.sum(dim=1)[..., None]
        raise ValueError(
            f"Unknown pooling_method {self.pooling_method!r}; expected 'cls', 'max' or 'mean'"
        )

    def encode(self, input_texts, encoder):
        """Tokenize ``input_texts`` and return one pooled embedding per text.

        Tokenized inputs are moved to the model's device before encoding.
        """
        inputs = self.tokenizer(input_texts,
                                max_length=self.max_length,
                                truncation=True,
                                padding=True,
                                return_tensors='pt').to(self.contrastive_temperature.device)

        outputs = encoder(**inputs)
        return self._pooling(outputs.last_hidden_state, inputs.attention_mask)

    def query_encode(self, h_emb, r_emb):
        """Fuse head and relation embeddings into a query embedding."""
        return self.fusion_net(h_emb * r_emb)

    def compute_score(self, q_emb, t_emb):
        """Dot-product scores between every query and every tail.

        Args:
            q_emb: 2-D tensor (num_queries, dim).
            t_emb: 2-D tensor (num_tails, dim).

        Returns:
            Tensor (num_queries, num_tails).
        """
        return torch.mm(q_emb, t_emb.t())

    @staticmethod
    def _all_gather_keep_local(tensor, device):
        """Concatenate ``tensor`` from every rank along dim 0, on ``device``.

        Remote copies are exchanged through ``all_gather_object`` (pickling) and
        are not connected to this rank's autograd graph. The local slot is
        therefore replaced by the original tensor so gradients still flow
        through this rank's contribution. Handles per-rank size differences.
        """
        gathered = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered, tensor)
        gathered[torch.distributed.get_rank()] = tensor
        return torch.cat([part.to(device) for part in gathered], dim=0)

    def forward(self, h_texts, r_texts, t_texts):
        """Encode a batch of triples and score queries against tails.

        In distributed training, queries and tails from all ranks are gathered,
        so the score matrix spans the global batch. Otherwise it is local.

        Returns:
            (score, h_emb, r_emb, q_emb, t_emb). The embeddings are always the
            local (this rank's) ones.
        """
        h_emb = self.encode(h_texts, self.node_encoder)
        r_emb = self.encode(r_texts, self.relation_encoder)
        q_emb = self.query_encode(h_emb, r_emb)
        t_emb = self.encode(t_texts, self.node_encoder)

        if self.training and torch.distributed.is_initialized():
            device = self.contrastive_temperature.device
            all_q_emb = self._all_gather_keep_local(q_emb, device)
            all_t_emb = self._all_gather_keep_local(t_emb, device)
            torch.distributed.barrier()
            score = self.compute_score(all_q_emb, all_t_emb)
        else:
            score = self.compute_score(q_emb, t_emb)

        return score, h_emb, r_emb, q_emb, t_emb
    

class CoSTPLMLightningModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``PLMReasonModel``.

    Training uses an in-batch contrastive loss, gathered across ranks under
    torch.distributed. With ``pretrain=False``, each batch is augmented with
    tails sampled from previously predicted pseudo facts.

    Validation ranks gold tails among in-batch candidates. Testing ranks them
    against all precomputed node embeddings. Prediction either encodes
    node/relation texts into embeddings, or samples new pseudo facts from the
    stored embeddings.
    """

    # Metrics shown in the progress bar; the others are logged only.
    _PROG_BAR_METRICS = ('mr', 'mrr', 'hits1', 'hits10')

    def __init__(self,
                 config: DictConfig,
                 model: PLMReasonModel,
                 pretrain: bool = True):
        """
        Args:
            config: experiment config. ``config.plm.lr`` and
                ``config.plm.weight_decay`` are read by the optimizer.
            model: the bi-encoder being trained/evaluated.
            pretrain: if False, ``training_step`` mixes in pseudo facts loaded
                from the path set via ``set_pseudo_fact_path``.
        """
        super().__init__()
        self.config = config
        self.model = model
        self.pretrain = pretrain

        # Paths of tensors saved with torch.save. They are configured through
        # the setters below and loaded in the fit/test/predict start hooks.
        self._node_embedding_path = None
        self._relation_embedding_path = None
        self._inductive_node_embedding_path = None
        self._inductive_relation_embedding_path = None
        self._pseudo_fact_path = None
        # True: predict_step encodes texts; False: it samples pseudo facts.
        self._predict_embedding = True

        self.mr_fn = MRMetric()
        self.mrr_fn = MRRMetric()
        self.hits1_fn = HitsMetric(topk=1)
        self.hits3_fn = HitsMetric(topk=3)
        self.hits10_fn = HitsMetric(topk=10)
        self.hits50_fn = HitsMetric(topk=50)
        self.hits100_fn = HitsMetric(topk=100)

    @classmethod
    def from_config(cls, config: DictConfig, pretrain: bool):
        """Build the wrapped model from ``config`` and wrap it."""
        model = PLMReasonModel.from_config(config)
        return cls(config, model, pretrain)

    def set_node_embedding_path(self, path):
        self._node_embedding_path = path

    def set_relation_embedding_path(self, path):
        self._relation_embedding_path = path

    def set_inductive_node_embedding_path(self, path):
        self._inductive_node_embedding_path = path

    def set_inductive_relation_embedding_path(self, path):
        self._inductive_relation_embedding_path = path

    def set_pseudo_fact_path(self, path):
        self._pseudo_fact_path = path

    def predict_embedding(self, tag: bool):
        """Select the predict mode: True encodes texts, False samples pseudo facts."""
        self._predict_embedding = tag

    # ------------------------------------------------------------------ helpers

    def _load_embedding(self, path):
        """Load a tensor saved with ``torch.save`` onto CPU, then move it to this module's device."""
        return torch.load(path, map_location=lambda storage, loc: storage).to(self.device)

    def _ranking_metrics(self):
        """Ranking metric objects keyed by the name used in logs and tables."""
        return {
            'mr': self.mr_fn,
            'mrr': self.mrr_fn,
            'hits1': self.hits1_fn,
            'hits3': self.hits3_fn,
            'hits10': self.hits10_fn,
            'hits50': self.hits50_fn,
            'hits100': self.hits100_fn,
        }

    def _update_ranking_metrics(self, all_ranks):
        """Feed a batch of ranks into every ranking metric."""
        for metric in self._ranking_metrics().values():
            metric.update(all_ranks)

    def _log_metrics_table(self, header, values):
        """Emit ``header`` followed by a grid table of the metric ``values`` (rank zero only)."""
        from tabulate import tabulate
        metrics = {name: [value] for name, value in zip(self._ranking_metrics(), values)}
        info(header + tabulate(metrics, headers='keys', tablefmt='grid'))

    def _gather_from_all_ranks(self, tensor):
        """Concatenate ``tensor`` from every rank along dim 0, on ``self.device``.

        Remote copies are exchanged through ``all_gather_object`` (pickling) and
        are not connected to this rank's autograd graph. The local slot is
        replaced by the original tensor so gradients flow through it.
        """
        gathered = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered, tensor)
        gathered[torch.distributed.get_rank()] = tensor
        return torch.cat([part.to(self.device) for part in gathered], dim=0)

    def _build_pseudo_batch(self, batch):
        """Duplicate each query with a tail sampled from its gold and pseudo tails.

        For each query, the gold tail and every pseudo-fact tail recorded for
        its (h, r) pair are marked. One marked tail is then drawn uniformly.

        Returns:
            (h_text, r_text, t_text, t, filter_mask). Every query appears twice:
            first with its gold tail, then with the sampled tail. ``filter_mask``
            (long) is the OR of the marked tails and ``batch.filter_mask``,
            repeated for both halves.
        """
        num_node = batch.filter_mask.size(1)
        pseudo_target = torch.zeros_like(batch.filter_mask)
        # pseudo_fact rows are (h, r, t). Encode (h, r) as one integer key; this
        # is loop-invariant, so it is computed once.
        pseudo_key = self.pseudo_fact[:, 1] * num_node + self.pseudo_fact[:, 0]
        for i, (h, r) in enumerate(zip(batch.h, batch.r)):
            key = r.item() * num_node + h.item()
            pseudo_target[i, self.pseudo_fact[pseudo_key == key, 2]] = 1
            pseudo_target[i, batch.t[i].item()] = 1

        pseudo_t = torch.multinomial(pseudo_target.float(), 1).to(self.device)
        dataset = self.trainer.train_dataloader.dataset
        is_inductive = dataset.split_data_dict['inductive']
        h_text = batch.h_text + batch.h_text
        r_text = batch.r_text + batch.r_text
        t_text = batch.t_text + [
            dataset.graph.get_node_textual_information(i.item(), ind=is_inductive)[1]
            for i in pseudo_t
        ]
        t = torch.cat([batch.t, pseudo_t.view(-1)], dim=0)

        filter_mask = pseudo_target.to(self.device).bool() | batch.filter_mask.bool()
        filter_mask = torch.cat([filter_mask, filter_mask], dim=0).long()
        return h_text, r_text, t_text, t, filter_mask

    def _contrastive_loss(self, scores, h_emb, q_emb, t, filter_mask):
        """Cross-entropy over candidate tails, with the diagonal as the positive.

        Args:
            scores: (N, N) query-vs-tail scores from the model, where N is the
                global batch size under torch.distributed.
            h_emb, q_emb: local head/query embeddings.
            t: local tail node ids, one per query.
            filter_mask: local per-query mask over all nodes; flagged tails are
                excluded as negatives (presumably other known true tails).
        """
        if torch.distributed.is_initialized():
            t = self._gather_from_all_ranks(t)
            filter_mask = self._gather_from_all_ranks(filter_mask)
            q_emb = self._gather_from_all_ranks(q_emb)
            h_emb = self._gather_from_all_ranks(h_emb)
            torch.distributed.barrier()

        target = torch.arange(scores.size(0), dtype=torch.long, device=self.device)
        # logits_mask[i, j] = filter_mask[i, t[j]]: candidate j's tail is flagged
        # for query i. The gold pair on the diagonal is never masked.
        logits_mask = torch.gather(filter_mask, dim=1,
                                   index=einops.repeat(t, 'b -> n b', n=filter_mask.size(0)))
        logits_mask[target, target] = 0
        # A large negative removes flagged candidates from the softmax (the
        # previous -1e-4 was ~0 and did not mask anything).
        logits = scores.masked_fill(logits_mask.bool(), -1e4)
        # Extra negative column: each query scored against its own head.
        self_logits = (q_emb * h_emb).sum(dim=1, keepdim=True)
        logits = torch.cat([logits, self_logits], dim=1)
        # Equivalent to scaling logits by exp(contrastive_temperature).
        return F.cross_entropy(logits / (1. / self.model.contrastive_temperature.exp()), target)

    # ------------------------------------------------------------- optimization

    def configure_optimizers(self):
        """AdamW (no weight decay on biases/LayerNorm) with linear warmup then decay.

        The scheduler warms up for a fixed 200 steps, then decays linearly over
        ``trainer.estimated_stepping_batches``. It steps every batch.
        """
        no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
        no_decay_params, decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        grouped_optimizer_parameters = [
            {'params': no_decay_params, 'weight_decay': 0.0},
            {'params': decay_params, 'weight_decay': self.config.plm.weight_decay},
        ]
        optimizer = torch.optim.AdamW(
            grouped_optimizer_parameters,
            lr=self.config.plm.lr,
        )

        scheduler = get_linear_schedule_with_warmup(optimizer, 200, self.trainer.estimated_stepping_batches)
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }

        return [optimizer], [scheduler]

    # ----------------------------------------------------------------- training

    def on_fit_start(self):
        """Lazily load node/relation embeddings and pseudo facts if paths are set."""
        if getattr(self, 'node_embedding', None) is None and self._node_embedding_path is not None:
            self.node_embedding = self._load_embedding(self._node_embedding_path)
            self.relation_embedding = self._load_embedding(self._relation_embedding_path)
        if getattr(self, 'pseudo_fact', None) is None and self._pseudo_fact_path is not None:
            # Kept on CPU; rows are (h, r, t) as produced by _predict_scores_step.
            self.pseudo_fact = torch.load(self._pseudo_fact_path, map_location=lambda storage, loc: storage)

    def training_step(self, batch):
        """Compute the contrastive loss, optionally augmented with pseudo facts."""
        if self.pretrain:
            h_text, r_text, t_text = batch.h_text, batch.r_text, batch.t_text
            t, filter_mask = batch.t, batch.filter_mask
        else:
            # The augmented t / filter_mask are used in both the distributed and
            # single-process paths so their shapes match the doubled batch.
            h_text, r_text, t_text, t, filter_mask = self._build_pseudo_batch(batch)

        scores, h_emb, _, q_emb, _ = self.model.forward(h_text, r_text, t_text)
        loss = self._contrastive_loss(scores, h_emb, q_emb, t, filter_mask)

        self.log('loss', loss.detach(), prog_bar=True)
        self.log('memory', torch.cuda.max_memory_allocated()/(1024**3), prog_bar=True)
        return loss

    # --------------------------------------------------------------- evaluation

    def _compute_metrics(self, mode='valid'):
        """Compute, reset and log all ranking metrics under the ``{mode}_`` prefix.

        Returns:
            (mr, mrr, hits1, hits3, hits10, hits50, hits100)
        """
        metrics = self._ranking_metrics()
        values = {name: metric.compute() for name, metric in metrics.items()}
        for metric in metrics.values():
            metric.reset()

        # Each metric is logged with its own value.
        for name, value in values.items():
            self.log(f'{mode}_{name}', value,
                     prog_bar=name in self._PROG_BAR_METRICS, sync_dist=True)

        return tuple(values.values())

    def validation_step(self, batch):
        """Rank each gold tail among the in-batch tails, ignoring filtered ones."""
        scores, *_ = self.model.forward(batch.h_text, batch.r_text, batch.t_text)
        num_queries = len(batch.t_text)
        diag = torch.arange(num_queries, device=self.device)
        mask = torch.gather(batch.filter_mask, dim=1,
                            index=einops.repeat(batch.t, 'b -> n b', n=num_queries))
        mask[diag, diag] = 0
        gold_scores = scores[diag, diag].unsqueeze(1)
        # Rank = number of unfiltered candidates scoring >= gold (gold included).
        all_ranks = torch.sum(scores.masked_fill(mask.bool(), -1e4) >= gold_scores, dim=1)

        self._update_ranking_metrics(all_ranks)
        return None, all_ranks

    def on_validation_epoch_end(self):
        values = self._compute_metrics(mode='valid')
        self._log_metrics_table(
            f'Valid Metrics at Epoch {self.trainer.current_epoch}: \n', values)

    def on_test_start(self):
        """Load the transductive or inductive embeddings matching the test split."""
        is_inductive = self.trainer.test_dataloaders.dataset.split_data_dict['inductive']
        if is_inductive:
            self.node_embedding = self._load_embedding(self._inductive_node_embedding_path)
            self.relation_embedding = self._load_embedding(self._inductive_relation_embedding_path)
        else:
            self.node_embedding = self._load_embedding(self._node_embedding_path)
            self.relation_embedding = self._load_embedding(self._relation_embedding_path)

    def test_step(self, batch):
        """Rank each gold tail against every node using the precomputed embeddings.

        Candidates flagged in ``filter_mask`` are excluded from the count. When
        ``batch.negative_target_node`` is given, only those nodes are counted
        instead. The returned rank is 1 + (number of counted candidates scoring
        >= the gold tail).
        """
        h_emb = self.node_embedding[batch.h]
        r_emb = self.relation_embedding[batch.r]
        q_emb = self.model.query_encode(h_emb, r_emb)
        all_scores = self.model.compute_score(q_emb, self.node_embedding)

        filter_mask = batch.filter_mask
        if batch.negative_target_node is not None:
            filter_mask = torch.ones_like(filter_mask).scatter(1, batch.negative_target_node, 0)
        excluded = filter_mask.bool()

        rows = torch.arange(batch.h.size(0), device=self.device)
        answer_scores = all_scores[rows, batch.t].unsqueeze(1)
        all_ranks = torch.sum((all_scores >= answer_scores) & ~excluded, dim=1) + 1

        self._update_ranking_metrics(all_ranks)
        return all_scores, all_ranks

    def on_test_epoch_end(self):
        values = self._compute_metrics(mode='test')
        self._log_metrics_table('Test Metrics: \n', values)

    def on_test_end(self):
        del self.node_embedding
        del self.relation_embedding

    # --------------------------------------------------------------- prediction

    def on_predict_start(self):
        """In pseudo-fact mode, load the embeddings needed for scoring."""
        if not self._predict_embedding:
            self.node_embedding = self._load_embedding(self._node_embedding_path)
            self.relation_embedding = self._load_embedding(self._relation_embedding_path)

    def _predict_embedding_step(self, batch):
        """Encode a batch of node or relation texts.

        Returns:
            (split_flags, indexs, emb). ``split_flags`` is a long tensor with one
            entry per text: 0 when ``split == 'train'``, else 1.
        """
        # NOTE(review): assumes the collate function yields a single ``is_node``
        # flag and a single ``split`` string per batch -- confirm.
        indexs, texts, is_node, split = batch
        encoder = self.model.node_encoder if is_node else self.model.relation_encoder
        emb = self.model.encode(texts, encoder)

        split_flag = 0 if split == 'train' else 1
        return (
            torch.full((len(texts),), split_flag, dtype=torch.long),
            indexs,
            emb,
        )

    def _predict_scores_step(self, batch):
        """Sample pseudo facts for a batch of (h, r) queries.

        Each node is drawn as a tail with probability equal to its softmax score
        over all nodes. It is kept only if that probability exceeds 0.75, so at
        most one tail per query can survive.

        Returns:
            Long tensor of (h, r, t) rows.
        """
        h, r, _ = batch

        q_emb = self.model.query_encode(self.node_embedding[h], self.relation_embedding[r])
        probs = F.softmax(self.model.compute_score(q_emb, self.node_embedding), dim=-1)
        pseudo_target = torch.bernoulli(probs).long() * (probs > 0.75)

        query_index = torch.stack([h, r], dim=1)
        pseudo_t = pseudo_target.nonzero()
        return torch.cat([query_index[pseudo_t[:, 0]], pseudo_t[:, 1:]], dim=1)

    def predict_step(self, batch):
        if self._predict_embedding:
            return self._predict_embedding_step(batch)
        return self._predict_scores_step(batch)

    def on_predict_end(self):
        if not self._predict_embedding:
            del self.node_embedding
            del self.relation_embedding