import json
import os
import copy
from dataclasses import dataclass
from typing import Optional, Dict
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.distributed as dist

from transformers import AutoModel, PreTrainedModel, GPTJForCausalLM
from transformers.modeling_outputs import ModelOutput
from .gptj_encoder import GPT2EncoderBlock

## ------- Modified.
import sys
sys.path.append("..")
## ------- Modified.

from ancetele.arguments import ModelArguments, DataArguments
from ancetele.arguments import DenseTrainingArguments as TrainingArguments


import logging
logger = logging.getLogger(__name__)



@dataclass
class DenseOutput(ModelOutput):
    """Model output container for the dense bi-encoder.

    Either the representations (``q_reps`` / ``p_reps``) are populated, when
    only one side was encoded, or the training quantities (``loss`` /
    ``scores``) are. Unset fields stay ``None``.
    """

    # Query representations, or None when no query batch was encoded.
    q_reps: Optional[Tensor] = None
    # Passage representations, or None when no passage batch was encoded.
    p_reps: Optional[Tensor] = None
    # Scalar contrastive loss, only set when both sides were encoded.
    loss: Optional[Tensor] = None
    # Query-by-passage similarity matrix, only set when both sides were encoded.
    scores: Optional[Tensor] = None


class Residual_Encoder(nn.Module):
    """Stack of ``GPT2EncoderBlock`` layers followed by masked mean pooling.

    Each layer receives the running hidden states (``tgt``), a fixed
    ``memory`` tensor, an additive attention mask built from ``tgt_pad_mask``,
    and position ids ``0..seq_len-1``. The output of the last layer is
    mean-pooled over the positions where ``tgt_pad_mask`` is non-zero.
    """

    def __init__(self, config, n_layer):
        super().__init__()

        blocks = []
        for _ in range(n_layer):
            blocks.append(GPT2EncoderBlock(config=config))
        # Attribute name kept as ``encoder_block`` so saved state_dict keys still match.
        self.encoder_block = nn.ModuleList(blocks)

    def forward(self, tgt, memory, tgt_pad_mask):
        """Run all encoder layers and return one pooled vector per sequence.

        Args:
            tgt: hidden states fed to the first layer; indexed as
                (batch, seq, hidden) by ``mean_pooling``.
            memory: tensor passed unchanged to every layer.
            tgt_pad_mask: padding mask, 1 for real tokens and 0 for padding
                (assumed from its use as both attention mask and pooling weight).

        Returns:
            Tensor of shape (batch, hidden) with masked mean-pooled states.
        """
        batch_size = tgt.size(0)

        # Broadcastable (batch, 1, 1, seq) additive mask: 0 for kept positions,
        # the most negative fp16 value for padded ones.
        extended_mask = tgt_pad_mask.view(batch_size, -1)[:, None, None, :]
        extended_mask = (1.0 - extended_mask) * torch.finfo(torch.float16).min

        positions = self.get_position_ids(tgt)

        hidden = tgt
        for block in self.encoder_block:
            hidden = block(
                tgt=hidden,
                memory=memory,
                attention_mask=extended_mask,
                position_ids=positions,
            )

        return self.mean_pooling(hidden, tgt_pad_mask)

    def get_position_ids(self, inputs_embeds):
        """Return position ids ``[[0, 1, ..., seq_len - 1]]`` (shape (1, seq_len)).

        ``seq_len`` is the second-to-last dimension of ``inputs_embeds``.
        """
        seq_len = inputs_embeds.size()[:-1][-1]
        ids = torch.arange(0, seq_len, dtype=torch.long, device=inputs_embeds.device)
        return ids.view(-1, seq_len)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average ``token_embeddings`` over dim 1, weighted by ``attention_mask``.

        The denominator is clamped to 1e-9 so fully-masked rows yield zeros
        rather than NaNs.
        """
        weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = (token_embeddings * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts
    


        



        
class BiGPTJDenseModel(nn.Module):
    """Dense bi-encoder that scores queries against passages with a GPT-J backbone.

    Each side (query / passage) is encoded by running ``clm`` with
    ``output_hidden_states=True``. The first entry of the returned hidden
    states is passed as ``tgt`` and the last entry as ``memory`` to
    ``residual_encoder``, which returns one pooled vector per sequence.
    During training the model computes a softmax cross-entropy over the
    query-by-passage score matrix, optionally after gathering representations
    from all processes.

    When ``model_args.cosine_scale`` is not ``None``, representations are
    L2-normalized, matching ``BiGPTJDenseModelForInference.forward``. The
    cosine scores are then divided by ``cosine_scale``, which is logged as
    the temperature. When it is ``None``, raw dot products are used.
    """

    def __init__(
            self,
            clm: PreTrainedModel,
            residual_encoder: nn.Module,
            config,
            model_args: ModelArguments = None,
            data_args: DataArguments = None,
            train_args: TrainingArguments = None,
            tokenizer: OrderedDict = None,
    ):
        """Store sub-modules and run configuration.

        Although ``model_args``, ``train_args`` and ``tokenizer`` default to
        ``None``, they are dereferenced unconditionally here, so callers must
        supply them (``build`` does).

        Raises:
            ValueError: if ``train_args.negatives_x_device`` is set but
                ``torch.distributed`` has not been initialized.
        """
        super().__init__()

        self.config = config

        self.clm = clm
        self.residual_encoder = residual_encoder
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.pad_token_id = tokenizer.pad_token_id
        self.bottom_layer_num = model_args.bottom_layer_num

        ## Cosine Func: None -> raw dot-product scoring; otherwise cosine / temperature.
        self.cosine_scale = model_args.cosine_scale

        logger.info("Cosine Func Temperature: {}".format(self.cosine_scale))

        if train_args.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def forward(
            self,
            query: Dict[str, Tensor] = None,
            passage: Dict[str, Tensor] = None,
    ):
        """Encode queries and/or passages and, when both are given, compute the loss.

        Args:
            query: tokenized queries with ``input_ids`` and ``attention_mask``.
            passage: tokenized passages with ``input_ids`` and
                ``attention_mask``. Presumably laid out as one contiguous
                group per query with the positive first (implied by the
                target computation below); TODO confirm against the collator.

        Returns:
            ``DenseOutput`` with only ``q_reps`` / ``p_reps`` set when either
            side is missing, otherwise with ``loss`` and ``scores`` set.
        """
        q_reps = self._encode_text(query) if query is not None else None
        p_reps = self._encode_text(passage) if passage is not None else None

        if q_reps is None or p_reps is None:
            return DenseOutput(
                q_reps=q_reps,
                p_reps=p_reps
            )

        ## Cross-device negatives: every process scores against all gathered reps.
        if self.train_args.negatives_x_device:
            q_reps = self.dist_gather_tensor(q_reps)
            p_reps = self.dist_gather_tensor(p_reps)

        # (num_queries, num_passages) similarity matrix.
        scores = torch.matmul(q_reps, p_reps.transpose(0, 1))
        if self.cosine_scale is not None:
            # Cosine scores lie in [-1, 1]; dividing by the temperature keeps
            # the softmax from being nearly uniform.
            scores = scores / self.cosine_scale

        # The target for query i is passage i * (n_passages // n_queries),
        # i.e. the first passage of that query's group.
        target = torch.arange(
            scores.size(0),
            device=scores.device,
            dtype=torch.long
        )
        target = target * (p_reps.size(0) // q_reps.size(0))

        loss = self.cross_entropy(scores, target)

        return DenseOutput(
            loss=loss,
            scores=scores,
        )

    def _encode_text(self, text: Dict[str, Tensor]) -> Tensor:
        """Encode one tokenized batch into pooled representations.

        The first entry of the backbone's hidden states becomes the residual
        encoder's ``tgt`` and the last entry its ``memory``. The result is
        L2-normalized when cosine scoring is enabled, so training uses the
        same similarity as ``BiGPTJDenseModelForInference``.
        """
        outs = self.get_clm_outputs(text["input_ids"], text["attention_mask"])
        reps = self.residual_encoder(
            tgt=outs[0],
            memory=outs[-1],
            tgt_pad_mask=text["attention_mask"],
        )
        if self.cosine_scale is not None:
            reps = F.normalize(reps, p=2, dim=-1)
        return reps

    def get_clm_outputs(self, input_ids, attention_mask):
        """Run the causal LM and return its tuple of hidden states."""
        outputs = self.clm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        return outputs.hidden_states

    def dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """All-gather ``t`` from every process and concatenate along dim 0.

        The local slot is replaced by the original ``t`` because tensors
        received through ``dist.all_gather`` carry no autograd history; this
        keeps gradients flowing through this process's shard.

        Returns ``None`` when ``t`` is ``None``.
        """
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            data_args: DataArguments,
            train_args: TrainingArguments,
            tokenizer: OrderedDict,
            config,
            cache_dir,
            # **hf_kwargs,
    ):
        """Construct the model: load GPT-J, create the residual encoder, and
        optionally restore the residual encoder from a checkpoint.

        The checkpoint is loaded only when ``residual_encoder_name_or_path``
        is set and does not contain ``"xx"``. The code treats such paths as
        "no checkpoint"; presumably ``"xx"`` is a placeholder value.
        ``cache_dir`` is forwarded to ``from_pretrained``.
        """
        clm = GPTJForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            cache_dir=cache_dir,
        )
        residual_encoder = Residual_Encoder(
            config=config,
            n_layer=model_args.residual_num_layer,
        )

        ckpt_path = model_args.residual_encoder_name_or_path
        if ckpt_path and "xx" not in ckpt_path:
            # ``save`` writes CPU tensors; map to CPU explicitly so loading
            # never depends on the device the checkpoint came from.
            state = torch.load(ckpt_path, map_location="cpu")
            residual_encoder.load_state_dict(state['residual_encoder'])
            print("load residual encoder from ", ckpt_path)

        model = cls(
            clm=clm,
            config=config,
            residual_encoder=residual_encoder,
            model_args=model_args,
            data_args=data_args,
            train_args=train_args,
            tokenizer=tokenizer,
        )
        return model

    def save(self, output_dir: str):
        """Save the model under ``output_dir``.

        The causal LM is saved with ``save_pretrained`` only if its input
        embedding weight requires grad (used here as the signal that the
        backbone was trained). The residual encoder is always saved to
        ``residual_encoder.ckpt`` as a dict with keys ``'residual_encoder'``
        (CPU state_dict) and ``'ckpt_name'``.
        """
        ## clm
        if self.clm.transformer.wte.weight.requires_grad:
            self.clm.save_pretrained(output_dir)

        ## residual enc
        ## deepspeed cannot auto convert state_dict to cpu in non-pretrained-model
        ckpt_name = "residual_encoder.ckpt"
        best_ckpt = {
            'residual_encoder': OrderedDict({k: v.cpu() for k, v in self.residual_encoder.state_dict().items()}),
            'ckpt_name': ckpt_name
        }
        torch.save(
            best_ckpt,
            os.path.join(output_dir, ckpt_name)
        )

        
        

## ********************************************************
## Infer Model
## ********************************************************
class BiGPTJDenseModelForInference(BiGPTJDenseModel):
    """Inference-only variant that encodes a single batch of text.

    ``__init__`` intentionally bypasses ``BiGPTJDenseModel.__init__``, which
    requires training arguments, and initializes ``nn.Module`` directly. The
    backbone forward pass runs under ``torch.no_grad()``.
    """

    def __init__(
            self,
            clm: PreTrainedModel,
            residual_encoder: nn.Module,
            config,
            tokenizer: OrderedDict,
            model_args: ModelArguments,
            **kwargs,
    ):
        """Store sub-modules; extra ``**kwargs`` are accepted and ignored."""
        nn.Module.__init__(self)

        self.clm = clm
        self.residual_encoder = residual_encoder
        self.config = config
        self.pad_token_id = tokenizer.pad_token_id
        self.bottom_layer_num = model_args.bottom_layer_num

        # None -> raw representations; otherwise reps are L2-normalized.
        self.cosine_scale = model_args.cosine_scale

        logger.info("Cosine Func Temperature: {}".format(self.cosine_scale))

    @torch.no_grad()
    def get_clm_outputs(self, input_ids, attention_mask):
        """Same as the parent implementation, but without building a graph."""
        return super().get_clm_outputs(input_ids, attention_mask)

    def forward(
            self,
            text: Dict[str, Tensor] = None,
    ):
        """Encode ``text`` (with ``input_ids`` and ``attention_mask``) into reps.

        The first entry of the backbone's hidden states is the residual
        encoder's ``tgt`` and the last entry its ``memory``. The pooled
        output is L2-normalized when ``cosine_scale`` is set.
        """
        outs = self.get_clm_outputs(text["input_ids"], text["attention_mask"])

        reps = self.residual_encoder(
            tgt=outs[0],
            memory=outs[-1],
            tgt_pad_mask=text["attention_mask"],
        )

        if self.cosine_scale is not None:
            reps = F.normalize(reps, p=2, dim=-1)

        return reps

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            data_args: DataArguments,
            tokenizer: OrderedDict,
            config,
            cache_dir,
            # **hf_kwargs,
    ):
        """Load GPT-J and a trained residual encoder, then build the model.

        ``cache_dir`` is forwarded to ``from_pretrained``. ``data_args`` is
        accepted for interface compatibility and not used. The residual
        encoder checkpoint is required here (no placeholder check) and is
        moved to the backbone's device.
        """
        clm = GPTJForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            cache_dir=cache_dir,
        )

        residual_encoder = Residual_Encoder(
            config=config,
            n_layer=model_args.residual_num_layer,
        )
        # Load on CPU first; the module is moved to the backbone's device below.
        state = torch.load(model_args.residual_encoder_name_or_path, map_location="cpu")
        residual_encoder.load_state_dict(state['residual_encoder'])

        residual_encoder.to(clm.device)

        print("load residual encoder from ", model_args.residual_encoder_name_or_path)

        model = cls(
            clm=clm,
            residual_encoder=residual_encoder,
            config=config,
            tokenizer=tokenizer,
            model_args=model_args,
        )
        return model