from abc import ABC, abstractmethod
from .utils import *
from .cache import DynamicCache
from transformers import AutoTokenizer, AutoModelForCausalLM
import einops
import torch
import json

class Models(ABC):
    """Abstract base for a (finetuned, pretrained) causal-LM pair.

    Subclasses are expected to set ``self.args``, ``self.model_names`` (a dict
    with "finetuned" / "pretrained" model ids) and to implement the lazy
    ``tokenizer`` / ``model_finetuned`` / ``model_pretrained`` /
    ``model_pretrained_gpu`` properties plus ``load_gpu``.
    """

    def __init__(self):
        self.SYSTEM_PROMPT = ""
        self._tokenizer = None
        self._model_chat = None
        self._model_pretrained = None
        self._model_pretrained_gpu = None
        # Model used for GPU inference; (re)assigned by load_gpu().
        self.model_in_gpu = None
        # Separator between dialogue turns; subclasses may override it.
        self.sep = "\n\n"

    @abstractmethod
    def get_normalized_unembedding(self, name):
        """Return the unembedding matrix scaled by the final norm weights.

        ``name`` selects the model ("finetuned" or "pretrained"). Idea:

            with torch.no_grad():
                rms_norm_weight = model.model.norm.weight
                unembedding_matrix = model.lm_head.weight
                normalized_unembedding = einops.einsum(unembedding_matrix, rms_norm_weight, "N H, H -> N H")
        """
        pass

    @property
    def tokenizer(self):
        """Tokenizer shared by both models; overridden by subclasses."""
        pass

    @property
    def model_finetuned(self):
        """Finetuned (chat/instruct) model kept on CPU; overridden by subclasses."""
        pass

    @property
    def model_pretrained(self):
        """Pretrained (base) model kept on CPU; overridden by subclasses."""
        pass

    @property
    def itos(self):
        """Map every token id to its decoded string (rebuilt on each access)."""
        return {
            tok: self.tokenizer.decode(tok)
            for tok in range(len(self.tokenizer))
        }

    def load_finetuned_and_pretrained_to_gpu(self):
        """Load the finetuned model via load_gpu and the pretrained one as fp16, both "balanced"."""
        self.load_gpu("finetuned", device_map="balanced")
        # TODO remove hardcoding
        self._model_pretrained_gpu = AutoModelForCausalLM.from_pretrained(
            self.model_names["pretrained"],
            device_map="balanced",
            torch_dtype=torch.float16
        )

    @abstractmethod
    def load_gpu(self, name, **kwargs):
        """Load model ``name`` ("finetuned"/"pretrained") into ``self.model_in_gpu``.

        Idea:
        make_interpolation(None, None, <corresponding model>, dtype=dtype)
        """
        pass

    @abstractmethod
    def _convert_formatted_dialogue(self, formatted_dialogue, **kwargs):
        """Render a list of {"role", "message"} turns into a prompt string."""
        pass

    def _default_convert_formatted_dialogue(self, formatted_dialogue, force_no_sysprompt=False, strip=True, **kwargs):
        """Plain "role: message" rendering used when chat templates are disabled.

        A trailing assistant turn is replaced by an open "assistant: " line so the
        model continues from there. The system prompt is prepended only when it is
        non-empty, ``self.args.use_sysprompt`` is set and ``force_no_sysprompt`` is
        False.
        """
        if strip:
            out = [f"{d['role']}: {d['message'].strip()}" for d in formatted_dialogue]
        else:
            out = [f"{d['role']}: {d['message']}" for d in formatted_dialogue]

        if formatted_dialogue[-1]["role"] == "assistant":
            out = out[:-1]
            out.append("assistant: ")

        final_out = "\n\n".join(out)

        if self.SYSTEM_PROMPT != "" and self.args.use_sysprompt and not force_no_sysprompt:
            final_out = f"system: {self.SYSTEM_PROMPT}\n\n" + final_out

        return final_out

    def _get_prompt_in_template(self, prompt, **kwargs):
        """Wrap a single user prompt in the chat format, ending with an open assistant turn."""
        return self._convert_formatted_dialogue([
            {"role": "user", "message": prompt},
            {"role": "assistant", "message": ""}
        ], **kwargs)

    @torch.no_grad()
    def batched_inference(self, texts, debug=False, use_pretrained_gpu=False, **kwargs):
        """Run a single forward pass over a padded batch of ``texts`` on CUDA.

        Uses ``model_pretrained_gpu`` when ``use_pretrained_gpu`` is set, otherwise
        ``model_in_gpu``. Extra ``kwargs`` are forwarded to the model call. ``debug``
        is currently unused. Returns the raw model output.
        """
        try:
            clean()
            assert isinstance(texts, list), "Need list of texts for batched inference"
            tokens = self.tokenizer(texts, padding=True, return_tensors="pt")
            model = self.model_pretrained_gpu if use_pretrained_gpu else self.model_in_gpu
            return model(
                input_ids=tokens.input_ids.to("cuda"),
                attention_mask=tokens.attention_mask.to("cuda"),
                **kwargs
            )
        finally:
            clean()

    @staticmethod
    def _completion_avg_logprobs(logits, completion_ids):
        """Average log-probability of each completion given the logits of its row.

        ``logits[i]`` is the (seq, vocab) output for row i. The completion is
        assumed to occupy the LAST ``len(completion_ids[i])`` positions of that
        row. Since a causal LM's logits at position t predict token t+1, the
        completion tokens are scored with the logits one step earlier.
        Returns a list of floats; an empty completion yields NaN.
        """
        avg_logprobs = []
        for row, completion_tokens in enumerate(completion_ids):
            tokens = torch.as_tensor(completion_tokens, dtype=torch.long)
            n_tokens = tokens.shape[0]
            if n_tokens == 0:
                avg_logprobs.append(float("nan"))
                continue
            # Shift by one: logits at [-n-1, -1) predict the last n tokens.
            # Upcast because fp16 log_softmax is imprecise (and unsupported on some CPU builds).
            step_logits = logits[row, -n_tokens - 1:-1, :].float()
            logprobs = torch.log_softmax(step_logits, dim=-1)
            token_logprobs = logprobs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
            avg_logprobs.append(token_logprobs.mean().item())
        return avg_logprobs

    @torch.no_grad()
    def batched_completions_logprobs(self, contexts, completions, debug=False, use_pretrained_gpu=False, **kwargs):
        """Average per-token log-probability of each completion given its context.

        Contexts are tokenized with special tokens, completions without. Each pair
        is concatenated and the batch is padded with ``pad`` from .utils. NOTE(review):
        scoring assumes ``pad`` left-pads so completions end each row — confirm.
        ``debug`` is currently unused. Returns a list of floats, one per pair.
        """
        try:
            clean()
            context_ids = self.tokenizer(contexts, return_tensors='np', padding=False).input_ids
            completion_ids = self.tokenizer(completions, return_tensors='np', add_special_tokens=False, padding=False).input_ids
            input_ids = [np.concatenate([context, completion]) for context, completion in zip(context_ids, completion_ids)]

            input_ids, attention_mask = pad(input_ids, self.tokenizer.pad_token_id)

            model = self.model_pretrained_gpu if use_pretrained_gpu else self.model_in_gpu
            out = model(
                input_ids=input_ids.to("cuda"),
                attention_mask=attention_mask.to("cuda"),
                **kwargs
            )

            return self._completion_avg_logprobs(out.logits.cpu(), completion_ids)

        finally:
            clean()

    def generate(self, prompt, apply_formatting=True, **kwargs):
        """Generate from a single prompt with ``model_in_gpu``; returns the decoded text (prompt included)."""
        if apply_formatting:
            prompt = self._get_prompt_in_template(prompt)
        tokens = self.tokenizer(prompt, return_tensors='pt')

        out = self.model_in_gpu.generate(
            input_ids=tokens.input_ids.to("cuda"),
            **kwargs
        )
        out = self.tokenizer.batch_decode(out)[0]
        return out

    def generate_batch(self, formatted_prompt, use_pretrained_gpu=False, get_additional_kwargs=None, **kwargs):
        """Generate for a padded batch of already-formatted prompts; returns raw token ids.

        ``get_additional_kwargs``, if given, is called with the (CPU) input ids and
        must return extra keyword arguments for ``generate``.
        """
        tokens = self.tokenizer(formatted_prompt, padding=True, truncation=False, return_tensors='pt')

        if get_additional_kwargs is not None:
            additional_kwargs = get_additional_kwargs(tokens.input_ids)
        else:
            additional_kwargs = {}

        model = self.model_pretrained_gpu if use_pretrained_gpu else self.model_in_gpu
        return model.generate(
            input_ids=tokens.input_ids.to("cuda"),
            attention_mask=tokens.attention_mask.to("cuda"),
            **kwargs, **additional_kwargs
        )

    def get_callback_shrink(self, alpha, use_timer=False, single_gpu=False):
        """Build a per-layer callback that mixes pretrained and finetuned layer outputs.

        The callback re-runs decoder layer ``n_layer`` of the pretrained model on
        the same ``inputs`` and returns
        ``((1 - alpha) * pretrained_out + alpha * finetuned_out, finetuned_cache)``.
        With ``single_gpu`` the pretrained GPU copy is used. Otherwise the CPU
        layer is moved to the hidden states' device as fp16 and moved back
        afterwards. All calls share one ``DynamicCache`` for the pretrained layers.
        """
        past_key_values = DynamicCache()

        @torch.no_grad()
        def callback_shrink(n_layer, inputs, layer_outputs, device):
            timer = Timer(active=use_timer)

            # The `device` argument is ignored; run where the hidden states live.
            device = inputs["hidden_states"].device

            if single_gpu:
                pretrained_layer = self.model_pretrained_gpu.model.layers[n_layer]
            else:
                pretrained_layer = self.model_pretrained.model.layers[n_layer]
                timer.checkpoint(f"move layer to {device}")
                pretrained_layer.to(device=device, dtype=torch.float16)

            # Read the caller's flag BEFORE overriding it for the pretrained layer:
            # it decides where the cache sits in `layer_outputs` (after the
            # attention weights when those were returned). Assumes `inputs` holds
            # the kwargs the finetuned layer was called with — TODO confirm.
            finetuned_output_attentions = bool(inputs.get("output_attentions", False))

            inputs["past_key_value"] = past_key_values if inputs["use_cache"] else None
            inputs["output_attentions"] = False

            past_seen_tokens = 0
            if inputs["use_cache"]:  # kept for BC (cache positions)
                past_seen_tokens = past_key_values.get_seq_length()

            if "cache_position" in inputs and inputs["cache_position"] is None:
                inputs["cache_position"] = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs["hidden_states"].shape[1], device=device
                )

            timer.checkpoint(f"move inputs to {device}")
            for k, v in inputs.items():
                if hasattr(v, "device"):
                    inputs[k] = v.to(device)

            timer.checkpoint("compute layer output")

            pretrained_layer_outputs = pretrained_layer(**inputs)

            timer.checkpoint("cast layer outs")
            input_hidden_states = inputs["hidden_states"]
            finetuned_out = layer_outputs[0].to(input_hidden_states)
            pretrained_out = pretrained_layer_outputs[0].to(input_hidden_states)

            timer.checkpoint("compute next hidden state")

            interpolated_out = (1 - alpha) * pretrained_out + alpha * finetuned_out
            cache_index = 2 if finetuned_output_attentions else 1
            next_decoder_cache_finetuned = layer_outputs[cache_index] if len(layer_outputs) > cache_index else None

            timer.checkpoint("move layer to cpu")

            if not single_gpu:
                pretrained_layer.to("cpu", non_blocking=True)
            timer.checkpoint("end")

            return interpolated_out, next_decoder_cache_finetuned

        return callback_shrink

    def weight_space_interpolation(self, alpha, params="layers", dtype=torch.float16):
        """Overwrite ``model_in_gpu`` weights with ``(1 - alpha) * pretrained + alpha * finetuned``.

        ``params`` selects which parameters are mixed: "all", "layers"
        (``model.layers``) or "lm_head". Parameters whose pretrained copy has
        ``requires_grad=False`` are skipped. ``dtype`` is currently unused: each
        target parameter keeps its own dtype. Loads the finetuned model first if
        nothing is on the GPU.

        Raises:
            ValueError: if ``params`` is not one of the supported selections.
        """
        selectors = {
            "all": lambda m: m,
            "layers": lambda m: m.model.layers,
            "lm_head": lambda m: m.lm_head,
        }
        if params not in selectors:
            raise ValueError(f"Unknown params selection {params!r}; expected one of {list(selectors)}")

        if self.model_in_gpu is None:
            self.load_gpu("finetuned")

        select = selectors[params]
        model_pretrained_params = select(self.model_pretrained).parameters()
        model_finetuned_params = select(self.model_finetuned).parameters()
        model_in_gpu_params = select(self.model_in_gpu).parameters()

        # Parameters are paired positionally, so the three modules must share an architecture.
        with torch.no_grad():
            for p_pre, p_fin, p_target in zip(model_pretrained_params, model_finetuned_params, model_in_gpu_params):
                if not p_pre.requires_grad:
                    continue
                # .to() never mutates the source; the mix below allocates a fresh tensor.
                pre = p_pre.to(p_target.device)
                fin = p_fin.to(p_target.device)
                p_target.data.copy_(((1 - alpha) * pre + alpha * fin).to(dtype=p_target.dtype))

        print(f"Interpolated {alpha}")
    

class LlamaModels(Models):
    """Llama-2 / Llama-3 chat model paired with its base (pretrained) model.

    version 2 -> meta-llama/Llama-2-{n}-chat-hf vs meta-llama/Llama-2-{n}-hf
    version 3 -> meta-llama/Meta-Llama-3-{n}-Instruct vs meta-llama/Meta-Llama-3-{n}

    ``load_model_kwargs`` is forwarded to ``from_pretrained`` for GPU loads
    (defaults to fp16 via ``dtypes_dict["float16"]``).
    """

    def __init__(self, args, version=2, force_no_sysprompt=False, load_model_kwargs=None):
        super().__init__()
        if load_model_kwargs is None:
            # Built per instance: a dict default would be shared by every instance.
            load_model_kwargs = {"torch_dtype": dtypes_dict["float16"]}

        self.args = args
        self.version = version
        self.force_no_sysprompt = force_no_sysprompt

        if version == 2:
            assert args.n_params in ["7b", "13b", "70b"], "Only n_params available are 7b, 13b, 70b"
            self.model_names = {
                "finetuned": f"meta-llama/Llama-{version}-{args.n_params}-chat-hf",
                "pretrained": f"meta-llama/Llama-{version}-{args.n_params}-hf"
            }
            self.name = f"llama_{args.n_params}"
            self.sep = "\n\n"
        elif version == 3:
            assert args.n_params in ["8B", "70B"], "Only n_params available are 8B, 70B"
            self.model_names = {
                "finetuned": f"meta-llama/Meta-Llama-3-{args.n_params}-Instruct",
                "pretrained": f"meta-llama/Meta-Llama-3-{args.n_params}"
            }
            self.name = f"llama3_{args.n_params}"
            self.sep = "<|eot_id|>"
        else:
            raise ValueError(f"Unsupported Llama version {version!r}; expected 2 or 3")

        self.model_in_gpu = None
        if self.args.no_formatting:
            self.name = f"{self.name}_noformatting"

            if self.args.use_sysprompt:
                self.name = f"{self.name}_sysprompt"

        if self.force_no_sysprompt:
            self.name = f"{self.name}_nosys"
        self.SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information."""

        self.load_model_kwargs = load_model_kwargs

    def load_gpu(self, name, device_map="balanced", force=True):
        """Make model ``name`` the active GPU model.

        Reuses an already-loaded pretrained GPU model. Otherwise loads from
        ``self.model_names[name]`` when ``force`` is set or nothing is loaded yet.
        """
        if name == "pretrained" and self._model_pretrained_gpu is not None:
            self.model_in_gpu = self._model_pretrained_gpu
        elif force or self.model_in_gpu is None:
            print(f"Loading {name} model")
            self.model_in_gpu = AutoModelForCausalLM.from_pretrained(
                self.model_names[name],
                device_map=device_map,
                **self.load_model_kwargs
            )

    def clean_loaded_model(self):
        """Drop the active GPU model and free memory via ``clean()``."""
        del self.model_in_gpu
        clean()
        self.model_in_gpu = None
        clean()

    def _convert_formatted_dialogue(self, formatted_dialogue, force_no_sysprompt=False, **kwargs):
        """Render {"role", "message"} turns in the Llama-2 or Llama-3 chat format.

        The last turn is expected to be an (empty) assistant turn; the result ends
        where the assistant's answer should start. The system prompt is omitted
        when either the per-call ``force_no_sysprompt`` or the instance-wide
        ``self.force_no_sysprompt`` is set.
        """
        # Instance flag (reflected as "_nosys" in self.name) applies to every path.
        no_sysprompt = force_no_sysprompt or self.force_no_sysprompt

        if self.args.no_formatting:
            return self._default_convert_formatted_dialogue(formatted_dialogue, force_no_sysprompt=no_sysprompt, **kwargs)

        if self.version == 2:
            formatted_dialogue = [dict(x) for x in formatted_dialogue]
            sysprompt = "" if no_sysprompt else "<<SYS>>\n" + self.SYSTEM_PROMPT + "\n<</SYS>>" + "\n\n"
            formatted_dialogue[0]["message"] = sysprompt + formatted_dialogue[0]["message"]
            out = [
                f"[INST] {d['message']} [/INST]" if d["role"] == "user" else d['message']
                for d in formatted_dialogue
            ]
            # The final (assistant) turn is dropped; the prompt ends after its separator.
            return "\n\n".join(out[:-1]) + "\n\n"

        if self.version == 3:
            formatted_dialogue = [{"role": x["role"], "content": x["message"]} for x in formatted_dialogue]
            if not no_sysprompt:
                formatted_dialogue = [
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    *formatted_dialogue
                ]
            out = self.tokenizer.apply_chat_template(formatted_dialogue, tokenize=False, add_generation_prompt=False, add_bos=False)
            bos = "<|begin_of_text|>"
            eot = "<|eot_id|>"
            # Strip BOS (the tokenizer adds it) and the trailing EOT so the
            # empty assistant turn stays open for generation.
            if out.startswith(bos):
                out = out[len(bos):]
            if out.endswith(eot):
                out = out[:-len(eot)]
            return out

        raise ValueError(f"Unsupported Llama version {self.version!r}; expected 2 or 3")

    def prepend_sysprompt(self, s):
        """Insert the system prompt into an already-formatted prompt ``s``."""
        if self.args.no_formatting:
            return f"system: {self.SYSTEM_PROMPT}\n\n" + s
        elif self.version == 2:
            pref = "[INST]"
            assert s.startswith(pref), f"Got string not starting with {pref}: {s}"
            s = s[len(pref):]
            s = pref + " " + "<<SYS>>\n" + self.SYSTEM_PROMPT + "\n<</SYS>>" + "\n\n" + s
            return s
        elif self.version == 3 and not self.force_no_sysprompt:
            bos = "<|begin_of_text|>"
            sysprompt = self.tokenizer.apply_chat_template([
                {"role": "system", "content": self.SYSTEM_PROMPT}],
                tokenize=False,
                add_generation_prompt=False
            )
            if sysprompt.startswith(bos):
                sysprompt = sysprompt[len(bos):]
            return sysprompt + s
        else:
            return s

    @torch.no_grad()
    def get_normalized_unembedding(self, name):
        """Return ``lm_head.weight`` scaled column-wise by the final RMSNorm weight.

        Raises:
            ValueError: if ``name`` is not "finetuned" or "pretrained".
        """
        if name == "finetuned":
            model = self.model_finetuned
        elif name == "pretrained":
            model = self.model_pretrained
        else:
            raise ValueError(f"Unknown model name {name!r}; expected 'finetuned' or 'pretrained'")

        rms_norm_weight = model.model.norm.weight
        unembedding_matrix = model.lm_head.weight
        normalized_unembedding = einops.einsum(unembedding_matrix, rms_norm_weight, "N H, H -> N H")
        return normalized_unembedding

    @property
    def tokenizer(self):
        """Lazily loaded chat tokenizer: BOS added, left padding, pad token = EOS."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_names["finetuned"], add_bos_token=True)

            self._tokenizer.padding_side = "left"
            self._tokenizer.pad_token = self._tokenizer.eos_token
        return self._tokenizer

    @property
    def model_finetuned(self):
        """Lazily loaded finetuned model on CPU in fp16."""
        if self._model_chat is None:
            print("loading model chat")
            self._model_chat = AutoModelForCausalLM.from_pretrained(self.model_names["finetuned"], device_map="cpu", torch_dtype=torch.float16)
        return self._model_chat

    @property
    def model_pretrained(self):
        """Lazily loaded pretrained model on CPU in fp16."""
        if self._model_pretrained is None:
            print("loading model pretrained")
            self._model_pretrained = AutoModelForCausalLM.from_pretrained(self.model_names["pretrained"], device_map="cpu", torch_dtype=torch.float16)
        return self._model_pretrained

    @property
    def model_pretrained_gpu(self, device_map="balanced", **kwargs):
        """Lazily loaded pretrained model on GPU ("balanced"), using ``load_model_kwargs``."""
        if self._model_pretrained_gpu is None:
            print("loading model pretrained to gpu")
            self._model_pretrained_gpu = AutoModelForCausalLM.from_pretrained(
                self.model_names["pretrained"],
                device_map=device_map,
                **self.load_model_kwargs
            )
        return self._model_pretrained_gpu
    
class VicunaModels(LlamaModels):
    """Vicuna v1.5 chat model paired with its Llama-2 base model.

    Inherits formatting and loading from LlamaModels (version 2). Only the
    model ids, the run name, the dtype and the system prompt are set here.
    """

    def __init__(self, args):
        super().__init__(args)
        self.args = args

        assert args.n_params in ["7b", "13b"], "Only n_params available are 7b, 13b"

        size = args.n_params
        self.model_names = dict(
            finetuned=f"lmsys/vicuna-{size}-v1.5",
            pretrained=f"meta-llama/Llama-2-{size}-hf",
        )
        self.model_in_gpu = None

        # The run name records whether plain "role: message" formatting (and a
        # system prompt within it) is in use.
        suffix = ""
        if self.args.no_formatting:
            suffix = "_noformatting"
            if self.args.use_sysprompt:
                suffix += "_sysprompt"
        self.name = f"vicunav15_{size}{suffix}"

        self.dtype = dtypes_dict[args.dtype]
        self.SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information."""


class GemmaModels(Models):
    """Gemma instruct model paired with the Gemma base model.

    ``version`` ("" or e.g. "1.1") selects the instruct release; the base model
    is always google/gemma-{n}. Gemma has no system prompt.
    """

    def __init__(self, args, version=""):
        super().__init__()
        self.args = args

        assert args.n_params in ["2b", "7b"], "Only n_params available are 2b, 7b"

        version_str = "" if version == "" else f"{version}-"

        self.model_names = {
            "finetuned": f"google/gemma-{version_str}{args.n_params}-it",
            "pretrained": f"google/gemma-{args.n_params}"
        }
        self.model_in_gpu = None
        self.name = f"gemma{version.replace('.', '')}_{args.n_params}"

        if self.args.no_formatting:
            self.name = f"{self.name}_noformatting"
            assert not args.use_sysprompt, "Gemma models don't have a system prompt"

        self.dtype = dtypes_dict[args.dtype]
        self.START = "<start_of_turn>"
        self.END = "<end_of_turn>"

    def load_gpu(self, name, device_map="balanced", force=True):
        """Make model ``name`` the active GPU model.

        Reuses an already-loaded pretrained GPU model. Otherwise loads in
        ``self.dtype`` when ``force`` is set or nothing is loaded yet.
        """
        if name == "pretrained" and self._model_pretrained_gpu is not None:
            self.model_in_gpu = self._model_pretrained_gpu
        elif force or self.model_in_gpu is None:
            print(f"Loading {name} model")
            self.model_in_gpu = AutoModelForCausalLM.from_pretrained(
                self.model_names[name],
                device_map=device_map,
                torch_dtype=self.dtype
            )

    def prepend_sysprompt(self, s):
        """Gemma has no system prompt: return ``s`` unchanged."""
        return s

    def _convert_formatted_dialogue(self, formatted_dialogue, **kwargs):
        """Render {"role", "message"} turns in Gemma's <start_of_turn>/<end_of_turn> format.

        The last turn is always replaced by an open "<start_of_turn>model" line,
        whatever its content, so generation continues as the model.
        """
        if self.args.no_formatting:
            return self._default_convert_formatted_dialogue(formatted_dialogue, **kwargs)
        formatted_dialogue = [dict(x) for x in formatted_dialogue]
        entity_map = {"user": "user", "assistant": "model"}
        formatted_dialogue = [
            f"{self.START}{entity_map[d['role']]}\n{d['message'].strip()}{self.END}"
            for d in formatted_dialogue
        ]

        formatted_dialogue[-1] = f"{self.START}model"

        return "\n".join(formatted_dialogue) + "\n"

    def clean_loaded_model(self):
        """Drop the active GPU model and free memory via ``clean()``."""
        del self.model_in_gpu
        clean()
        self.model_in_gpu = None
        clean()

    @torch.no_grad()
    def get_normalized_unembedding(self, name):
        """Return ``lm_head.weight`` scaled column-wise by Gemma's final-norm scale.

        Gemma's RMSNorm multiplies by ``(1 + weight)`` rather than ``weight``
        (transformers' GemmaRMSNorm), so the +1 is folded in here. Unlike the
        Llama version, this is not a plain multiply by ``norm.weight``.

        Raises:
            ValueError: if ``name`` is not "finetuned" or "pretrained".
        """
        if name == "finetuned":
            model = self.model_finetuned
        elif name == "pretrained":
            model = self.model_pretrained
        else:
            raise ValueError(f"Unknown model name {name!r}; expected 'finetuned' or 'pretrained'")

        rms_norm_scale = 1.0 + model.model.norm.weight
        unembedding_matrix = model.lm_head.weight
        normalized_unembedding = einops.einsum(unembedding_matrix, rms_norm_scale, "N H, H -> N H")
        return normalized_unembedding

    @property
    def tokenizer(self):
        """Lazily loaded base-model tokenizer with left padding."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_names["pretrained"])
            self._tokenizer.padding_side = "left"
        return self._tokenizer

    @property
    def model_finetuned(self):
        """Lazily loaded finetuned model on CPU in fp16."""
        if self._model_chat is None:
            print("loading model chat")
            self._model_chat = AutoModelForCausalLM.from_pretrained(self.model_names["finetuned"], device_map="cpu", torch_dtype=torch.float16)
        return self._model_chat

    @property
    def model_pretrained(self):
        """Lazily loaded pretrained model on CPU in fp16."""
        if self._model_pretrained is None:
            print("loading model pretrained")
            self._model_pretrained = AutoModelForCausalLM.from_pretrained(self.model_names["pretrained"], device_map="cpu", torch_dtype=torch.float16)
        return self._model_pretrained

    @property
    def model_pretrained_gpu(self, device_map="balanced"):
        """Lazily loaded pretrained model on GPU ("balanced") in fp16."""
        if self._model_pretrained_gpu is None:
            print("loading model pretrained to gpu")
            self._model_pretrained_gpu = AutoModelForCausalLM.from_pretrained(self.model_names["pretrained"], device_map=device_map, torch_dtype=torch.float16)
        return self._model_pretrained_gpu
    
class ZephyrGemmaModels(GemmaModels):
    """Zephyr-Gemma (HuggingFaceH4) finetune paired with the Gemma base model.

    Reuses GemmaModels' chat formatting and loading; only the model ids and the
    run name differ.
    """

    def __init__(self, args, version="v0.1"):
        super().__init__(args)
        self.args = args

        assert args.n_params in ["7b"], "Only n_params available is 7b"

        n_params = args.n_params
        self.model_names = dict(
            finetuned=f"HuggingFaceH4/zephyr-{n_params}-gemma-{version}",
            pretrained=f"google/gemma-{n_params}",
        )
        self.model_in_gpu = None

        run_name = "zephyr_gemma_{}_{}".format(n_params, version.replace(".", ""))
        if self.args.no_formatting:
            run_name += "_noformatting"
            assert not args.use_sysprompt, "Gemma models don't have a system prompt"
        self.name = run_name

        self.dtype = dtypes_dict[args.dtype]
        # Gemma turn delimiters used by GemmaModels._convert_formatted_dialogue.
        self.START, self.END = "<start_of_turn>", "<end_of_turn>"

class MistralModels(Models):
    """Mistral-7B Instruct model paired with the Mistral-7B v0.1 base model."""

    def __init__(self, args, version="v0.1"):
        super().__init__()
        self.args = args

        assert args.n_params in ["7B"], "Only n_params available are 7B"

        self.model_names = {
            "finetuned": f"mistralai/Mistral-{args.n_params}-Instruct-{version}",
            # The base model is always v0.1, whichever Instruct version is used.
            "pretrained": f"mistralai/Mistral-{args.n_params}-v0.1"
        }
        self.model_in_gpu = None
        self.name = f"mistral_{args.n_params}_{version.replace('.', '')}"
        self.SYSTEM_PROMPT = "Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity."

        if self.args.no_formatting:
            self.name = f"{self.name}_noformatting"

            if self.args.use_sysprompt:
                self.name = f"{self.name}_sysprompt"

        self.dtype = dtypes_dict[args.dtype]

    def load_gpu(self, name, device_map="balanced", force=True):
        """Make model ``name`` the active GPU model.

        Reuses an already-loaded pretrained GPU model. Otherwise loads in
        ``self.dtype`` when ``force`` is set or nothing is loaded yet.
        """
        if name == "pretrained" and self._model_pretrained_gpu is not None:
            self.model_in_gpu = self._model_pretrained_gpu
        elif force or self.model_in_gpu is None:
            print(f"Loading {name} model")
            self.model_in_gpu = AutoModelForCausalLM.from_pretrained(
                self.model_names[name],
                device_map=device_map,
                torch_dtype=self.dtype
            )

    def clean_loaded_model(self):
        """Drop the active GPU model and free memory via ``clean()``."""
        del self.model_in_gpu
        clean()
        self.model_in_gpu = None
        clean()

    def prepend_sysprompt(self, s):
        """Insert the system prompt into an already-formatted prompt ``s``.

        The [INST] path inserts the bare prompt text (no <<SYS>> tags), unlike
        ``_convert_formatted_dialogue``.
        """
        if self.args.no_formatting:
            return f"system: {self.SYSTEM_PROMPT}\n\n" + s
        else:
            pref = "[INST]"
            assert s.startswith(pref), f"Got string not starting with {pref}: {s}"
            s = s[len(pref):]
            s = pref + " " + self.SYSTEM_PROMPT + "\n\n" + s
            return s

    def _convert_formatted_dialogue(self, formatted_dialogue, strip=True, force_no_sysprompt=False, **kwargs):
        """Render {"role", "message"} turns in [INST] format with a <<SYS>> block.

        The final (assistant) turn is dropped; the prompt ends after its separator.
        """
        if self.args.no_formatting:
            out = self._default_convert_formatted_dialogue(formatted_dialogue, strip=strip, force_no_sysprompt=force_no_sysprompt, **kwargs)
            return out

        formatted_dialogue = [dict(x) for x in formatted_dialogue]

        sysprompt = "" if force_no_sysprompt else "<<SYS>>\n" + self.SYSTEM_PROMPT + "\n<</SYS>>" + "\n\n"
        formatted_dialogue[0]["message"] = sysprompt + formatted_dialogue[0]["message"]
        out = [
            f"[INST] {d['message']} [/INST]" if d["role"] == "user" else d['message']
            for d in formatted_dialogue
        ]

        return "\n\n".join(out[:-1]) + "\n\n"

    @torch.no_grad()
    def get_normalized_unembedding(self, name):
        """Return ``lm_head.weight`` scaled column-wise by the final RMSNorm weight.

        Raises:
            ValueError: if ``name`` is not "finetuned" or "pretrained".
        """
        if name == "finetuned":
            model = self.model_finetuned
        elif name == "pretrained":
            model = self.model_pretrained
        else:
            raise ValueError(f"Unknown model name {name!r}; expected 'finetuned' or 'pretrained'")

        rms_norm_weight = model.model.norm.weight
        unembedding_matrix = model.lm_head.weight
        normalized_unembedding = einops.einsum(unembedding_matrix, rms_norm_weight, "N H, H -> N H")
        return normalized_unembedding

    @property
    def tokenizer(self):
        """Lazily loaded base-model tokenizer: BOS added, pad token = EOS, left padding."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_names["pretrained"], add_bos_token=True)
            self._tokenizer.pad_token = self._tokenizer.eos_token
            self._tokenizer.padding_side = "left"
        return self._tokenizer

    @property
    def model_finetuned(self):
        """Lazily loaded finetuned model on CPU in fp16."""
        if self._model_chat is None:
            print("loading model chat")
            self._model_chat = AutoModelForCausalLM.from_pretrained(self.model_names["finetuned"], device_map="cpu", torch_dtype=torch.float16)
        return self._model_chat

    @property
    def model_pretrained(self):
        """Lazily loaded pretrained model on CPU in fp16."""
        if self._model_pretrained is None:
            print("loading model pretrained")
            self._model_pretrained = AutoModelForCausalLM.from_pretrained(self.model_names["pretrained"], device_map="cpu", torch_dtype=torch.float16)
        return self._model_pretrained

    @property
    def model_pretrained_gpu(self, device_map="balanced"):
        """Lazily loaded pretrained model on GPU ("balanced") in fp16."""
        if self._model_pretrained_gpu is None:
            print("loading model pretrained to gpu")
            self._model_pretrained_gpu = AutoModelForCausalLM.from_pretrained(self.model_names["pretrained"], device_map=device_map, torch_dtype=torch.float16)
        return self._model_pretrained_gpu



# Registry: model-family key -> factory taking the parsed ``args`` and
# returning a Models instance.
models_classes = dict(
    # Llama-2 / Llama-3
    llama=LlamaModels,
    llama_4bit=lambda args: LlamaModels(args, load_model_kwargs={"load_in_4bit": True}),
    llama3=lambda args: LlamaModels(args, version=3),
    llama3_nosys=lambda args: LlamaModels(args, version=3, force_no_sysprompt=True),
    llama3_4bit=lambda args: LlamaModels(args, version=3, load_model_kwargs={"load_in_4bit": True}),
    # Mistral-7B Instruct
    mistralv01=lambda args: MistralModels(args, version="v0.1"),
    mistralv02=lambda args: MistralModels(args, version="v0.2"),
    # Vicuna (Llama-2 based)
    vicuna=VicunaModels,
    # Gemma family
    gemma=GemmaModels,
    gemma11=lambda args: GemmaModels(args, version="1.1"),
    zephyr_gemma_v01=lambda args: ZephyrGemmaModels(args, version="v0.1"),
)
