import numpy as np 
from sympy import * 
import math 
import PyNormaliz
from numba import jit, prange
import psutil 
import multiprocessing as mp
from tqdm import tqdm

class transvectant_class:
    '''
    Lazy representation of the transvectant ``((left, right)_tvec_order)**power``.

    ``left`` and ``right`` are objects exposing ``order``, ``degree``,
    ``symbol`` and ``evaluate`` (``binary_form_class`` or nested
    ``transvectant_class`` instances). Nothing is evaluated at construction
    time; ``evaluate`` builds the polynomial on demand.

    Attributes:
        order:  ``(left.order + right.order - 2*tvec_order)*power``.
        degree: ``(left.degree + right.degree)*power``.
        cache_dir: per-instance memo, random seed -> coefficient dict of the
            evaluated polynomial (a plain dict, not shared between processes).
    '''

    def __init__(self,
                 left,
                 right,
                 tvec_order,
                 expl_symbol = None,
                 power = 1,
                 modulo_prime = 65521,
                 cache = True):
        '''
        Args:
            left, right: the two covariants being transvected.
            tvec_order: transvectant order r (0 means the plain product).
            expl_symbol: optional readable name returned by ``symbol(1)``.
            power: exponent applied to the transvectant.
            modulo_prime: modulus of the polynomials built by ``evaluate``.
            cache: whether ``get_tvec_poly_dir`` stores its results.
        '''
        self.left = left
        self.right = right
        self.tvec_order = tvec_order
        self.power = power
        self.modulo_prime = modulo_prime
        self.expl_symbol = expl_symbol
        self.order = (self.left.order + self.right.order - 2*self.tvec_order)*self.power
        self.degree = (self.left.degree + self.right.degree)*self.power
        self.cache = cache
        self.x_symbol, self.y_symbol = symbols('x y')

        self.cache_dir = {}

    def __pow__(self, power):
        '''Return this transvectant raised to a positive integer ``power``.'''

        if power <= 0: raise NotImplementedError

        if power == 1:
            return self

        return transvectant_class(self.left,
                                  self.right,
                                  self.tvec_order,
                                  expl_symbol = f"{self.expl_symbol}**{power}",
                                  power = self.power*power,
                                  modulo_prime = self.modulo_prime,
                                  cache = True)

    def __mul__(self, other):
        '''Return the product ``self*other`` as a 0-th transvectant.'''

        assert self.modulo_prime == other.modulo_prime

        return transvectant_class(self,
                                  other,
                                  0,
                                  expl_symbol = f"{self.expl_symbol}*{other.expl_symbol}",
                                  modulo_prime = self.modulo_prime,
                                  cache = True)

    def symbol(self, compact_expression = 0):
        '''
        Readable expression for this covariant.

        With ``compact_expression`` truthy and a usable ``expl_symbol`` (set
        and not containing the text "None"), that name is returned.
        Otherwise the structural form ``[(L, R)_r]^power`` is built recursively.
        '''

        if compact_expression:
            if self.expl_symbol is not None:
                if "None" not in self.expl_symbol:
                    return self.expl_symbol

        return '[({}, {})_{}]^{}'.format(self.left.symbol(compact_expression),
                                         self.right.symbol(compact_expression),
                                         self.tvec_order,
                                         self.power)

    @property
    def __cache_key__(self):
        '''Identity-based key for external memoisation decorators.'''
        # The class defines no ``id`` attribute; use the builtin identity.
        return str(id(self))

    def cache_key(self, *args, **kwargs):
        '''
        Hash the instance state plus ``args``/``kwargs`` into 16 hex digits.

        ``cache_dir`` is excluded so the key stays the same as results are
        memoised. Nested objects contribute their ``repr``, which is
        identity-based for classes without ``__repr__``. Keys are therefore
        only stable within one process.
        '''
        import hashlib

        state = {name: value for name, value in self.__dict__.items()
                 if name != 'cache_dir'}
        key = f"{self.__class__.__name__}-{state}-{args}-{kwargs}"
        return hashlib.blake2b(key.encode(), digest_size = 8).hexdigest()

    def get_tvec_poly_dir(self,
                          random_seed):
        '''
        Coefficient dict (``Poly.as_dict()``) of the transvectant evaluated at
        ``random_seed``, memoised in ``cache_dir`` when ``self.cache`` is set.
        '''

        ck = random_seed

        if ck in self.cache_dir : return self.cache_dir[ck]

        tvec = transvectant_mod(self.left.evaluate(random_seed = random_seed),
                                self.right.evaluate(random_seed = random_seed),
                                self.tvec_order,
                                order1 = self.left.order,
                                order2 = self.right.order)**self.power
        result = tvec.as_dict()

        if self.cache:
            self.cache_dir[ck] = result

        return result

    def evaluate(self,
                 coefs_dir = None,
                 cache_dir = None,
                 random_seed = -1):
        '''
        Evaluate the transvectant as a sympy ``Poly`` in x, y.

        Args:
            coefs_dir: substitution dict forwarded to the children; ignored
                when ``random_seed`` is given.
            cache_dir: accepted for interface compatibility; unused.
            random_seed: when != -1, evaluate both children at that seed
                (memoised via ``get_tvec_poly_dir``) and rebuild the polynomial
                modulo ``modulo_prime``.
        '''

        if random_seed != -1:

            tvec_dir = self.get_tvec_poly_dir(random_seed)

            return Poly.from_dict(tvec_dir,
                                  gens = (self.x_symbol,
                                          self.y_symbol),
                                  modulus = self.modulo_prime)

        return transvectant_mod(self.left.evaluate(coefs_dir = coefs_dir),
                                self.right.evaluate(coefs_dir = coefs_dir),
                                self.tvec_order,
                                order1 = self.left.order,
                                order2 = self.right.order)**self.power
    



class binary_form_class:
    '''
    A generic binary form (as built by ``binary_form``) raised to ``power``.

    Attributes:
        bf_order: order n of the underlying generic form.
        order: ``order*power`` (``order`` as passed to the constructor).
        degree: ``power``.
        pre_computed_bf: symbolic ``Poly`` of the powered form, reduced mod
            ``modulo_prime``.
        cache_dir: per-instance memo, random seed -> coefficient dict.
    '''

    def __init__(self,
                 bf_order,
                 order,
                 coef_symbol,
                 coef_list = None,
                 symbol_str = None,
                 modulo_prime = 65521,
                 power = 1):
        '''
        Args:
            bf_order: order n of the generic binary form.
            order: order of the unpowered form.
            coef_symbol: prefix of the coefficient symbols (e.g. 'a').
            coef_list: optional explicit coefficient names for ``binary_form``.
            symbol_str: short display name of the form.
            modulo_prime: modulus used for all polynomial arithmetic.
            power: exponent applied to the form.
        '''

        self.bf_order = bf_order
        # Unpowered order, kept so __pow__ does not apply the power twice.
        self._base_order = order
        self.order = order*power
        self.power = power
        self.degree = 1*self.power
        self.coef_symbol = coef_symbol
        self.modulo_prime = modulo_prime
        self.symbol_str = symbol_str
        self.expl_symbol = str(symbol_str) + f"**{power}"
        self.coef_list = coef_list
        self.pre_computed_bf = self._evaluate_form()
        self.x_symbol, self.y_symbol = symbols('x y')
        self.cache_dir = {}

    def __pow__(self,
                power):
        '''Return the form raised to a further positive integer ``power``.'''

        if power <= 0:
            raise NotImplementedError

        # Pass the *unpowered* order: the constructor multiplies by the total
        # power itself, so passing self.order would compound the exponent.
        return binary_form_class(self.bf_order,
                                 self._base_order,
                                 self.coef_symbol,
                                 coef_list = self.coef_list,
                                 symbol_str = self.symbol_str,
                                 modulo_prime = self.modulo_prime,
                                 power = self.power*power)

    def __mul__(self,
                other):
        '''Return the product ``self*other`` as a 0-th transvectant.'''

        # Keep the modulus (as transvectant_class.__mul__ does) instead of
        # silently falling back to the constructor's default prime.
        assert self.modulo_prime == other.modulo_prime

        return transvectant_class(self, other, 0, modulo_prime = self.modulo_prime)

    def symbol(self, compact_expression = 1):
        '''Sympy symbol ``symbol_str`` raised to ``power``.'''
        return symbols(self.symbol_str)**self.power

    def _evaluate_form(self):
        '''Symbolic powered form as a ``Poly`` modulo ``modulo_prime``.'''

        bf_eval = Poly(binary_form(self.bf_order,
                                   self.coef_symbol,
                                   coef_symbol_list = self.coef_list)**self.power,
                       modulus = self.modulo_prime)
        return bf_eval

    @property
    def __cache_key__(self):
        '''Identity-based key for external memoisation decorators.'''
        # The class defines no ``id`` attribute; use the builtin identity.
        return str(id(self))

    def cache_key(self, *args, **kwargs):
        '''
        Hash the instance state plus ``args``/``kwargs`` into 16 hex digits.

        ``cache_dir`` is excluded so the key stays the same as results are
        memoised. Keys are only meant to be stable within one process.
        '''
        import hashlib

        state = {name: value for name, value in self.__dict__.items()
                 if name != 'cache_dir'}
        key = f"{self.__class__.__name__}-{state}-{args}-{kwargs}"
        return hashlib.blake2b(key.encode(), digest_size = 8).hexdigest()

    @staticmethod
    def _symbol_seed_offset(symbol_str):
        '''
        Per-form seed offset derived from ``symbol_str``.

        Equals ``ord(symbol_str)`` for one-character names (the previous
        scheme, so existing seeds are unchanged). Also works for longer names
        and ``None``, where ``ord`` raised TypeError.
        '''
        return sum(ord(ch) for ch in str(symbol_str))

    def get_poly_dir(self,
                     random_seed,
                     cache_dir = None):
        '''
        Coefficient dict of the form with its coefficients drawn for
        ``random_seed``, memoised in ``cache_dir``.

        The draw is seeded by ``random_seed`` plus an offset derived from
        ``symbol_str``, so different forms get different coefficients.
        ``cache_dir`` argument: accepted for compatibility; unused.
        '''

        ck = random_seed

        if ck in self.cache_dir : return self.cache_dir[ck]

        bf_poly = self.pre_computed_bf

        seed_unique_to_bf = random_seed + self._symbol_seed_offset(self.symbol_str)

        # NOTE(review): reseeds numpy's *global* RNG; presumably
        # create_coef_dir draws from it -- confirm.
        np.random.seed(seed_unique_to_bf)

        coef_dir_np = create_coef_dir(self.bf_order,
                                      self.coef_symbol,
                                      prime = self.modulo_prime)

        bf_eval_subs = Poly(bf_poly.subs(coef_dir_np),
                            modulus = self.modulo_prime)

        result = bf_eval_subs.as_dict()
        self.cache_dir[ck] = result
        return result

    def evaluate(self,
                 coefs_dir = None,
                 random_seed = -1,
                 cache_dir = None):
        '''
        Evaluate the form as a ``Poly``.

        With ``random_seed`` != -1 the coefficients come from
        ``get_poly_dir``. Otherwise ``coefs_dir`` is substituted, and when it
        is None the symbolic form is returned. ``cache_dir`` is unused.
        '''

        bf_eval = self.pre_computed_bf

        if random_seed != -1 :

            poly_dict = self.get_poly_dir(random_seed)

            return Poly.from_dict(poly_dict,
                                  gens = (self.x_symbol,
                                          self.y_symbol),
                                  modulus = self.modulo_prime)

        if coefs_dir is not None:

            return Poly(bf_eval.subs(coefs_dir),
                        modulus = self.modulo_prime)
        else:
            return bf_eval

def binary_form(n, coef_symbol = 'a', coef_symbol_list = None):
    '''
    Generic binary form of order ``n``:

        sum_{i=0}^{n} C(n, i) * a_i * x**(n-i) * y**i

    Args:
        n: order of the form.
        coef_symbol: prefix for the auto-generated coefficients
            ``<prefix>0 .. <prefix>n`` (used when ``coef_symbol_list`` is None).
        coef_symbol_list: explicit coefficient names in any format sympy
            ``symbols`` accepts (e.g. ``"b0, b1, b2"`` or ``"b0,b1,b2"``);
            must define exactly ``n + 1`` symbols.

    Returns:
        sympy expression in x, y and the coefficient symbols.
    '''

    if coef_symbol_list is None:
        sn_coefs = symbols(coef_symbol + '0:' + str(n + 1))
    else:
        # seq=True always yields a tuple, even for a single name; counting
        # the parsed symbols (rather than splitting on ", ") accepts every
        # separator sympy understands.
        sn_coefs = symbols(coef_symbol_list, seq = True)
        assert len(sn_coefs) == n + 1, \
            f"expected {n + 1} coefficient symbols, got {len(sn_coefs)}"

    x, y = symbols('x y')

    bform = 0
    for i in range(n + 1):
        bform += sn_coefs[i]*math.comb(n, i)*x**(n - i)*y**i
    return bform

def transvectant_mod(poly1, poly2, r, order1 = None, order2 = None):
    '''
    r-th transvectant (poly1, poly2)_r:

        (m-r)!(n-r)!/(m! n!) * sum_{i=0}^{r} (-1)^i C(r, i)
            * d^r poly1 / (dx^(r-i) dy^i) * d^r poly2 / (dx^i dy^(r-i))

    with m = order1 and n = order2.

    Args:
        poly1, poly2: sympy ``Poly`` objects in x, y (possibly with a modulus).
        r: transvectant order; r == 0 returns the plain product.
        order1, order2: orders of the forms. When None, the polynomial's total
            degree is used (presumes the form is homogeneous). Previously
            None crashed for any r > 0.

    Returns:
        ``Poly`` in x, y over ``poly1.domain``; zero if r exceeds either order.
    '''

    if r == 0: return poly1*poly2

    if order1 is None:
        order1 = poly1.total_degree()
    if order2 is None:
        order2 = poly2.total_degree()

    x, y = symbols('x y')

    transv = Poly(0, x, y, domain = poly1.domain)

    if min(order1, order2) < r: return transv

    for i in range(r + 1):
        sign_binom = (-1)**i * math.comb(r, i)
        d1 = poly1.diff(*([x] * (r - i)), *([y] * i))
        d2 = poly2.diff(*([x] * i), *([y] * (r - i)))
        transv += (d1*d2).mul(sign_binom)

    # c1 = 1/D with D an integer (product of two falling factorials). Dividing
    # the Poly by the constant D applies the normalisation, as an inverse
    # mod p for modular polys, and leaves no remainder.
    c1 = factorial(order1 - r)/factorial(order1)*factorial(order2 - r)/factorial(order2)

    num, rem = transv.div(c1.denominator)

    assert rem == 0
    assert c1.numerator == 1

    return num

def solve_boundary_lattice_points_v2(eq1, eq2):
    '''
    Non-negative integer solutions of a two-row inhomogeneous system.

    ``eq1`` and ``eq2`` are passed to PyNormaliz as ``inhom_equations``; in
    this module their last entry is the constant term (see solve_int_prog).
    The identity matrix on the remaining coordinates is passed as
    ``inequalities``, i.e. every unknown is constrained to be >= 0.

    Returns:
        The lattice points reported by Normaliz, restricted to rows whose
        last coordinate is 1 (the empty array is returned unchanged).
    '''
    assert len(eq1) == len(eq2)

    num_unknowns = len(eq1) - 1
    nonneg_rows = np.eye(num_unknowns, dtype = np.int32).tolist()
    equation_rows = np.array([eq1, eq2], dtype = np.int32).tolist()

    cone = PyNormaliz.Cone(inhom_equations = equation_rows,
                           inequalities = nonneg_rows)
    points = np.array(cone.LatticePoints())

    if len(points) == 0 or np.all(points[:, -1] == 1):
        return points

    return np.array([pt for pt in points if pt[-1] == 1])
    

def do(a):
    '''Return the ``(degree, order)`` pair of ``a``.'''
    degree = a.degree
    order = a.order
    return degree, order

def solve_int_prog(d, k, d_arr, k_arr, verbose = False):
    '''
    Non-negative integer exponent vectors z with

        sum_i d_arr[i]*z_i == d   and   sum_i k_arr[i]*z_i == k,

    as returned by ``solve_boundary_lattice_points_v2``.

    Args:
        d, k: target degree and order.
        d_arr, k_arr: degree / order of each generator. They are NOT
            modified (previously ``-d``/``-k`` were appended to them in place).
        verbose: accepted for compatibility; unused.

    Returns:
        List of 1-D arrays (one per solution, without the homogenising
        coordinate), or the empty array from the solver when there is none.
    '''

    # Constant terms go on copies so the caller's lists stay untouched.
    eq_degree = list(d_arr) + [-d]
    eq_order = list(k_arr) + [-k]

    basis = solve_boundary_lattice_points_v2(eq_degree, eq_order)

    if len(basis) == 0: return basis
    assert np.all(basis[:, -1] == 1)
    red_basis = [row[:-1] for row in basis if row[-1] == 1]
    return red_basis

def create_fun_from_basis(all_polys, basis):
    '''
    Product ``prod_i all_polys[i]**basis[i]`` over the nonzero exponents.

    Factors are multiplied left to right, in index order.

    Returns:
        The product, or None when every exponent is zero (or basis is empty).
    '''
    poly = None
    for i, exponent in enumerate(basis):
        if exponent == 0:
            continue
        factor = all_polys[i]**exponent
        poly = factor if poly is None else poly*factor
    return poly
    
   
def generate_homo_space(degree, order, all_polys, verbose = False):
    '''
    All products of ``all_polys`` with total degree ``degree`` and total
    order ``order``.

    Returns:
        (homo_arr, basis_for_homo): the products built by
        ``create_fun_from_basis`` and the exponent vectors that produced them.
    '''

    all_degrees = [poly.degree for poly in all_polys]
    all_orders = [poly.order for poly in all_polys]

    if verbose:
        print('solving int prog')

    basis_for_homo = solve_int_prog(degree, order, all_degrees, all_orders,
                                    verbose = verbose)

    if verbose:
        print('int prog solved, has ', len(basis_for_homo), 'entries. ')
        print('the basis are {}.'.format(basis_for_homo))

    homo_arr = [create_fun_from_basis(all_polys, exponents)
                for exponents in basis_for_homo]

    return homo_arr, basis_for_homo

from numba import int64, boolean 

def create_diophantine_system(min_set_1, min_set_2, only_invariants = False):
    '''
    Two linear equations over the orders of the non-invariant members of
    ``min_set_1`` and ``min_set_2``.

    Layout of each row: [orders of set 1 | orders of set 2 | tail].
    Row 1 carries the set-1 orders, row 2 the set-2 orders. The tail is
    [-1, 0, -1] / [0, -1, -1] (columns u, v, r) or, with
    ``only_invariants``, a single -1 in each row.

    Returns:
        [row_1, row_2] as nested Python int lists.
    '''
    orders_1 = [form.order for form in min_set_1 if form.order != 0]
    orders_2 = [form.order for form in min_set_2 if form.order != 0]

    if only_invariants:
        tail_1, tail_2 = [-1], [-1]
    else:
        tail_1, tail_2 = [-1, 0, -1], [0, -1, -1]

    row_1 = np.array(orders_1 + [0] * len(orders_2) + tail_1, dtype = np.int32)
    row_2 = np.array([0] * len(orders_1) + orders_2 + tail_2, dtype = np.int32)

    return np.vstack([row_1, row_2]).tolist()

@jit(nopython=True)
def modular_inverse(a, p):
    """Return the inverse of ``a`` modulo ``p`` (extended Euclidean algorithm).

    ``a`` is first reduced into ``[0, p)``. Without that, negative inputs
    gave a wrong result (e.g. ``a = -1, p = 7`` returned 1 instead of 6).

    Raises:
        ValueError: if ``a`` is not invertible mod ``p`` (e.g. a multiple of ``p``).
    """
    a = a % p

    t, new_t = 0, 1
    r, new_r = p, a

    while new_r != 0:
        quotient = r // new_r
        t, new_t = new_t, t - quotient * new_t
        r, new_r = new_r, r - quotient * new_r

    # r now holds gcd(a, p); anything but 1 means no inverse exists.
    if r > 1:
        raise ValueError("a is not invertible")
    if t < 0:
        t = t + p

    return t


@jit(int64[:](int64[:, :], int64, boolean), 
     nopython = True,
     parallel = True)
def fast_rref(matrix,
              mod, 
              rref = False):
    '''
    Gaussian elimination over the integers modulo ``mod`` (JIT-compiled).

    Inputs: 
        matrix: [n x m] int64 matrix. Modified IN PLACE: rows are
            swapped, scaled and eliminated.
        mod: modulo prime 
        rref: True/False 
            If False, only the entries below each pivot are cleared,
            giving a row echelon form (REF). If True, entries above
            the pivots are cleared as well, giving the reduced row
            echelon form (RREF).

    Outputs: 
        row_order[:rank]: original indices of the rows that became
        pivot rows. These input rows are linearly independent mod
        ``mod``, and their count is the rank.

    Notes: 
        In most cases this is only used to get the rank of the
        matrix, for which REF is sufficient; hence rref defaults
        to False.
        NOTE(review): entries are assumed already reduced into
        [0, mod). A pivot that is a nonzero multiple of ``mod`` would
        make ``modular_inverse`` raise, and large unreduced values
        could overflow int64 when the pivot row is scaled.
        NOTE(review): there is no prange loop; parallel=True presumably
        only affects numba's handling of the array expressions.
    '''
    
    rows, cols = matrix.shape
    
    rank = 0
    
    # row_order[k] = original index of the row currently at position k.
    row_order = np.arange(rows)
    
    for col in range(cols): 

        pivot_row = -1 
        
        # First row at or below the current rank with a nonzero entry in col.
        for row in range(rank, rows):
            if matrix[row, col] != 0:
                pivot_row = row
                break
        
        # No pivot in this column: move on without increasing the rank.
        if pivot_row == -1: continue 
            
        if pivot_row != rank:
            
            # Swap the pivot row into position ``rank``. The copies avoid
            # aliasing numpy row views; row_order is swapped so the returned
            # indices refer to the caller's original numbering.
            matrix[rank], matrix[pivot_row] = matrix[pivot_row].copy(), matrix[rank].copy() 
            row_order[rank], row_order[pivot_row] = row_order[pivot_row], row_order[rank]
        
        pivot = matrix[rank, col]
                
        inv_pivot = modular_inverse(pivot, mod)
        
        # REF only eliminates below the pivot; RREF eliminates every other row.
        if rref: 
            start = 0 
        else: 
            start = rank + 1 

        # Scale the pivot row so the pivot becomes 1 (mod mod).
        matrix[rank] = matrix[rank]*inv_pivot % mod

        
        for row in range(start, rows):

            if row == rank: continue 
            
            factor = matrix[row][col] % mod
            
            # Subtract factor * pivot row, zeroing matrix[row, col] mod ``mod``.
            matrix[row] = (matrix[row] - factor * matrix[rank]) % mod

            
        rank += 1
                        
    return row_order[:rank]


def hilbert_basis_to_tvec(poly_list_1, poly_list_2, basis_1, basis_2, u, v, r):
    '''
    Turn one Hilbert-basis vector into a covariant.

        U = prod_i poly_list_1[i]**basis_1[i]
        V = prod_j poly_list_2[j]**basis_2[j]

    Returns:
        The transvectant (U, V)_r, carrying U's modulus. If one product is
        empty, the other is returned (None if both are).

    Args:
        u, v: presumably the orders left over from U and V after the r-fold
            transvection (per the equations from create_diophantine_system);
            unused here.
    '''
    U = create_fun_from_basis(poly_list_1, basis_1)
    V = create_fun_from_basis(poly_list_2, basis_2)

    if U is None: return V
    if V is None: return U

    # Propagate the modulus like transvectant_class.__mul__ does; without it
    # the constructor falls back to its default prime.
    assert U.modulo_prime == V.modulo_prime

    return transvectant_class(U, V, r, modulo_prime = U.modulo_prime)

def create_generating_set_U_plus_V(hilbert_basis, min_set_U, min_set_V):
    '''
    Generating set for the joint covariants of two systems U and V.

    It contains all invariants (order 0) of U, then those of V, then one
    ``hilbert_basis_to_tvec`` covariant per nonzero Hilbert-basis vector.
    Each vector is laid out as [U covariant exponents | V covariant
    exponents | ... | last three entries]. All-zero vectors are skipped
    with a message.
    '''
    generating_set = [form for form in min_set_U if form.order == 0]
    generating_set += [form for form in min_set_V if form.order == 0]

    u_covariants = [form for form in min_set_U if form.order != 0]
    v_covariants = [form for form in min_set_V if form.order != 0]

    split = len(u_covariants)
    end = split + len(v_covariants)

    for base in hilbert_basis:
        if sum(base) == 0:
            print('ignoring an all 0 base')
            continue
        generating_set.append(
            hilbert_basis_to_tvec(u_covariants,
                                  v_covariants,
                                  base[:split],
                                  base[split:end],
                                  *base[-3:]))
    return generating_set



def multi_var_poly_coeffs(poly, order):
    '''
    Coefficients [c_0, ..., c_order] with c_j the coefficient of
    x**(order-j) * y**j in ``poly.expr``.

    Assumes ``poly`` is homogeneous of degree ``order`` in x, y. For modular
    polys the values may be symmetric residues (possibly negative); callers
    reduce them mod p afterwards.
    '''
    x, y = symbols('x y')
    expr = poly.expr
    return [expr.coeff(x, order - j).coeff(y, j) for j in range(order + 1)]
    
def evaluate_tvec_at_seed(tvec_list, seed, mod_prime):
    '''
    Evaluate every covariant in ``tvec_list`` at random ``seed``.

    All entries are read with ``tvec_list[0].order`` coefficients, so the
    list presumably holds covariants of a single order. ``mod_prime`` is
    accepted for interface compatibility and unused.

    Returns:
        One coefficient list per covariant, in input order.
    '''
    shared_order = tvec_list[0].order
    return [list(multi_var_poly_coeffs(tvec.evaluate(random_seed = seed),
                                       shared_order))
            for tvec in tvec_list]
    

def process_group(p_values,
                  tvec_list,
                  mod_prime):
    '''
    Worker task: evaluate ``tvec_list`` at every seed in ``p_values``.

    Returns:
        One ``evaluate_tvec_at_seed`` result per seed, or None when
        ``p_values`` is empty.
    '''
    per_seed = [evaluate_tvec_at_seed(tvec_list, seed, mod_prime)
                for seed in p_values]
    return per_seed if per_seed else None

def reshape_appropriately(results):
    '''
    Merge the pooled per-worker results into one matrix per transvectant.

    ``results`` is a list over workers, each a list over seeds of
    [transvectant][coefficient] lists. Row t of the output lists, for each
    coefficient index in turn, that coefficient of transvectant t across all
    seeds (worker order, then seed order within a worker).
    '''
    tvec_count = len(results[0][0])
    seeds_first = np.vstack(results)
    per_seed_blocks = [block.reshape(tvec_count, -1) for block in seeds_first]
    tvec_coef_seed = np.stack(per_seed_blocks, axis = -1)
    return tvec_coef_seed.reshape(tvec_count, -1)


def evaluate_tvecs_at_random_seeds(tvec_list,
                                   num_forms,
                                   processes_pool = None,
                                   num_procs = 1,
                                   modulo_prime = 65521):
    '''
    Evaluate every transvectant in ``tvec_list`` at seeds ``1..num_forms``
    in parallel and return the coefficients reduced mod ``modulo_prime``.

    Seeds are dealt round-robin (``seed % num_procs``) into ``num_procs``
    groups; each group is one ``process_group`` task run via ``starmap``.

    Args:
        tvec_list: covariants to evaluate (each worker receives a pickled
            copy, so per-instance caches are not shared back).
        num_forms: number of random seeds.
        processes_pool: optional pool. When None, a pool of ``num_procs``
            workers is created here, then closed and joined here.
        num_procs: number of seed groups.
        modulo_prime: modulus applied to the final matrix.

    Returns:
        int64 array with one row per transvectant (layout from
        ``reshape_appropriately``).

    Notes:
        The pool is closed on exit, as before, even if it was supplied by
        the caller. It is now also closed when ``starmap`` raises.
    '''

    owns_pool = processes_pool is None
    if owns_pool:
        processes_pool = mp.Pool(processes = num_procs)
    print(f"Available Memory: {psutil.virtual_memory().available/1024**3}")

    grouped_p_values = [[] for _ in range(num_procs)]
    for p in range(1, num_forms + 1):
        grouped_p_values[p % num_procs].append(p)

    try:
        _results = processes_pool.starmap(process_group,
                                          zip(grouped_p_values,
                                              [tvec_list] * num_procs,
                                              [modulo_prime] * num_procs))
        print(f"Available Memory: {psutil.virtual_memory().available/1024**3}")
    finally:
        # Stop accepting work even on failure; only wait for workers of a
        # pool this function created.
        processes_pool.close()
        if owns_pool:
            processes_pool.join()

    # Groups that received no seeds (num_forms < num_procs) return None.
    results = [res for res in _results if res is not None]

    eval_array = reshape_appropriately(results)

    return np.array(np.array(eval_array) % modulo_prime, dtype = np.int64)


