""" Named entity recognition fine-tuning: utilities to work with the CNER, CLUENER, MSRA, CoNLL-2003 and WNUT-17 tasks. """
import torch
import logging
import os
import copy
import json
from .utils_ner import DataProcessor
import random
from random import shuffle

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

class InputExample(object):
    """One example (id, text, labels) for token classification."""

    def __init__(self, guid, text_a, labels):
        """Store the three fields of an example unchanged.

        Args:
            guid: Unique id for the example.
            text_a: list. The words of the sequence.
            labels: (Optional) list. One label per word of the sequence.
                Expected for train and dev examples, but not for test examples.
        """
        self.guid, self.text_a, self.labels = guid, text_a, labels

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dictionary."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as indented, key-sorted JSON plus a trailing newline."""
        rendered = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return rendered + "\n"

    def __repr__(self):
        # The representation is exactly the JSON form of the instance.
        return str(self.to_json_string())

class InputFeatures(object):
    """The model-ready features of one example."""

    def __init__(self, input_ids, input_mask, input_len,segment_ids, label_ids):
        """Store the feature sequences and the recorded length as given.

        Args:
            input_ids: Token ids of the sequence.
            input_mask: Attention mask aligned with ``input_ids``.
            input_len: Length value recorded for the sequence.
            segment_ids: Segment (token type) ids aligned with ``input_ids``.
            label_ids: Label ids aligned with ``input_ids``.
        """
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.input_len = input_len

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dictionary."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as indented, key-sorted JSON plus a trailing newline."""
        rendered = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return rendered + "\n"

    def __repr__(self):
        # The representation is exactly the JSON form of the instance.
        return str(self.to_json_string())

def collate_fnOLD(batch):
    """Stack the per-example tensors of a batch, field by field.

    Args:
        batch: list of 5-tuples of tensors, unpacked in the order
            (input_ids, attention_mask, token_type_ids, length, labels).

    Returns:
        Tuple (input_ids, attention_mask, token_type_ids, labels, lens) of
        stacked tensors. Labels and lengths are swapped relative to the
        input order. No sorting or trimming is performed.
    """
    stacked = [torch.stack(field) for field in zip(*batch)]
    input_ids, attention_mask, token_type_ids, lens, labels = stacked
    # Trimming every field to the longest sequence of the batch
    # (max(lens).item()) was disabled here and is intentionally not applied.
    return input_ids, attention_mask, token_type_ids, labels, lens

def collate_fn(batch, tokenizer, max_pred=5, mask_prob=0.15):
    """Collate a batch and dynamically choose masked positions and labels.

    The per-example tensors are stacked, then for every row a random subset
    of non-special token positions is replaced by the tokenizer's mask id.

    Args:
        batch: list of 5-tuples of tensors, unpacked in the order
            (input_ids, attention_mask, token_type_ids, length, labels).
        tokenizer: object exposing ``cls_token_id``, ``sep_token_id``,
            ``pad_token_id`` and ``mask_token_id``.
        max_pred: maximum number of masked tokens per example; also the width
            of the returned ``masked_pos`` / ``masked_labels`` tensors.
        mask_prob: fraction of candidate positions to mask (at least one
            position is requested whenever ``max_pred`` allows it).

    Returns:
        Tuple (input_ids, attention_mask, token_type_ids, labels, lens,
        masked_pos, masked_labels). ``input_ids`` is the stacked tensor with
        the chosen positions overwritten by ``tokenizer.mask_token_id``.
        ``masked_pos`` is padded with -1 and ``masked_labels`` with -100 up to
        ``max_pred`` entries per row.
    """
    # Unpack the batch; torch.stack builds new tensors, so the in-place
    # masking below does not touch the tensors held by the caller.
    all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))

    # Token ids that must never be selected for masking.
    special_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
        tokenizer.mask_token_id,
    }

    masked_pos_list = []
    masked_labels_list = []
    for input_ids in all_input_ids:
        # Compare plain Python ints: a 0-d tensor hashes by identity, so
        # testing it against a set of ints never matches and special tokens
        # would end up among the candidates.
        token_ids = input_ids.tolist()
        cand_positions = [
            i for i, token_id in enumerate(token_ids)
            if token_id not in special_ids
        ]
        n_pred = min(max_pred, max(1, int(round(len(cand_positions) * mask_prob))))
        shuffle(cand_positions)

        masked_pos = cand_positions[:n_pred]
        masked_labels = [token_ids[pos] for pos in masked_pos]
        for pos in masked_pos:
            # ``input_ids`` is a row view, so this writes into all_input_ids.
            input_ids[pos] = tokenizer.mask_token_id

        # Pad the unused slots so every row has exactly ``max_pred`` entries.
        masked_pos.extend([-1] * (max_pred - len(masked_pos)))
        masked_labels.extend([-100] * (max_pred - len(masked_labels)))
        masked_pos_list.append(masked_pos)
        masked_labels_list.append(masked_labels)

    masked_pos = torch.tensor(masked_pos_list, dtype=torch.long)
    masked_labels = torch.tensor(masked_labels_list, dtype=torch.long)

    return (
        all_input_ids,
        all_attention_mask,
        all_token_type_ids,
        all_labels,
        all_lens,
        masked_pos,
        masked_labels
    )

def is_numeric_value(s):
    """Return True when ``s`` parses as a float once its commas are removed.

    Commas are dropped first so that thousands separators ("1,000") are
    accepted. Anything ``float`` rejects with ``ValueError`` yields False.
    """
    without_commas = s.replace(',', '')
    try:
        float(without_commas)
    except ValueError:
        return False
    return True

def convert_examples_to_features(examples,label_list,max_seq_length,tokenizer,
                                 cls_token_at_end=False,cls_token="[CLS]",cls_token_segment_id=1,
                                 sep_token="[SEP]",pad_on_left=False,pad_token=0,pad_token_segment_id=0,
                                 sequence_a_segment_id=0,mask_padding_with_zero=True,):
    """Convert examples into a list of fixed-length ``InputFeatures``.

    Each example is tokenized, its labels are re-aligned with the tokenizer
    output, both are truncated to ``max_seq_length - 2``, wrapped with
    ``sep_token`` / ``cls_token`` and padded up to ``max_seq_length``.

    `cls_token_at_end` define the location of the CLS token:
        - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
        - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
    `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)

    Args:
        examples: sequence of objects with ``guid``, ``text_a`` and ``labels``.
            A list-valued ``text_a`` is joined with spaces and written back to
            the example (the example is modified in place).
        label_list: labels; a label's id is its index. Must contain ``'O'``.
        max_seq_length: length of every returned feature sequence.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        cls_token, sep_token: special tokens added around the sequence.
        pad_on_left: pad at the start instead of the end.
        pad_token: id used to pad both ``input_ids`` and ``label_ids``.
        pad_token_segment_id, sequence_a_segment_id: segment ids for padding
            and for the sequence tokens.
        mask_padding_with_zero: when True real tokens get mask 1 and padding 0,
            otherwise the two values are swapped.

    Returns:
        list of ``InputFeatures``, one per example. ``input_len`` is the
        sequence length including the special tokens, before padding.

    Raises:
        KeyError: if an example label (or ``'O'``) is not in ``label_list``.
        AssertionError: if a padded sequence does not have ``max_seq_length``
            entries.
    """

    label_map = {label: i for i, label in enumerate(label_list)}
    # Label given to a "##" word-piece continuation: B-XXX becomes I-XXX when
    # the label list contains it, every other label is repeated unchanged.
    # Deriving the ids from ``label_list`` keeps this valid for every task
    # instead of relying on one particular label ordering.
    continuation_label = {
        idx: label_map.get("I-" + label[2:], idx)
        for label, idx in label_map.items()
        if label.startswith("B-")
    }
    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))
        if isinstance(example.text_a,list):
            # NOTE: the joined string replaces the list on the example itself.
            example.text_a = " ".join(example.text_a)
        tokens = tokenizer.tokenize(example.text_a)
        label_ids = [label_map[x] for x in example.labels]

        # Repair the token/label misalignment introduced by the tokenizer.
        # ``i`` walks the tokens, ``j`` walks the labels; inside the loop
        # ``new_label_ids`` always holds one label per token before ``i``.
        # The first token takes the first label; with no tokens or no labels
        # nothing is seeded (this also avoids an IndexError on empty labels).
        new_label_ids = [label_ids[0]] if tokens and label_ids else []
        i = j = 1
        while i < len(tokens) and j < len(label_ids):

            if tokens[i].startswith('##'):
                # Word-piece continuation: reuse the previous token's label,
                # turning a B- label into its I- counterpart.
                previous_label = new_label_ids[-1]
                new_label_ids.append(continuation_label.get(previous_label, previous_label))
                i = i + 1
            elif tokens[i] == '-':
                # Drop the hyphen and the token after it; no label is consumed.
                tokens.pop(i)
                if i < len(tokens):
                    tokens.pop(i)
            elif i+2<len(tokens) and tokens[i] == 'n' and tokens[i+1] == "'" and tokens[i+2] == 't':
                # "n", "'", "t" collapse into the single token "not".
                tokens[i] = 'not'
                tokens.pop(i+1)
                tokens.pop(i+1)
                new_label_ids.append(label_ids[j])
                i = i + 1
                j = j + 1
            elif i+1< len(tokens) and tokens[i] == "'" and tokens[i+1] == 's':
                # "'" and "s" both receive the label of the current word.
                new_label_ids.append(label_ids[j])
                new_label_ids.append(label_ids[j])
                i=i+2
                j=j+1
            elif i+2< len(tokens) and is_numeric_value(tokens[i]) and tokens[i+1] == ',' and is_numeric_value(tokens[i+2]):
                # number "," number: keep the first number only, one label.
                tokens.pop(i+1)
                tokens.pop(i+1)
                new_label_ids.append(label_ids[j])
                i = i + 1
                j = j + 1
            elif i+2<len(tokens) and is_numeric_value(tokens[i]) and tokens[i+1]== '.' and not is_numeric_value(tokens[i+2]):
                # number followed by "." and a non-number: the number takes the
                # word label and the "." takes label id 0.
                # NOTE(review): id 0 is the first entry of label_list, not
                # necessarily 'O' - confirm this is the intended filler.
                new_label_ids.append(label_ids[j])
                new_label_ids.append(0)
                i=i+2
                j=j+1
            else :
                new_label_ids.append(label_ids[j])
                i = i + 1
                j = j + 1
        # Labels ran out: still drop any hyphen pair sitting at the cursor.
        while i < len(tokens) and tokens[i] == '-':
            tokens.pop(i)
            if i <len(tokens):
                tokens.pop(i)
        # Fill the remaining tokens with label id 0. Only ever grow the list:
        # comparing with ``!=`` would loop forever if there were fewer tokens
        # than labels collected so far (e.g. an empty tokenization).
        new_label_ids.extend([0] * (len(tokens) - len(new_label_ids)))
        label_ids = new_label_ids
        for idx, val in enumerate(label_ids):
            if not isinstance(val, int):
                logger.warning(f"label_ids 第 {idx} 个元素类型异常：{type(val)}（值：{val}）")

        # Account for [CLS] and [SEP] with "- 2".
        special_tokens_count = 2
        if len(tokens) > max_seq_length - special_tokens_count:
            tokens = tokens[: (max_seq_length - special_tokens_count)]
            label_ids = label_ids[: (max_seq_length - special_tokens_count)]

        tokens += [sep_token]
        label_ids += [label_map['O']]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if cls_token_at_end:
            tokens += [cls_token]
            label_ids += [label_map['O']]
            segment_ids += [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            label_ids = [label_map['O']] + label_ids
            segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
        input_len = len(label_ids)
        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token] * padding_length) + label_ids
        else:
            input_ids += [pad_token] * padding_length
            input_mask += [0 if mask_padding_with_zero else 1] * padding_length
            segment_ids += [pad_token_segment_id] * padding_length
            label_ids += [pad_token] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        if ex_index < 5:
            # Log the first few converted examples for manual inspection.
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))

        features.append(InputFeatures(input_ids=input_ids, input_mask=input_mask,input_len = input_len,
                                      segment_ids=segment_ids, label_ids=label_ids))
    return features





class CnerProcessor(DataProcessor):
    """Processor for the chinese ner data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_path = os.path.join(data_dir, "train.char.bmes")
        return self._create_examples(self._read_text(train_path), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_path = os.path.join(data_dir, "dev.char.bmes")
        return self._create_examples(self._read_text(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_path = os.path.join(data_dir, "test.char.bmes")
        return self._create_examples(self._read_text(test_path), "test")

    def get_labels(self):
        """See base class."""
        return [
            "X",
            'B-CONT', 'B-EDU', 'B-LOC', 'B-NAME', 'B-ORG', 'B-PRO', 'B-RACE', 'B-TITLE',
            'I-CONT', 'I-EDU', 'I-LOC', 'I-NAME', 'I-ORG', 'I-PRO', 'I-RACE', 'I-TITLE',
            'O',
            'S-NAME', 'S-ORG', 'S-RACE',
            "[START]", "[END]",
        ]

    def _create_examples(self, lines, set_type):
        """Build ``InputExample`` objects from the records in ``lines``.

        The record at index 0 is skipped. Labels containing ``M-`` or ``E-``
        have that marker rewritten to ``I-`` (``M-`` is checked first), every
        other label is kept unchanged.
        """
        examples = []
        for index, record in enumerate(lines):
            if index == 0:
                continue
            guid = "{}-{}".format(set_type, index)
            words = record['words']
            converted = []
            for tag in record['labels']:
                # Only the first marker found is rewritten.
                for marker in ('M-', 'E-'):
                    if marker in tag:
                        tag = tag.replace(marker, 'I-')
                        break
                converted.append(tag)
            examples.append(InputExample(guid=guid, text_a=words, labels=converted))
        return examples

class ConllProcessor(DataProcessor):
    """Processor for the conll2003 ner data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_path = os.path.join(data_dir, "train.txt")
        return self._create_examples(self._read_text(train_path), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_path = os.path.join(data_dir, "valid.txt")
        return self._create_examples(self._read_text(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_path = os.path.join(data_dir, "test.txt")
        return self._create_examples(self._read_text(test_path), "test")

    def get_labels(self):
        """See base class."""
        return [
            "X",
            'B-PER', 'B-ORG', 'B-LOC', 'B-MISC',
            'I-PER', 'I-ORG', 'I-LOC', 'I-MISC',
            'O',
            "[START]", "[END]",
        ]

    def _create_examples(self, lines, set_type):
        """Build ``InputExample`` objects from the records in ``lines``.

        The record at index 0 is skipped. Labels containing ``M-`` or ``E-``
        have that marker rewritten to ``I-`` (``M-`` is checked first). The
        first three examples are printed, and each of them is asserted to have
        as many words as labels.
        """
        examples = []
        for index, record in enumerate(lines):
            if index == 0:
                continue
            guid = "{}-{}".format(set_type, index)
            words = record['words']
            converted = []
            for tag in record['labels']:
                # Only the first marker found is rewritten.
                for marker in ('M-', 'E-'):
                    if marker in tag:
                        tag = tag.replace(marker, 'I-')
                        break
                converted.append(tag)
            examples.append(InputExample(guid=guid, text_a=words, labels=converted))

        # Preview (and sanity-check) the first few examples on stdout.
        separator = "-" * 60
        for preview in examples[:3]:
            print(preview.guid)
            print("words :", preview.text_a)
            print("labels:", preview.labels)
            assert len(preview.text_a) == len(preview.labels), "长度不等：words 与 labels 不匹配"
            print(separator)
        return examples

class CluenerProcessor(DataProcessor):
    """Processor for the chinese ner data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_path = os.path.join(data_dir, "train.json")
        return self._create_examples(self._read_json(train_path), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_path = os.path.join(data_dir, "dev.json")
        return self._create_examples(self._read_json(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_path = os.path.join(data_dir, "test.json")
        return self._create_examples(self._read_json(test_path), "test")

    def get_labels(self):
        """See base class."""
        return [
            "X",
            "B-address", "B-book", "B-company", 'B-game', 'B-government',
            'B-movie', 'B-name', 'B-organization', 'B-position', 'B-scene',
            "I-address", "I-book", "I-company", 'I-game', 'I-government',
            'I-movie', 'I-name', 'I-organization', 'I-position', 'I-scene',
            "S-address", "S-book", "S-company", 'S-game', 'S-government',
            'S-movie', 'S-name', 'S-organization', 'S-position', 'S-scene',
            'O',
            "[START]", "[END]",
        ]

    def _create_examples(self, lines, set_type):
        """Build one ``InputExample`` per record, using its words and labels as-is."""
        return [
            InputExample(
                guid="{}-{}".format(set_type, index),
                text_a=record['words'],
                labels=record['labels'],
            )
            for index, record in enumerate(lines)
        ]

class MSRAProcessor(DataProcessor):
    """Processor for the MSRA data set."""

    def __init__(self):
        super().__init__()
        # Label scheme of the MSRA data set (without the [START]/[END] markers).
        self.labels = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]

    def get_train_examples(self, data_dir):
        """Read ``train.txt`` from ``data_dir`` and return its examples."""
        train_path = os.path.join(data_dir, "train.txt")
        return self._create_examples(self._read_txt(train_path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.txt`` from ``data_dir`` and return its examples."""
        dev_path = os.path.join(data_dir, "dev.txt")
        return self._create_examples(self._read_txt(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.txt`` from ``data_dir`` and return its examples."""
        test_path = os.path.join(data_dir, "test.txt")
        return self._create_examples(self._read_txt(test_path), "test")

    def get_labels(self):
        """Return the label list of the data set, including [START] and [END]."""
        return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O", "[START]", "[END]"]

    def _read_txt(self, input_file):
        """Read a UTF-8 text file into a list of ``{'words', 'labels'}`` records.

        Every non-blank line contributes its first whitespace-separated field
        as the word and its last field as the label. Blank lines end the
        current record; empty records are never emitted.
        """
        records = []
        words, labels = [], []
        with open(input_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                fields = raw_line.split()
                if fields:
                    words.append(fields[0])
                    labels.append(fields[-1])
                    continue
                # Blank line: close the sentence collected so far, if any.
                if words:
                    records.append({'words': words, 'labels': labels})
                    words, labels = [], []
        # The file may end without a trailing blank line.
        if words:
            records.append({'words': words, 'labels': labels})
        return records

    def _create_examples(self, lines, set_type):
        """Build one ``InputExample`` per record, using its words and labels as-is."""
        return [
            InputExample(
                guid="{}-{}".format(set_type, index),
                text_a=record['words'],
                labels=record['labels'],
            )
            for index, record in enumerate(lines)
        ]

class WNUT17Processor(DataProcessor):
    """Processor for the WNUT-17 data set."""

    def __init__(self):
        super().__init__()
        # Label scheme of the data set (without the [START]/[END] markers).
        self.labels = [
            "B-person", "I-person", "B-location", "I-location", "B-corporation", "I-corporation",
            "B-creative-work", "I-creative-work", "B-product", "I-product",
            "B-group", "I-group", "O",
        ]

    def get_train_examples(self, data_dir):
        """Read ``train.txt`` from ``data_dir`` and return its examples."""
        train_path = os.path.join(data_dir, "train.txt")
        return self._create_examples(self._read_txt(train_path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.txt`` from ``data_dir`` and return its examples."""
        dev_path = os.path.join(data_dir, "dev.txt")
        return self._create_examples(self._read_txt(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.txt`` from ``data_dir`` and return its examples."""
        test_path = os.path.join(data_dir, "test.txt")
        return self._create_examples(self._read_txt(test_path), "test")

    def get_labels(self):
        """Return the label list of the data set, including [START] and [END]."""
        return [
            "B-person", "I-person", "B-location", "I-location", "B-corporation", "I-corporation",
            "B-creative-work", "I-creative-work", "B-product", "I-product",
            "B-group", "I-group", "O", "[START]", "[END]",
        ]

    def _read_txt(self, input_file):
        """Read a UTF-8 text file into a list of ``{'words', 'labels'}`` records.

        Every non-blank line contributes its first whitespace-separated field
        as the word and its last field as the label. Blank lines end the
        current record; empty records are never emitted.
        """
        records = []
        words, labels = [], []
        with open(input_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                fields = raw_line.split()
                if fields:
                    words.append(fields[0])
                    labels.append(fields[-1])
                    continue
                # Blank line: close the sentence collected so far, if any.
                if words:
                    records.append({'words': words, 'labels': labels})
                    words, labels = [], []
        # The file may end without a trailing blank line.
        if words:
            records.append({'words': words, 'labels': labels})
        return records

    def _create_examples(self, lines, set_type):
        """Build one ``InputExample`` per record, using its words and labels as-is."""
        return [
            InputExample(
                guid="{}-{}".format(set_type, index),
                text_a=record['words'],
                labels=record['labels'],
            )
            for index, record in enumerate(lines)
        ]


# Registry mapping a task name to the processor class that handles it.
ner_processors = {
    "cner": CnerProcessor,
    "cluener": CluenerProcessor,
    "msra": MSRAProcessor,  # processor class for the MSRA task
    "conll03": ConllProcessor,
    "wnut17": WNUT17Processor,
}
