from fractions import Fraction
from functools import reduce
from itertools import chain
import re

import numpy as np

def monom_to_prefix(coeff, exponent, symbol='x', split_rational=True):
   '''
   Encode the monomial coeff * symbol0^e0 * symbol1^e1 * ... as prefix tokens.

   coeff: nonzero scalar; its string form is used, so a rational prints as 'a/b'
   exponent: tuple of exponents, one per variable symbol0, symbol1, ...
   symbol: stem of the variable names
   split_rational: if True, a rational coefficient 'a/b' becomes ['/', 'Ca', 'Cb'];
      otherwise it becomes the single token 'Ca/b'

   Returns a list of tokens. Coefficients carry a 'C' tag and exponents greater
   than 1 an 'E' tag; a coefficient of 1 is omitted, exponent 1 is written as the
   bare variable, and variables whose exponent is not positive are skipped.
   Products nest to the right: x0*x1*x2 -> ['*', 'x0', '*', 'x1', 'x2'].
   '''
   assert(coeff != 0)

   coeff_str = str(coeff)

   # Constant monomial: only the coefficient is emitted.
   if all(ei == 0 for ei in exponent):
      if split_rational and '/' in coeff_str:
         num, den = coeff_str.split('/')
         return ['/', f'C{num}', f'C{den}']
      return [f'C{coeff}']

   factors = []
   for i, ei in enumerate(exponent):
      name = f'{symbol}{i}'
      if ei > 1:
         factors.append(['^', name, f'E{ei}'])
      elif ei > 0:
         factors.append([name])

   # Fold from the last factor backwards so the product nests to the right.
   term = reduce(lambda acc, factor: ['*', *factor, *acc], reversed(factors))

   if coeff == 1:
      return term
   if '/' not in coeff_str:
      return ['*', f'C{coeff}', *term]
   num, den = coeff_str.split('/')
   coeff_tokens = ['/', f'C{num}', f'C{den}'] if split_rational else [f'C{num}/{den}']
   return ['*', *coeff_tokens, *term]

def poly_to_prefix(p, split_rational=True):
   '''
   Encode a polynomial as a whitespace-separated prefix string.

   p: sage polynomial; p.dict() is expected to map exponent tuples to coefficients
   split_rational: forwarded to monom_to_prefix (how rational coefficients are tokenized)

   Returns e.g. '+ * C3 ^ x0 E2 x1' for 3*x0^2 + x1. Monomials appear in the
   iteration order of p.dict(), combined with right-nested '+'. The zero
   polynomial is encoded as 'C0'.
   '''
   if p.is_zero():
      return 'C0'
   if p.is_constant():
      # Use the same tagged coefficient format as every other coefficient
      # ('C5', or '/ C1 C2' for 1/2) so constants decode like any other term.
      return ' '.join(monom_to_prefix(p.coefficients()[0], (), split_rational=split_rational))

   monom_prefix_list = [monom_to_prefix(c, es, split_rational=split_rational)
                        for es, c in p.dict().items() if c != 0]

   # Fold from the last monomial backwards so the sum nests to the right.
   poly_prefix = reduce(lambda acc, monom: ['+', *monom, *acc], reversed(monom_prefix_list))

   return ' '.join(poly_prefix)

def sinfix_to_infix(sinfix):
   '''
   Insert spaces between the tokens of a Sage-style infix polynomial string.

   sinfix: infix obtained by str(p) for p: polynomial in sage, with terms
      separated by ' + ' / ' - ' and no spaces inside a monomial.
   example --- "-2*x0^2 - x0*x2 + x0 + 1"

   output: the same expression with every token (operands and the operators
   '*', '^', '/', '+', '-') separated by a single space.
   example --- "-2 * x0 ^ 2 - x0 * x2 + x0 + 1"
   '''
   tokens = []
   for token in sinfix.split():
      if token in ('+', '-'):
         tokens.append(token)
      else:
         # The capturing group makes re.split keep the operators; drop the empty
         # strings it yields when an operator sits at either end of the chunk.
         tokens.extend(part for part in re.split(r'([*^/])', token) if part)
   return ' '.join(tokens)

def poly_to_infix(p, split_rational=True, keep_tag=True):
   '''
   Render a polynomial as a space-separated infix string via its prefix encoding.

   p: sage polynomial
   split_rational: forwarded to both poly_to_prefix and prefix_to_infix
   keep_tag: forwarded to prefix_to_infix (whether 'C'/'E' tags stay in the output)
   '''
   prefix = poly_to_prefix(p, split_rational=split_rational)
   infix = prefix_to_infix(prefix, split_rational=split_rational, keep_tag=keep_tag)
   return infix

def prefix_to_poly(prefix, ring, return_empty_for_invalid=True, split_rational=True):
   '''
   Parse a prefix expression (e.g. from poly_to_prefix) into an element of `ring`.

   prefix: prefix expression, as a whitespace-separated string or a list of tokens
   ring: target ring; it is called on the infix string to build the polynomial
   return_empty_for_invalid: if True, return None when the prefix cannot be turned
      into infix (malformed token sequence); if False, the error propagates
   split_rational: must match the setting used when the prefix was produced
   '''
   try:
      # Tags are stripped because the ring parser does not know names like 'C3' or 'E2'.
      infix = prefix_to_infix(prefix, split_rational=split_rational, ring=ring, keep_tag=False)
   except Exception:
      if return_empty_for_invalid:
         return None
      raise

   # Tokenized input yields tokenized output; the ring expects a single string.
   if isinstance(infix, list):
      infix = ' '.join(infix)

   return ring(infix)

def infix_to_poly(infix, ring):
   '''
   Parse an infix string into an element of `ring` by calling the ring on it.

   infix: infix expression string (tags such as 'C'/'E' must already be removed)
   ring: target ring, or any callable accepting the string
   '''
   poly = ring(infix)
   return poly

def prefix_to_infix(prefix, split_rational=True, ring=None, keep_tag=True):
   '''
   Convert a prefix (Polish-notation) polynomial expression to infix.

   prefix: prefix representation of a polynomial, either a whitespace-separated
      string or a list of tokens (e.g. as produced by poly_to_prefix)
   split_rational: if False, tokens containing '/' (other than the '/' operator
      itself) are rational coefficient literals such as 'C1/2'
   ring: optional; when given and ring.base_ring() is QQ, coefficient literals
      containing '.' are converted via base(float(x)) once their tag is stripped
   keep_tag: if False, strip the 'E' (exponent) and 'C' (coefficient) tag characters

   Returns the infix expression as a string with tokens separated by single
   spaces, or as a list of tokens when `prefix` was given as a list. No
   parentheses are emitted, so the result relies on the usual precedence of
   + * / ^; that suffices for the sum-of-monomials shape produced by
   poly_to_prefix, not for arbitrary prefix input.

   Raises IndexError if `prefix` is empty or an operator lacks operands, and
   ValueError if operands are left over (the tokens do not form one expression).
   '''
   is_tokenized = isinstance(prefix, list)
   tokens = prefix if is_tokenized else prefix.split()

   # QQ is not imported in this module; evaluate it only when a ring is given so
   # that ring=None works outside a Sage global namespace.
   base = ring.base_ring() if ring is not None else None
   convert_decimals = base is not None and base is QQ

   def f2r(x):
      return str(base(float(x))) if convert_decimals and '.' in x else x

   stack = []
   # Scan right to left: operands are pushed; an operator combines the two most
   # recent results, with its first operand on top of the stack.
   for token in reversed(tokens):
      if token in ('+', '-', '*', '/', '^'):
         # Checked first so the '/' operator is never mistaken for a rational literal.
         arg1 = stack.pop()
         arg2 = stack.pop()
         stack.append(f'{arg1} {token} {arg2}')
      elif '/' in token and not split_rational:
         assert(token[0] == 'C')
         stack.append(token if keep_tag else f2r(token[1:]))
      else:
         if token[0] == 'E' and not keep_tag: token = token[1:]  # remove exponent tag.
         if token[0] == 'C' and not keep_tag: token = f2r(token[1:])  # remove coeff tag.
         stack.append(token)

   if len(stack) > 1:
      raise ValueError(f'malformed prefix expression: {len(stack)} operands were not combined')

   infix = str(stack[0])
   return infix.split() if is_tokenized else infix

def poly_to_sequence(poly, encoding=None):
   '''
   Encode a polynomial as a flat, space-separated token sequence.

   Each monomial becomes its coefficient token 'C<c>' followed by one 'E<e>'
   token per exponent entry; monomials are joined by '+', in the iteration
   order of poly.dict(). Example: 'C3 E2 E0 + C-1 E0 E1'.

   poly: sage polynomial; poly.dict() is expected to map exponent tuples to
      coefficients. For the zero polynomial, one 'E0' is emitted per element of
      poly.args() (presumably the ring generators -- verify).
   encoding: unused
   '''
   if poly.is_zero():
      return ' '.join(['C0', *('E0' for _ in poly.args())])

   monomials = (' '.join([f'C{c}', *(f'E{ei}' for ei in e)]) for e, c in poly.dict().items())
   return ' + '.join(monomials)

def sequence_to_poly(seq, ring, encoding=None):
   '''
   Decode a token sequence produced by poly_to_sequence back into a polynomial.

   seq: string such as 'C3 E2 E0 + C-1 E0 E1': monomials separated by '+',
      each a coefficient token 'C<c>' followed by one 'E<e>' token per variable
   ring: callable accepting a dict {exponent tuple: coefficient}
      (e.g. a sage polynomial ring)
   encoding: unused

   Integral coefficients are passed to `ring` as int, and rational ones such as
   'C1/2' (the str() form of a rational coefficient) as fractions.Fraction.
   Raises ValueError for a coefficient or exponent that is not a number, and
   IndexError for an empty monomial. If an exponent tuple repeats, the last
   occurrence wins.
   '''
   d = {}
   for monom in seq.split('+'):
      m = monom.split()
      coeff, ex = m[0], m[1:]
      # int() alone rejects 'a/b'; Fraction accepts both forms.
      value = Fraction(coeff[1:])
      d[tuple(int(ei[1:]) for ei in ex)] = int(value) if value.denominator == 1 else value

   return ring(d)
   
   
   