from collections import defaultdict
import re
import os.path as osp
from typing import List
import torch
import gc
from logic.logic_engine import LogicEngine
import pickle

class ValueMonitor(LogicEngine):
    """Tracks generation progress: current decoder layer, output-token count and qid.

    Setters write onto ``cls.__base__`` (``LogicEngine``), so the values live on the
    base class; getters read them back through ``cls`` via normal attribute lookup.
    """

    @classmethod
    def activate(cls):
        """Set this class's flag to True and print the resulting flag value."""
        cls.set_flag(True)
        print(f"{cls.__name__} flag set to {cls.get_flag()}")

    @classmethod
    def remember_layer(cls, l):
        """Record ``l`` as the current decoder layer."""
        cls.__base__.current_decoder_layer = l

    @classmethod
    def count_output_token(cls):
        """Increment the output-token counter by one.

        Requires ``output_token_count`` to be initialised beforehand
        (``clear()`` sets it to 0).
        """
        cls.__base__.output_token_count += 1

    @classmethod
    def remember_qid(cls, qid):
        """Record ``qid`` as the current question id."""
        cls.__base__.qid = qid

    @classmethod
    def recall_layer(cls):
        """Return the most recently recorded decoder layer."""
        return cls.current_decoder_layer

    @classmethod
    def recall_output_token_counts(cls):
        """Return the current output-token count."""
        return cls.output_token_count

    @classmethod
    def recall_qid(cls):
        """Return the most recently recorded question id."""
        return cls.qid

    @classmethod
    def clear(cls):
        """Reset the token counter and decoder layer to 0 and release unused memory.

        Note: ``qid`` is not reset by this method.
        """
        cls.__base__.output_token_count = 0
        cls.__base__.current_decoder_layer = 0
        # Collect unreachable Python objects first so any CUDA tensors they held are
        # returned to PyTorch's caching allocator; only then can empty_cache() hand
        # that now-unoccupied cached memory back to the driver.
        gc.collect()
        torch.cuda.empty_cache()

class FeatureSaver(LogicEngine):
    """Collects hidden-state tensors as CPU copies and pickles them to disk.

    The collected list lives on the base class (``cls.__base__.hidden_state``),
    which must already exist and support ``append``.
    """

    @classmethod
    def get_feat(cls):
        """Return ``{"hs": <collected hidden-state list>}`` (the list itself, not a copy)."""
        results = {
            "hs": cls.hidden_state,
        }
        return results

    @classmethod
    def set_hidden_state(cls, hs):
        """Append a detached, independent CPU copy of ``hs`` to the collected list.

        Args:
            hs: Hidden-state tensor. An observed example shape is
                ``torch.Size([1, 632, 4096])``; the shape is not checked.

        Raises:
            AssertionError: If ``hs`` is not a ``torch.Tensor``.
        """
        assert isinstance(hs, torch.Tensor), "Hidden state must be a Tensor."
        # detach() drops autograd history so the stored copy does not keep the
        # computation graph (and its activations) alive; clone() guarantees an
        # independent copy even when ``hs`` is already on the CPU (where .cpu()
        # returns the same tensor).
        cls.__base__.hidden_state.append(hs.detach().cpu().clone())

    @classmethod
    def save_hs(cls, save_to_path):
        """Pickle ``get_feat()`` to ``save_to_path``; never overwrites an existing file.

        If the path already exists, a message is printed and nothing is written.
        """
        if osp.exists(save_to_path):
            print(f"File already exists at {save_to_path}.")
            return
        hs = cls.get_feat()
        with open(save_to_path, "wb") as f:
            pickle.dump(hs, f)

class AttentionSaver(LogicEngine):
    """Toggleable saver that pickles attention weights to a file."""

    @classmethod
    def activate(cls):
        """Enable this saver by setting its own flag to True, then report the flag."""
        # Each subclass sets the flag that belongs to itself.
        cls.set_flag(True)
        name, flag = cls.__name__, cls.get_flag()
        print(f"{name} flag set to {flag}")

    @classmethod
    def save_attn_weights(cls, attns, save_to_path):
        """Pickle ``attns`` into ``save_to_path`` unless that path already exists.

        An existing file is left untouched and a message is printed instead.
        """
        already_there = osp.exists(save_to_path)
        if not already_there:
            with open(save_to_path, "wb") as handle:
                pickle.dump(attns, handle)
        else:
            print(f"File already exists at {save_to_path}.")

class TweakMetadataSaver(LogicEngine):
    """Per-key registry of correctness values, answer-id entries and answer strings.

    Unlike the other savers in this module, the state lives on this class itself as
    class-level ``defaultdict`` objects (not on ``cls.__base__``). Reading a missing
    key inserts its default value.
    """

    correct = defaultdict(int)  # key -> value given to set_correct (0 if never set)
    answer_id = defaultdict(list)  # key -> list of entries appended by push_answer
    answer = defaultdict(str)  # key -> value given to set_answer ("" if never set)
    save_to_path = None  # not read by any method of this class

    @classmethod
    def pull_answer(cls):
        """Return the whole key -> answer-id-list mapping (the live object)."""
        registry = cls.answer_id
        return registry

    @classmethod
    def push_answer(cls, key, answer_ids):
        """Append ``answer_ids`` as one entry to the list stored under ``key``."""
        bucket = cls.answer_id[key]
        bucket.append(answer_ids)

    @classmethod
    def set_correct(cls, key, correct):
        """Store ``correct`` under ``key``."""
        table = cls.correct
        table[key] = correct

    @classmethod
    def set_answer(cls, key, answer):
        """Store the answer string ``answer`` under ``key``."""
        table = cls.answer
        table[key] = answer

    @classmethod
    def get_metadata(cls, key):
        """Return ``{"correct", "answer_id", "answer"}`` for ``key`` (defaults if unset)."""
        is_correct = cls.correct[key]
        ids = cls.answer_id[key]
        text = cls.answer[key]
        return {"correct": is_correct, "answer_id": ids, "answer": text}

class MetadataSaver(LogicEngine):
    """Collects per-sample prompt/answer metadata and pickles it to disk.

    Setters write onto ``cls.__base__`` (``LogicEngine``); getters read through
    ``cls`` via normal attribute lookup.
    """

    @classmethod
    def activate(cls):
        """Set this class's flag to True and print the resulting flag value."""
        cls.set_flag(True)
        print(f"{cls.__name__} flag set to {cls.get_flag()}")

    @classmethod
    def set_image_path(cls, image_path):
        """Record the image path of the current sample."""
        cls.__base__.image_path = image_path

    @classmethod
    def set_qid(cls, qid):
        """Record the question id of the current sample."""
        cls.__base__.qid = qid

    @classmethod
    def get_metadata(cls):
        """Return a picklable snapshot of the collected metadata.

        Tensors taken from ``cls.indices`` and ``cls.forked_head_per_token`` are
        detached and copied to the CPU, so the snapshot holds no GPU memory and no
        autograd history.

        Returns:
            dict with keys ``prompt``, ``gt_label``, ``answer``, ``answer_ids``,
            ``correct``, ``id_pieces``, ``text_pieces``, ``begin_pos``, ``vis_len``,
            ``image_path``, ``sink`` (list of CPU tensors, one per value of
            ``cls.indices`` in its iteration order) and ``c_head`` (token -> layer ->
            head, with tensor heads copied to the CPU and other values kept as-is).
        """
        # Only the values of ``indices`` are saved; its keys are not kept.
        indices_for_save = [indices.detach().cpu().clone() for indices in cls.indices.values()]

        forked_head_for_save = defaultdict(dict)
        for tok, forked_head in cls.forked_head_per_token.items():
            for layer, head in forked_head.items():
                forked_head_for_save[tok][layer] = (
                    head.detach().cpu().clone() if isinstance(head, torch.Tensor) else head
                )

        return {
            "prompt": cls.prompt,
            "gt_label": cls.gt_label,
            "answer": cls.answer,
            "answer_ids": cls.answer_ids,
            "correct": cls.correct,
            "id_pieces": cls.id_pieces,
            "text_pieces": cls.text_pieces,
            "begin_pos": cls.begin_pos,
            "vis_len": cls.vis_len,
            "image_path": cls.image_path,
            "sink": indices_for_save,
            "c_head": forked_head_for_save,
        }

    @classmethod
    def set_correct(cls, correct):
        """Record whether the current answer is correct."""
        cls.__base__.correct = correct

    @classmethod
    def set_prompt(cls, prompt):
        """Record the prompt of the current sample."""
        cls.__base__.prompt = prompt

    @classmethod
    def set_answer(cls, answer_ids, answer):
        """Record the generated answer both as token ids and as text."""
        cls.__base__.answer_ids = answer_ids
        cls.__base__.answer = answer

    @classmethod
    def set_gt_label(cls, gt_label):
        """Record the ground-truth label of the current sample."""
        cls.__base__.gt_label = gt_label

    @classmethod
    def set_vis_len(cls, vis_len):
        """Record the number of visual tokens (used by the position/piece setters)."""
        cls.__base__.vis_len = vis_len

    @classmethod
    def set_begin_pos(cls, key, idx):
        """Store the begin position of segment ``key`` in ``begin_pos``.

        For ``"inst_q"`` and ``"role_1"`` the index is shifted by ``vis_len - 1`` —
        presumably because the single image placeholder expands into ``vis_len``
        visual tokens before those segments (TODO confirm). ``set_vis_len`` must be
        called first for these keys.
        """
        shift = cls.vis_len - 1 if key in ["inst_q", "role_1"] else 0
        cls.__base__.begin_pos[key] = idx + shift

    @classmethod
    def set_id_pieces(cls, key, ids: List[int]):
        """Store the token ids of segment ``key``; ``"image"`` ids are repeated ``vis_len`` times."""
        assert isinstance(ids, list)
        cls.__base__.id_pieces[key] = ids * cls.vis_len if key == "image" else ids

    @classmethod
    def set_text_pieces(cls, key, texts: List[str]):
        """Store the text pieces of segment ``key``; ``"image"`` pieces are repeated ``vis_len`` times."""
        assert isinstance(texts, list)
        cls.__base__.text_pieces[key] = texts * cls.vis_len if key == "image" else texts

    @classmethod
    def save_metadata(cls, save_to_path):
        """Pickle ``get_metadata()`` to ``save_to_path``; never overwrites an existing file."""
        if osp.exists(save_to_path):
            print(f"File already exists at {save_to_path}.")
            return
        metadata = cls.get_metadata()
        with open(save_to_path, "wb") as f:
            pickle.dump(metadata, f)

    @classmethod
    def prompt_segments(cls, input_ids, tokenizer, image_token_id, image_token_idx=-200):
        """Split a chat prompt into segments and report each segment's first token id.

        The prompt is rebuilt by decoding ``input_ids`` on both sides of position
        ``image_token_id`` and inserting the literal ``<image>`` marker followed by a
        newline there. It is then split on the ``"USER: "`` and ``"ASSISTANT:"`` role
        markers and on ``<image>``.

        Args:
            input_ids: Token ids of the full prompt (any sliceable sequence).
            tokenizer: Object with ``decode`` that, when called, returns an object
                with an ``input_ids`` attribute.
            image_token_id: Despite its name, the *position* of the image placeholder
                inside ``input_ids`` (it is used as a slice index).
            image_token_idx: Despite its name, the token *id* used to represent the
                image segment (default ``-200``).

        Returns:
            dict mapping ``"system"``, ``"role_0"``, ``"image"``, ``"inst_q"`` and
            ``"role_1"`` to ``[first token id of the segment, first token id of the
            next segment]``, with ``-1`` as the end marker for ``"role_1"``. A segment
            that tokenizes to nothing (e.g. no image, or no assistant text) is
            reported as ``None`` instead of raising ``IndexError``.
        """

        def tokenize_segment(text: str, add_special_tokens: bool = False) -> List[int]:
            return tokenizer(text, add_special_tokens=add_special_tokens).input_ids

        def first_token(tokens):
            # First token id of a segment, or None when the segment is empty.
            return tokens[0] if len(tokens) > 0 else None

        prompt = (
            tokenizer.decode(input_ids[:image_token_id])
            + "<image>\n"
            + tokenizer.decode(input_ids[image_token_id + 1 :])
        )
        segments = {}

        role_0 = "USER: "
        role_1 = "ASSISTANT:"

        # Everything before the first "USER: " marker is the system segment.
        user_split = prompt.split(role_0, 1)
        if len(user_split) > 1:
            segments["system"] = user_split[0].strip()
            remaining_text = role_0 + user_split[1]
        else:
            segments["system"] = prompt.strip()
            remaining_text = ""

        # role_1 keeps the "ASSISTANT:" marker plus whatever text follows it.
        assistant_split = remaining_text.split(role_1, 1)
        if len(assistant_split) > 1:
            user_text = assistant_split[0].rstrip()
            segments["role_1"] = role_1 + assistant_split[1].strip()
        else:
            segments["role_1"] = ""
            user_text = remaining_text

        if "<image>" in user_text:
            user_text_part, inst_q_text = user_text.split("<image>", 1)
            segments["role_0"] = user_text_part.strip()
            segments["image"] = "<image>"
            segments["inst_q"] = inst_q_text.strip()
        else:
            segments["role_0"] = user_text
            segments["image"] = None
            segments["inst_q"] = ""

        # Re-attach the separators removed by strip() above (presumably so each piece
        # tokenizes as it would inside the full prompt).
        segments["system"] += " "
        segments["role_0"] += " "
        segments["inst_q"] = "\n" + segments["inst_q"] + " "

        tokenized_segments = {
            "system": tokenize_segment(segments["system"], False),
            "role_0": tokenize_segment(segments["role_0"], False),
            "inst_q": tokenize_segment(segments["inst_q"], False),
            "role_1": tokenize_segment(segments["role_1"], False),
            "image": [image_token_idx] if segments["image"] else [],
        }
        first = {name: first_token(tokens) for name, tokens in tokenized_segments.items()}

        token_id_range = {
            "system": [first["system"], first["role_0"]],
            "role_0": [first["role_0"], first["image"]],
            "image": [first["image"], first["inst_q"]],
            "inst_q": [first["inst_q"], first["role_1"]],
            "role_1": [first["role_1"], -1],
        }

        return token_id_range


class VisionDataSaver(LogicEngine):
    """Holds two vision-pipeline intermediates on the base class.

    Judging by the names, these are presumably the vision-encoder output
    (``after_ve``) and the projector output (``after_proj``); confirm with callers.
    """

    @classmethod
    def activate(cls):
        """Enable this saver by setting its own flag to True, then report the flag."""
        # Each subclass sets the flag that belongs to itself.
        cls.set_flag(True)
        name, flag = cls.__name__, cls.get_flag()
        print(f"{name} flag set to {flag}")

    @classmethod
    def set_after_ve(cls, after_ve):
        """Store ``after_ve`` on the base class."""
        base = cls.__base__
        base.after_ve = after_ve

    @classmethod
    def set_after_proj(cls, after_proj):
        """Store ``after_proj`` on the base class."""
        base = cls.__base__
        base.after_proj = after_proj
