import json
import os
import copy
from dataclasses import dataclass
from typing import Optional, Dict
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.distributed as dist

from transformers import AutoModel, PreTrainedModel, GPT2LMHeadModel, GPTJForCausalLM
from transformers.modeling_outputs import ModelOutput

from .vanilla_transformer_encoder import Vanilla_Transformer_Encoder
from .vanilla_gptj_encoder import VanillaGPTJEncoderBlock

## ------- Modified.
import sys
sys.path.append("..")
## ------- Modified.

from ancetele.arguments import ModelArguments, DataArguments
from ancetele.arguments import DenseTrainingArguments as TrainingArguments


import logging
logger = logging.getLogger(__name__)



@dataclass
class DenseOutput(ModelOutput):
    """Output container returned by ``VanillaDenseModel.forward``.

    When only one of query/passage is encoded, ``q_reps`` / ``p_reps`` are set
    and ``loss`` / ``scores`` stay ``None``.  When both are encoded, ``loss``
    and ``scores`` are set and the representations are left ``None``.
    """
    q_reps: Optional[Tensor] = None  # pooled query representations, one row per query
    p_reps: Optional[Tensor] = None  # pooled passage representations, one row per passage
    loss: Optional[Tensor] = None  # cross-entropy loss over query-passage scores
    scores: Optional[Tensor] = None  # (num_queries, num_passages) dot-product scores

## *********************************
## GPT2-Cross-Attn
## *********************************
class GPT2_Residual_Encoder(nn.Module):
    """Stack of ``Vanilla_Transformer_Encoder`` layers applied to GPT-2 input
    embeddings, followed by masked mean pooling into one vector per sequence.

    Args:
        input_size: model width passed to each layer as ``d_model``.
        n_head: number of attention heads per layer.
        n_layer: number of stacked layers.
    """

    def __init__(self, input_size, n_head, n_layer):
        super().__init__()

        layers = []
        for _ in range(n_layer):
            layers.append(
                Vanilla_Transformer_Encoder(d_model=input_size, nhead=n_head, batch_first=True)
            )
        self.encoder_block = nn.ModuleList(layers)

    def forward(self, tgt, tgt_pad_mask):
        """Run every layer over ``tgt`` and mean-pool the result.

        Args:
            tgt: batch-first token embeddings (layers are built with
                ``batch_first=True``).
            tgt_pad_mask: nonzero for real tokens, zero for padding.

        Returns:
            One pooled vector per sequence (see ``mean_pooling``).
        """
        # The layers receive the inverse mask: True marks positions to ignore.
        ignore_positions = ~tgt_pad_mask.bool()

        hidden = tgt
        for layer in self.encoder_block:
            hidden = layer(tgt=hidden, tgt_key_padding_mask=ignore_positions)

        return self.mean_pooling(hidden, tgt_pad_mask)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average ``token_embeddings`` over dim 1, weighting each token by
        ``attention_mask``; rows with no kept tokens yield zeros."""
        weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = (token_embeddings * weights).sum(dim=1)
        # Clamp avoids division by zero for fully-masked rows.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    
## *********************************
## GPT-j-Cross-Attn
## *********************************
class GPTJ_Residual_Encoder(nn.Module):
    """Stack of ``VanillaGPTJEncoderBlock`` layers applied to GPT-J token
    embeddings, followed by masked mean pooling into one vector per sequence.

    Args:
        config: model config forwarded to every ``VanillaGPTJEncoderBlock``.
        n_layer: number of stacked blocks.
    """

    def __init__(self, config, n_layer):
        super().__init__()
        self.encoder_block = nn.ModuleList(
            [VanillaGPTJEncoderBlock(config=config) for _ in range(n_layer)]
        )

    def forward(self, tgt, tgt_pad_mask):
        """Run every block over ``tgt`` and mean-pool the result.

        Args:
            tgt: token embeddings whose first dim is the batch and whose
                second-to-last dim is the sequence length.
            tgt_pad_mask: 1 for real tokens, 0 for padding.

        Returns:
            One pooled vector per sequence (see ``mean_pooling``).
        """
        position_ids = self.get_position_ids(tgt)

        # Additive mask of shape (batch, 1, 1, seq): 0 where tgt_pad_mask is 1
        # and the float16 minimum where it is 0.
        batch_size = tgt.size(0)
        keep = tgt_pad_mask.view(batch_size, -1)[:, None, None, :]
        large_negative = torch.finfo(torch.float16).min
        additive_mask = (1.0 - keep) * large_negative

        hidden = tgt
        for block in self.encoder_block:
            hidden = block(
                tgt=hidden,
                attention_mask=additive_mask,
                position_ids=position_ids,
            )

        return self.mean_pooling(hidden, tgt_pad_mask)

    def get_position_ids(self, inputs_embeds):
        """Return positions ``0 .. seq_len-1`` as a ``(1, seq_len)`` long tensor
        on ``inputs_embeds``' device, where ``seq_len`` is its second-to-last dim."""
        seq_len = inputs_embeds.size()[:-1][-1]
        positions = torch.arange(0, seq_len, dtype=torch.long, device=inputs_embeds.device)
        return positions.unsqueeze(0).view(-1, seq_len)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average ``token_embeddings`` over dim 1, weighting each token by
        ``attention_mask``; rows with no kept tokens yield zeros."""
        weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = (token_embeddings * weights).sum(dim=1)
        # Clamp avoids division by zero for fully-masked rows.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts
    


        
class VanillaDenseModel(nn.Module):
    """Dense retriever that embeds text with a causal LM's input embedding
    tables and pools it with a trainable residual encoder.

    Only ``clm.transformer.wte`` (and ``wpe`` for GPT-2) are used; the causal
    LM's transformer layers are never run by this class.
    """

    def __init__(
            self,
            clm: PreTrainedModel,
            residual_encoder: nn.Module,
            config,
            model_args: ModelArguments = None,
            data_args: DataArguments = None,
            train_args: TrainingArguments = None,
            tokenizer: OrderedDict = None,
    ):
        """Store sub-modules and arguments.

        Despite the ``None`` defaults, ``model_args`` (for ``cosine_scale``),
        ``train_args`` (for ``negatives_x_device``) and ``tokenizer`` (for
        ``pad_token_id``) are dereferenced here and must be provided.

        Raises:
            ValueError: if cross-device negatives are requested but
                ``torch.distributed`` is not initialized.
        """
        super().__init__()

        self.config = config

        self.clm = clm
        self.residual_encoder = residual_encoder
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        self.model_args = model_args
        self.cosine_scale = model_args.cosine_scale
        self.train_args = train_args
        self.data_args = data_args
        self.pad_token_id = tokenizer.pad_token_id

        if train_args.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def forward(
            self,
            query: Dict[str, Tensor] = None,
            passage: Dict[str, Tensor] = None,
    ):
        """Encode queries and/or passages; with both, compute the contrastive loss.

        Args:
            query: tokenized queries with ``input_ids`` and ``attention_mask``.
            passage: tokenized passages in the same format.  Passages are
                treated as consecutive groups of
                ``p_reps.size(0) // q_reps.size(0)`` per query, and the first
                passage of each group is the target class.

        Returns:
            ``DenseOutput`` holding ``q_reps``/``p_reps`` when either side is
            missing, otherwise ``loss`` and ``scores``.
        """
        q_reps = self._encode(query)
        p_reps = self._encode(passage)

        if q_reps is None or p_reps is None:
            return DenseOutput(
                q_reps=q_reps,
                p_reps=p_reps
            )

        if self.train_args.negatives_x_device:
            q_reps = self.dist_gather_tensor(q_reps)
            p_reps = self.dist_gather_tensor(p_reps)

        # NOTE(review): self.cosine_scale is stored but not applied here, whereas
        # VanillaDenseModelForInference.forward L2-normalizes reps when it is set.
        # Confirm that training and inference are meant to score the same way.
        scores = torch.matmul(q_reps, p_reps.transpose(0, 1))
        scores = scores.view(q_reps.size(0), -1)

        # Query i's target is passage column i * group_size (first of its group).
        group_size = p_reps.size(0) // q_reps.size(0)
        target = torch.arange(
            scores.size(0),
            device=scores.device,
            dtype=torch.long
        ) * group_size

        loss = self.cross_entropy(scores, target)

        return DenseOutput(
            loss=loss,
            scores=scores,
        )

    def _encode(self, inputs: Optional[Dict[str, Tensor]]) -> Optional[Tensor]:
        """Embed and pool one side (query or passage); ``None`` passes through."""
        if inputs is None:
            return None
        if "gpt-j" in self.model_args.model_name_or_path:
            embs = self.get_gptj_embeddings(inputs["input_ids"], inputs["attention_mask"])
        else:
            embs = self.get_gpt2_embeddings(inputs["input_ids"], inputs["attention_mask"])
        return self.residual_encoder(
            tgt=embs,
            tgt_pad_mask=inputs["attention_mask"],
        )

    def get_gpt2_embeddings(self, input_ids, attention_mask):
        """Return GPT-2 token embeddings plus absolute position embeddings.

        Position ids count attended tokens only, so the first real token is at
        position 0 even with left padding.  Padding slots before the first real
        token (and all-padding rows) would otherwise get id -1, which is out of
        range for ``wpe``; they are clamped to 0.  The residual encoder receives
        the same ``attention_mask``, so those slots are masked there.
        """
        inputs_embeds = self.clm.transformer.wte(input_ids)

        position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
        position_embeds = self.clm.transformer.wpe(position_ids)

        return inputs_embeds + position_embeds

    def get_gptj_embeddings(self, input_ids, attention_mask):
        """Return GPT-J token embeddings only.

        ``attention_mask`` is unused; it is accepted to mirror
        ``get_gpt2_embeddings``.
        """
        return self.clm.transformer.wte(input_ids)

    def dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """All-gather ``t`` from every rank and concatenate the results along dim 0.

        The local slot is replaced by ``t`` itself, presumably so autograd flows
        through this rank's own shard (``dist.all_gather`` outputs are not
        tracked by autograd).  Returns ``None`` for ``None`` input.
        """
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            data_args: DataArguments,
            train_args: TrainingArguments,
            tokenizer: OrderedDict,
            config,
            cache_dir,
    ):
        """Create the CLM and residual encoder and wrap them in ``cls``.

        If ``model_args.model_name_or_path`` contains ``"gpt-j"``, GPT-J is used
        with a ``GPTJ_Residual_Encoder``.  Otherwise GPT-2 is used with a
        ``GPT2_Residual_Encoder`` sized to the LM head's input width.

        Residual-encoder weights are loaded from
        ``model_args.residual_encoder_name_or_path``, unless that path is
        empty/None or contains ``"xx"``.  ``cache_dir`` is accepted but unused.
        """
        if "gpt-j" in model_args.model_name_or_path:
            clm = GPTJForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
            residual_encoder = GPTJ_Residual_Encoder(
                config=config,
                n_layer=model_args.residual_num_layer,
            )
        else:
            clm = GPT2LMHeadModel.from_pretrained(model_args.model_name_or_path, config=config)
            input_size = clm.get_output_embeddings().in_features
            residual_encoder = GPT2_Residual_Encoder(
                input_size=input_size,
                n_head=model_args.residual_num_head,
                n_layer=model_args.residual_num_layer,
            )

        residual_path = model_args.residual_encoder_name_or_path
        # Paths containing "xx" skip loading (apparently a "train from scratch"
        # placeholder).  Empty/None paths are skipped instead of crashing.
        if residual_path and "xx" not in residual_path:
            # load_state_dict copies into the module's own parameters, so the
            # checkpoint does not need to land on the device it was saved from.
            checkpoint = torch.load(residual_path, map_location="cpu")
            residual_encoder.load_state_dict(checkpoint['residual_encoder'])
            print("load residual encoder from ", residual_path)

        model = cls(
            clm=clm,
            config=config,
            residual_encoder=residual_encoder,
            model_args=model_args,
            data_args=data_args,
            train_args=train_args,
            tokenizer=tokenizer,
        )
        return model

    def save(self, output_dir: str):
        """Save the CLM (only if its token embeddings are trainable) and the
        residual encoder under ``output_dir``.

        The residual encoder goes to ``output_dir/residual_encoder.ckpt`` as
        ``{'residual_encoder': state_dict_on_cpu, 'ckpt_name': ...}``.
        """
        # save_pretrained would create the directory, but it is skipped when
        # the CLM is frozen, so torch.save below needs it to exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.clm.transformer.wte.weight.requires_grad:
            self.clm.save_pretrained(output_dir)

        # DeepSpeed does not move non-PreTrainedModel state_dicts to CPU, so
        # do it explicitly.
        ckpt_name = "residual_encoder.ckpt"
        best_ckpt = {
            'residual_encoder': OrderedDict(
                (k, v.cpu()) for k, v in self.residual_encoder.state_dict().items()
            ),
            'ckpt_name': ckpt_name,
        }
        torch.save(best_ckpt, os.path.join(output_dir, ckpt_name))

        
        

## ********************************************************
## Infer Model
## ********************************************************
class VanillaDenseModelForInference(VanillaDenseModel):
    """Inference-only variant of ``VanillaDenseModel``.

    It encodes a single batch of texts into pooled representations, which are
    L2-normalized when ``model_args.cosine_scale`` is not ``None``.
    """

    def __init__(
            self,
            clm: PreTrainedModel,
            residual_encoder: nn.Module,
            config,
            tokenizer: OrderedDict,
            model_args: ModelArguments,
            **kwargs,
    ):
        """Store sub-modules and arguments; extra ``kwargs`` are ignored."""
        # Deliberately bypass VanillaDenseModel.__init__, which dereferences
        # train_args; only the plain nn.Module state is initialized here.
        nn.Module.__init__(self)

        self.clm = clm
        self.config = config
        self.model_args = model_args
        self.residual_encoder = residual_encoder
        self.pad_token_id = tokenizer.pad_token_id
        self.cosine_scale = model_args.cosine_scale

    @torch.no_grad()
    def get_gpt2_embeddings(self, input_ids, attention_mask):
        """Parent's GPT-2 embedding lookup, run without gradient tracking."""
        return super().get_gpt2_embeddings(input_ids, attention_mask)

    @torch.no_grad()
    def get_gptj_embeddings(self, input_ids, attention_mask):
        """Parent's GPT-J embedding lookup, run without gradient tracking."""
        return super().get_gptj_embeddings(input_ids, attention_mask)

    def forward(
            self,
            text: Dict[str, Tensor] = None,
    ):
        """Encode ``text`` (with ``input_ids`` and ``attention_mask``) into
        pooled representations.

        Only the embedding lookup runs under ``torch.no_grad``; callers should
        wrap this call in ``no_grad`` too if they want no residual-encoder graph.
        """
        if "gpt-j" in self.model_args.model_name_or_path:
            embs = self.get_gptj_embeddings(text["input_ids"], text["attention_mask"])
        else:
            embs = self.get_gpt2_embeddings(text["input_ids"], text["attention_mask"])

        reps = self.residual_encoder(
            tgt=embs,
            tgt_pad_mask=text["attention_mask"],
        )

        if self.cosine_scale is not None:
            reps = F.normalize(reps, p=2, dim=-1)

        return reps

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            data_args: DataArguments,
            tokenizer: OrderedDict,
            config,
            cache_dir,
    ):
        """Create the CLM and residual encoder, then load the residual weights.

        The residual weights are loaded (unconditionally) from
        ``model_args.residual_encoder_name_or_path`` and moved to the CLM's
        device.  ``data_args`` and ``cache_dir`` are accepted but unused.
        """
        if "gpt-j" in model_args.model_name_or_path:
            clm = GPTJForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
            residual_encoder = GPTJ_Residual_Encoder(
                config=config,
                n_layer=model_args.residual_num_layer,
            )
        else:
            clm = GPT2LMHeadModel.from_pretrained(model_args.model_name_or_path, config=config)
            input_size = clm.get_output_embeddings().in_features
            residual_encoder = GPT2_Residual_Encoder(
                input_size=input_size,
                n_head=model_args.residual_num_head,
                n_layer=model_args.residual_num_layer,
            )

        # Load onto CPU first: checkpoints saved from a GPU would otherwise need
        # that device to exist.  The module is moved to clm.device right after.
        checkpoint = torch.load(model_args.residual_encoder_name_or_path, map_location="cpu")
        residual_encoder.load_state_dict(checkpoint['residual_encoder'])

        residual_encoder.to(clm.device)
        print("load residual encoder from ", model_args.residual_encoder_name_or_path)

        model = cls(
            clm=clm,
            residual_encoder=residual_encoder,
            config=config,
            tokenizer=tokenizer,
            model_args=model_args,
        )
        return model