from typing import List, Dict, Optional
from .base import BaseProcessor, ProcessTarget
import torch
import itertools as it
import warnings
class TokenTypeProcessor(BaseProcessor):
    """Processor for positional embeddings based on token types.

    Every token of an input text receives one type ID:

    * ``num_variables + 1`` (``'COEF'``) for coefficient tokens (``C...``),
    * ``i + 1`` (``'VAR_i'``) for the i-th exponent token (``E...``) of a monomial,
    * ``0`` (``'OTHER'``) for the ``'+'`` / ``'[SEP]'`` separators and for any
      other token, so the output stays aligned one-to-one with the tokens.
    """
    def __init__(self, num_variables: int, target: ProcessTarget = ProcessTarget.BOTH):
        super().__init__(target)
        self.num_variables = num_variables
        self.token_type_dict = self._create_token_type_dict()

    def _create_token_type_dict(self) -> Dict[str, int]:
        """Create a dictionary mapping token types to indices."""
        token_types = {
            'COEF': self.num_variables + 1,  # For coefficient tokens
        }
        # Variable exponent tokens: the i-th exponent of a monomial -> i + 1
        for i in range(self.num_variables):
            token_types[f'VAR_{i}'] = i + 1

        # Separators and any other token
        token_types['OTHER'] = 0

        return token_types

    def _process_monomial(self, monomial: str) -> List[int]:
        """Return one token type ID per whitespace-separated token of ``monomial``.

        Raises:
            KeyError: if the monomial has more than ``num_variables`` exponent tokens.
        """
        token_ids = []
        var_id = 0  # position of the next exponent token within this monomial
        for token in monomial.strip().split():
            if token.startswith('C'):
                token_ids.append(self.token_type_dict['COEF'])
            elif token.startswith('E'):
                token_ids.append(self.token_type_dict[f'VAR_{var_id}'])
                var_id += 1
            else:
                # Unrecognized tokens used to be dropped, which shortened the
                # output and misaligned every following position.
                token_ids.append(self.token_type_dict['OTHER'])

        return token_ids

    def _process_polynomial(self, polynomial: str) -> List[int]:
        """Process a ``' + '``-separated polynomial; each ``'+'`` gets ``'OTHER'``."""
        token_ids = []
        for monomial in polynomial.split(' + '):
            token_ids += self._process_monomial(monomial)
            token_ids += [self.token_type_dict['OTHER']]  # type ID for the '+' token

        return token_ids[:-1]  # drop the separator appended after the last monomial

    def _process(self, text: str) -> List[int]:
        """Process a ``' [SEP] '``-separated text; each ``'[SEP]'`` gets ``'OTHER'``."""
        token_ids = []
        for poly in text.split(' [SEP] '):
            token_ids += self._process_polynomial(poly)
            token_ids += [self.token_type_dict['OTHER']]  # type ID for the '[SEP]' token

        return token_ids[:-1]  # drop the separator appended after the last polynomial

    def __call__(self, texts: List[str]) -> List[List[int]]:
        """Process multiple texts."""
        return [self._process(text) for text in texts]
    
    
class MonomialTypeProcessor(BaseProcessor):
    """Processor that assigns IDs based on monomial types (exponent vectors).

    Every token of a monomial is labelled with the ID of that monomial's
    exponent vector; the ``'+'`` and ``'[SEP]'`` separators are labelled with
    ``exception_id``. IDs are handed out lazily, starting at 1, in the order in
    which exponent vectors are first seen. The mapping lives on the instance and
    therefore persists across calls.
    """
    def __init__(self, target: ProcessTarget = ProcessTarget.BOTH):
        super().__init__(target)
        self.type_dict = {}    # exponent vector (tuple) -> assigned type ID
        self.exception_id = 0  # label used for separator positions
        self.next_id = 1       # ID that the next unseen exponent vector receives

    def _get_exponent_vector(self, monomial: str) -> tuple:
        """Collect the integer suffixes of all ``E<k>`` tokens as a hashable tuple."""
        return tuple(int(tok[1:]) for tok in monomial.strip().split() if tok.startswith('E'))

    def _get_monomial_type_id(self, monomial: str) -> int:
        """Look up the type ID of ``monomial``, registering its exponent vector if new."""
        key = self._get_exponent_vector(monomial)
        if key in self.type_dict:
            return self.type_dict[key]
        assigned = self.next_id
        self.type_dict[key] = assigned
        self.next_id = assigned + 1
        return assigned

    def _process_monomial(self, monomial: str) -> List[int]:
        """Repeat the monomial's type ID once per token of the monomial."""
        width = len(monomial.strip().split())
        return [self._get_monomial_type_id(monomial)] * width

    def _join_with_separator(self, chunks: List[List[int]]) -> List[int]:
        """Concatenate ``chunks``, placing one ``exception_id`` between neighbours."""
        joined = []
        for position, chunk in enumerate(chunks):
            if position:
                joined.append(self.exception_id)
            joined.extend(chunk)
        return joined

    def _process_polynomial(self, polynomial: str) -> List[int]:
        """Label a ``' + '``-separated polynomial; each ``'+'`` gets ``exception_id``."""
        pieces = [self._process_monomial(mono) for mono in polynomial.split(' + ')]
        return self._join_with_separator(pieces)

    def _process(self, text: str) -> List[int]:
        """Label a ``' [SEP] '``-separated text; each ``'[SEP]'`` gets ``exception_id``."""
        pieces = [self._process_polynomial(poly) for poly in text.split(' [SEP] ')]
        return self._join_with_separator(pieces)

    def __call__(self, texts: List[str]) -> List[List[int]]:
        return [self._process(text) for text in texts]
    
    
from enum import Enum
from typing import List, Dict, Tuple
import itertools

class MonomialIDProcessor(BaseProcessor):
    """Processor that converts monomials to ``(coef_id, pattern_id, op_id)`` triples.

    Monomials are expected in prefix form (``"C5 E2 E1 E0"``). ``op_id`` is the
    ID of the special token following the monomial: ``'+'`` inside a polynomial,
    ``'[SEP]'`` after the last monomial of a polynomial, and ``'</s>'`` after
    the last monomial of the whole text.
    """
    def __init__(self, num_variables: int, max_degree: int, max_coef: int, target: ProcessTarget = ProcessTarget.BOTH):
        super().__init__(target)
        self.num_variables = num_variables
        self.max_degree = max_degree
        self.max_coef = max_coef
        self.pattern_to_id = self._create_pattern_dict()
        self.coef_to_id = self._create_coef_dict()
        self.special_to_id = self._create_special_dict()

        # For reverse conversion
        self.id_to_pattern = {v: k for k, v in self.pattern_to_id.items()}
        self.id_to_coef = {v: k for k, v in self.coef_to_id.items()}

    def _create_pattern_dict(self) -> Dict[Tuple[int, ...], int]:
        """Enumerate every exponent pattern with each exponent in ``0..max_degree``."""
        patterns = itertools.product(
            range(self.max_degree + 1),
            repeat=self.num_variables
        )
        return {pattern: idx for idx, pattern in enumerate(patterns)}

    def _create_coef_dict(self) -> Dict[int, int]:
        """Map coefficients ``1..max_coef`` to IDs ``0..max_coef-1``."""
        return {i: i - 1 for i in range(1, self.max_coef + 1)}

    def _create_special_dict(self) -> Dict[str, int]:
        """Map special tokens to IDs."""
        return {
            '[SEP]': 0,
            '[PAD]': 1,
            '<s>': 2,
            '</s>': 3,
            '+': 4
        }

    def _process_monomial(self, monomial: str) -> Tuple[int, int]:
        """Convert a prefix monomial ``"C5 E2 E1 E0"`` to ``(coef_id, pattern_id)``.

        Raises:
            KeyError: if the coefficient or exponent pattern is out of range.
        """
        tokens = monomial.strip().split()
        coef = int(tokens[0][1:])  # "C5" -> 5
        exponents = tuple(int(e[1:]) for e in tokens[1:])  # ["E2","E1","E0"] -> (2,1,0)

        return (self.coef_to_id[coef], self.pattern_to_id[exponents])

    def _process_polynomial(self, polynomial: str) -> List[Tuple[int, int, int]]:
        """Convert a polynomial to a list of ``(coef_id, pattern_id, op_id)``."""
        if not polynomial.strip():  # empty polynomial -> no monomials
            return []
        monomials = polynomial.split(' + ')
        special_tokens = [self.special_to_id['+']] * (len(monomials) - 1) + [self.special_to_id['[SEP]']]

        return [(*self._process_monomial(mono), op) for mono, op in zip(monomials, special_tokens)]

    def _process(self, text: str) -> List[List[Tuple[int, int, int]]]:
        """Process an entire ``' [SEP] '``-separated text, one list per polynomial."""
        polys = text.split(' [SEP] ')
        processed = [self._process_polynomial(poly) for poly in polys]

        # Tag the final monomial of the text with '</s>'. Search backwards so that
        # a trailing empty polynomial (empty text, or text ending in ' [SEP] ')
        # no longer raises IndexError.
        eos_id = self.special_to_id['</s>']
        for poly in reversed(processed):
            if poly:
                poly[-1] = (*poly[-1][:-1], eos_id)
                break

        return processed

    def __call__(self, texts: List[str]) -> List[List[List[Tuple[int, int, int]]]]:
        """Process multiple texts."""
        return [self._process(text) for text in texts]


class MonomialProcessor(BaseProcessor):
    """Processor that converts monomials to ``(coef_id, e_1, ..., e_n, op_id)`` tuples.

    ``op_id`` is the ID of the special token following the monomial: ``'+'``
    inside a polynomial, ``'[SEP]'`` after the last monomial of a polynomial,
    and ``'</s>'`` after the last monomial of the whole text.
    """
    def __init__(self, num_variables: int, max_degree: int, max_coef: int, target: ProcessTarget = ProcessTarget.BOTH):
        super().__init__(target)
        self.num_variables = num_variables
        self.max_degree = max_degree
        self.max_coef = max_coef
        self.coef_to_id = self._create_coef_dict()
        self.special_to_id = self._create_special_dict()

    def _create_coef_dict(self) -> Dict[int, int]:
        """Map coefficients ``1..max_coef`` to IDs ``0..max_coef-1``."""
        return {i: i - 1 for i in range(1, self.max_coef + 1)}

    def _create_special_dict(self) -> Dict[str, int]:
        """Map special tokens to IDs."""
        return {
            '[SEP]': 0,
            '[PAD]': 1,
            '<s>': 2,
            '</s>': 3,
            '+': 4
        }

    def _process_monomial(self, monomial: str) -> Tuple[int, ...]:
        """Convert a monomial to ``(coef_id, e_1, ..., e_n)``.

        Accepts prefix (``"C5 E2 E1 E0"``) or postfix (``"E2 E1 E0 C5"``) order.

        Raises:
            ValueError: if the monomial contains no ``C`` token.
            KeyError: if the coefficient is outside ``1..max_coef``.
        """
        coef = None
        exponents = []
        for token in monomial.strip().split():
            if token.startswith('C'):
                coef = int(token[1:])  # "C5" -> 5
            elif token.startswith('E'):
                exponents.append(int(token[1:]))  # "E2" -> 2
        if coef is None:
            # Previously surfaced as an opaque UnboundLocalError.
            raise ValueError(f'Monomial has no coefficient token: {monomial!r}')

        return (self.coef_to_id[coef], *exponents)

    def _process_polynomial(self, polynomial: str) -> List[Tuple[int, ...]]:
        """Convert a polynomial to a list of ``(coef_id, e_1, ..., e_n, op_id)``."""
        if not polynomial.strip():  # empty polynomial -> no monomials
            return []
        monomials = polynomial.split(' + ')
        special_tokens = [self.special_to_id['+']] * (len(monomials) - 1) + [self.special_to_id['[SEP]']]

        return [(*self._process_monomial(mono), op) for mono, op in zip(monomials, special_tokens)]

    def _process(self, text: str) -> list:
        """Process an entire ``' [SEP] '``-separated text.

        Returns ``[bos] + [polynomial_1, polynomial_2, ...]`` where each
        polynomial is a list of monomial tuples.
        """
        # NOTE(review): ``bos`` is a flat list sitting next to lists of tuples;
        # this structure is kept unchanged for existing consumers.
        bos = [0] * (self.num_variables + 1) + [self.special_to_id['<s>']]
        polys = text.split(' [SEP] ')
        processed = [bos] + [self._process_polynomial(poly) for poly in polys]

        # Tag the final monomial with the '</s>' ID. The original code stored the
        # entire eos *list* inside the tuple, and crashed on an empty trailing
        # polynomial; search backwards over the polynomials (not ``bos``).
        eos_id = self.special_to_id['</s>']
        for poly in reversed(processed[1:]):
            if poly:
                poly[-1] = (*poly[-1][:-1], eos_id)
                break

        return processed

    def __call__(self, texts: List[str]) -> list:
        """Process multiple texts."""
        return [self._process(text) for text in texts]


class MonomialProcessorPlus(BaseProcessor):
    """Processor that converts monomials to ``[coef_id, e_1, ..., e_n, special_id]`` vectors.

    A text is a sequence of components separated by ``' [BIGSEP] '``. Each
    component is a sequence of polynomials separated by ``' [SEP] '``, and each
    polynomial is a sum of monomials separated by ``' + '``. Every monomial
    becomes one vector whose last entry is the ID of the special token that
    follows it. The encoded sequence starts with an all-zero ``<s>`` vector,
    and its last monomial is tagged ``</s>``.
    """
    def __init__(self, num_variables: int, max_degree: int, max_coef: int, target: ProcessTarget = ProcessTarget.BOTH):
        super().__init__(target)
        self.num_variables = num_variables
        self.max_degree = max_degree
        self.max_coef = max_coef
        self.coef_to_id = self._create_coef_dict()
        self.id_to_coef = {v: k for k, v in self.coef_to_id.items()}
        self.special_to_id = self._create_special_dict()
        self.id_to_special = {v: k for k, v in self.special_to_id.items()}

    def _create_coef_dict(self) -> Dict[int, int]:
        """Map coefficients ``0..max_coef`` to IDs (identity mapping)."""
        return {i: i for i in range(self.max_coef + 1)}

    def _create_special_dict(self) -> Dict[str, int]:
        """Map special tokens to IDs."""
        return {
            '[SEP]': 0,
            '[PAD]': 1,
            '<s>': 2,
            '</s>': 3,
            '+': 4,
            '[BIGSEP]': 5
        }

    def _special_vector(self, special_token: str) -> List[int]:
        """Return an all-zero monomial vector tagged with ``special_token``."""
        return [0] * (self.num_variables + 1) + [self.special_to_id[special_token]]

    def _process_monomial(self, monomial: str) -> List[int]:
        """Convert a monomial to ``[coef_id, e_1, ..., e_n]``.

        Accepts prefix (``"C5 E2 E1 E0"``) or postfix (``"E2 E1 E0 C5"``) order.
        Tokens other than ``C...``/``E...`` (e.g. a trailing special token) are ignored.

        Raises:
            ValueError: if the monomial contains no ``C`` token.
            KeyError: if the coefficient is outside ``0..max_coef``.
        """
        coef = None
        exponents = []
        for token in monomial.strip().split():
            if token.startswith('C'):
                coef = int(token[1:])  # "C5" -> 5
            elif token.startswith('E'):
                exponents.append(int(token[1:]))  # "E2" -> 2
        if coef is None:
            # Previously surfaced as an opaque UnboundLocalError.
            raise ValueError(f'Monomial has no coefficient token: {monomial!r}')

        return [self.coef_to_id[coef]] + exponents

    def _process_polynomial(self, polynomial: str) -> List[List[int]]:
        """Convert a polynomial to a list of ``[coef_id, e_1, ..., e_n, op_id]``."""
        if not polynomial.strip():  # empty polynomial -> no monomials
            return []
        monomials = polynomial.split(' + ')
        special_tokens = [self.special_to_id['+']] * (len(monomials) - 1) + [self.special_to_id['[SEP]']]

        return [self._process_monomial(mono) + [op] for mono, op in zip(monomials, special_tokens)]

    def _process(self, text: str) -> List[List[int]]:
        """Encode an entire text as ``[bos, monomial_1, ..., monomial_k]``.

        The last monomial of each component is tagged ``[BIGSEP]``; the very last
        monomial is tagged ``</s>``. An input with no monomials at all (e.g. an
        empty string) is encoded as ``[bos, eos]``.
        """
        bos = self._special_vector('<s>')

        processed = []
        for component in text.split(' [BIGSEP] '):
            monomials = list(it.chain.from_iterable(
                self._process_polynomial(poly) for poly in component.split(' [SEP] ')
            ))
            if not monomials:
                # An empty component has no monomial to carry the [BIGSEP]
                # marker; skip it instead of raising IndexError.
                continue
            monomials[-1][-1] = self.special_to_id['[BIGSEP]']
            processed.extend(monomials)

        if not processed:
            # Nothing to encode (nothing to generate): terminate explicitly.
            return [bos, self._special_vector('</s>')]

        processed[-1][-1] = self.special_to_id['</s>']
        return [bos] + processed

    def __call__(self, texts: List[str]) -> List[List[List[int]]]:
        """Process multiple texts."""
        return [self._process(text) for text in texts]

    def is_valid_monomial(self, texts: List[str]) -> List[bool]:
        """Return, for each text, whether it is a well-formed generated monomial."""
        return [self._is_valid_monomial(monomial_text) for monomial_text in texts]

    def _is_valid_monomial(self, monomial: str) -> bool:
        """Check that ``monomial`` looks like ``C<int> E<int> ... E<int> <special>``.

        A valid monomial has exactly ``num_variables`` exponent tokens and a
        coefficient in ``0..max_coef``, and it ends with a known special token.
        This guarantees that ``_process_monomial`` can convert it into a vector of
        the expected length.
        """
        items = monomial.split()
        if len(items) != self.num_variables + 2:  # also rejects empty strings
            return False
        coef_token, exponent_tokens, special_token = items[0], items[1:-1], items[-1]
        # isdecimal (not isdigit) so that int() is guaranteed to succeed.
        if not (coef_token.startswith('C') and coef_token[1:].isdecimal()):
            return False
        if int(coef_token[1:]) not in self.coef_to_id:
            return False
        if not all(t.startswith('E') and t[1:].isdecimal() for t in exponent_tokens):
            return False
        return special_token in self.special_to_id

    def generation_helper(self, monomial_texts: List[str]) -> List[List[int]]:
        """Convert generated monomial texts to vectors (invalid ones become ``</s>``)."""
        return [self._generative_helper(monomial_text) for monomial_text in monomial_texts]

    def _generative_helper(self, monomial_text: str) -> List[int]:
        """Convert one generated monomial text to ``[coef_id, e_1, ..., e_n, special_id]``.

        Text that fails ``_is_valid_monomial`` is replaced by the ``</s>`` vector.
        """
        if not self._is_valid_monomial(monomial_text):
            return self._special_vector('</s>')

        special_token = monomial_text.split()[-1]
        return self._process_monomial(monomial_text) + [self.special_to_id[special_token]]

    def _decode_monomial_token(self, monomial: torch.Tensor, skip_special_tokens: bool = False) -> Tuple[str, bool]:
        """Decode one monomial vector to text.

        ``monomial`` may be a 1-D tensor or any sequence of ints. ``int()`` is
        used instead of ``.item()`` so that both are supported and float tensors
        do not render as ``E2.0``.

        Returns:
            ``(text, is_eos)``.

        Raises:
            KeyError: if the special or coefficient ID is unknown.
        """
        coeff = int(monomial[0])
        exponents = [int(e) for e in monomial[1:-1]]
        special_id = int(monomial[-1])

        special_token = self.id_to_special[special_id]
        is_eos = special_token == '</s>'

        if special_token == '<s>':
            monomial_text = '' if skip_special_tokens else '<s>'
        else:
            if is_eos and skip_special_tokens:
                special_token = ''
            monomial_text = ' '.join([f'C{self.id_to_coef[coeff]}'] + [f'E{e}' for e in exponents] + [special_token])

        return monomial_text.strip(), is_eos

    def decode(self, monomial_tokens: torch.Tensor, skip_special_tokens: bool = False, raise_warning: bool = True) -> str:
        """Decode a sequence of monomial vectors into a single string.

        Decoding stops at the first ``</s>``. If none is found, a warning is
        emitted when ``raise_warning`` is set.
        """
        decoded_tokens = []
        is_eos = False  # stays False for an empty input (was UnboundLocalError)
        for monomial in monomial_tokens:
            decoded_token, is_eos = self._decode_monomial_token(monomial, skip_special_tokens=skip_special_tokens)
            decoded_tokens.append(decoded_token)

            if is_eos:
                break

        # give warning if there is no eos token
        if (not is_eos) and raise_warning:
            warnings.warn('Generation ended before EOS token was found. If you are decoding a generated sequence, the max_length might be too small.')

        # Skip empty pieces (e.g. a skipped '<s>') so no double spaces appear.
        return ' '.join(token for token in decoded_tokens if token)

    def batch_decode(self, batch_monomial_tokens: torch.Tensor, skip_special_tokens: bool = True, raise_warning: bool = True) -> List[str]:
        """Decode every sequence in a batch with ``decode``."""
        return [self.decode(monomial_tokens, skip_special_tokens=skip_special_tokens, raise_warning=raise_warning) for monomial_tokens in batch_monomial_tokens]