import inst_utils
import rudy_py_generator
import torch
import numpy as np
import networkx as nx
import os

# Import specific planter functions from chook

import chook.wrappers

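# Problem classes handled by gen_problem: "SK", "BA*", "tBA*", "GSET", "RUDY_*",
# "GSETb", "CHOOK*", "MIS*", "tMIS*", "MC*", "tMC*" (anything else falls back to
# a random Gaussian coupling matrix).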

def generate_chook_instance(N, seed, instance_type="WP", alpha=0.5):
    """
    Generate a Chook planted instance (Wishart-planted or tile-planted)
    using the planter wrappers from ``chook.wrappers``.

    The global NumPy RNG is seeded with ``seed`` before the planter is called,
    presumably so the planter's randomness is reproducible (TODO confirm that
    chook draws from the global NumPy RNG).

    Args:
        N (int): Number of variables; the returned matrix is always N x N.
        seed (int): Seed passed to ``np.random.seed``.
        instance_type (str): 'WP' for Wishart-planted or 'TP' for tile-planted.
        alpha (float): Wishart ratio; the planter receives ``M = int(N * alpha)``.
            Ignored for 'TP'.

    Returns:
        tuple: (2 * ground-state energy reported by chook, J) where J is a
        symmetric N x N torch tensor created with ``torch.zeros`` and filled
        from the planted bonds.

    Raises:
        ValueError: If ``instance_type`` is neither 'WP' nor 'TP'.
    """
    np.random.seed(seed)

    if instance_type == "WP":
        M = int(N * alpha)
        bonds, gs_energy = chook.wrappers.wishart_planting_wrapper(
            N,
            M,
            discretize_couplers=True,
            gauge_transform=False,
            convert_to_hobo=False,
        )
    elif instance_type == "TP":
        # Tile planting on a square 2-D lattice of side int(sqrt(N)). When N is
        # not a perfect square the lattice presumably has side_length**2 < N
        # spins, and the remaining rows/columns of J simply stay zero.
        side_length = int(np.sqrt(N))
        tile_params = (0.1, 0.1, 0.1)
        bonds, gs_energy = chook.wrappers.tile_planting_wrapper(
            side_length,
            2,
            tile_params,
            gauge_transform=False,
            convert_to_hobo=False,
        )
    else:
        raise ValueError(f"Unknown instance type: {instance_type}")

    # Each bond is stored in both J[i, j] and J[j, i].
    J = torch.zeros(N, N)
    for i, j, Jij in bonds:
        J[i, j] = Jij
        J[j, i] = Jij

    # NOTE(review): the factor 2 presumably converts chook's energy (each bond
    # counted once) to the full-sum convention s^T J s of the symmetric J above.
    return 2 * gs_energy, J
    

def generate_BA_graph(N, m, seed):
    """Return the dense adjacency matrix of a Barabasi-Albert random graph.

    Args:
        N (int): Number of nodes.
        m (int): Number of edges attached from each new node (networkx ``m``).
        seed: Seed for the networkx generator; cast to ``int``.

    Returns:
        np.ndarray: Symmetric N x N adjacency matrix.
    """
    adjacency = np.array(
        nx.adjacency_matrix(
            nx.barabasi_albert_graph(N, m, seed=int(seed))
        ).todense()
    )
    # The BA graph is undirected, so this should be a no-op; it guarantees
    # symmetry regardless of how networkx laid out the matrix.
    return np.maximum(adjacency, adjacency.T)

def get_GSET(N, version, inst):
    """Load a stored G-set instance as a dense symmetric coupling matrix.

    Reads ``./Gset/Gset_{N}_v{version}/G{inst+1}.txt`` (one header line, then
    ``i j w`` rows with 1-based node indices) and the reference value in row
    ``inst`` of ``./Gset/SOL.txt``. Both paths are relative to the current
    working directory. Prints ``E0`` as a side effect.

    Args:
        N (int): Number of nodes (side of the returned matrix).
        version (int): Directory version suffix.
        inst (int): 0-based instance index (file ``G{inst+1}.txt``).

    Returns:
        tuple: (E0, J) where J is the N x N NumPy array of edge weights, with
        each edge stored at both (i, j) and (j, i), and ``E0 = sum(J) - 4 * sol``.
        Presumably ``sol`` is the best-known max-cut value, in which case E0 is
        s^T J s for the corresponding spin assignment — TODO confirm.
    """
    inst_path = f"./Gset/Gset_{N}_v{version}/G{inst+1}.txt"
    sol_path = "./Gset/SOL.txt"
    # ndmin=1: a SOL.txt holding a single value would otherwise load as a 0-d
    # array, which cannot be indexed.
    sol = np.loadtxt(sol_path, ndmin=1)[inst]

    # ndmin=2: a file with a single edge would otherwise load as a 1-D array,
    # breaking the (i, j, w) row unpacking below.
    edges = np.loadtxt(inst_path, skiprows=1, ndmin=2)

    J = np.zeros((N, N))
    for i, j, Jij in edges:
        a, b = int(i) - 1, int(j) - 1
        J[a, b] = J[b, a] = Jij

    E0 = np.sum(J) - sol * 4
    print("E0", E0)
    return E0, J

# Lookup tables for gen_problem's "GSETb" class, indexed by seed: entry k of
# GSET_configs_inst[N] is the 0-based instance passed to get_GSET (file
# G{inst+1}.txt, row `inst` of SOL.txt) and entry k of GSET_configs_version[N]
# is the matching directory version suffix.
GSET_configs_inst = {
    800: list(range(0, 21)),      # five sub-families of 5, 5, 3, 4, 4 instances
    2000: list(range(21, 42)),
}
GSET_configs_version = {
    800: [version
          for version, count in ((1, 5), (2, 5), (3, 3), (4, 4), (5, 4))
          for _ in range(count)],
    2000: [1] * 21,
}

# The N_ returned by inst_utils.read_Jij_MIS for the most recent
# MIS/tMIS/MC/tMC instance; updated by gen_problem.
current_MIS_size = 0

def _size_range(clss):
    """Parse the two integers of a ``<prefix>_<N_1>_<N_2>[...]`` class name.

    Raises:
        ValueError: If the name has fewer than three ``_``-separated fields.
    """
    parts = clss.split("_")
    if len(parts) < 3:
        raise ValueError(
            f"Problem class {clss!r} must look like '<prefix>_<N_1>_<N_2>'"
        )
    return int(parts[1]), int(parts[2])


def _padded_ba_graph(N_1, N_2, rng_seed, graph_seed):
    """Return an (N_2, N_2) array with a random-size BA graph in its top-left corner.

    The node count is drawn from N_1..N_2 (inclusive) with a generator seeded by
    ``rng_seed``; the graph itself (m = 4) is generated with ``graph_seed``.
    Rows/columns beyond the drawn node count are zero padding.
    """
    rng = np.random.default_rng(seed=rng_seed)
    n_nodes = N_1 + int((1 + N_2 - N_1) * rng.random())
    J = np.zeros((N_2, N_2))
    J[:n_nodes, :n_nodes] = generate_BA_graph(n_nodes, 4, graph_seed)
    return J


def _rudy_instance(N, seed, clss):
    """Build the coupling matrix for a ``RUDY_<type>[_<density>]`` class.

    ``<type>`` is two characters: the graph family ("S" -> create_random_graph,
    "T" -> toroidal 2-D grid, "P" -> union of two planar graphs), then "s" to
    multiply elementwise by ``inst_utils.get_problem_SK(N, seed)`` (signed) or
    any other character for unsigned. ``<density>`` is required for "S" and "P".

    Raises:
        ValueError: If the class name is malformed or the family is unknown.
    """
    parts = clss.split("_")
    if len(parts) < 2 or len(parts[1]) < 2:
        raise ValueError(
            f"RUDY class {clss!r} must look like 'RUDY_<family><sign>[_<density>]'"
        )
    graph_type = parts[1][0]
    is_signed = parts[1][1] == "s"
    density = int(parts[2]) if len(parts) > 2 else None

    if graph_type in ("S", "P") and density is None:
        raise ValueError(
            f"RUDY class {clss!r} needs a density field for graph type {graph_type!r}"
        )

    if graph_type == "S":
        J = torch.tensor(rudy_py_generator.create_random_graph(N, density, seed))
    elif graph_type == "T":
        # Torus width drawn between about sqrt(N)/2 and sqrt(N); the grid
        # covers width1 * width2 <= N nodes and the rest of J is zero padding.
        np.random.seed(seed)
        width1 = int(np.random.rand()*np.sqrt(N)*0.5 + np.sqrt(N)*0.5)
        width2 = int(N/width1)
        n_grid = width1 * width2
        J = torch.zeros(N, N)
        J[:n_grid, :n_grid] = torch.tensor(
            rudy_py_generator.create_toroidal_grid_2d(width1, width2)
        )
    elif graph_type == "P":
        J = torch.tensor(rudy_py_generator.graph_union(
            rudy_py_generator.create_planar_graph(N, density, seed),
            rudy_py_generator.create_planar_graph(N, density, seed + 2),
        ))
    else:
        raise ValueError(f"Unknown RUDY graph type {graph_type!r} in {clss!r}")

    if is_signed:
        J = J*torch.tensor(inst_utils.get_problem_SK(N, seed))
    return J


def _embed_with_field(J_, N_, N, print_sparsity):
    """Embed an N_ x N_ matrix in an N x N zero matrix plus a field row/column.

    The field vector h (column sums of J_ minus 0.9) is written into row N_
    and column N_, so N must be strictly larger than N_.

    Raises:
        ValueError: If N_ >= N (no room for the field row/column).
    """
    if N_ >= N:
        raise ValueError(
            f"Instance has {N_} nodes; N={N} leaves no room for the field row/column"
        )
    J = np.zeros((N, N))
    J[:N_, :N_] = J_
    h = np.sum(np.array(J_), axis=0) - 0.9
    if print_sparsity:
        # Computed before the field is inserted.
        print("sparsity", np.sum(J)/(N*(N-1)))
    J[N_, :N_] = h
    J[:N_, N_] = h
    return J


def gen_problem(N, seed, clss):
    """
    Generate a problem instance.

    Args:
        N (int): Problem size. For "BA_<N_1>_<N_2>" / "tBA_<N_1>_<N_2>[_<off>]"
            it is ignored and the matrix is N_2 x N_2 instead.
        seed: Random seed / instance index; converted with ``int``.
        clss (str): Problem class. Checked in this order:
            - "SK": ``inst_utils.get_problem_SK(N, seed)``.
            - "BA" / "tBA": Barabasi-Albert graph on N nodes (m = 4).
            - "BA_<N_1>_<N_2>": random-size BA graph, zero-padded to N_2 x N_2.
            - "tBA_<N_1>_<N_2>[_<off>]": as above with graph seed shifted by
              10**4 (and both seeds shifted by ``off`` if given). Other tBA
              shapes fall through to the random fallback.
            - "GSET": one of five rudy generators selected by ``seed % 5``.
            - "RUDY_<type>[_<density>]": see ``_rudy_instance``.
            - "GSETb": stored G-set instance ``seed`` (N must be 800 or 2000).
            - "CHOOK[_<alpha percent>]": Wishart-planted chook instance.
            - "MIS*", "tMIS*", "MC*", "tMC*" followed by "_<N_1>_<N_2>":
              instances from ``inst_utils.read_Jij_MIS`` ("t" passes
              ``test=True``; "MC" uses the complement, 1 - J_ with zero
              diagonal), embedded in N x N with a field row/column at index N_.
              Sets the module global ``current_MIS_size`` to N_.
            - anything else: Gaussian N x N matrix / sqrt(N) from the (unseeded)
              global NumPy RNG.

    Returns:
        tuple: (ground-state energy, or 0 where unknown; coupling matrix). The
        matrix is a NumPy array for the BA/tBA/MIS/MC families and a torch
        tensor for GSETb, RUDY, CHOOK and the fallback; "SK" and "GSET" return
        whatever ``inst_utils.get_problem_SK`` / the torch expressions produce.

    Raises:
        ValueError: For malformed BA/RUDY/MIS/MC class names, or when an
            MIS/MC instance leaves no room for the field row/column.
    """
    global current_MIS_size

    seed = int(seed)

    if clss == "SK":
        return 0, inst_utils.get_problem_SK(N, seed)

    # Barabasi-Albert graphs.
    if clss == "BA":
        return 0, generate_BA_graph(N, 4, seed)
    if clss.startswith("BA"):
        N_1, N_2 = _size_range(clss)
        return 0, _padded_ba_graph(N_1, N_2, seed, seed)

    # "tBA" variants use graph seeds shifted by 10**4 relative to "BA".
    if clss == "tBA":
        return 0, generate_BA_graph(N, 4, seed)
    if clss.startswith("tBA"):
        parts = clss.split("_")
        if len(parts) == 3:
            N_1, N_2 = int(parts[1]), int(parts[2])
            return 0, _padded_ba_graph(N_1, N_2, seed, seed + 10**4)
        if len(parts) == 4:
            N_1, N_2, seed_offset = int(parts[1]), int(parts[2]), int(parts[3])
            return 0, _padded_ba_graph(
                N_1, N_2, seed + seed_offset, seed + 10**4 + seed_offset
            )
        # Any other tBA shape falls through to the random fallback below.

    if clss == "GSET":
        # seed % 5 picks the generator: 0 random graph, 1 signed random graph,
        # 2 toroidal grid, 3 union of two planar graphs, 4 signed planar union.
        subc = seed % 5
        if subc == 0:
            return 0, torch.tensor(rudy_py_generator.create_random_graph(N, 10.0, seed))
        if subc == 1:
            return 0, inst_utils.get_problem_SK(N, seed)*torch.tensor(rudy_py_generator.create_random_graph(N, 10.0, seed))
        if subc == 2:
            return 0, torch.tensor(rudy_py_generator.create_toroidal_grid_2d(int(np.sqrt(N)), int(np.sqrt(N))))
        if subc == 3:
            return 0, torch.tensor(rudy_py_generator.graph_union(rudy_py_generator.create_planar_graph(N, 10.0, seed), rudy_py_generator.create_planar_graph(N, 10.0, seed+2)))
        if subc == 4:
            return 0, inst_utils.get_problem_SK(N, seed)*torch.tensor(rudy_py_generator.graph_union(rudy_py_generator.create_planar_graph(N, 10.0, seed), rudy_py_generator.create_planar_graph(N, 10.0, seed+2)))

    if clss.startswith("RUDY"):
        return 0, _rudy_instance(N, seed, clss)

    if clss == "GSETb":
        configs_inst = GSET_configs_inst[N]
        configs_version = GSET_configs_version[N]
        E0, J = get_GSET(N, configs_version[seed], configs_inst[seed])
        return E0, torch.tensor(J)

    if clss.startswith("CHOOK"):
        alpha = 1.0
        parts = clss.split("_")
        if len(parts) > 1:
            # Given in percent, e.g. "CHOOK_50" -> alpha = 0.5.
            alpha = float(int(parts[1]))/100
        # NOTE(review): seed % 1 is always 0, so only Wishart-planted instances
        # are produced and the tile-planted branch is unreachable. Kept as-is to
        # preserve existing datasets; change the modulus to re-enable TP.
        subc = seed % 1
        if subc == 0:
            return generate_chook_instance(N, seed, "WP", alpha=alpha)
        elif subc == 1:
            return generate_chook_instance(N, seed, "TP")

    # MIS / max-clique families, checked in this order so that e.g. "MISsdds"
    # matches "MIS" first. Columns: prefix, test split?, complement graph?
    for prefix, is_test, is_clique in (
        ("MIS", False, False),
        ("tMIS", True, False),
        ("MC", False, True),
        ("tMC", True, True),
    ):
        if clss.startswith(prefix):
            N_1, N_2 = _size_range(clss)
            read_kwargs = {"test": True} if is_test else {}
            if not is_clique:
                read_kwargs["sdds"] = clss.startswith(prefix + "sdds")
            J_, N_ = inst_utils.read_Jij_MIS((N_1, N_2), seed, **read_kwargs)
            if is_clique:
                # Complement graph with the self-loops removed.
                J_ = 1 - J_
                J_ = J_ - np.diag(np.diag(J_))
            current_MIS_size = N_
            return 0, _embed_with_field(J_, N_, N, print_sparsity=not is_clique)

    # Default fallback (uses the global, unseeded NumPy RNG).
    return 0, torch.tensor(np.random.randn(N, N) / np.sqrt(N))