import copy
import json
import os
import logging
from dataclasses import dataclass
from typing import Dict, Optional

import torch
from torch import nn, Tensor
import torch.distributed as dist
from transformers import PreTrainedModel, AutoModel
from transformers.file_utils import ModelOutput
import math

from sup_con_xmc.arguments import (
    ModelArguments,
    MyTrainingArguments as TrainingArguments
)

logger = logging.getLogger(__name__)


@dataclass
class EncoderOutput(ModelOutput):
    """Result container returned by ``EncoderModel.forward``.

    Every field defaults to ``None``; ``forward`` fills in only the fields
    that are available for the current mode (encoding only, evaluation, or
    training).
    """
    # query representations as returned by ``encode_query``
    q_reps: Optional[Tensor] = None
    # passage representations as returned by ``encode_passage``
    p_reps: Optional[Tensor] = None
    # training loss; ``forward`` leaves it as ``None`` outside training mode
    loss: Optional[Tensor] = None
    # similarity scores computed by ``forward``
    scores: Optional[Tensor] = None


class EncoderModel(nn.Module):
    TRANSFORMER_CLS = AutoModel
    _keys_to_ignore_on_save = []

    def __init__(self,
        lm_q: PreTrainedModel,
        lm_p: PreTrainedModel,
        pooling_type: str = "avg",
        skip_l2norm: bool = False,
        temperature: float = 0.05,
        untie_encoder: bool = False,
        negatives_x_device: bool = False,
        use_q_neg: bool =True,
    ):
        super().__init__()
        if pooling_type not in ["cls", "avg", "last"]:
            raise ValueError(f"pooling_type={pooling_type} is not support yet!")
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.pooling_type = pooling_type
        self.skip_l2norm = skip_l2norm
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()
        self.negatives_x_device = negatives_x_device
        self.untie_encoder = untie_encoder
        self.use_q_neg = use_q_neg
        if self.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        # for multiple positive instances
        self.supc_target = None
        self.dual_target = None
        self.dual_mask = None
        self.supc_pos_mask = None

    def forward(self, query: Optional[Dict[str, Tensor]] = None, passage: Optional[Dict[str, Tensor]] = None, key: Optional[Dict[str, Tensor]] = None):
        """Encode the inputs and, in training mode, compute the contrastive loss.

        Outside training mode the method returns the representations only when
        one side is missing, otherwise the representations plus the
        query/passage similarity scores (``loss`` is ``None``).

        In training mode the scores of every query against all passages (and,
        when ``use_q_neg`` is set, against all queries) are fed to
        ``compute_sample_dec_loss`` together with a false-negative mask built
        from the label tensors.

        Args:
            query: encoder inputs for the queries. In training mode it must
                also hold ``"pos_labels"``, which is popped from the mapping
                (the caller's object is modified) before encoding.
            passage: encoder inputs for the passages. In training mode it must
                also hold ``"id"``, which is popped from the mapping before
                encoding.
            key: not used by this method.

        Returns:
            An ``EncoderOutput``.

        Raises:
            ValueError: in training mode when ``query`` is ``None`` or
                ``passage`` is missing/empty; the loss cannot be computed
                without both sides.
        """
        if self.training:
            if query is None or not passage:
                raise ValueError(
                    "Training requires both `query` and `passage` inputs to compute the loss."
                )
            # bookkeeping entries must not be forwarded to the encoders
            query_pos_labels = query.pop("pos_labels")
            passage_ids = passage.pop("id")
        q_reps = self.encode_query(query)
        p_reps = self.encode_passage(passage)

        if not self.training:
            # for inference: only one side was encoded
            if q_reps is None or p_reps is None:
                return EncoderOutput(
                    q_reps=q_reps,
                    p_reps=p_reps
                )
            # for eval
            return EncoderOutput(
                loss=None,
                scores=self.compute_similarity(q_reps, p_reps),
                q_reps=q_reps,
                p_reps=p_reps,
            )

        # for training
        if self.negatives_x_device:
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)
            passage_ids = self._dist_gather_tensor(passage_ids)

        if self.use_q_neg:
            # Without cross-device gathering the keys are simply the local
            # query labels (previously this name was left unbound).
            neg_query_pos_labels = query_pos_labels
            if self.negatives_x_device:
                neg_query_pos_labels = self._dist_gather_tensor(query_pos_labels)
            neg_query_false_neg_mask = self._get_single_gpu_false_negative(query_pos_labels, neg_query_pos_labels)
            if self.negatives_x_device:
                # local rows -> rows for the queries of every rank
                neg_query_false_neg_mask = self._dist_gather_tensor(neg_query_false_neg_mask)
            neg_query_scores = self.compute_similarity(q_reps, q_reps).view(q_reps.size(0), -1)

        dual_scores = self.compute_similarity(q_reps, p_reps)
        dual_scores = dual_scores.view(q_reps.size(0), -1)
        # number of passages per query; the first one of each group is the target
        dual_group_size = p_reps.size(0) // q_reps.size(0)

        # for false neg mask
        dual_false_neg_mask = self._get_single_gpu_false_negative(query_pos_labels, passage_ids)
        if self.negatives_x_device:
            dual_false_neg_mask = self._dist_gather_tensor(dual_false_neg_mask)

        if self.use_q_neg:
            # query-query columns are appended after the query-passage columns
            dual_false_neg_mask = torch.cat([dual_false_neg_mask, neg_query_false_neg_mask], dim=1)
            dual_scores = torch.cat([dual_scores, neg_query_scores], dim=1)

        # The target matrix only depends on the score layout, so it is cached
        # and rebuilt when the shape (or device) changes.
        if (
            self.dual_target is None
            or self.dual_target.shape != dual_scores.shape
            or self.dual_target.device != dual_scores.device
        ):
            dual_target = torch.zeros_like(dual_scores)
            row_ids = torch.arange(dual_target.size(0), device=dual_target.device)
            dual_target[row_ids, row_ids * dual_group_size] = 1.0
            self.dual_target = dual_target

        # the targets themselves are not false negative
        dual_false_neg_mask = dual_false_neg_mask.masked_fill_(self.dual_target > 0, 0)

        scores = dual_scores
        loss = self.compute_sample_dec_loss(scores / self.temperature, self.dual_target, dual_false_neg_mask)
        if self.negatives_x_device:
            loss = loss * self.world_size  # counter average weight reduction
        return EncoderOutput(
            loss=loss,
            scores=scores,
            q_reps=q_reps,
            p_reps=p_reps,
        )

    def encode_passage(self, psg):
        """Turn passage inputs into representations; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = 'EncoderModel is an abstract class'
        raise NotImplementedError(message)

    def encode_query(self, qry):
        """Turn query inputs into representations; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = 'EncoderModel is an abstract class'
        raise NotImplementedError(message)

    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))

    def compute_loss(self, scores, target):
        return self.cross_entropy(scores, target)
    
    def compute_sample_dec_loss(self, scores, target, false_neg_mask):
        nom = (-scores*target).sum()
        den = (torch.logsumexp(scores.masked_fill_(false_neg_mask > 0, -100.0),dim=1) * target.sum(dim=1)).sum()
        return (nom + den)/target.sum()
    
    def _get_single_gpu_false_negative(self, query_pos_labels, key_pos_labels, K=24):
        key_pos_labels = key_pos_labels.view(key_pos_labels.size(0), -1)
        key_pos_labels = key_pos_labels.unsqueeze(0).unsqueeze(2) # shape: [1, k_reps.size(0), 1, L]
        gathered_false_neg_mask = []
        for i in torch.arange(math.ceil(query_pos_labels.size(0)/K)):
            cur_query_pos_labels = query_pos_labels[i*K:(i+1)*K]
            cur_query_pos_labels = cur_query_pos_labels.unsqueeze(1).unsqueeze(3)
            false_neg_mask = (cur_query_pos_labels == key_pos_labels).sum(-1).sum(-1)
            gathered_false_neg_mask.append(false_neg_mask)
        return torch.cat(gathered_false_neg_mask, dim=0)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            train_args: TrainingArguments,
            **hf_kwargs,
    ):
        """Create a model for training from the model and training arguments.

        * Local directory with ``untie_encoder``: the query and passage
          encoders are loaded separately from ``query_model`` /
          ``passage_model``; if ``query_model`` does not exist both are
          loaded from the directory itself.
        * Local directory without ``untie_encoder``: one encoder is loaded
          and shared.
        * Anything else: one encoder is loaded and, with ``untie_encoder``,
          deep-copied for the passage side.

        ``hf_kwargs`` are forwarded to ``from_pretrained``.
        """
        name_or_path = model_args.model_name_or_path

        if os.path.isdir(name_or_path) and model_args.untie_encoder:
            # load local, two separate encoders
            qry_dir = os.path.join(name_or_path, 'query_model')
            if os.path.exists(qry_dir):
                psg_dir = os.path.join(name_or_path, 'passage_model')
            else:
                qry_dir = psg_dir = name_or_path
            logger.info(f'loading query model weight from {qry_dir}')
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(qry_dir, **hf_kwargs)
            logger.info(f'loading passage model weight from {psg_dir}')
            lm_p = cls.TRANSFORMER_CLS.from_pretrained(psg_dir, **hf_kwargs)
        else:
            # tied local weights, or a pre-trained checkpoint name
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(name_or_path, **hf_kwargs)
            if model_args.untie_encoder:
                # only reachable for a non-local name: start from identical copies
                lm_p = copy.deepcopy(lm_q)
            else:
                lm_p = lm_q

        return cls(
            lm_q=lm_q,
            lm_p=lm_p,
            pooling_type=model_args.pooling_type,
            skip_l2norm=model_args.skip_l2norm,
            temperature=train_args.temperature,
            untie_encoder=model_args.untie_encoder,
            negatives_x_device=train_args.negatives_x_device,
            use_q_neg=train_args.use_q_neg,
        )

    @classmethod
    def load(cls, model_name_or_path, **hf_kwargs):
        """Load a model from a local directory or by checkpoint name.

        For a local directory, separate encoders are loaded when a
        ``query_model`` sub-directory exists, otherwise a single shared
        encoder is loaded. The pooler settings come from
        ``pooler_config.json`` in that directory when present; missing
        settings (or a missing file) fall back to the defaults
        ``pooling_type="avg"``, ``skip_l2norm=False``, ``temperature=0.05``.
        Any other ``model_name_or_path`` is passed to ``from_pretrained`` as
        is and uses the defaults.

        Args:
            model_name_or_path: local directory or checkpoint name.
            **hf_kwargs: forwarded to ``from_pretrained``.

        Returns:
            The constructed model. ``untie_encoder`` is ``True`` exactly when
            separate query/passage weights were loaded, so that a later
            ``save`` writes both encoders again.
        """
        untie_encoder = False
        pooler_config = {
            "pooling_type": "avg",
            "skip_l2norm": False,
            "temperature": 0.05,
        }
        if os.path.isdir(model_name_or_path):
            _qry_model_path = os.path.join(model_name_or_path, 'query_model')
            _psg_model_path = os.path.join(model_name_or_path, 'passage_model')
            if os.path.exists(_qry_model_path):
                logger.info('found separate weight for query/passage encoders')
                logger.info(f'loading query model weight from {_qry_model_path}')
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(
                    _qry_model_path,
                    **hf_kwargs
                )
                logger.info(f'loading passage model weight from {_psg_model_path}')
                lm_p = cls.TRANSFORMER_CLS.from_pretrained(
                    _psg_model_path,
                    **hf_kwargs
                )
                # two distinct encoders were loaded
                untie_encoder = True
            else:
                logger.info('try loading tied weight')
                logger.info(f'loading model weight from local: {model_name_or_path}')
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
                lm_p = lm_q
            # override pooler config from locally-saved directory
            pooler_config_path = os.path.join(model_name_or_path, "pooler_config.json")
            if os.path.exists(pooler_config_path):
                with open(pooler_config_path, "r") as fin:
                    # update() keeps the defaults for keys the file does not define
                    pooler_config.update(json.load(fin))
            else:
                logger.warning(
                    f'no pooler_config.json found in {model_name_or_path}; using default pooler config'
                )
        else:
            logger.info('try loading tied weight')
            logger.info(f'loading model weight from HF-hub: {model_name_or_path}')
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, **hf_kwargs)
            lm_p = lm_q
        model = cls(
            lm_q=lm_q,
            lm_p=lm_p,
            pooling_type=pooler_config["pooling_type"],
            skip_l2norm=pooler_config["skip_l2norm"],
            temperature=pooler_config["temperature"],
            untie_encoder=untie_encoder,
        )
        return model

    def save(self, output_dir: str):
        """Write the encoder weights and the pooler settings to ``output_dir``.

        With ``untie_encoder`` the two encoders go to the ``query_model`` and
        ``passage_model`` sub-directories; otherwise only ``lm_q`` is saved,
        directly into ``output_dir``. The pooling type, the ``skip_l2norm``
        flag and the temperature are written to ``pooler_config.json``.

        Existing directories are reused, so saving twice into the same
        location does not fail.
        """
        os.makedirs(output_dir, exist_ok=True)

        # save encoder
        if self.untie_encoder:
            encoders = (('query_model', self.lm_q), ('passage_model', self.lm_p))
            for sub_dir, encoder in encoders:
                encoder_dir = os.path.join(output_dir, sub_dir)
                os.makedirs(encoder_dir, exist_ok=True)
                encoder.save_pretrained(encoder_dir)
        else:
            self.lm_q.save_pretrained(output_dir)

        # save pooler config
        pooler_config = {
            "pooling_type": self.pooling_type,
            "skip_l2norm": self.skip_l2norm,
            "temperature": self.temperature,
        }
        with open(os.path.join(output_dir, "pooler_config.json"), "w") as fout:
            # indent=1 yields the same text as the former ``indent=True``
            fout.write(json.dumps(pooler_config, indent=1))


class DenseModel(EncoderModel):
    """Encoder model that pools the last hidden states into one vector per input."""

    def get_pooled_emb(self, hidden_states, attention_mask):
        """Pool token states into one embedding per sequence.

        Args:
            hidden_states: token states indexed as ``[batch, token, hidden]``.
            attention_mask: mask indexed as ``[batch, token]``; non-zero marks
                a real token.

        Returns:
            One embedding per sequence, L2-normalized unless
            ``self.skip_l2norm`` is set.

        Raises:
            ValueError: if ``self.pooling_type`` is not ``"cls"``, ``"avg"``
                or ``"last"``.
        """
        if self.pooling_type == "cls":
            embeddings = hidden_states[:, 0]
        elif self.pooling_type == "avg":
            hidden_states = hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
            # clamp so that a fully masked row yields zeros instead of NaN (0/0)
            token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
            embeddings = hidden_states.sum(dim=1) / token_counts
        elif self.pooling_type == "last":
            # emb is the last token representation that is not padding
            # NOTE(review): this assumes padding sits at the end of each
            # sequence (right padding) - confirm against the tokenizer setup.
            sequence_lengths = attention_mask.sum(dim=1)
            last_token_indices = sequence_lengths.long() - 1
            batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
            embeddings = hidden_states[batch_indices, last_token_indices]
        else:
            raise ValueError(f"pooling_type={self.pooling_type} is not support yet!")
        # end of pooling logic
        if not self.skip_l2norm:
            embeddings = nn.functional.normalize(embeddings, p=2, dim=-1)
        return embeddings

    def encode_passage(self, psg):
        """Encode passages with ``lm_p``; returns ``None`` when ``psg`` is ``None``."""
        if psg is None:
            return None
        psg_out = self.lm_p(**psg, return_dict=True)
        p_hidden = psg_out.last_hidden_state
        p_reps = self.get_pooled_emb(p_hidden, psg["attention_mask"])
        return p_reps

    def encode_query(self, qry):
        """Encode queries with ``lm_q``; returns ``None`` when ``qry`` is ``None``."""
        if qry is None:
            return None
        qry_out = self.lm_q(**qry, return_dict=True)
        q_hidden = qry_out.last_hidden_state
        return self.get_pooled_emb(q_hidden, qry["attention_mask"])
