#!/usr/bin/env python
# coding: utf-8

# In[1]:


#from pulp.apis import GUROBI_CMD
from gurobipy import Model, GRB, quicksum
import h5py
import numpy as np
from itertools import product
from pulp import LpProblem, LpVariable, LpMaximize, lpSum, GLPK, PULP_CBC_CMD, LpStatus
from pulp.apis import CPLEX_CMD
import matplotlib.pyplot as plt
from mip import OptimizationStatus
import time


# In[2]:


import os 


# In[3]:


def shift_with_boundary_preservation(array, shift_range):
    """Return a stack of copies of ``array`` displaced along axis 0.

    One copy is produced per entry of ``shift_range``, in the same order.
    For a positive shift ``s`` output row ``r`` holds input row ``r + s``;
    for a negative shift the rows move the other way.  Rows that have no
    source row after the displacement are left at zero (nothing wraps
    around), and a shift of 0 yields a plain copy.

    Args:
        array: NumPy array with at least two dimensions (non-zero shifts
            index it as ``array[rows, :]``).
        shift_range: Sized iterable of integer row offsets.

    Returns:
        Array of shape ``(len(shift_range),) + array.shape`` with the dtype
        of ``array``.
    """
    # Start from all zeros so only the surviving rows need to be written.
    stack = np.zeros((len(shift_range),) + array.shape, dtype=array.dtype)

    for idx, offset in enumerate(shift_range):
        if offset > 0:
            # Drop the first `offset` rows; the tail of the slot stays zero.
            stack[idx, :-offset, :] = array[offset:, :]
        elif offset < 0:
            # Drop the last `-offset` rows; the head of the slot stays zero.
            stack[idx, -offset:, :] = array[:offset, :]
        else:
            stack[idx] = array

    return stack




# In[ ]:


# Batch driver: for every (channel, noise, mat_ctr) combination, load the
# 'err_matrix_bin' dataset from the matching .mat file, build a Gurobi ILP that
# assigns each of the m slices along the last axis to one row shift out of
# shift_range, solve it, and save the resulting label vector as a .npy file.

# Problem parameters -- none of them depend on the input file.
p = 80           # number of axis-0 samples kept, centred on index 64
n = 32           # size of the second model axis (axis 1 of the error matrix)
num_levels = 4   # number of shift levels that must be active in a solution
k = 10           # largest shift magnitude
k_step = 1
cursors = 4
m = 2**cursors   # number of slices along the last axis that get a shift
shift_range = np.arange(-k, k + k_step, k_step)
kp = len(shift_range)
midpoint = kp // 2   # index of the zero shift inside shift_range

for channel in range(1, 1024):
    print('channel = ' + str(channel), flush=True)
    for noise in range(1, 2):
        for mat_ctr in range(1, 5):
            file_path = r'C:\Users\user\OneDrive - NVIDIA Corporation\103backup\synthetic\err_matrix_binary_4taps_badsynthchannel_' + str(channel) + '_noise_' + str(noise) + '_ctr_' + str(mat_ctr) + '.mat'
            res_file_path = 'unbalanced_4lvl_badsynthchannel_4tap_channel_' + str(channel) + '_noise_' + str(noise) + '_ctr_' + str(mat_ctr) + '_labels.npy'

            # Best-effort batch: a failure on one file is reported and the loop
            # moves on.  Only Exception is caught so that Ctrl-C / SystemExit
            # still stop the run.
            try:
                if os.path.exists(file_path):
                    print('valid file')
                else:
                    # Skip explicitly instead of letting the open() below raise.
                    print('invalid file')
                    continue

                # Check for an existing result before touching the input file.
                if os.path.exists(res_file_path):
                    print('res already exist', flush=True)
                    continue

                # Read the whole dataset into memory and close the file right
                # away; a missing key raises KeyError and is reported below.
                with h5py.File(file_path, 'r') as f:
                    vorg = f['err_matrix_bin'][()]

                # trim for training: keep p samples of axis 0 around index 64
                err_mat = vorg[64 - p // 2:64 + p // 2, :, :]
                print(err_mat.shape, flush=True)

                # sub_eye[i, s, j, z]: slice i of err_mat shifted by
                # shift_range[s], transposed so that z indexes the trimmed axis.
                sub_eye = np.zeros((m, kp, n, p))
                for i in range(m):
                    shifted = shift_with_boundary_preservation(err_mat[:, :, i], shift_range)
                    sub_eye[i, :, :, :] = np.transpose(shifted, axes=(0, 2, 1))

                # ---- Gurobi ILP model ----
                model = Model("SL_Problem")
                model.setParam('Cuts', 2)           # aggressive cut generation
                model.setParam('Heuristics', 0.1)   # fraction of time spent in MIP heuristics
                model.setParam('VarBranch', 0)      # pseudo reduced cost branching
                model.setParam('ConcurrentMIP', 8)  # number of independent concurrent MIP solves
                model.setParam('Presolve', 2)       # aggressive presolve
                model.setParam('OutputFlag', 0)     # silence solver logging
                model.setParam('TimeLimit', 90)     # seconds
                #model.setParam('MIPGap', 1e-6)  # Tighten the MIP gap
                model.setParam('MIPFocus', 2)       # focus on proving optimality

                # X[i][l] == 1  <=>  slice i is assigned shift index l
                X = [[model.addVar(vtype=GRB.BINARY, name=f"x({i},{l})") for l in range(kp)] for i in range(m)]

                # W[j][z] == 1  <=>  pixel (j, z) counts towards the objective
                W = [[model.addVar(vtype=GRB.BINARY, name=f"W_{j}_{z}") for z in range(p)] for j in range(n)]

                # unique_levels[l] == 1  <=>  shift index l may be used by some slice
                unique_levels = [model.addVar(vtype=GRB.BINARY, name=f"unique_level_{l}") for l in range(kp)]

                # assignment_count[l]: number of slices assigned to shift index l
                assignment_count = [model.addVar(vtype=GRB.INTEGER, name=f"count_level_{l}") for l in range(kp)]

                for l in range(kp):
                    model.addConstr(assignment_count[l] == quicksum(X[i][l] for i in range(m)), name=f"Count_Assign_{l}")

                # Pixels that equal 1 in every slice at zero shift are pinned to W = 1.
                all_one_mask = np.all(sub_eye[:, midpoint, :, :] == 1, axis=0)
                for j, z in np.argwhere(all_one_mask).tolist():
                    W[j][z].lb = 1
                    W[j][z].ub = 1

                # Disabled: pairwise balancing of assignment counts between active levels.
                # for l1 in range(kp):
                #     for l2 in range(l1 + 1, kp):
                #         both_active = model.addVar(vtype=GRB.BINARY, name=f"both_active_{l1}_{l2}")
                #         model.addGenConstrAnd(both_active, [unique_levels[l1], unique_levels[l2]], name=f"active_pair_{l1}_{l2}")
                #         model.addGenConstrIndicator(both_active, True, assignment_count[l1] == assignment_count[l2], name=f"Balance_{l1}_{l2}")

                # NOTE: the timer starts here, so the "Setup time" printed after
                # the solve covers the remaining model construction AND optimize().
                start_time = time.time()

                # Every slice gets exactly one shift.
                model.addConstrs((quicksum(X[i][l] for l in range(kp)) == 1 for i in range(m)), name="SumX_to_1")

                # A shift can only be assigned if its level is switched on.
                model.addConstrs((X[i][l] <= unique_levels[l] for l in range(kp) for i in range(m)), name="Link_X_to_UL")

                # Exactly num_levels levels are switched on.
                model.addConstr(quicksum(unique_levels) == num_levels, name="Sum_unique_levels")

                # Active levels are mirrored around the middle of shift_range.
                for l in range(kp // 2):
                    model.addConstr(unique_levels[l] == unique_levels[kp - 1 - l], name=f"Symmetry_{l}")

                model.setObjective(quicksum(W[j][z] for j in range(n) for z in range(p)), GRB.MAXIMIZE)

                # Group the non-zero entries of sub_eye by pixel once, instead of
                # filtering one global list for each of the n * p pixels.
                # np.argwhere walks the array in (i, l, j, z) row-major order, so
                # each pixel's terms stay ordered by (i, l).
                nonzero_terms = {}
                for i, l, j, z in np.argwhere(sub_eye).tolist():
                    nonzero_terms.setdefault((j, z), []).append((i, l))

                for j in range(n):
                    for z in range(p):
                        intersection = quicksum(
                            X[i][l] * sub_eye[i, l, j, z]
                            for i, l in nonzero_terms.get((j, z), ())
                        )
                        # W may be 1 only if the weighted sum reaches m ...
                        model.addConstr(intersection >= m * W[j][z], name=f"intersec_ge_{j}_{z}")
                        # ... and must be 1 as soon as the sum exceeds m - 1.
                        model.addConstr(intersection <= m - 1 + m * W[j][z], name=f"intersec_le_{j}_{z}")

                # Update and optimize model
                model.update()
                model.optimize()

                # Only a proven-optimal solution is turned into labels.
                if model.status == GRB.INF_OR_UNBD:
                    print('Model is infeasible or unbounded')
                    continue
                if model.status != GRB.OPTIMAL:
                    print('Optimization ended with status:', model.status)
                    continue

                print('optimal')
                end_time = time.time()
                execution_time = end_time - start_time
                print("Setup time: ", execution_time, " seconds", flush=True)

                # Reference area: pixels whose product over the whole last axis
                # of the untrimmed matrix equals 1.
                base_eye = np.prod(vorg, axis=2)
                base_sum = np.count_nonzero(base_eye == 1)
                print('base eye area is ' + str(base_sum), flush=True)

                print('Optimal solution found:')

                # Read the solution straight from the variable objects rather
                # than parsing variable names; 0.9 is the activity threshold
                # used for every binary variable.
                area_sum = sum(1 for row in W for w_var in row if w_var.X > 0.9)
                print(f"Total Area: {area_sum}")

                categories = [-1] * m
                for i in range(m):
                    for l in range(kp):
                        if X[i][l].X > 0.9:
                            print(i)
                            print(l)
                            categories[i] = l

                # Only the binary unique_level variables define the active
                # levels (the integer count_level_* variables must not be mixed
                # in).  The list is ascending because l is.
                levels = [l for l in range(kp) if unique_levels[l].X > 0.9]

                # Replace each assigned shift index by its rank among the active levels.
                category_map = {level: rank for rank, level in enumerate(levels)}
                categories = [category_map[level] for level in categories]

                # Label layout: m category ranks, the active shift indices,
                # the reference area, the optimised area.
                labels = categories + levels
                labels.append(base_sum)
                labels.append(area_sum)

                print("Labels:", labels)

                print(type(labels))

                np.save(res_file_path, np.array(labels, dtype=np.int32))

            except Exception as exc:
                print('bad key index ' + file_path)
                # The message above is historical; show what actually failed.
                print('  cause: ' + repr(exc), flush=True)
                continue
            
        # start_time = time.time()
        
        # # Solve the problem
        # #model.solve()
        # status = model.solve(GUROBI_CMD(msg=0))
        # end_time = time.time()
        # execution_time = end_time - start_time
        # print("Solve time: ", execution_time, " seconds",flush=True)



