import os
import torch
from torch import nn
from einops import rearrange
from transformers import (
    GPT2LMHeadModel, 
    GPT2Model,
    AutoTokenizer
)
from typing import Optional, Tuple, List, Union, Callable
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2MLP#, eager_attention_forward
# from transformers.integrations.sdpa_attention import sdpa_attention_forward
# GPT2SdpaAttention

from transformers.utils import logging
logger = logging.get_logger(__name__)

import inspect
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

from transformers.cache_utils import StaticCache, Cache
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils import is_torchdynamo_compiling

from .my_modeling_gpt2 import MoTGPT2Model
from .modality_utils import get_modalities_infos
from .mot_module import get_embeds_from_ids


import random
import numpy as np
def seed_setting(seed):
    """Seed every random-number generator used here so runs are repeatable.

    Args:
        seed: Seed value. It is exported (as a string) through the
            ``PL_GLOBAL_SEED`` environment variable and then passed, in order,
            to Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and
            the RNGs of all CUDA devices.
    """
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op until CUDA is initialised
    )
    for apply_seed in seeders:
        apply_seed(seed)


class MoTGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM-head model whose backbone is replaced by ``MoTGPT2Model``.

    The backbone is built with ``modality_num`` modality branches. Each
    modality has its own pre-processor (ids -> embeddings) and post-processor
    (hidden states -> logits). For the text modality these are the backbone's
    ``wte`` and this model's ``lm_head``. The other modalities are whatever
    ``get_modalities_infos`` returns.

    ``set_modality_info(tokenizer)`` must be called before ``forward``.
    """

    def __init__(self, config, modality_num=2, motion_codebook_size=512+4):
        """Build the model.

        Args:
            config: GPT-2 config. It must provide ``n_embd``, ``vocab_size``
                and ``text_vocab_size``.
            modality_num: number of modality branches in the MoT backbone.
            motion_codebook_size: stored as ``config.motion_vocab_size``.
                NOTE(review): the default 512+4 presumably means 512 codebook
                entries plus 4 special motion tokens. Confirm.
        """
        super().__init__(config)
        self.modality_num = modality_num
        self.forward_mod = 0  # 'text'
        # Extra config fields, presumably read by MoTGPT2Model /
        # get_modalities_infos (neither is visible here).
        self.config.d_model = config.n_embd
        self.config.motion_vocab_size = motion_codebook_size
        # Raises AttributeError if the incoming config lacks text_vocab_size.
        self.config.text_vocab_size = config.text_vocab_size
        self.modality_infos = None
        self.last_pos_ids = None

        # Replace the stock GPT2Model built by super().__init__ with the MoT
        # backbone. Pass self.config, which carries the fields set above, so
        # the backbone sees them even if the base class stored a copy of
        # ``config``.
        self.transformer = MoTGPT2Model(self.config, modality_num=modality_num)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.post_init()

    def set_modality_info(self, tokenizer):
        """Build the per-modality metadata and register its processors.

        The text modality is wired to the backbone's token embedding
        (``transformer.wte``) and to ``lm_head``. The pre- and post-processors
        of all modalities are stored in ``nn.ModuleList``s so their parameters
        are registered on this model.

        Args:
            tokenizer: tokenizer passed through to ``get_modalities_infos``.
        """
        modality_infos = get_modalities_infos(config=self.config, tokenizer=tokenizer)
        self.mod2id = {m.modality_name: i for i, m in enumerate(modality_infos)}
        self.text_id = self.mod2id['text']
        modality_infos[self.text_id].pre_processor = self.transformer.wte
        modality_infos[self.text_id].post_processor = self.lm_head
        self.modality_infos = modality_infos
        self.pad_ids = [m.pad_id for m in modality_infos]

        self.pre_processors = nn.ModuleList([m.pre_processor for m in modality_infos])
        self.post_processors = nn.ModuleList([m.post_processor for m in modality_infos])

    def update_typeids(self, type_ids):
        """Store the per-modality position masks and forward them to the backbone.

        Despite the parameter name, ``forward`` passes the output of
        ``typeids_2_valpos`` here, not the raw type ids.
        """
        self.valid_pos = type_ids
        self.transformer.update_typeids(type_ids)

    def get_embeddings_from_ids(self, input_ids, type_ids):
        """Embed ``input_ids`` the same way ``forward`` does when no embeddings are given.

        Args:
            input_ids: token ids.
            type_ids: per-token modality indices (see ``typeids_2_valpos``).

        Returns:
            Whatever ``get_embeds_from_ids`` returns for these ids and masks.
        """
        valid_pos = self.typeids_2_valpos(type_ids)
        return get_embeds_from_ids(input_ids, valid_pos, self.pad_ids, self.pre_processors)

    def typeids_2_valpos(self, type_ids):
        """Turn per-token modality indices into one boolean mask per modality.

        NOTE(review): the original was an unfinished stub that returned an
        undefined name. The list-of-masks format is inferred from how
        ``forward`` indexes ``valid_pos[i]`` per modality. Confirm it against
        ``get_embeds_from_ids`` and ``MoTGPT2Model.update_typeids``.

        Args:
            type_ids: tensor whose entries are modality indices in
                ``range(self.modality_num)``.

        Returns:
            ``[type_ids == i for i in range(self.modality_num)]``.

        Raises:
            ValueError: if ``type_ids`` is None.
        """
        if type_ids is None:
            raise ValueError("type_ids is required to route tokens to their modalities")
        return [type_ids == i for i in range(self.modality_num)]

    def _project_to_modalities(self, hidden_states):
        """Apply each modality post-processor and return the list of logits.

        ``nn.ModuleList`` has no ``forward``, so each post-processor must be
        called individually. If the backbone returns a list or tuple
        (presumably one tensor per modality; confirm against MoTGPT2Model),
        the i-th post-processor is applied to the i-th entry. Otherwise every
        post-processor receives the same tensor.
        """
        if isinstance(hidden_states, (list, tuple)):
            return [post(h) for post, h in zip(self.post_processors, hidden_states)]
        return [post(hidden_states) for post in self.post_processors]

    def _modality_loss(self, lm_logits, labels):
        """Sum the per-modality losses produced by each modality's ``loss_fct``.

        ``labels`` is either one tensor shared by every modality, or a
        list/tuple with one entry per modality. Each loss is passed through
        ``nan_to_num(0.)``: NaN becomes 0 and +/-inf is clamped to the dtype's
        finite range.

        Raises:
            ValueError: if the number of label entries differs from the number
                of logit tensors.
        """
        if isinstance(labels, (list, tuple)):
            mlabels = list(labels)
        else:
            mlabels = [labels] * self.modality_num
        if len(mlabels) != len(lm_logits):
            raise ValueError('labels not match modalities')

        total_loss = 0.
        for i, (out_logit, mlabel) in enumerate(zip(lm_logits, mlabels)):
            mlabel = mlabel.to(out_logit.device)
            loss_fct = self.modality_infos[i].loss_fct
            loss_mod = loss_fct(out_logit, mlabel, self.valid_pos[i])
            # Out-of-place, so autograd never sees an in-place edit of a loss tensor.
            total_loss = total_loss + loss_mod.nan_to_num(0.)
        return total_loss

    def forward(
        self,
        type_ids: torch.Tensor = None,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        forward_mod: np.ndarray[int] = None,
    ):
        """Run the MoT backbone and compute per-modality logits (and loss).

        Args:
            type_ids: per-token modality indices, converted to per-modality
                masks by ``typeids_2_valpos``.
            input_ids: used only when ``inputs_embeds`` is None. The
                embeddings are then built by ``get_embeds_from_ids``, and
                ``input_ids`` is dropped so the backbone sees only embeddings.
            labels: one tensor shared by all modalities, or a list/tuple with
                one entry per modality.
            forward_mod: defaults to ``self.forward_mod``; not otherwise used here.
            The remaining arguments are forwarded unchanged to the backbone.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` whose ``logits`` is a list of
            per-modality logits. If ``return_dict`` is false, returns the
            tuple ``(loss, logits, *backbone_outputs[1:])``, with ``loss``
            omitted when no labels are given.

        Raises:
            RuntimeError: if ``set_modality_info`` has not been called.
            ValueError: if ``type_ids`` is None, or if the label count does not
                match the number of modalities.
        """
        if self.modality_infos is None:
            raise RuntimeError("call set_modality_info(tokenizer) before forward()")

        self.type_ids = type_ids
        valid_pos = self.typeids_2_valpos(type_ids)
        self.update_typeids(valid_pos)

        if forward_mod is None:
            forward_mod = self.forward_mod

        if inputs_embeds is None:
            inputs_embeds = get_embeds_from_ids(input_ids, self.valid_pos, self.pad_ids, self.pre_processors)
            input_ids = None

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self._project_to_modalities(hidden_states)

        total_loss = None
        if labels is not None:
            total_loss = self._modality_loss(lm_logits, labels)

        if not return_dict:
            output = (lm_logits, ) + transformer_outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=total_loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
    

if __name__ == "__main__":
    # Demo: sample from stock GPT-2 and from the MoT variant with the same
    # seed and prompt, and print both continuations.

    model_config = "deps/gpt2-medium"
    tokenizer = AutoTokenizer.from_pretrained(model_config)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = GPT2LMHeadModel.from_pretrained(model_config).eval()
    model.transformer._attn_implementation = "sdpa"

    model2 = MoTGPT2LMHeadModel(model.config).eval()
    model2.generation_config = model.generation_config

    # map_location keeps the demo runnable on CPU-only machines even if the
    # checkpoint was saved from GPU tensors.
    new_state_dict = torch.load('deps/mot-gpt2-medium/model_state_dict.pth', map_location="cpu")
    msg = model2.load_state_dict(new_state_dict, strict=False)
    # strict=False silently tolerates mismatches, so print them explicitly.
    print('missing keys:', msg.missing_keys)
    print('unexpected keys:', msg.unexpected_keys)

    tokenizer.add_tokens([f'<motion_id_{i}>' for i in range(512)])
    tokenizer.add_special_tokens({'additional_special_tokens': ['<start_of_motion>', '<end_of_motion>', '<masked_motion>', '<pad_motion>']})
    # The original line was a self-assignment that raised AttributeError when
    # the field was absent. Fill it from motion_vocab_size, where __init__
    # stores the codebook size.
    # NOTE(review): confirm that get_modalities_infos reads this field.
    if not hasattr(model2.config, "motion_codebook_size"):
        model2.config.motion_codebook_size = model2.config.motion_vocab_size
    model2.set_modality_info(tokenizer)

    _pre = True
    if _pre:
        inputs1 = tokenizer([" A A AHow are you today"], return_tensors="pt")
    else:
        inputs1 = tokenizer(["How are you today A A A"], return_tensors="pt")

    generation_config = dict(
        max_new_tokens=20,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # 50256 for GPT-2
    )

    input_ids = inputs1.input_ids
    # The prompt is pure text, so every token is routed to the text modality.
    type_ids = torch.full_like(input_ids, model2.text_id)

    torch.manual_seed(0)
    outputs1 = model.generate(input_ids=input_ids, attention_mask=inputs1.attention_mask,
                              **generation_config)
    print('inputs', tokenizer.batch_decode(outputs1, skip_special_tokens=True))

    torch.manual_seed(0)
    outputs2 = model2.generate(input_ids=input_ids, attention_mask=inputs1.attention_mask,
                               type_ids=type_ids,
                               **generation_config)
    print('mine inputs1', tokenizer.batch_decode(outputs2, skip_special_tokens=True))
