import os
from typing import List, Union
import numpy as np
import torch
from torch import Tensor, nn
from transformers import AutoTokenizer
from .diffloss import DiffLoss

class MotionUndHead(nn.Module):
    """Placeholder motion-understanding head.

    This is an empty ``nn.Module`` subclass: it defines no layers, no
    constructor and no ``forward`` of its own.

    NOTE(review): ``MLM.__init__`` instantiates this class with two
    positional arguments and a ``projector_type`` keyword, which the
    inherited constructor is not written to accept -- the real
    implementation presumably lives elsewhere; confirm.
    """

class MLM(nn.Module):
    """Wrap a MoT GPT-2 language model that consumes text and motion tokens.

    The module tokenizes text prompts, asks the language model for input
    embeddings, overwrites the embeddings at the input motion positions with
    the output of ``motion_und_head`` and, in training, feeds the hidden
    states at the output motion positions to a ``DiffLoss``.

    NOTE(review): the methods below read ``self.vae_latent_dim``,
    ``self.motion_input_dim``, ``self.llm_decoder_embed_dim``,
    ``self.mod_id``, ``self.lm_type`` and ``self.clean_text_string``, none of
    which is assigned in this class body. They are presumably supplied
    elsewhere (subclass / mixin / later code) -- confirm.
    """

    def __init__(
        self,
        model_path: str,
        model_type: str = "gpt2",
        max_length: int = 256,
        motion_holder_repeat=None,
        output_motion_holder_seq=None,
        input_motion_holder_seq=None,
        diff_kwargs=None,
    ) -> None:
        """Load the tokenizer and language model and build the motion heads.

        Args:
            model_path: Path or hub id passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``MoTGPT2LMHeadModel.from_pretrained``.
            model_type: Stored on the instance as ``self.model_type``.
            max_length: Tokenizer padding / truncation length.
            motion_holder_repeat: First dimension of ``self.fake_latent``;
                it is passed to ``torch.zeros``, so it has to be an int
                even though the default is ``None``.
            output_motion_holder_seq: Stored as-is on the instance.
            input_motion_holder_seq: Stored as-is on the instance.
            diff_kwargs: Keyword arguments forwarded to ``DiffLoss``.
                ``None`` is treated as "no keyword arguments".
        """
        super().__init__()
        self.model_type = model_type
        self.max_length = max_length

        self.motion_holder_repeat = motion_holder_repeat
        self.multi_hidden = True

        self.output_motion_holder_seq = output_motion_holder_seq
        self.input_motion_holder_seq = input_motion_holder_seq
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True)

        # Function-scope import: mot_code is only loaded when an MLM is built.
        from mot_code.mot_example_gpt2 import MoTGPT2LMHeadModel

        self.language_model = MoTGPT2LMHeadModel.from_pretrained(model_path)
        # NOTE(review): vae_latent_dim / motion_input_dim are not assigned
        # anywhere in this class -- confirm where they come from.
        self.motion_und_head = MotionUndHead(
            self.vae_latent_dim, self.motion_input_dim, projector_type='linear'
        )
        # `**None` raises TypeError, so fall back to an empty mapping.
        self.diffloss = DiffLoss(**(diff_kwargs or {}))
        # Trainable, zero-initialised parameter of shape
        # (motion_holder_repeat, llm_decoder_embed_dim).
        self.fake_latent = nn.Parameter(
            torch.zeros(self.motion_holder_repeat, self.llm_decoder_embed_dim)
        )

    @staticmethod
    def _tensors_to_device(tensors, device):
        """Move every tensor in ``tensors`` to ``device`` and return a tuple.

        Works for any iterable of tensors (a tuple/list, or a stacked tensor
        that is iterated along its first dimension).
        """
        return tuple(tensor.to(device) for tensor in tensors)

    def _resolve_device(self):
        """Return ``self.device`` if set, else the language model's device."""
        device = getattr(self, "device", None)
        if device is None:
            # No `device` attribute on this module: use the device of the
            # language model's first parameter instead.
            device = next(self.language_model.parameters()).device
        return device

    def _embed_with_motion(self, input_ids, motion_tokens_input, device):
        """Build input embeddings with motion features written in place.

        Args:
            input_ids: Token ids, already on ``device``.
            motion_tokens_input: Input passed to ``self.motion_und_head``;
                when ``None`` the embeddings are left untouched.
            device: Device the type ids and masks are moved to.

        Returns:
            ``(type_ids, output_is_motion, inputs_embeds)``.
        """
        type_ids, input_is_motion, output_is_motion = self._tensors_to_device(
            self.process_type_ids(input_ids), device
        )
        inputs_embeds = self.language_model.get_embeddings_from_ids(input_ids, type_ids)
        if motion_tokens_input is not None:
            # Overwrite the embeddings at input motion positions.
            inputs_embeds[self.mod_id][input_is_motion, :] = self.motion_und_head(
                motion_tokens_input
            )
        return type_ids, output_is_motion, inputs_embeds

    def process_type_ids(self, labels_input_ids):
        """Placeholder: no implementation here, so this returns ``None``.

        ``forward`` and ``generate`` unpack the result into
        ``(type_ids, input_is_motion, output_is_motion)``.
        """
        pass

    def forward(self,
                inputs: List[str],
                motion_tokens_input: List[Tensor],
                motion_tokens_output: List[Tensor],
                ):
        """Run a training step and attach the diffusion loss to the outputs.

        Args:
            inputs: Text prompts; tokenized with padding / truncation to
                ``self.max_length``.
            motion_tokens_input: Motion inputs for ``motion_und_head``; the
                device of the first element is used for all tensors.
            motion_tokens_output: Target passed to
                ``diffloss.forward_diff_loss``.

        Returns:
            The language model outputs, with an added ``diff_loss`` attribute.
        """
        encoding = self.tokenizer(inputs,
                                  padding='max_length',
                                  max_length=self.max_length,
                                  truncation=True,
                                  return_attention_mask=True,
                                  return_tensors="pt")

        device = motion_tokens_input[0].device
        labels_input_ids = encoding.input_ids.to(device)
        labels_attention_mask = encoding.attention_mask.to(device)

        type_ids, output_is_motion, inputs_embeds = self._embed_with_motion(
            labels_input_ids, motion_tokens_input, device
        )

        # The labels are a detached copy of the input ids.
        labels = labels_input_ids.clone().detach()
        outputs = self.language_model(
            type_ids=type_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=labels_attention_mask,
            labels=labels,
            output_hidden_states=True,
        )

        hidden = outputs.hidden_states[-1][self.mod_id]
        # NOTE(review): torch.cat expects a sequence of tensors; this only
        # works if indexing `hidden` yields such a sequence -- confirm.
        hidden_to_diff = torch.cat(hidden[output_is_motion])
        outputs.diff_loss = self.diffloss.forward_diff_loss(
            z=hidden_to_diff, target=motion_tokens_output
        )

        return outputs

    def generate(self,
                 texts: List[str],
                 motion_tokens_input=None,
                 max_length: int = 256,
                 num_beams: int = 1,
                 do_sample: bool = True,
                 gen_mode: Union[str, List[str]] = 'text'
                 ):
        """Generate from text prompts and optional motion inputs.

        Args:
            texts: Prompts; tokenized with padding / truncation to
                ``self.max_length``.
            motion_tokens_input: Optional input for ``motion_und_head``;
                when ``None`` no motion embeddings are injected.
            max_length: Forwarded to the language model as
                ``max_new_tokens`` (it does not affect tokenization).
            num_beams: Forwarded to the language model's ``generate``.
            do_sample: Forwarded to the language model's ``generate``.
            gen_mode: Forwarded to the language model's ``generate`` as
                ``mode``.

        Returns:
            ``(motion_latents, cleaned_text)``: the last hidden states at the
            output motion positions and the cleaned decoded strings.
        """
        if self.lm_type == 'dec':
            texts = [text + " \n " for text in texts]

        source_encoding = self.tokenizer(texts,
                                         padding='max_length',
                                         max_length=self.max_length,
                                         truncation=True,
                                         return_attention_mask=True,
                                         add_special_tokens=True,
                                         return_tensors="pt")

        device = self._resolve_device()
        source_input_ids = source_encoding.input_ids.to(device)
        source_attention_mask = source_encoding.attention_mask.to(device)

        # Nothing here is back-propagated, so embedding preparation runs
        # under no_grad together with generation.
        with torch.no_grad():
            type_ids, output_is_motion, inputs_embeds = self._embed_with_motion(
                source_input_ids, motion_tokens_input, device
            )
            # NOTE(review): stock Hugging Face `generate` uses
            # `return_dict_in_generate` and exposes `sequences`; confirm that
            # the MoT model really takes `return_dict` / returns `sequence`.
            outputs = self.language_model.generate(
                type_ids=type_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=source_attention_mask,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=do_sample,
                max_new_tokens=max_length,
                num_beams=num_beams,
                mode=gen_mode,
                return_dict=True,
                output_hidden_states=True,
            )

        outputs_string = self.tokenizer.batch_decode(outputs.sequence, skip_special_tokens=False)
        cleaned_text = self.clean_text_string(outputs_string)

        # Original author's note: [bs, labels_seq_len, emb_dim(768)] -- TODO confirm.
        hidden = outputs.hidden_states[-1][self.mod_id]
        motion_latents = hidden[output_is_motion]

        return motion_latents, cleaned_text

