import sys
import torch
from pathlib import Path
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoConfig
)


# Put the parent of this file's directory (presumably the repository root) at
# the front of sys.path so that the `diffusion_prediction` package imported
# inside the ModelCompletion methods can be resolved.
sys.path.insert(0, Path(__file__).parent.parent.as_posix())

class AttributeDict(dict):
    """A ``dict`` whose keys can also be read as attributes (``d.key``).

    Used in this module to pass a plain parameter dict to a helper that
    presumably reads its options as attributes.
    """

    def __getattr__(self, name):
        # ``__getattr__`` is only consulted after normal attribute lookup
        # fails. Translate a missing key into AttributeError (not KeyError) so
        # that ``hasattr``, ``getattr(obj, name, default)`` and the probes done
        # by ``copy``/``pickle`` (e.g. for ``__deepcopy__``) behave as they do
        # for ordinary objects.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

class ModelCompletion:
    """Load a diffusion language model and generate completions for chats.

    Three model families are handled:

    * ``model.config.model_type == "Dream"``: generated with the model's own
      ``diffusion_generate`` method.
    * ``model.config.model_type == "llada"``: generated with
      ``diffusion_prediction.LLaDA.generate.generate``.
    * DiffuLLaMA (any ``model_path`` containing ``"diffullama"``): a
      ``LlamaForCausalLM`` backbone wrapped in ``DiscreteDiffusionModel`` and
      generated with ``generate_samples``.

    :meth:`complete` is the entry point; it dispatches on the model type.
    """

    def __init__(self, model_path, max_input_length=None, device=None):
        """Load the model and tokenizer found at ``model_path``.

        Args:
            model_path: Hugging Face model id or local directory. If it
                contains ``"diffullama"`` the DiffuLLaMA loading path is used.
            max_input_length: If truthy, prompts are truncated from the left
                to at most this many tokens.
            device: Passed as ``device_map`` to ``from_pretrained``; defaults
                to ``"auto"``. For DiffuLLaMA it is always overridden with
                ``"cuda"``.

        Raises:
            AssertionError: if the loaded model is neither Dream, LLaDA nor
                DiffuLLaMA.
        """
        self.model_path = model_path
        is_diffullama = "diffullama" in model_path
        if device is None:
            device = "auto"

        if is_diffullama:
            # DiffuLLaMA is loaded as a plain Llama backbone with eager
            # attention; the same concrete device is later handed to
            # DiscreteDiffusionModel.
            device = "cuda"
            from transformers import LlamaForCausalLM
            self.model = LlamaForCausalLM.from_pretrained(
                model_path,
                device_map=device,
                _attn_implementation="eager",
                torch_dtype=torch.bfloat16,
            )
        else:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=device,
            )

        self.max_input_length = max_input_length
        self.extra_tokenizer_config = {}
        if self.max_input_length:
            # Truncate from the left so the end of the prompt, which the
            # generation continues from, is kept.
            self.extra_tokenizer_config["model_max_length"] = self.max_input_length
            self.extra_tokenizer_config["truncation_side"] = "left"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, **self.extra_tokenizer_config
        )

        if is_diffullama:
            # NOTE(review): presumably needed so the DiffuLLaMA sources can
            # import their own modules by bare name; confirm.
            sys.path.insert(
                0,
                (Path(__file__).parent.parent / "diffusion_prediction" / "DiffuLLaMA").as_posix(),
            )
            from diffusion_prediction.DiffuLLaMA.model import DiscreteDiffusionModel
            config = AutoConfig.from_pretrained(model_path)
            self.model = DiscreteDiffusionModel(
                model=self.model,
                config=config,
                tokenizer=self.tokenizer,
                device=device,
            )

        self.model = self.model.eval()
        self.model_type = self.model.config.model_type
        assert self.model_type in ("Dream", "llada") or is_diffullama, (
            f"Unsupported model type {self.model_type!r} for {model_path!r}"
        )

    def to(self, device):
        """Move the underlying model to ``device``; returns ``self``."""
        self.model.to(device)
        return self

    def complete_dream(self, chats, **kwargs):
        """Generate completions for a padded batch of chats with a Dream model.

        Args:
            chats: Batch of conversations accepted by
                ``tokenizer.apply_chat_template``.
            **kwargs: Overrides for keys of the default parameter dict below.
                ``gen_length`` is an alias for ``max_new_tokens`` and takes
                precedence over it. Unknown keys are ignored.

        Returns:
            ``(generations, raw_output)``. ``generations[i]`` is a one-element
            list holding the text generated for ``chats[i]``, cut at the first
            EOS token. ``raw_output`` is the object returned by
            ``diffusion_generate``.
        """
        inputs = self.tokenizer.apply_chat_template(
            chats,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
            padding=True,
            truncation=self.max_input_length is not None,
            max_length=self.max_input_length,
        )
        device = self.model.device
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)

        if "gen_length" in kwargs:
            kwargs["max_new_tokens"] = kwargs["gen_length"]
        parameters = {
            "max_new_tokens": 512,
            "output_history": True,
            # Needed so the result exposes ``.sequences`` below.
            "return_dict_in_generate": True,
            "steps": 512,
            "temperature": 0.2,
            "top_p": 0.95,
            "alg": "entropy",
            "alg_temp": 0.,
        }
        parameters.update({k: v for k, v in kwargs.items() if k in parameters})

        raw_output = self.model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            **parameters,
        )

        eos = self.tokenizer.eos_token
        generations = []
        for prompt, sequence in zip(input_ids, raw_output.sequences):
            text = self.tokenizer.decode(sequence[len(prompt):].tolist())
            # Only cut at EOS when the tokenizer defines one: str.split(None)
            # would split on whitespace and keep just the first word.
            if eos:
                text = text.split(eos)[0]
            generations.append([text])

        return generations, raw_output

    def complete_llada(self, chats, **kwargs):
        """Generate a completion for each chat with an LLaDA model, one at a time.

        Args:
            chats: Iterable of conversations, each accepted by
                ``tokenizer.apply_chat_template``.
            **kwargs: Overrides for keys of the default parameter dict below.
                Unknown keys are ignored.

        Returns:
            ``(generations, raw_output)``. ``generations[i]`` is the list
            returned by ``tokenizer.batch_decode`` for the tokens generated for
            ``chats[i]`` (prompt removed, special tokens skipped).
            ``raw_output[i]`` is the token tensor returned by ``generate``.
        """
        from diffusion_prediction.LLaDA.generate import generate

        parameters = {
            "steps": 128,
            "gen_length": 128,
            "block_length": 32,
            "temperature": 0.,
            "cfg_scale": 0.,
            "remasking": "low_confidence",
        }
        parameters.update({k: v for k, v in kwargs.items() if k in parameters})

        generations, raw_output = [], []
        for chat in chats:
            inputs = self.tokenizer.apply_chat_template(
                chat,
                return_tensors="pt",
                return_dict=True,
                add_generation_prompt=True,
                padding=True,
                truncation=self.max_input_length is not None,
                max_length=self.max_input_length,
            )
            input_ids = inputs.input_ids.to(self.model.device)
            output = generate(self.model, input_ids, **parameters)
            # Strip *this* chat's prompt: prompt lengths differ between chats,
            # so the offset must not be taken from another iteration.
            prompt_length = input_ids.shape[1]
            generations.append(
                self.tokenizer.batch_decode(output[:, prompt_length:], skip_special_tokens=True)
            )
            raw_output.append(output)

        return generations, raw_output

    # https://github.com/HKUNLP/DiffuLLaMA/blob/60abd1e372c62aabf1a070949c95988cb576cabd/evaluation/eval-diffullama.py#L313
    def complete_diffullama(self, chats, **kwargs):
        """Generate a completion for each chat with DiffuLLaMA, one at a time.

        Only the ``content`` of the first message of each chat is used. It is
        encoded as raw text, without a chat template.

        Args:
            chats: Iterable of conversations (lists of message dicts).
            **kwargs: Must contain ``gen_length``, the number of positions to
                generate. ``steps`` and ``temperature`` are aliases for
                ``diffusion_steps`` and ``logits_temp`` and take precedence over
                them. Other keys of the default parameter dict below can be
                overridden directly; unknown keys are ignored.

        Returns:
            ``(generations, raw_output)``. ``generations[i]`` is the decoded
            string for ``chats[i]``; ``raw_output`` holds one ``None`` per chat.

        Raises:
            KeyError: if ``gen_length`` is missing.
        """
        from diffusion_prediction.DiffuLLaMA.model import generate_samples

        # Defaults follow the upstream evaluation script:
        # https://github.com/HKUNLP/DiffuLLaMA/blob/60abd1e372c62aabf1a070949c95988cb576cabd/evaluation/eval-diffullama.py#L361
        parameters = {
            "shift": True,
            "diffusion_steps": 32,
            "logits_temp": 0.9,
            "topp_temp": 0.9,
            "verbose": False,
            "flash_attn": "eager",
        }
        aliases = {
            "steps": "diffusion_steps",
            "temperature": "logits_temp",
        }
        overrides = dict(kwargs)
        for alias, name in aliases.items():
            if alias in kwargs:
                overrides[name] = kwargs[alias]
        parameters.update({k: v for k, v in overrides.items() if k in parameters})

        gen_length = kwargs["gen_length"]
        device = self.model.device
        generations, raw_output = [], []
        for chat in chats:
            prompt_ids = self.tokenizer.encode(chat[0]["content"])
            # Append gen_length placeholder ids (0) after the prompt. src_mask
            # marks prompt positions with 1 and the region to generate with 0.
            # https://github.com/HKUNLP/DiffuLLaMA/blob/60abd1e372c62aabf1a070949c95988cb576cabd/evaluation/eval-diffullama.py#L336
            # NOTE(review): upstream may tie the number of diffusion steps to
            # gen_length; confirm whether diffusion_steps should follow it.
            x0 = prompt_ids + [0] * gen_length
            src_mask = [1] * len(prompt_ids) + [0] * gen_length
            model_inputs = {
                "input_ids": torch.tensor([x0], device=device),
                "src_mask": torch.tensor([src_mask], device=device),
            }
            res = generate_samples(self.model, AttributeDict(parameters), self.tokenizer, model_inputs)
            # Decoding starts one position before the generated region.
            # NOTE(review): presumably related to shift=True; confirm.
            generations.append(self.tokenizer.decode(res.tolist()[0][len(prompt_ids) - 1:]))
            raw_output.append(None)

        return generations, raw_output

    def complete(self, chats, **kwargs):
        """Generate completions with the routine matching the loaded model.

        Args:
            chats: Conversations to complete; see the model-specific methods.
            **kwargs: Generation overrides forwarded to that method.

        Returns:
            ``(generations, raw_output)`` from the model-specific method.

        Raises:
            NotImplementedError: if no generation routine exists for the
                loaded model.
        """
        if self.model_type == "Dream":
            return self.complete_dream(chats, **kwargs)
        if self.model_type == "llada":
            return self.complete_llada(chats, **kwargs)
        if "diffullama" in self.model_path:
            return self.complete_diffullama(chats, **kwargs)
        raise NotImplementedError(
            f"No completion routine for model type {self.model_type!r} ({self.model_path!r})"
        )
