import torch
from .constants import MODEL_LLM
from collections import defaultdict
class LogicEngine:
    """Class-level state and configuration shared by the analysis logic classes.

    Every attribute lives on the class and every method is a ``classmethod``,
    so no instance is ever needed. Per-logic settings (``logic_flag`` and
    ``save_to_path``) are stored in dicts keyed by the calling class's
    ``__name__``, so subclasses share one registry. Concrete logics are
    expected to override :meth:`run_logic`.
    """

    # Prompt segment names, in the order used by id_pieces / text_pieces / begin_pos.
    _SEGMENTS = ("system", "role_0", "image", "inst_q", "role_1")

    # Model Config
    # Default LLM identifier; overwritten by set_llm_name().
    # NOTE(review): "llama-7b" was the effective default (a separate
    # "llama-v2-7b" assignment was always overridden by it) — confirm intent.
    llm_name = "llama-7b"
    model_config = None  # set by export_model_config(); read by set_sink_select_layers("all")

    # Flag
    # Maps a logic class name -> enabled flag (see set_flag / _flag).
    logic_flag = {}

    # Ablation
    ablation_config = {
        "sink": False,
        "head": False,
    }
    # Rule variants; updated in place by rule_setting().
    rule_config = {
        "sink_rule": "ours",
        "head_rule": "ours",
    }

    # Features
    hidden_state = []
    # Maps a logic class name -> output path (see set_save_to_path / get_save_to_path).
    save_to_path = {}

    # Metadata, ValueMonitor
    qid = None
    output_token_count = 0
    current_decoder_layer = 0
    # -inf until set_sink_select_layers() stores an int index or a list of indices.
    sink_select_layers = -float("inf")
    gt_label = None
    answer = None
    answer_ids = None
    prompt = None
    correct = None
    # One entry per prompt segment (presumably token ids / text / start index
    # of each segment — confirm against callers). Only the outermost iterable
    # of a class-body comprehension may reference class names, which is all
    # these use.
    id_pieces = {segment: [] for segment in _SEGMENTS}
    text_pieces = {segment: [] for segment in _SEGMENTS}
    begin_pos = {segment: float("-inf") for segment in _SEGMENTS}
    vis_len = 0
    image_path = None

    # vision data
    after_ve = None
    after_proj = None

    # DimProspector
    indices = {}

    # HeadFork
    forked_head = {}
    # Missing keys auto-create an empty inner dict.
    forked_head_per_token = defaultdict(dict)

    @classmethod
    def set_sink_select_layers(cls, layer="all"):
        """Store the decoder-layer selection in ``sink_select_layers``.

        Args:
            layer: ``"all"`` selects every layer index
                ``0 .. model_config.num_hidden_layers - 1`` (requires
                :meth:`export_model_config` to have been called first);
                otherwise a single ``int`` layer index.

        Raises:
            TypeError: if ``layer`` is neither ``"all"`` nor an ``int``.
        """
        if layer == "all":
            cls.sink_select_layers = list(range(cls.model_config.num_hidden_layers))
            return
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(layer, int):
            raise TypeError(
                f"layer must be 'all' or an int, got {type(layer).__name__}"
            )
        cls.sink_select_layers = layer

    @classmethod
    def rule_setting(cls, config):
        """Merge ``config`` into ``rule_config`` in place, printing each change."""
        for key, value in config.items():
            cls.rule_config[key] = value
            print(f"Logic rule :: {key} set to {value}")

    @classmethod
    def set_llm_name(cls, model_name):
        """Set ``llm_name`` from the first ``MODEL_LLM`` entry whose key contains ``model_name``.

        The check is ``model_name in key`` (a substring test when keys are
        strings). If no entry matches, ``llm_name`` is left unchanged.
        """
        for body, llm in MODEL_LLM.items():
            if model_name in body:
                cls.llm_name = llm
                break

    @classmethod
    def export_model_config(cls, config):
        """Store the model config object (must expose ``num_hidden_layers`` for "all")."""
        cls.model_config = config

    @classmethod
    def set_flag(cls, flag_value=True):
        """Record ``flag_value`` in the shared ``logic_flag`` under this class's name."""
        cls.logic_flag[cls.__name__] = flag_value

    @classmethod
    def _flag(cls, name=None):
        """Return the flag for ``name`` (default: this class's name); ``False`` if unset."""
        if name is not None:
            return cls.logic_flag.get(name, False)
        return cls.logic_flag.get(cls.__name__, False)

    @classmethod
    def set_save_to_path(cls, save_to_path):
        """Record the output path for this class in the shared ``save_to_path`` dict."""
        cls.save_to_path[cls.__name__] = save_to_path

    @classmethod
    def get_save_to_path(cls):
        """Return the output path recorded for this class, or ``None``."""
        return cls.save_to_path.get(cls.__name__, None)

    @classmethod
    def run_logic(cls):
        """Entry point of a concrete logic; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("Running basic logic in LogicEngine.")

    @classmethod
    def clear(cls):
        """Reset features, metadata, vision data and head-fork state to defaults.

        Configuration (``llm_name``, ``model_config``, ``logic_flag``,
        ``ablation_config``, ``rule_config``, ``sink_select_layers``) is left
        untouched; note that ``save_to_path`` *is* reset.

        NOTE(review): ``vis_len`` is not reset here — confirm it is meant to
        persist across calls.
        NOTE(review): attributes are rebound on ``cls``; calling this on a
        subclass creates subclass attributes that shadow the ``LogicEngine``
        ones rather than resetting the shared base-class values.
        """
        cls.hidden_state = []
        cls.save_to_path = {}
        cls.indices = {}
        cls.forked_head = {}
        cls.forked_head_per_token = defaultdict(dict)
        cls.qid = None
        cls.output_token_count = 0
        cls.current_decoder_layer = 0
        cls.gt_label = None
        cls.answer = None
        cls.answer_ids = None
        cls.prompt = None
        cls.correct = None
        cls.image_path = None
        cls.after_ve = None
        cls.after_proj = None
        # Fresh containers each time so no list is aliased across dicts or calls.
        cls.id_pieces = {segment: [] for segment in cls._SEGMENTS}
        cls.text_pieces = {segment: [] for segment in cls._SEGMENTS}
        cls.begin_pos = {segment: float("-inf") for segment in cls._SEGMENTS}
