from tools import *
from sympy import * 
import PyNormaliz 
import matplotlib.pyplot as plt 
import math
from importlib import reload 
from tqdm import tqdm 
from sage.all import macaulay2 
from importlib import reload 
from tools import * 


# SymPy symbols. t and z are not referenced anywhere else in this script;
# x and y are presumably the binary-form variables — TODO confirm against tools.
t, z = symbols('t z')

x, y = symbols('x y')
# Modulus for all modular arithmetic done by the tools classes.
# NOTE(review): the same value is written as a literal in the Macaulay2 commands
# further down (field ZZ/65521, random coefficient range, fast_rref call).
prime_number = 65521 

# Generic binary sextic. The Macaulay2 ring declared below uses coefficient
# variables a0..a6, which presumably correspond to the 'a' prefix here.
f = binary_form_class(6, 6, 'a', symbol_str = 'f', modulo_prime=prime_number)

# Covariants of the sextic, named hD_O (or hD_O_k for several of the same
# type): D = degree in the coefficients, O = order in the form variables.
# The third argument of transvectant_class is presumably the transvectant
# index k; every name below agrees with the usual order formula
# ord(g) + ord(h) - 2k. expl_symbol presumably sets a display label f_1..f_25.
# Operands such as h3_2**2 are presumably powers (products) of covariants.

# degree 2
h2_0 = transvectant_class(f, f, 6, modulo_prime=prime_number, expl_symbol = 'f_1')
h2_4 = transvectant_class(f, f, 4, modulo_prime=prime_number, expl_symbol = 'f_2')
h2_8 = transvectant_class(f, f, 2, modulo_prime=prime_number, expl_symbol = 'f_3')


# degree 3
h3_2 = transvectant_class(h2_4, f, 4, modulo_prime=prime_number, expl_symbol = 'f_4')
h3_6 = transvectant_class(h2_4, f, 2, modulo_prime=prime_number, expl_symbol = 'f_5')
h3_8 = transvectant_class(h2_4, f, 1, modulo_prime=prime_number, expl_symbol = 'f_6')
h3_12 = transvectant_class(h2_8, f, 1, modulo_prime=prime_number, expl_symbol = 'f_7')


# degree 4
h4_0 = transvectant_class(h2_4, h2_4, 4, modulo_prime=prime_number, expl_symbol = 'f_8')
h4_4 = transvectant_class(h3_2, f, 2, modulo_prime=prime_number, expl_symbol = 'f_9')
h4_6 = transvectant_class(h3_2, f, 1, modulo_prime=prime_number, expl_symbol = 'f_10')
h4_10 = transvectant_class(h2_8, h2_4, 1, modulo_prime=prime_number, expl_symbol = 'f_11')


# degree 5
h5_2 = transvectant_class(h2_4, h3_2, 2, modulo_prime=prime_number, expl_symbol = 'f_12')
h5_4 = transvectant_class(h2_4, h3_2, 1, modulo_prime=prime_number, expl_symbol = 'f_13')
h5_8 = transvectant_class(h2_8, h3_2, 1, modulo_prime=prime_number, expl_symbol = 'f_14')


# degree 6
h6_0 = transvectant_class(h3_2, h3_2, 2, modulo_prime=prime_number, expl_symbol = 'f_15')
h6_6_1 = transvectant_class(h3_8, h3_2, 2, modulo_prime=prime_number, expl_symbol = 'f_16')
h6_6_2 = transvectant_class(h3_6, h3_2, 1, modulo_prime=prime_number, expl_symbol = 'f_17')

# degree 7
h7_2 = transvectant_class(f, h3_2**2, 4, modulo_prime=prime_number, expl_symbol = 'f_18')
h7_4 = transvectant_class(f, h3_2**2, 3, modulo_prime=prime_number, expl_symbol = 'f_19')

# degree 8
h8_2 = transvectant_class(h2_4, h3_2**2, 3, modulo_prime=prime_number, expl_symbol = 'f_20')

# degree 9
h9_4 = transvectant_class(h3_8, h3_2**2, 4, modulo_prime=prime_number, expl_symbol = 'f_21')

# degree 10
h10_0 = transvectant_class(h3_2**3, f, 6, modulo_prime=prime_number, expl_symbol = 'f_22')
h10_2 = transvectant_class(h3_2**3, f, 5, modulo_prime=prime_number, expl_symbol = 'f_23')

# degree 12
h12_2 = transvectant_class(h3_8, h3_2**3, 6, modulo_prime=prime_number, expl_symbol = 'f_24')

# degree 15
h15_0 = transvectant_class(h3_8, h3_2**4, 8, modulo_prime=prime_number, expl_symbol = 'f_25')

# Generating set used for S6: 26 entries (f and the 25 covariants above).
# The five order-0 entries (invariants, by the naming) come first.
s6_min_set = [h2_0, h4_0, h6_0, h10_0, h15_0, h3_2, h2_4, h4_4, f, h3_6, h4_6, h5_4, h2_8, h6_6_1, h6_6_2, h3_8, h7_4, h5_2, h7_2, h9_4, h12_2, h10_2, h8_2, h5_8, h4_10, h3_12]


# Generating set for S4 + S4: joint covariants of two binary quartics.

# Two generic binary quartics. The Macaulay2 ring declared below uses
# coefficient variables m0..m4 and n0..n4, which presumably correspond to the
# 'm' and 'n' prefixes here.
b1 = binary_form_class(4, 4, 'm',
                       modulo_prime = prime_number,
                       symbol_str = 'p')

b2 = binary_form_class(4, 4, 'n',
                       modulo_prime = prime_number,
                       symbol_str = 'q')

# Joint transvectants. The third argument is presumably the transvectant index
# k, giving order ord(g) + ord(h) - 2k. expl_symbol presumably sets a display
# label z_1..z_26.
b3 = transvectant_class(b1, b1, 4, modulo_prime = prime_number, expl_symbol = 'z_1')
b4 = transvectant_class(b2, b2, 4, modulo_prime = prime_number, expl_symbol = 'z_2')
b5 = transvectant_class(b1, b2, 4, modulo_prime = prime_number, expl_symbol = 'z_3')
b6 = transvectant_class(b1, b2, 3, modulo_prime = prime_number, expl_symbol = 'z_4')
b7 = transvectant_class(b1, b1, 2, modulo_prime = prime_number, expl_symbol = 'z_5')
b8 = transvectant_class(b2, b2, 2, modulo_prime = prime_number, expl_symbol = 'z_6')
b9 = transvectant_class(b1, b2, 2, modulo_prime = prime_number, expl_symbol = 'z_7')
b10 = transvectant_class(b1, b2, 1, modulo_prime = prime_number, expl_symbol = 'z_8')
b11 = transvectant_class(b1, b7, 4, modulo_prime = prime_number, expl_symbol = 'z_9')
b12 = transvectant_class(b2, b8, 4, modulo_prime = prime_number, expl_symbol = 'z_10')
b13 = transvectant_class(b1, b8, 4, modulo_prime = prime_number, expl_symbol = 'z_11')
b14 = transvectant_class(b2, b7, 4, modulo_prime = prime_number, expl_symbol = 'z_12')

b15 = transvectant_class(b1, b8, 3, modulo_prime = prime_number, expl_symbol = 'z_13')
b16 = transvectant_class(b2, b7, 3, modulo_prime = prime_number, expl_symbol = 'z_14')
b17 = transvectant_class(b1, b8, 2, modulo_prime = prime_number, expl_symbol = 'z_15')
b18 = transvectant_class(b2, b7, 2, modulo_prime = prime_number, expl_symbol = 'z_16')
b19 = transvectant_class(b1, b7, 1, modulo_prime = prime_number, expl_symbol = 'z_17')
b20 = transvectant_class(b2, b8, 1, modulo_prime = prime_number, expl_symbol = 'z_18')
b21 = transvectant_class(b1, b8, 1, modulo_prime = prime_number, expl_symbol = 'z_19')
b22 = transvectant_class(b2, b7, 1, modulo_prime = prime_number, expl_symbol = 'z_20')
b23 = transvectant_class(b7, b8, 4, modulo_prime = prime_number, expl_symbol = 'z_21')
b24 = transvectant_class(b7, b8, 3, modulo_prime = prime_number, expl_symbol = 'z_22')
b25 = transvectant_class(b19, b2, 4, modulo_prime = prime_number, expl_symbol = 'z_23')
b26 = transvectant_class(b1, b20, 4, modulo_prime = prime_number, expl_symbol = 'z_24')
# b1**2 and b2**2 are presumably products of the forms (degree 8 in x, y).
b27 = transvectant_class(b1**2, b20, 6, modulo_prime = prime_number, expl_symbol = 'z_25')
b28 = transvectant_class(b19, b2**2, 6, modulo_prime = prime_number, expl_symbol = 'z_26')

# Generating set used for S4 + S4: 28 entries (b1..b28). Under the order
# formula above, the first eight (b3, b4, b5, b11, b12, b13, b14, b23) have
# order 0.
s4s4_min_set = [b3, b4, b5, b11, b12, b13, b14, b23, b1, b2, b6, b7, b8, b9, b10, b15, b16, b17, b18, b19, b20, b21, b22, b24, b25, b26, b27, b28]

# Tabulate what the project helper do() returns for every generator —
# presumably (degree, order), judging by the list names. The S6 set is
# processed first, then S4 + S4, each in list order.
degrees_s6, orders_s6 = [], []
degrees_s4s4, orders_s4s4 = [], []

for forms, degree_list, order_list in (
    (s6_min_set, degrees_s6, orders_s6),
    (s4s4_min_set, degrees_s4s4, orders_s4s4),
):
    for form in forms:
        deg, order = do(form)
        degree_list.append(deg)
        order_list.append(order)

# Couple the two generating sets through the linear Diophantine system built
# by the project helper create_diophantine_system, and take the Hilbert basis
# of the Normaliz cone defined by those equations.
A = create_diophantine_system(s6_min_set, s4s4_min_set)
C = PyNormaliz.Cone(equations = A )
gen_set_basis = C.HilbertBasis()

print(len(gen_set_basis))


# Turn each Hilbert-basis vector into a generator built from the two sets
# (project helper create_generating_set_U_plus_V).
gen_set = create_generating_set_U_plus_V(gen_set_basis,
                                         s6_min_set,
                                         s4s4_min_set)

print(f"There are {len(gen_set_basis)} polynomials in the Cov Algebra of S6 + S4+S4.")

# Keep generators of coefficient degree <= max_degree. Split them into proper
# covariants (order > 0) and invariants (order 0).
max_degree = 3

gens_wo_invs = [g for g in gen_set if (g.order > 0 and g.degree <= max_degree)]
# NOTE: despite "lt" in the name, this keeps degree <= max_degree.
inv_polys_degree_lt_degree = [g for g in gen_set if g.degree <= max_degree and g.order == 0]

print("Total Number of Invariants: ", len(inv_polys_degree_lt_degree))

# Second argument of generate_homo_space, presumably the order (degree in x, y)
# of the products to build. Later, coefficients are read off with the
# degree-8 monomial list {x^8, ..., y^8}, so the two must agree.
target_order = 8

# Candidate module generators, grouped by coefficient degree 1..max_degree.
gen_set_polys_ltd = []
for degree in range(1, max_degree + 1):
    # The second value returned by generate_homo_space is not used.
    gen_set_polys_d, _ = generate_homo_space(degree, target_order, gens_wo_invs)
    gen_set_polys_ltd.extend(gen_set_polys_d)
    print(f"Degree - {degree}. Num Polys  - {len(gen_set_polys_d)}")




# Base field K = ZZ/p and polynomial ring R. The characteristic comes from
# prime_number so it always matches the modulus used by the Python-side
# form/transvectant objects (a hard-coded literal would silently desync).
macaulay2(f'K = ZZ/{prime_number};')
macaulay2('R = K[a0, a1, a2, a3, a4, a5, a6, m0, m1, m2, m3, m4, n0, n1, n2, n3, n4, x, y];')


def _m2_expr(poly):
    """Return ``poly.evaluate().as_expr()`` rendered as a Macaulay2 expression.

    SymPy prints powers as ``**``; Macaulay2 uses ``^``.
    """
    return str(poly.evaluate().as_expr()).replace("**", "^")


# Load the invariants into Macaulay2 as I1, I2, ...
for ind, poly in enumerate(inv_polys_degree_lt_degree, start=1):
    inv_poly = f"I{ind} = " + _m2_expr(poly) + ";"
    print(inv_poly)
    macaulay2(inv_poly)

# In Linear case, no degree 1 invariants are found. Nature of harmonic tensors.

# Load the candidate module generators into Macaulay2 as f1, f2, ...
for ind, poly in enumerate(gen_set_polys_ltd, start=1):
    cov_poly = f"f{ind} = " + _m2_expr(poly) + ";"
    print(cov_poly)
    macaulay2(cov_poly)


def generate_ideal_string(n):
    """Build the Macaulay2 command that defines ideal ``J`` from ``I1 .. In``.

    Args:
        n: number of invariants ``I1 .. In`` already defined in the session.

    Returns:
        A command such as ``"J = ideal(I1, I2);"``, or ``"J = ideal();"``
        when ``n`` is not positive.
    """
    # max(n, 0) makes the generator list empty for n <= 0.
    names = ", ".join(f"I{k}" for k in range(1, max(n, 0) + 1))
    return f"J = ideal({names});"

# Ideal J generated by all loaded invariants I1..In.
n = len(inv_polys_degree_lt_degree)
cmd = generate_ideal_string(n)
print(cmd)
macaulay2(cmd)

# Groebner basis of J truncated at gb_max_degree.
# NOTE(review): R is declared without a Degrees option, so every variable
# (x and y included) has degree 1. A basis truncated at this degree is not
# guaranteed complete for the much higher-degree f_i, so `%` below is a
# reduction, not necessarily the full normal form. Confirm this is intended.
# The first call only reports timing; Macaulay2 caches gb computations, so
# the second presumably reuses the result.
gb_max_degree = 3
macaulay2(f"time gb (J, DegreeLimit=>{gb_max_degree})")
macaulay2(f'G = gb (J, DegreeLimit => {gb_max_degree});')

# mf_i = f_i % G: reduce every candidate generator by G.
num_gens = len(gen_set_polys_ltd)
for i in tqdm(range(1, num_gens + 1)):
    macaulay2(f"mf{i} = f{i} % G")

# F = {{mf1, mf2, ...}}. Built with join, so an empty generator list gives
# valid Macaulay2 ("F = {{}}") instead of a malformed command.
cmd = "F = {{" + ", ".join(f"mf{i}" for i in range(1, num_gens + 1)) + "}}"
print(cmd)
macaulay2(cmd)

# Keep only the non-zero remainders.
macaulay2("nonZeroF = select(flatten F, e -> e!=0);")
print("Non-zero remainders:", macaulay2("# nonZeroF"))


def generate_mons_string(n):
    """Build the Macaulay2 command ``mons = {...}`` listing degree-``n`` monomials.

    Monomials in ``x`` and ``y`` are listed from ``x^n`` down to ``y^n``,
    e.g. ``n = 2`` gives ``"mons = {x^2, x^1*y^1, y^2}"``. For ``n = 0`` the
    list is ``{x^0}``; for negative ``n`` it is empty.
    """
    def monomial(power_of_x):
        # The pure x-power is tested first, so n == 0 yields "x^0".
        if power_of_x == n:
            return f"x^{n}"
        if power_of_x == 0:
            return f"y^{n}"
        return f"x^{power_of_x}*y^{n - power_of_x}"

    body = ", ".join(monomial(k) for k in reversed(range(n + 1)))
    return "mons = {" + body + "}"

def generate_subList_cmd(var_prefixes, degrees):
    """Build the Macaulay2 command ``subList = {...}`` of substitution rules.

    For each prefix ``p`` with degree ``d``, variables ``p0 .. pd`` are mapped
    to consecutive entries ``rand_0, rand_1, ...`` of the list ``rand``, with
    numbering running across all prefixes in order.

    Args:
        var_prefixes: variable name prefixes, e.g. ``["a", "m", "n"]``.
        degrees: matching form degrees, e.g. ``[6, 4, 4]``.
    """
    variables = [
        f"{prefix}{k}"
        for prefix, deg in zip(var_prefixes, degrees)
        for k in range(deg + 1)
    ]
    rules = ", ".join(f"{var} => rand_{pos}" for pos, var in enumerate(variables))
    return "subList = {" + rules + "};"

# Monomials x^8, x^7*y, ..., y^8. The 8 must match the order passed to
# generate_homo_space when the candidate generators were built.
cmd = generate_mons_string(8)
print(cmd)
macaulay2(cmd)

# Every generator of R except x and y is a coefficient variable.
numCoeffs = int(str(macaulay2("numgens R"))) - 2
print(f"There are {numCoeffs} coefficients in each polynomial.")

# nonZeroF dropped the zero remainders, so row k of the evaluation matrix is
# the k-th NON-zero remainder, not mf(k+1). positions() returns the 0-based
# indices j (into flatten F, hence into gen_set_polys_ltd) whose remainder is
# non-zero, in the same order select() kept them.
nonzero_idx = [int(j) for j in macaulay2("positions(flatten F, e -> e != 0)").sage()]

# Rules a0..a6, m0..m4, n0..n4 => rand_0 .. rand_16.
subListCmd = generate_subList_cmd(["a", "m", "n"], [6, 4, 4])
numPoints = int(str(macaulay2("# nonZeroF")))

# For each random point in coefficient space, substitute it into every
# non-zero remainder and record the x, y coefficients of the result.
# Per-point blocks are concatenated once at the end; concatenating inside the
# loop recopies the whole matrix on every iteration.
eval_blocks = []
for _ in tqdm(range(numPoints)):
    macaulay2(f"rand = apply({numCoeffs}, i -> random(1, {prime_number}));")
    macaulay2(subListCmd)
    macaulay2("Fsub = apply(nonZeroF, e -> substitute(e, subList))")
    macaulay2("evalMatrix = apply(Fsub, fsub -> flatten entries last coefficients(fsub, Monomials => mons))")
    eval_blocks.append(np.array(macaulay2("matrix evalMatrix").sage(), dtype = int))

if eval_blocks:
    # Rows: non-zero remainders. Columns: (point, monomial) pairs.
    matrix = np.concatenate(eval_blocks, axis=1)
    print(matrix.shape)
    min_rows = fast_rref(np.array(matrix, dtype = int), prime_number, False)
else:
    # Every remainder vanished, so there is nothing to rank.
    matrix = np.zeros((0, 0), dtype = int)
    min_rows = []

print(f"There are {len(min_rows)} R-linearly independent entries.")

# Map matrix rows back to generators through nonzero_idx. Indexing
# gen_set_polys_ltd directly by row number is wrong once any remainder is 0.
min_polys = [gen_set_polys_ltd[nonzero_idx[r]] for r in min_rows]

### Converting to Tensors etc.
min_poly_symbols = [p.symbol() for p in min_polys]
print(min_poly_symbols)