import torch
import numpy as np
import instance_gen
from inst_utils import get_problem, get_GE, save_GE
import instance_gen

### This file describes the instances used: type of instance, problem size,
### number of time steps and type of reward.

# Width of the per-instance meta-feature vector (second dimension of meta_info).
# preprocess_inst writes feature slots 11..15, so this must be at least 16.
meta_M = 16

### IMPORTANT PARAMETERS ###

## Every entry of config_list is a tuple (N, T, rho, inst_type):
##   N          problem size
##   T          number of time steps
##   rho        type of reward
##   inst_type  instance class name (e.g. "SK")

# rho = -1 -> success rate
# rho > 0  -> reward = (E/E0)^{rho}   (soft success rate)

### config_list -> types of problems generated for training

# Alternative presets, kept for reference. Exactly one assignment to
# config_list should be active; earlier assignments would simply be overwritten.
# config_list = [(100,100,-1,"SK"), (100,100,-1,"GSET"), (100,100,-1,"CHOOK"), (100,100,3,"SK"), (100,100,3,"GSET"), (100,100,3,"CHOOK")]
# config_list = [(144, 100, 10, "SK"),(144, 100, 10, "CHOOK")]
# config_list = [(64, 100, -1, "SK"),(64, 100, -1, "CHOOK")]
# config_list = [(100, 100, -1, "SK")]
# config_list = [(50, 50, -1, "SK"),(100, 100, -1, "SK"),(200, 200, -1, "SK"),(400, 400, -1, "SK")]
# config_list = [(50, 50, -1, "SK"),(100, 100, -1, "SK"),(200, 200, -1, "SK"),(400, 400, -1, "SK"),(600, 600, -1, "SK"),(800, 800, -1, "SK")]
# config_list = [(200, 200, -1, "SK"),(800, 800, -1, "SK"),(800, 800, -1, "SK")]
# config_list = [(100, 100, -1, "SK"), (800, 800, -1, "SK"),  (800, 800, -1, "SK"),  (800, 800, -1, "SK")]
# config_list = [(600, 600, 10, "SK"),(600, 600, 20, "SK"),(800, 800, -1, "SK")]
# config_list = [(50,50, 10, "SK")]

# Active default configuration (may be replaced by set_instance_group).
config_list = [(100,100,-1,"SK"),  (800,1500, -1, "SK"), (800,1500, -1, "SK"), (800,1500, -1, "SK")]


inst_seed_numb = 100

# Added to the instance index to form the generation seed (see get_config).
inst_seed_offset = 0


# instance logic

### Number of instance slots; get_config cycles through config_list by index.
numb_inst_config = 100


### Per-instance meta features and normalisation factors (filled by preprocess_inst)
meta_info = torch.zeros(numb_inst_config, meta_M)

inst_norm = torch.zeros(numb_inst_config)


# Per-instance N, T and class name, populated by preprocess_inst.
inst_N = []
inst_T = []

inst_clss = []

# Distinct N, T and class names encountered by preprocess_inst.
N_all = set()
T_all = set()
clss_all = set()


def _parse_T_factor(token):
    """Turn a time-factor token such as "1T", "2T" or "05T" into a float.

    The last character is dropped; if the remainder starts with "0" a decimal
    point is inserted after that zero, so "05T" becomes 0.5.
    """
    digits = token[:-1]
    if digits.startswith("0"):
        digits = "0." + digits[1:]
    return float(digits)


def _pair_config(parts, label):
    """Build the config list for names shaped ``<name>_<k>T_<N1>_<N2>``.

    N is ``max(N1, N2) + 1`` and T is ``int(T_factor * N)``; the instance
    class is ``<label>_<N1>_<N2>``.
    """
    T_factor = _parse_T_factor(parts[1])
    N_1 = int(parts[2])
    N_2 = int(parts[3])
    N = max(N_1, N_2) + 1
    return [(N, int(T_factor*N), -1, f"{label}_{N_1}_{N_2}")]


def set_instance_group(group_name):
    """Select a named instance group and re-allocate the per-instance buffers.

    Recognised names (fields are separated by "_"):
      SK_many                              six SK sizes, 100 instances
      SK_05T_N_<N>, SK_1T_N_<N>, SK_2T_N_<N>
                                           SK with T = 0.5*N, N, 2*N; 100 instances
      SK_SINGLE_1T_N_<N>[_<offset>]        one SK instance, optional seed offset
      BA_1T_N_<N>[_<N2>]                   100 instances
      tBA_1T_N_<N>[_<N2>[_<start>]]        1000 instances
      RUDY_<k>T_<N>_<type>[_<density>]     100 instances
      CHOOK_<k>T_N_<N>[_<suffix>]          100 instances
      GSETb_<k>T_N_<N>[_v<version>]        N must be 800 or 2000
      MIS.../MC..._<k>T_<N1>_<N2>          100 instances
      tMIS.../tMC..._<k>T_<N1>_<N2>        1000 instances

    Updates the module globals ``config_list`` and ``numb_inst_config`` (and
    ``inst_seed_offset`` for the SK_SINGLE / GSETb variants that carry one),
    then re-creates ``meta_info``, ``inst_norm``, ``rew_inst``, ``obj_inst``,
    ``obj_t10_inst`` and ``obj_max_inst`` as zero arrays of the new length.

    A name that matches no pattern leaves ``config_list`` and
    ``numb_inst_config`` as they were but still re-creates the buffers.
    ``inst_seed_offset`` is never reset here, so a value set by an earlier
    call stays in effect.

    Raises:
        IndexError, ValueError: if a matching name lacks a required field or
            a numeric field cannot be parsed.
        KeyError: for a GSETb size other than 800 or 2000.
    """
    global config_list, meta_info, inst_norm, numb_inst_config, rew_inst, obj_inst, obj_t10_inst, obj_max_inst, inst_seed_offset

    parts = group_name.split("_")

    if group_name == "SK_many":
        config_list = [(50, 50, -1, "SK"),(100, 100, -1, "SK"),(200, 200, -1, "SK"),(400, 400, -1, "SK"),(600, 600, -1, "SK"),(800, 800, -1, "SK")]
        numb_inst_config = 100

    elif group_name.startswith("SK_05T_N_"):
        N = int(parts[3])
        config_list = [(N, int(0.5*N), -1, "SK")]
        numb_inst_config = 100

    elif group_name.startswith("SK_1T_N_"):
        N = int(parts[3])
        config_list = [(N, N, -1, "SK")]
        numb_inst_config = 100

    elif group_name.startswith("SK_SINGLE_1T_N_"):
        N = int(parts[4])
        if len(parts) > 5:
            inst_seed_offset = int(parts[5])
            print("iso", inst_seed_offset)
        config_list = [(N, N, -1, "SK")]
        numb_inst_config = 1

    elif group_name.startswith("SK_2T_N_"):
        N = int(parts[3])
        config_list = [(N, 2*N, -1, "SK")]
        numb_inst_config = 100

    elif group_name.startswith("BA_1T_N_"):
        N = int(parts[3])
        if len(parts) == 4:
            config_list = [(N, N, -1, "BA")]
        elif len(parts) == 5:
            N2 = int(parts[4])
            config_list = [(N2, N2, -1, f"BA_{N}_{N2}")]
        # any other field count keeps the previous config_list
        numb_inst_config = 100

    elif group_name.startswith("RUDY"):
        T_factor = _parse_T_factor(parts[1])
        N = int(parts[2])
        rudy_type = parts[3]
        inst_type = f"RUDY_{rudy_type}"
        if len(parts) > 4:
            density = int(parts[4])
            inst_type = f"RUDY_{rudy_type}_{density}"
        config_list = [(N, int(T_factor*N), -1, inst_type)]
        numb_inst_config = 100

    elif group_name.startswith("tBA_1T_N_"):
        N = int(parts[3])
        if len(parts) == 4:
            config_list = [(N, N, -1, "tBA")]
        elif len(parts) == 5:
            N2 = int(parts[4])
            config_list = [(N2, N2, -1, f"tBA_{N}_{N2}")]
        elif len(parts) == 6:
            N2 = int(parts[4])
            start = int(parts[5])
            config_list = [(N2, N2, -1, f"tBA_{N}_{N2}_{start}")]
        numb_inst_config = 1000

    print("setting_instance_group",group_name )

    if group_name.startswith("CHOOK"):
        T_factor = _parse_T_factor(parts[1])
        N = int(parts[3])
        config_list = [(N, N, -1, "CHOOK")]
        # Only names with a fifth field carry a suffix. The previous test
        # (> 3) was always true here and raised IndexError for 4-field names.
        # NOTE(review): without a suffix T is N and T_factor is ignored,
        # as in the original default - confirm this is intended.
        if len(parts) > 4:
            config_list = [(N, int(T_factor*N), -1, "CHOOK_" + parts[4])]
        numb_inst_config = 100

    elif group_name.startswith("GSETb"):
        # NOTE(review): the T factor is parsed (so malformed names still
        # fail) but T is always set to N - confirm this is intended.
        _parse_T_factor(parts[1])

        N = int(parts[3])
        config_list = [(N, N, -1, "GSETb")]

        version = None
        if len(parts) > 4:
            # field looks like "v<k>"; drop the leading character
            version = int(parts[4][1:])

        n_config = {800: 21, 2000: 21}
        # (first, one-past-last) instance index of every version slice
        v_config = {800: [(0,5),(5,10),(10,13),(13,17),(17,21)], 2000: [(0,5),(5,10),(10,13),(13,17),(17,21)]}

        numb_inst_config = n_config[N]

        if version is not None:
            first, stop = v_config[N][version]
            numb_inst_config = stop - first
            inst_seed_offset = first

    elif group_name.startswith("MIS"):
        config_list = _pair_config(parts, parts[0])
        numb_inst_config = 100

    elif group_name.startswith("tMIS"):
        config_list = _pair_config(parts, parts[0])
        numb_inst_config = 1000

    elif group_name.startswith("MC"):
        config_list = _pair_config(parts, "MC")
        numb_inst_config = 100

    elif group_name.startswith("tMC"):
        config_list = _pair_config(parts, "tMC")
        numb_inst_config = 1000

    # Re-create every per-instance buffer with the (possibly new) length.
    meta_info = torch.zeros(numb_inst_config, meta_M)

    inst_norm = torch.zeros(numb_inst_config)

    rew_inst = np.zeros(numb_inst_config)
    obj_inst = np.zeros(numb_inst_config)
    obj_t10_inst = np.zeros(numb_inst_config)
    obj_max_inst = np.zeros(numb_inst_config)
  

def get_config(inst_config_idx):
    """Generate instance ``inst_config_idx`` and return it with its stored metadata.

    The (N, T, rho, clss) configuration is picked round-robin from
    ``config_list`` and the problem is produced by
    ``instance_gen.gen_problem(N, inst_seed, clss)`` where
    ``inst_seed = inst_config_idx + inst_seed_offset``.

    Returns:
        A 10-tuple ``(J, norm, inst_seed, GE_base, N, T, rho, MIS, meta, clss)``:
          J          generator output converted to a float32 tensor
          norm       ``inst_norm[inst_config_idx]``
          inst_seed  seed passed to the generator
          GE_base    first value returned by the generator
          N, T, rho  entries of the selected configuration
          MIS        True when ``clss`` starts with "MIS", "tMIS", "MC" or "tMC"
          meta       ``meta_info[inst_config_idx, :]``
          clss       instance class name of the selected configuration
    """
    N, T, rho, clss = config_list[inst_config_idx % len(config_list)]

    inst_seed = inst_config_idx + inst_seed_offset

    # base GE only relevant for planted instances
    GE_base, J = instance_gen.gen_problem(N, inst_seed, clss)

    MIS = clss.startswith(("MIS", "tMIS", "MC", "tMC"))

    J = torch.tensor(J, dtype = torch.float32)
    return J, inst_norm[inst_config_idx], inst_seed, GE_base, N, T, rho, MIS, meta_info[inst_config_idx, :], clss

def get_current_MIS_size():
    """Return the ``current_MIS_size`` attribute of the ``instance_gen`` module."""
    current_size = instance_gen.current_MIS_size
    return current_size


def preprocess_inst(use_rho = True):
    """Generate every instance once and fill the per-instance bookkeeping.

    For each index below ``numb_inst_config`` the instance is obtained from
    ``get_config`` and the following module globals are updated:
      inst_N, inst_T, inst_clss   per-instance N, T and class name
      N_all, T_all, clss_all      sets of the distinct values encountered
      inst_norm[i]                ``(mean(J**2) * N) ** -0.5``
      meta_info[i, 11:16]         ``1/N, N**-0.5, log(N), 1/T, T**-0.5``
    The remaining meta slots are not written here. Afterwards ``meta_info``
    is standardised column-wise as ``(x - mean) / (std + 0.1)``.

    Args:
        use_rho: currently unused; kept so existing callers keep working.

    Returns:
        ``(meta_info, inst_norm, meta_norm_mean, meta_norm_std)`` where the
        last two are the column statistics used for the standardisation.
    """
    global meta_info

    global inst_norm
    global inst_N, inst_T, inst_clss, N_all, T_all, clss_all

    inst_N = np.zeros(numb_inst_config)
    inst_T = np.zeros(numb_inst_config)
    # inst_clss is indexed in parallel with inst_N / inst_T, so it has to be
    # emptied as well; otherwise a second call would append stale entries.
    # NOTE(review): N_all / T_all / clss_all still accumulate across calls,
    # as before - confirm whether that is intended.
    inst_clss.clear()

    for i in range(numb_inst_config):
        print(f"{i}/{numb_inst_config}")
        J, _, _, _, N, T, _, _, meta, clss = get_config(i)

        inst_N[i] = N
        inst_T[i] = T
        inst_clss.append(clss)
        N_all.add(N)
        T_all.add(T)
        clss_all.add(clss)

        norm = (torch.mean(J**2)*N)**(-0.5)

        # size-based features
        meta[11] = N**(-1)
        meta[12] = N**(-0.5)
        meta[13] = np.log(N)

        meta[14] = T**(-1)
        meta[15] = T**(-0.5)

        meta_info[i,:] = meta
        inst_norm[i] = norm

    # normalize meta_info
    meta_norm_mean = torch.mean(meta_info, dim = 0)
    if meta_info.shape[0] > 1:
        meta_norm_std = torch.std(meta_info, dim = 0)
    else:
        # The sample standard deviation of a single row is NaN, which would
        # turn every meta feature into NaN; use zero spread instead.
        meta_norm_std = torch.zeros_like(meta_norm_mean)

    meta_info = (meta_info - meta_norm_mean.reshape(1,-1))/(meta_norm_std.reshape(1,-1) + 0.1)

    return meta_info, inst_norm, meta_norm_mean, meta_norm_std

