from .base import BaseTokenizer
import torch
from typing import Dict, List, Union


class NumericalTokenizer(BaseTokenizer):
    """Tokenizer for comma-separated sequences of integer node ids.

    Encoding rules (see ``encode``):

    * a maximal run of characters from ``numbers`` is a single token whose id
      is the node id it spells (``"12"`` -> ``12``); valid ids are
      ``0 .. num_nodes - 1``;
    * every other non-comma character must be a key of ``special_tokens`` and
      maps to ``num_nodes + special_tokens[char]``;
    * ``','`` is a pure separator and produces no token.

    ``'$'`` is the padding token and ``'='`` the delimiter; the delimiter id is
    also what ``eos_token_id`` reports.
    """

    def __init__(self, num_nodes, special_tokens=None, numbers=None):
        """Build the encoder/decoder lookup tables.

        Args:
            num_nodes: Number of node ids; ids ``0 .. num_nodes - 1`` encode
                their decimal string representation.
            special_tokens: Mapping from special character to its offset above
                ``num_nodes``. Defaults to ``{'|': 0, '=': 1, '/': 2, '$': 3}``.
                ``'='`` is required here (looked up below); ``'$'`` is
                required by ``pad_token_id``.
            numbers: Characters treated as digits when scanning input text.
                Defaults to ``'0'`` .. ``'9'``.

        Raises:
            KeyError: if ``special_tokens`` has no ``'='`` entry.
        """
        self.num_nodes = num_nodes
        if special_tokens is None:
            special_tokens = {'|': 0, '=': 1, '/': 2, '$': 3}
        self.special_tokens = special_tokens
        if numbers is None:
            numbers = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        self.numbers = numbers
        # Node ids occupy 0 .. num_nodes-1; special tokens are placed after them.
        self.encoder = {str(i): i for i in range(num_nodes)}
        self.encoder.update({k: num_nodes + v for k, v in special_tokens.items()})
        self.decoder = {i: str(i) for i in range(num_nodes)}
        self.decoder.update({num_nodes + v: k for k, v in special_tokens.items()})
        # Id -1 decodes to ':' even though ':' has no encoder entry.
        # NOTE(review): presumably a marker id produced outside this class — confirm.
        self.decoder[-1] = ':'
        self.pad_token = '$'
        self.delimiter = '='
        self.delimiter_id = self.encoder[self.delimiter]
        # Which side ``__call__`` pads on: 'left' or anything else (= right).
        self.padding_side = 'right'

    def encode(self, text: str) -> List[int]:
        """Convert ``text`` into a list of token ids.

        Commas are skipped, each maximal run of digit characters becomes one
        node-id token, and every other character is looked up as a special
        token.

        Raises:
            KeyError: for a character that is neither a digit, a comma, nor a
                special token, or for a digit run that is not a valid node id
                (e.g. ``"012"`` or a value ``>= num_nodes``).
        """
        digits = set(self.numbers)  # O(1) membership instead of a list scan
        out = []
        i = 0
        n = len(text)
        while i < n:
            if text[i] == ',':
                i += 1
                continue
            j = i
            while j < n and text[j] in digits:
                j += 1
            if j > i:
                # A digit run is one node id.
                out.append(self.encoder[text[i:j]])
                i = j
            else:
                # A single non-digit character must be a special token.
                out.append(self.encoder[text[i]])
                i += 1
        return out

    def decode(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
        """Map each token id to its string token.

        ``ids`` may be a list of ints or a tensor (moved to CPU first).

        Raises:
            KeyError: for an id with no decoder entry.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.cpu().tolist()
        return [self.decoder[int(i)] for i in ids]

    def batch_decode(self, ids_list: List[List[int]], skip_special_tokens: bool = True, return_str: bool = True) -> List[str]:
        """Decode several id sequences.

        Args:
            ids_list: Iterable of id sequences (lists or 1-D tensors).
            skip_special_tokens: If true, drop padding ids. Other special tokens
                (e.g. ``'='``, ``'|'``) are kept because they carry structure.
            return_str: If true, join each row into a comma-separated string;
                otherwise return each row's list of string tokens.

        Ids equal to ``-100`` are always dropped (presumably the loss-masking
        label value used by the training code — confirm).
        """
        # Hoisted out of the loop. A tuple (not a set) keeps ``==``-based
        # membership, so 0-d tensor elements are still compared by value.
        excluded = (self.pad_token_id, -100) if skip_special_tokens else (-100,)
        decoded = []
        for ids in ids_list:
            if isinstance(ids, torch.Tensor):
                ids = ids.cpu().tolist()
            kept = [tok for tok in ids if tok not in excluded]
            tokens = self.decode(kept)
            decoded.append(self._decode_tokens_to_str(tokens) if return_str else tokens)
        return decoded

    def _decode_tokens_to_str(self, tokens: List[str]) -> str:
        """Join string tokens back into text.

        A ``','`` is inserted between two adjacent tokens when neither is a
        key of ``special_tokens`` (so consecutive node ids stay separable).
        ``':'`` is not a special-token key, so it is treated like a node id
        here. Returns ``''`` for an empty token list (e.g. an all-padding row).
        """
        if not tokens:
            return ''
        parts = []
        for cur, nxt in zip(tokens, tokens[1:]):
            parts.append(cur)
            if cur not in self.special_tokens and nxt not in self.special_tokens:
                parts.append(',')
        parts.append(tokens[-1])
        return ''.join(parts)

    def __call__(
        self,
        text: Union[str, List[str]],
        max_length: int = None,
        padding: bool = False,
        truncation: bool = False,
        return_tensors: str = None
    ) -> Dict[str, Union[torch.Tensor, List[List[int]]]]:
        """Encode one string or a batch of strings.

        Args:
            text: A single string or a list of strings.
            max_length: Target length for padding/truncation. If ``None`` and
                ``padding`` is set, rows are padded to the longest row.
            padding: Pad rows with ``pad_token_id`` on ``self.padding_side``.
                Rows already longer than the target are left as-is.
            truncation: If true and ``max_length`` is given, cut rows to
                ``max_length`` before padding.
            return_tensors: ``'pt'`` to return torch tensors; any other value
                returns nested Python lists. ``'pt'`` with ragged rows raises
                from ``torch.tensor``.

        Returns:
            ``{'input_ids': ..., 'attention_mask': ...}``, where the mask is 1
            for real tokens and 0 for padding.
        """
        if isinstance(text, str):
            text = [text]
        encoded = [self.encode(t) for t in text]
        if truncation and max_length is not None:
            encoded = [ids[:max_length] for ids in encoded]
        if padding:
            if max_length is not None:
                target_len = max_length
            else:
                # Previously ``max_length=None`` crashed here; pad to longest.
                target_len = max((len(ids) for ids in encoded), default=0)
            pad_id = self.pad_token_id
            pad_left = self.padding_side == 'left'
            padded = []
            attention_masks = []
            for ids in encoded:
                n_pad = max(target_len - len(ids), 0)
                pad = [pad_id] * n_pad
                if pad_left:
                    padded.append(pad + ids)
                    attention_masks.append([0] * n_pad + [1] * len(ids))
                else:
                    padded.append(ids + pad)
                    attention_masks.append([1] * len(ids) + [0] * n_pad)
            encoded = padded
        else:
            attention_masks = [[1] * len(ids) for ids in encoded]
        if return_tensors == 'pt':
            input_ids = torch.tensor(encoded)
            attention_mask = torch.tensor(attention_masks)
            return {'input_ids': input_ids, 'attention_mask': attention_mask}
        return {'input_ids': encoded, 'attention_mask': attention_masks}

    @property
    def vocab_size(self) -> int:
        """Number of encodable tokens (node ids plus special tokens)."""
        return len(self.encoder)

    @property
    def pad_token_id(self) -> int:
        """Id of the padding token ``'$'``."""
        return self.encoder[self.pad_token]

    @property
    def eos_token_id(self) -> int:
        """Id reported as end-of-sequence; this is the ``'='`` delimiter id."""
        return self.encoder[self.delimiter]