"""STAT Implementation"""
import abc, json
import logging
import random
import string
from collections import deque
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Literal, Optional, Sequence, cast, overload, Tuple

import numpy as np
import torch
import transformers
from scipy.special import softmax

from .utils import tokenizer_utils
from .utils.typing import (
    Model,
    Dataset,
    Device,
    ModelInput,
    ModelOutput,
    StrSequence,
    Tokenizer,
    TokenizerOffsetMapping,
)


class STAT(abc.ABC):
    """
    Create STAT to steer the attention of transformer models.

    For every steered layer, `apply_steering` registers a forward pre-hook
    (`edit_attention`) that adds ``log(scale)`` offsets to the attention mask,
    and a forward hook (`select_unit_id`) that reads the attention weights the
    module returns (``output[1]``) to advance the current stage and to pick a
    branch.

    Args:
        model ([`transformers.PreTrainedModel`]): The model to be steered.
        tokenizer ([`transformers.PreTrainedTokenizer`]): The model's tokenizer.
        stat_mode (`str`): Which masks to apply. ``"relative_only"`` applies the
            static structure mask when the mask is square (``seq1 == seq2``),
            ``"span_only"`` applies the dynamic stage/branch masks when it is
            not, and ``"relative+span"`` enables both.
        head_config (`dict` or `list`): Heads to steer. A dict maps a layer
            index to a list of head indices; a list of layer indices steers
            every head of those layers.
        scale_static (`float`): Factor whose natural log is added to the mask
            entries selected by the static structure mask.
        scale_dynamic (`float`): Factor whose natural log is added to the mask
            entries selected by the dynamic stage/branch masks.

    Returns:
        STAT: STAT steerer that can register the hooks on target models.
    """

    # Module path of the attention block of layer `{}` for each supported model.
    ATTN_MODULE_NAME = {
        "gptj": "transformer.h.{}.attn",
        "llama": "model.layers.{}.self_attn",
        "mistral": "model.layers.{}.self_attn",
        "gemma": "model.layers.{}.self_attn",
        "phi3mini": "model.layers.{}.self_attn"
    }
    # Positional index of `attention_mask` in the attention module's forward
    # args. Only consulted when the mask is not passed as a keyword argument.
    ATTENTION_MASK_ARGIDX = {
        "gptj": 2,
        "llama": 1,
        "mistral": 1,
        "gemma": 1,
        "phi3mini": 1,
    }

    def __init__(
        self,
        model: Model,
        tokenizer: Tokenizer,
        stat_mode: str,
        head_config: dict,
        scale_static: float,
        scale_dynamic: float
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.stat_mode = stat_mode
        self.model_name = "llama"  # placeholder; overwritten by setup_model
        self.setup_model(model)
        self.setup_head_config(head_config)

        ## Per-run steering state, (re)initialised by `apply_steering`.
        self.num_stage = None
        self.stage_type = None
        self.num_branch = None
        self.current_stage = None
        self.candidate_branch = None
        self.select_branch = None
        # A stage/branch is detected only when the entropy of the normalised
        # per-unit attention is below `alpha * log(number of units)`.
        self.alpha = 0.5
        # Added to probabilities before `np.log` in `entropy` to avoid log(0).
        self.eps = 1e-38
        self.scale_static = scale_static
        self.scale_dynamic = scale_dynamic

    def setup_model(self, model):
        """Obtain the model type and record its head/layer counts.

        Raises:
            ValueError: If the model class is not one of the supported types.
        """
        if isinstance(model, transformers.LlamaForCausalLM):
            self.model_name = "llama"
            self.num_attn_head = model.config.num_attention_heads
            self.num_atten_layer = model.config.num_hidden_layers
        elif isinstance(model, transformers.GPTJForCausalLM):
            self.model_name = "gptj"
            self.num_attn_head = model.config.n_head
            self.num_atten_layer = model.config.n_layer
        elif isinstance(model, transformers.MistralForCausalLM):
            self.model_name = "mistral"
            self.num_attn_head = model.config.num_attention_heads
            self.num_atten_layer = model.config.num_hidden_layers
        elif isinstance(model, transformers.GemmaForCausalLM):
            self.model_name = "gemma"
            self.num_attn_head = model.config.num_attention_heads
            self.num_atten_layer = model.config.num_hidden_layers
        elif model.__class__.__name__ == "Phi3ForCausalLM":
            # Matched by class name rather than isinstance.
            self.model_name = "phi3mini"
            self.num_attn_head = model.config.num_attention_heads
            self.num_atten_layer = model.config.num_hidden_layers
        else:
            raise ValueError("Unimplemented Model Type.")

    def setup_head_config(self, head_config):
        """
        Configure the attention heads to be steered.

        If `head_config` is a `dict`, its keys are layer indices and its values
        the head indices to steer in that layer. If it is a `list` of layer
        indices, STAT steers every head of those layers.

        Raises:
            ValueError: If `head_config` is neither a dict nor a list.
        """
        if isinstance(head_config, dict):
            self.head_config = {int(k): v for k, v in head_config.items()}
            self.all_layers_idx = [int(key) for key in head_config]
        elif isinstance(head_config, list):
            self.all_layers_idx = [int(v) for v in head_config]
            self.head_config = {
                idx: list(range(self.num_attn_head)) for idx in self.all_layers_idx
            }
        else:
            raise ValueError(f"Incorrect head config: {head_config}")

    def find_token_ranges(self, strings, seg, offset_mappings):
        """
        Map the first occurrence of `seg` in `strings` to a token index range.

        Args:
            strings (`str`): The full input text.
            seg (`str`): The substring to locate.
            offset_mappings: Per-token `(start, end)` character offsets
                produced by the tokenizer for `strings`.

        Returns:
            A half-open `(token_start, token_end)` range of token indices, or
            `None` when `seg` is empty or does not occur in `strings`.

        Raises:
            ValueError: If `seg` occurs in `strings` but the offsets do not
                cover its character span.
        """
        if not seg:
            # An empty segment (e.g. one made only of stripped punctuation)
            # would "match" at position 0 and wrongly mark the leading tokens.
            return None
        start_idx = strings.find(seg)
        if start_idx == -1:
            return None
        end_idx = start_idx + len(seg)
        token_start = token_end = None
        for i, (start, end) in enumerate(offset_mappings):
            # Bounds are inclusive and the first match wins, so a character
            # boundary shared by two tokens is attributed to the earlier one.
            if token_start is None and start <= start_idx <= end:
                token_start = i
            if token_end is None and start <= end_idx <= end:
                token_end = i
                break

        if token_start is None or token_end is None or token_start > token_end:
            raise ValueError(
                f"Cannot map segment {seg!r} (chars {start_idx}-{end_idx}) to tokens."
            )
        return (token_start, token_end + 1)

    def remove_intersection_token(self, token_range1, token_range2):
        """Return both token lists with their shared indices removed (sorted)."""
        set1, set2 = set(token_range1), set(token_range2)
        return sorted(set1 - set2), sorted(set2 - set1)

    def safe_mean(self, scores, tokens, mask_heads):
        """
        Mean of `scores` over heads `mask_heads` and last-axis positions `tokens`.

        Assumes `scores` is 4D with heads on axis 1. The axis-3 positions are
        presumably key positions (batch, heads, query, key); TODO confirm.
        Returns 0 when either index list is empty.
        """
        if len(tokens) == 0 or len(mask_heads) == 0:
            return 0
        heads, keys = torch.meshgrid(
            torch.tensor(mask_heads, dtype=torch.long),
            torch.tensor(tokens, dtype=torch.long),
            indexing="ij",
        )
        # The advanced indices on axes 1 and 3 are separated by a slice, so the
        # result is laid out as [|heads|, |tokens|, batch, axis-2].
        val = scores[:, heads, :, keys].mean()

        if isinstance(val, torch.Tensor):
            val = val.item()
        return val

    def check_unit_num(self, token_range):
        """Return True when at least two units in `token_range` are non-empty."""
        return sum(1 for unit in token_range if len(unit) > 0) >= 2

    def entropy(self, prob_list):
        """Shannon entropy (natural log) of `prob_list`; `eps` guards log(0)."""
        prob_array = np.asarray(prob_list, dtype=float) + self.eps
        return -np.sum(prob_array * np.log(prob_array))

    def _unit_attention(self, attention_scores, units, head_idx):
        """Per-unit mean attention score (see `safe_mean`) for each unit in `units`."""
        return [self.safe_mean(attention_scores, unit, head_idx) for unit in units]

    def detect_stage(self, attention_scores, token_ranges, key, head_idx):
        """
        Return True when attention has moved to the stage after the current one.

        The per-unit mean attention over `token_ranges[key]` is normalised to a
        distribution. The stage advances when its argmax is
        `current_stage + 1` and its entropy is below
        `alpha * log(num_stage)`. Returns False if every unit scores zero.
        """
        avg_strings = self._unit_attention(attention_scores, token_ranges[key], head_idx)
        total = sum(avg_strings)
        if not total > 0:
            # Nothing to normalise (all-zero or NaN); previously yielded NaN -> False.
            return False
        nor_avg = np.array(avg_strings) / total
        max_index = np.argmax(nor_avg)
        entropy_avg = self.entropy(nor_avg)
        return bool(
            max_index == self.current_stage + 1
            and entropy_avg < self.alpha * np.log(self.num_stage)
        )

    def detect_branch(self, attention_scores, token_ranges, key, head_idx):
        """
        Return the index of the branch attention concentrates on, or -1.

        The first branch's score is scaled by a hard-coded 0.8 before
        normalising; the rationale is not stated in this file. A branch is
        selected when the entropy of the normalised scores is below
        `alpha * log(num_branch)`.
        """
        avg_strings = self._unit_attention(attention_scores, token_ranges[key], head_idx)
        if not avg_strings:
            return -1
        avg_strings[0] = avg_strings[0] * 0.8
        total = sum(avg_strings)
        if not total > 0:
            # Nothing to normalise (all-zero or NaN); previously yielded NaN -> -1.
            return -1
        nor_avg = np.array(avg_strings) / total
        max_index = np.argmax(nor_avg)
        entropy_avg = self.entropy(nor_avg)
        if entropy_avg < self.alpha * np.log(self.num_branch):
            return max_index
        return -1

    def check_consistency(self, candidate_branch):
        """Return True when the last two branch candidates agree on a valid (> -1) branch."""
        return bool(
            len(candidate_branch) == 2
            and len(set(candidate_branch)) == 1
            and candidate_branch[-1] > -1
        )

    def select_unit_id(self,
        module: torch.nn.Module,
        input: tuple,
        output: tuple,
        head_idx: list[int],
        token_ranges: dict):
        '''
        Forward hook: use the attention scores to decide which units to mask.

        Reads the attention weights from `output[1]` (so the module must
        return them). It may advance `self.current_stage`, which also resets
        the branch selection, and may set `self.select_branch`.
        '''
        attention_scores = output[1]  # after softmax

        if self.num_stage > 1 and self.check_unit_num(token_ranges[self.stage_type]):
            if self.detect_stage(attention_scores, token_ranges, self.stage_type, head_idx):
                self.current_stage = min(self.current_stage + 1, self.num_stage - 1)
                self.select_branch = None
                self.candidate_branch = deque(maxlen=2)
        # Selecting a branch only makes sense with >1 branches; with 0 branches
        # `detect_branch` has nothing to score.
        if "branch" in token_ranges and self.num_branch > 1 and self.select_branch is None:
            branch_cand = self.detect_branch(attention_scores, token_ranges, "branch", head_idx)
            if branch_cand > -1:
                self.select_branch = branch_cand

    @staticmethod
    def _log_scale(scale, like):
        """`log(scale)` as a 1-element tensor with the dtype/device of `like`."""
        return torch.Tensor([scale]).to(like.dtype).to(like.device).log()

    def _add_log_scale(self, attention_mask, heads, keys, scale):
        """Add `log(scale)` to `attention_mask[:, h, :, k]` for all `h` in heads, `k` in keys."""
        if len(heads) == 0 or len(keys) == 0:
            # Nothing to edit; also avoids meshgrid on empty (float) tensors.
            return attention_mask
        head_grid, key_grid = torch.meshgrid(
            torch.tensor(heads, dtype=torch.long),
            torch.tensor(keys, dtype=torch.long),
            indexing="ij",
        )
        attention_mask[:, head_grid, :, key_grid] += self._log_scale(scale, attention_mask)
        return attention_mask

    def token_stage_mask(self, attention_mask, token_ranges, key, mask_heads):
        """
        Rescale mask columns of units other than the current stage.

        For `key == "chain"` only later stages are rescaled; otherwise every
        stage except the current one is. Tokens shared with the current stage
        are left untouched.
        """
        units = token_ranges[key]
        for i, unit in enumerate(units):
            if key == "chain":
                masked = i > self.current_stage
            else:
                masked = i != self.current_stage
            if not masked or len(unit) == 0:
                continue
            _, keys = self.remove_intersection_token(units[self.current_stage], unit)
            attention_mask = self._add_log_scale(attention_mask, mask_heads, keys, self.scale_dynamic)

        return attention_mask

    def token_branch_mask(self, attention_mask, token_ranges, key, mask_heads):
        """Rescale mask columns of every branch except `self.select_branch` (shared tokens excluded)."""
        units = token_ranges[key]
        for i, unit in enumerate(units):
            if i == self.select_branch or len(unit) == 0:
                continue
            _, keys = self.remove_intersection_token(units[self.select_branch], unit)
            attention_mask = self._add_log_scale(attention_mask, mask_heads, keys, self.scale_dynamic)

        return attention_mask

    def structure_mask(self, attention_mask, token_ranges, key, mask_heads):
        """
        Static mask between every pair of units `i < j` of `token_ranges[key]`.

        Adds `log(scale_static)` at rows of unit `i` and columns of unit `j`
        (shared tokens removed). For "branch" and "parallel" groups the
        transposed entries are rescaled too.
        """
        units = token_ranges[key]
        symmetric = key in ("branch", "parallel")
        for i in range(len(units)):
            for j in range(i + 1, len(units)):
                token1, token2 = self.remove_intersection_token(units[i], units[j])
                if len(mask_heads) == 0 or len(token1) == 0 or len(token2) == 0:
                    # Nothing to edit; also avoids meshgrid on empty (float) tensors.
                    continue
                heads, rows, cols = torch.meshgrid(
                    torch.tensor(mask_heads, dtype=torch.long),
                    torch.tensor(token1, dtype=torch.long),
                    torch.tensor(token2, dtype=torch.long),
                    indexing="ij",
                )
                log_scale = self._log_scale(self.scale_static, attention_mask)
                if symmetric:
                    attention_mask[:, heads, cols, rows] += log_scale
                attention_mask[:, heads, rows, cols] += log_scale

        return attention_mask

    def edit_attention(self,
        module: torch.nn.Module,
        input_args: tuple,
        input_kwargs: dict,
        head_idx: list[int],
        token_ranges: dict):
        """
        Forward pre-hook that edits the attention mask passed to `module`.

        Args:
            module (torch.nn.Module): The current attention module.
            input_args (tuple): Positional args of the module call.
            input_kwargs (dict): Keyword args of the module call.
            head_idx (list): Indices of the heads to apply masking to.
            token_ranges (dict): Token indices of each structure unit, per group.

        Returns:
            `(args, kwargs)` with the edited (cloned) mask substituted in.

        Raises:
            ValueError: If no attention mask can be found, or it is `None`.
        """
        from_kwargs = "attention_mask" in input_kwargs
        if from_kwargs:
            raw_mask = input_kwargs['attention_mask']
        else:
            arg_idx = self.ATTENTION_MASK_ARGIDX.get(self.model_name)
            if input_args is None or arg_idx is None or arg_idx >= len(input_args):
                raise ValueError(f"Not found attention masks in {str(module)}")
            raw_mask = input_args[arg_idx]
        if raw_mask is None:
            raise ValueError(
                f"Attention mask is None in {str(module)}; STAT requires an explicit 4D attention mask."
            )
        attention_mask = raw_mask.clone().detach()

        bs, heads, seq1, seq2 = attention_mask.size()
        if heads != self.num_attn_head:
            # Broadcast a shared mask to one slice per head so heads can differ.
            attention_mask = attention_mask.expand(
                bs, self.num_attn_head, seq1, seq2
            ).clone()

        if seq1 == seq2:
            # Static relative masking.
            if self.stat_mode in ("relative_only", "relative+span"):
                for key in token_ranges:
                    if self.check_unit_num(token_ranges[key]):
                        attention_mask = self.structure_mask(attention_mask, token_ranges, key, head_idx)
        elif self.stat_mode in ("span_only", "relative+span"):
            # Dynamic span masking.
            if self.num_stage > 1 and self.check_unit_num(token_ranges[self.stage_type]):
                attention_mask = self.token_stage_mask(attention_mask, token_ranges, self.stage_type, head_idx)
            if (
                self.num_branch > 1
                and self.select_branch is not None
                and self.check_unit_num(token_ranges["branch"])
            ):
                attention_mask = self.token_branch_mask(attention_mask, token_ranges, "branch", head_idx)

        if from_kwargs:
            input_kwargs['attention_mask'] = attention_mask
            return input_args, input_kwargs
        # The pre-hook must return a flat args tuple with the mask substituted.
        new_args = input_args[:arg_idx] + (attention_mask,) + input_args[arg_idx + 1:]
        return new_args, input_kwargs

    def _build_token_ranges(self, strings, substring_groups, offsets_mapping):
        """Convert every text unit of `substring_groups` into a sorted list of unique token indices."""
        logger = logging.getLogger(__name__)
        punctuations = string.punctuation + "，。！？；：“”‘’《》【】（）「」"
        token_ranges = {}
        for key, text_group in substring_groups.items():
            token_ranges[key] = []
            for segments in text_group:
                unit_tokens = set()
                for seg in segments:
                    if isinstance(seg, list) and len(seg) == 1:
                        seg = seg[0].rstrip(punctuations)
                    logger.debug("seg: %s", seg)
                    rng = self.find_token_ranges(strings, seg, offsets_mapping)
                    if rng is not None:
                        unit_tokens.update(range(rng[0], rng[1]))
                token_ranges[key].append(sorted(unit_tokens))
        logger.debug("token ranges: %s", token_ranges)
        return token_ranges

    @contextmanager
    def apply_steering(self, model, strings, substring_groups, model_input, offsets_mapping: Sequence[TokenizerOffsetMapping], occurrence: int = 0):
        """
        Context manager that registers the steering hooks on `model`.

        Args:
            model ([`transformers.PreTrainedModel`]): The transformer model to be steered.
            strings (`str`): The input string.
            substring_groups (`dict{"chain":[[],[]]}`): The structure spans for the string.
            model_input (`transformers.BatchEncoding`): The tokenized model inputs (unused here).
            offsets_mapping (`TokenizerOffsetMapping`): The offset mapping output by
                the tokenizer when encoding `strings`.
            occurrence (`int`): Unused here; kept for API compatibility.

        Yields:
            `model` with hooks registered. All hooks are removed on exit, also
            when registration fails partway or the body raises.
        """
        token_ranges = self._build_token_ranges(strings, substring_groups, offsets_mapping)

        ## Prepare structure settings.
        if "branch" in token_ranges:
            self.num_branch = len(token_ranges["branch"])
        else:
            self.num_branch = 0
        if "parallel" in token_ranges:
            self.num_stage = len(token_ranges["parallel"])
            self.stage_type = "parallel"
        elif "chain" in token_ranges:
            self.num_stage = len(token_ranges["chain"])
            self.stage_type = "chain"
        else:
            self.num_stage = 1
            self.stage_type = None

        ## Initialise the dynamic masking state.
        self.current_stage = 0
        self.candidate_branch = deque(maxlen=2)
        self.select_branch = None

        registered_hooks = []
        try:
            for layer_idx in self.all_layers_idx:
                name = self.ATTN_MODULE_NAME[self.model_name].format(layer_idx)
                module = model.get_submodule(name)
                head_idx = self.head_config[layer_idx]
                mask_unit = partial(self.edit_attention, head_idx=head_idx, token_ranges=token_ranges)
                select_unit = partial(self.select_unit_id, head_idx=head_idx, token_ranges=token_ranges)
                registered_hooks.append(module.register_forward_pre_hook(mask_unit, with_kwargs=True))
                registered_hooks.append(module.register_forward_hook(select_unit))
            yield model
        finally:
            for registered_hook in registered_hooks:
                registered_hook.remove()

    def inputs_from_batch(self, text):
        """Tokenize `text`; return the inputs on the model's device and the first sequence's offset mapping."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            return_offsets_mapping=True,
        )
        offset_mapping = inputs.pop("offset_mapping")[0]
        inputs = inputs.to(self.model.device)

        return inputs, offset_mapping
