from models_rl import SimplePolicy, AgentReinforce, GNNPolicy, BipartiteGraphConvolution, BipartiteNodeData
import yaml
import random
import itertools
import argparse
import gurobipy as gp
from pyscipopt import *
from pyscipopt import Model
import pyscipopt
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
import joblib
import pandas as pd
import numpy as np
import warnings
import torch
import time
import gzip
import pickle
import csv
import numpy as np
import scipy.sparse as sp
from dataset import InstanceDataset, custom_collate
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import pathlib
import ecole
from ecole_extend.environment_extend import SimpleConfiguring, SimpleConfiguringEnablecuts, SimpleConfiguringEnableheuristics
from utility import lbconstraint_modes, instancetypes, incumbent_modes, instancesizes, generator_switcher, binary_support, copy_sol, mean_filter,mean_forward_filter, imitation_accuracy, haming_distance_solutions, haming_distance_solutions_asym


import gurobipy as gp
class addBranching(object):
    """Build restricted copies of a MIP around a reference 0/1 solution.

    The SCIP-based builders read a fresh ``pyscipopt.Model`` from
    ``model_path`` on every call and return it with extra constraints added.
    Two kinds of restriction are used:

    * *fixing*: ``var == reference value``;
    * *majority constraints* ``mvb_up`` / ``mvb_low``: the sum of the
      variables that are 1 in the reference solution must be at least
      ``ceil(k_up) + 1`` and the sum of the variables that are 0 must be at
      most ``floor(k_low) - 1``, where
      ``k_up = n1 - sqrt(n1 * ln(1 / p_fail) / 2)`` and
      ``k_low = n0 + sqrt(n0 * ln(1 / p_fail) / 2)``.

    Args:
        model_path: Problem file read by the SCIP-based builders. Defaults to
            the path that used to be hard-coded in every method.
        model: Optional Gurobi model used by
            ``add_branching_groups_onlybranching``. It may also be assigned
            later through the ``model`` attribute.
    """

    # Default location of the problem file read by the SCIP-based builders.
    DEFAULT_MODEL_PATH = '/Users/bula/Downloads/model.cip'

    def __init__(self, model_path=DEFAULT_MODEL_PATH, model=None):
        self.p_fail = 1e-3
        self.model_path = model_path
        self.model = model

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _load_model(self):
        """Return a new SCIP model read from ``self.model_path``."""
        model = Model()
        model.readProblem(self.model_path)
        return model

    def _majority_thresholds(self, n_ones, n_zeros):
        """Return the raw ``(k_up, k_low)`` thresholds for the given counts."""
        log_term = np.log(1 / np.maximum(self.p_fail, 1e-15))
        k_up = n_ones - np.sqrt(n_ones * log_term / 2)
        k_low = n_zeros + np.sqrt(n_zeros * log_term / 2)
        return k_up, k_low

    @staticmethod
    def _add_majority_cons(model, vars1, vars0, k_up, k_low):
        """Add ``mvb_up`` / ``mvb_low`` to a SCIP model, skipping empty sums."""
        if vars1:
            model.addCons(quicksum(vars1) >= np.ceil(k_up) + 1, name="mvb_up")
        if vars0:
            model.addCons(quicksum(vars0) <= np.floor(k_low) - 1, name="mvb_low")

    @staticmethod
    def _collect_by_value(var, value, vars1, vars0):
        """Append ``var`` to ``vars1`` / ``vars0`` when ``value`` is 1.0 / 0.0."""
        if value == 1.0:
            vars1.append(var)
        elif value == 0.0:
            vars0.append(var)

    def _build_group_model(self, solution_map, branching_list, verbose):
        """Majority-constrain the variables named in ``branching_list``."""
        model_add_cons = self._load_model()
        branching_names = set(branching_list)  # O(1) membership tests
        vars1 = []
        vars0 = []
        for var in model_add_cons.getVars():
            if var.name in branching_names:
                self._collect_by_value(var, solution_map[var.name], vars1, vars0)

        k_up, k_low = self._majority_thresholds(len(vars1), len(vars0))
        if verbose:
            print(len(vars1), len(vars0))
            print(k_up, k_low)
        self._add_majority_cons(model_add_cons, vars1, vars0, k_up, k_low)
        return model_add_cons

    # ------------------------------------------------------------------
    # Public builders
    # ------------------------------------------------------------------
    def add_branching(self, mip_model, mip_sol):
        """Add majority constraints over the binary variables of the model.

        NOTE(review): as in the original code, the ``mip_model`` argument is
        discarded and a fresh model is read from ``self.model_path``;
        ``mip_sol`` is then evaluated against the variables of that fresh
        model. Confirm with the callers that this is intended.

        Args:
            mip_model: Ignored (see note above).
            mip_sol: Solution object queried through ``getSolVal``.

        Returns:
            The freshly read SCIP model with the constraints added.

        Raises:
            AssertionError: If a value is not integral or one of the first
                ``getNBinVars()`` variables is not binary.
        """
        mip_model = self._load_model()
        model_vars = mip_model.getVars()
        n_binvars = mip_model.getNBinVars()
        cons_vars_0 = []
        cons_vars_1 = []
        for i in range(n_binvars):
            var = model_vars[i]
            val = mip_model.getSolVal(mip_sol, var)

            assert mip_model.isFeasIntegral(val), "Error: Solution passed to LB is not integral!"

            if mip_model.isFeasEQ(val, 1):
                cons_vars_1.append(var)
            elif mip_model.isFeasEQ(val, 0):
                cons_vars_0.append(var)
            assert var.vtype() == "BINARY", "Error: local branching constraint uses a non-binary variable!"

        k_up, k_low = self._majority_thresholds(len(cons_vars_1), len(cons_vars_0))
        # k_low is 0 only when there are no zero-valued variables, in which
        # case no "mvb_low" constraint is added; this only affects the print.
        if k_low == 0:
            k_low = 1
        self._add_majority_cons(mip_model, cons_vars_1, cons_vars_0, k_up, k_low)

        print(np.ceil(k_up) + 1, np.floor(k_low) - 1)
        print(len(cons_vars_1), len(cons_vars_0))
        return mip_model

    def add_branching_groups_fixed(self, solution_map, branching_list):
        """Fix every variable named in ``branching_list`` to its reference value.

        Args:
            solution_map: Mapping from variable name to reference value.
            branching_list: Iterable of variable names to fix.

        Returns:
            A fresh SCIP model with one equality constraint per fixed variable.
        """
        model_add_cons = self._load_model()
        fixed_names = set(branching_list)  # O(1) membership tests
        for var in model_add_cons.getVars():
            if var.name in fixed_names:
                model_add_cons.addCons(var == solution_map[var.name])
        return model_add_cons

    def add_branching_groups_onlybranching(self, variables, solution):
        """Gurobi variant: majority-constrain a copy of ``self.model``.

        Args:
            variables: Variable names, aligned by position with ``solution``.
            solution: Reference values; entries equal to 1 / 0 are used.

        Returns:
            A copy of ``self.model`` (in a new Gurobi environment) with the
            ``mvb_up`` / ``mvb_low`` constraints added.

        Raises:
            AttributeError: If no Gurobi model has been provided.
        """
        if self.model is None:
            raise AttributeError(
                "addBranching.model is not set; pass model=... to the "
                "constructor or assign the attribute before calling "
                "add_branching_groups_onlybranching()."
            )
        env = gp.Env()
        model_add_cons = self.model.copy(env=env)
        vars1 = []
        vars0 = []
        for i, value in enumerate(solution):
            if value == 1:
                vars1.append(model_add_cons.getVarByName(variables[i]))
            elif value == 0:
                vars0.append(model_add_cons.getVarByName(variables[i]))

        k_up, k_low = self._majority_thresholds(len(vars1), len(vars0))
        print(len(vars1), len(vars0))
        print(k_up, k_low)
        # Guard against empty sums like the SCIP variants do: an empty
        # "mvb_up" would read 0 >= 1 and make the model trivially infeasible.
        if vars1:
            model_add_cons.addConstr(gp.quicksum(vars1) >= np.ceil(k_up) + 1, name="mvb_up")
        if vars0:
            model_add_cons.addConstr(gp.quicksum(vars0) <= np.floor(k_low) - 1, name="mvb_low")
        return model_add_cons

    def add_branching_random_fixed(self, solution_map, fix_prob=1.0):
        """Fix each variable with probability ``fix_prob``; constrain the rest.

        With the default ``fix_prob=1.0`` every variable is fixed, because
        ``random.random()`` is always below 1 (this matches the previously
        hard-coded ``score < 1`` test).

        Args:
            solution_map: Mapping from variable name to reference value.
            fix_prob: Probability of fixing a variable to its reference value.

        Returns:
            A fresh SCIP model with the fixings and majority constraints.
        """
        model_add_cons = self._load_model()
        vars1 = []
        vars0 = []
        for var in model_add_cons.getVars():
            if random.random() < fix_prob:
                model_add_cons.addCons(var == solution_map[var.name])
            else:
                self._collect_by_value(var, solution_map[var.name], vars1, vars0)

        k_up, k_low = self._majority_thresholds(len(vars1), len(vars0))
        self._add_majority_cons(model_add_cons, vars1, vars0, k_up, k_low)
        return model_add_cons

    def add_branching_branching_fixed(self, solution_map, branching_list, fix_prob=1.0):
        """Majority-constrain ``branching_list``; fix the other variables.

        Variables outside ``branching_list`` are fixed to their reference
        value with probability ``fix_prob`` (always, with the default 1.0).

        Args:
            solution_map: Mapping from variable name to reference value.
            branching_list: Iterable of variable names left free but subject
                to the majority constraints.
            fix_prob: Probability of fixing each remaining variable.

        Returns:
            A fresh SCIP model with the fixings and majority constraints.
        """
        model_add_cons = self._load_model()
        branching_names = set(branching_list)  # O(1) membership tests
        vars1 = []
        vars0 = []
        for var in model_add_cons.getVars():
            if var.name in branching_names:
                self._collect_by_value(var, solution_map[var.name], vars1, vars0)
            elif random.random() < fix_prob:
                model_add_cons.addCons(var == solution_map[var.name])

        k_up, k_low = self._majority_thresholds(len(vars1), len(vars0))
        self._add_majority_cons(model_add_cons, vars1, vars0, k_up, k_low)
        return model_add_cons

    def add_branching_groups_together(self, solution_map, branching_list):
        """Majority-constrain the variables in ``branching_list`` (silent).

        Args:
            solution_map: Mapping from variable name to reference value.
            branching_list: Iterable of variable names to constrain.

        Returns:
            A fresh SCIP model with the majority constraints added.
        """
        return self._build_group_model(solution_map, branching_list, verbose=False)

    def add_branching_groups(self, solution_map, branching_list):
        """Majority-constrain ``branching_list`` and print counts/thresholds.

        Args:
            solution_map: Mapping from variable name to reference value.
            branching_list: Iterable of variable names to constrain.

        Returns:
            A fresh SCIP model with the majority constraints added.
        """
        return self._build_group_model(solution_map, branching_list, verbose=True)
class RLChooseGroup(object):
    def __init__(self,addBranching, instance_type, instance_size, lbconstraint_mode, incumbent_mode, seed, total_time_available, each_time_available, stop_noimprove_num):
        """Store the run configuration and create the Ecole environment.

        Args:
            addBranching: Object providing the ``add_branching_*`` model
                builders used by the other methods.
            instance_type: Instance family name; used in file patterns and to
                pick the Ecole environment class.
            instance_size: Stored as-is.
            lbconstraint_mode: Stored as-is.
            incumbent_mode: ``'firstsol'`` or ``'rootsol'``; selects the Ecole
                environment in ``initialize_ecole_env``.
            seed: Seed passed to ``self.env.seed``.
            total_time_available: Stored as ``total_time_available_input``.
            each_time_available: Base time value read by
                ``get_time_allocation``.
            stop_noimprove_num: Stored as-is.

        Raises:
            AttributeError: If ``incumbent_mode`` is not recognised, because
                no environment is created and ``self.env`` does not exist.
        """
        self.addBranching = addBranching
        self.instance_type = instance_type
        self.instance_size = instance_size
        self.incumbent_mode = incumbent_mode
        self.lbconstraint_mode = lbconstraint_mode
        self.group_number = 0
        self.each_group_number = 10
        self.p_fail = 1e-3
        self.eps = .0000001
        self.actions = {'decrease': 0, 'increase':1 , 'reset': 2}

        self.groups = []
        self.n_groups_stepsize = 0.2
        self.solution = []
        # NOTE(review): both arrays have a hard-coded length of 52 —
        # presumably the variable count of the target instance; confirm.
        self.state = np.zeros(52)
        self.records_var_branch = np.zeros(52)
        self.obj = 1e20
        self.improve = 0
        self.no_improve = 0
        self.total_iter_limit_N = 400
        self.total_time_available_input = total_time_available
        self.each_time_available = each_time_available
        self.stop_noimprove_num = stop_noimprove_num

        self.initialize_ecole_env()
        self.env.seed(seed)

        self.selection_history = []

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_N_path = './results/saved_models/setcovering_graph.pkl'

        parser = argparse.ArgumentParser()
        parser.add_argument('--regression_model_path', type = str, default='./results/saved_models/regression/trained_params_mean_setcover-independentset-combinatorialauction_asymmetric_firstsol_k_prime_epoch163.pth')
        parser.add_argument('--rl_model_path', type = str, default='./results/saved_models/rl/reinforce/SAA/checkpoint_trained_reward3_simplepolicy_rl4lb_reinforce_trainset_setcovering_graph-small_0.1trainset_lr0.01_epochs1.pth')
        parser.add_argument('--yaml_path', type = str, default='./model_formulation/yaml_path/Scenario1_bonded.yaml')
        parser.add_argument('--samplsize_model_path', type = str, default=self.model_N_path)

        # parse_known_args() instead of parse_args(): the constructor must not
        # abort the process (SystemExit) when the hosting program was started
        # with command-line flags this parser does not define, e.g. under a
        # test runner, a notebook kernel, or a script with its own options.
        args, _unknown_args = parser.parse_known_args()
        self.regression_model_path = args.regression_model_path
        self.rl_model_path = args.rl_model_path
        self.samplsize_model_path = args.samplsize_model_path
        self.yaml_path =  args.yaml_path
    def initialize_ecole_env(self):
        """Create ``self.env`` according to the incumbent mode.

        * ``'firstsol'`` -> ``ecole.environment.Configuring``
        * ``'rootsol'``  -> ``SimpleConfiguring`` for the ``'independentset'``
          instance type, ``SimpleConfiguringEnablecuts`` otherwise.

        For any other incumbent mode nothing is created and ``self.env`` is
        left untouched. All environments share the same settings: presolving
        deactivated, a ``MilpBipartite`` observation, no reward function, and
        the cumulative solving time as the ``'time'`` information entry.
        """
        mode = self.incumbent_mode
        if mode == 'firstsol':
            env_factory = ecole.environment.Configuring
        elif mode == 'rootsol':
            if self.instance_type == 'independentset':
                env_factory = SimpleConfiguring
            else:
                env_factory = SimpleConfiguringEnablecuts
        else:
            # Unknown mode: the original code builds no environment either.
            return

        self.env = env_factory(
            # deactivate presolving
            scip_params={
                "presolving/maxrounds": 0,
                "presolving/maxrestarts": 0,
            },
            observation_function=ecole.observation.MilpBipartite(),
            reward_function=None,
            # collect additional metrics for information purposes
            information_function={
                'time': ecole.reward.SolvingTime().cumsum(),
            },
        )
    def load_mip_dataset(self, instances_directory=None, sols_directory=None, incumbent_mode=None):
        """Build the train / validation / test datasets from files on disk.

        Files are matched by glob pattern and sorted by the integer found in
        the second-to-last ``_``-separated field of the file stem (after
        replacing ``-`` with ``_``). The first 5/8 of the ``train/`` files
        form the training set and the last 1/8 the validation set; the
        ``test/`` directory is used in full.

        The same split is applied to instances and solutions so that both
        lists stay aligned by position (the solutions used to be cut at 7/8
        while the instances were cut at 5/8).

        Args:
            instances_directory: Directory prefix (ending with a path
                separator) that contains ``train/`` and ``test/``.
            sols_directory: Same, for the solution files.
            incumbent_mode: Prefix of the solution file names. Falls back to
                ``self.incumbent_mode`` when ``None``; otherwise the pattern
                would start with the literal text ``None`` and match nothing.

        Returns:
            ``(train_dataset, valid_dataset, test_dataset)`` as
            ``InstanceDataset`` objects.
        """
        if incumbent_mode is None:
            incumbent_mode = self.incumbent_mode

        instance_filename = f'{self.instance_type}-*_transformed.cip'
        sol_filename = f'{incumbent_mode}-{self.instance_type}-*_transformed.sol'

        def _sorted_files(directory, pattern):
            # Sort numerically by the instance index embedded in the stem.
            return [str(path) for path in sorted(pathlib.Path(directory).glob(pattern), key=lambda path: int(path.stem.replace('-', '_').rsplit("_", 2)[1]))]

        def _train_valid_split(files):
            # First 5/8 for training, last 1/8 for validation.
            return files[:int(5/8 * len(files))], files[int(7/8 * len(files)):]

        instance_files = _sorted_files(instances_directory + 'train/', instance_filename)
        instance_train_files, instance_valid_files = _train_valid_split(instance_files)
        instance_test_files = _sorted_files(instances_directory + 'test/', instance_filename)

        sol_files = _sorted_files(sols_directory + 'train/', sol_filename)
        sol_train_files, sol_valid_files = _train_valid_split(sol_files)
        sol_test_files = _sorted_files(sols_directory + 'test/', sol_filename)

        train_dataset = InstanceDataset(mip_files=instance_train_files, sol_files=sol_train_files)
        valid_dataset = InstanceDataset(mip_files=instance_valid_files, sol_files=sol_valid_files)
        test_dataset = InstanceDataset(mip_files=instance_test_files, sol_files=sol_test_files)

        return train_dataset, valid_dataset, test_dataset
    def load_test_mip_dataset(self, instances_directory=None, sols_directory=None, incumbent_mode=None):
        """Build the test dataset from files on disk.

        For the ``'miplib2017_binary'`` and ``'miplib_39binary'`` instance
        types the given directories are used directly; for every other type
        the ``test/`` sub-directory is used. Files are sorted by the integer
        found in the second-to-last ``_``-separated field of the file stem
        (after replacing ``-`` with ``_``).

        Args:
            instances_directory: Directory (prefix) of the instance files.
            sols_directory: Directory (prefix) of the solution files.
            incumbent_mode: Prefix of the solution file names. Falls back to
                ``self.incumbent_mode`` when ``None``; otherwise the pattern
                would start with the literal text ``None`` and match nothing.

        Returns:
            An ``InstanceDataset`` over the matching test files.
        """
        if incumbent_mode is None:
            incumbent_mode = self.incumbent_mode

        # Use the instance attribute: the bare name ``instance_type`` is not
        # defined in this scope and raised NameError unless a global of that
        # name happened to exist.
        instance_filename = f'{self.instance_type}-*_transformed.cip'
        sol_filename = f'{incumbent_mode}-{self.instance_type}-*_transformed.sol'

        uses_flat_layout = self.instance_type in ('miplib2017_binary', 'miplib_39binary')
        if uses_flat_layout:
            test_instances_directory = instances_directory
            test_sols_directory = sols_directory
        else:
            test_instances_directory = instances_directory + 'test/'
            test_sols_directory = sols_directory + 'test/'

        def _sorted_files(directory, pattern):
            # Sort numerically by the instance index embedded in the stem.
            return [str(path) for path in sorted(pathlib.Path(directory).glob(pattern),
                                                 key=lambda path: int(path.stem.replace('-', '_').rsplit("_", 2)[1]))]

        instance_test_files = _sorted_files(test_instances_directory, instance_filename)
        sol_test_files = _sorted_files(test_sols_directory, sol_filename)

        test_dataset = InstanceDataset(mip_files=instance_test_files, sol_files=sol_test_files)

        return test_dataset
    def devide_groups(self, MIP_model, normalize='row', topk=10, method='louvain', n_groups=10, 
                                  seed=0, weight_abs=True, min_weight=0.0):
        """Partition the model's variables into groups by constraint coupling.

        The bipartite observation returned by ``self.env.reset`` provides the
        constraint/variable coefficient matrix ``A``. A variable-variable
        affinity matrix ``W = A^T A`` (zero diagonal) is built, optionally
        thresholded and sparsified, and then clustered.

        Args:
            MIP_model: PySCIPOpt model to analyse.
            normalize: ``'row'`` or ``'col'`` to scale rows / columns of ``A``
                to unit Euclidean norm; any other value skips normalisation.
            topk: Keep only the ``topk`` largest affinities per variable (the
                result is symmetrised afterwards). ``None`` disables this.
            method: ``'louvain'``, ``'spectral'``, or anything else for
                agglomerative clustering.
            n_groups: Number of clusters for the spectral / agglomerative
                methods and for the fallback; not used by Louvain.
            seed: Random state passed to the clustering routines.
            weight_abs: Use absolute coefficient values when ``True``.
            min_weight: Affinities not above this value are dropped when it is
                positive.

        Returns:
            A list of groups, each a list of variable names. The same list is
            stored in ``self.groups``.

        If the clustering raises for any reason (missing optional package,
        numerical failure, ...), variables are instead dealt round-robin in
        order of decreasing weighted degree and the cause is printed.
        """
        instance = ecole.scip.Model.from_pyscipopt(MIP_model)
        observation, action_set, reward_offset, done, info = self.env.reset(instance)
  
        idx = observation.edge_features.indices
        vals = observation.edge_features.values
        cons_idx = idx[0]
        var_idx = idx[1]
        m = observation.constraint_features.shape[0]
        n = observation.variable_features.shape[0]

        # Fetch the variable list once instead of once per variable; columns
        # without a matching model variable get a synthetic "x_<i>" name.
        try:
            model_vars = MIP_model.getVars()
        except Exception:
            model_vars = []
        variable_names = [
            model_vars[i].name if i < len(model_vars) else f"x_{i}"
            for i in range(n)
        ]

        data = np.abs(vals) if weight_abs else vals
        A = sp.coo_matrix((data, (cons_idx, var_idx)), shape=(m, n)).tocsr()

        if normalize == 'row':
            row_norms = np.sqrt(A.power(2).sum(axis=1)).A1 + 1e-12
            A = sp.diags(1.0 / row_norms) @ A
        elif normalize == 'col':
            col_norms = np.sqrt(A.power(2).sum(axis=0)).A1 + 1e-12
            A = A @ sp.diags(1.0 / col_norms)

        # Variable-variable affinity: shared constraints, no self loops.
        W = (A.T @ A).tocsr()
        W.setdiag(0.0)
        W.eliminate_zeros ()
        if min_weight > 0.0:
            W = W.multiply(W > min_weight)
            W.eliminate_zeros()

        if topk is not None and topk < n:

            W_lil = W.tolil()
            
            for j in range(n):
                row_indices = W_lil.rows[j]
                data_values = W_lil.data[j]
                
                if len(data_values) > topk:
                    topk_idx = np.argpartition(data_values, -topk)[-topk:]
                    W_lil.rows[j] = [row_indices[i] for i in topk_idx]
                    W_lil.data[j] = [data_values[i] for i in topk_idx]
            
            W = W_lil.tocsr()
            # Row-wise pruning breaks symmetry; restore it.
            W = W.maximum(W.T)
        try:
            if method == 'louvain':
                import networkx as nx
                import community as community_louvain  # python-louvain
                G = nx.Graph()
                G.add_nodes_from(range(n))
                coo = W.tocoo()
                for i, j, w in zip(coo.row, coo.col, coo.data):
                    if i < j:
                        G.add_edge(int(i), int(j), weight=float(w))
                parts = community_louvain.best_partition(G, weight='weight', random_state=seed)
                labels = np.array([parts.get(i, 0) for i in range(n)])
            elif method == 'spectral':
                from sklearn.cluster import SpectralClustering
                maxw = W.data.max() if W.nnz > 0 else 1.0
                S = (W / maxw).astype(np.float64)
                if n_groups is None:
                    # Pick the number of clusters from the largest gap among
                    # the smallest Laplacian eigenvalues.
                    deg = np.array(S.sum(axis=1)).ravel()
                    L = sp.diags(deg) - S
                    evals = np.linalg.eigvalsh(L.toarray()) if n <= 1500 else np.linalg.eigvalsh(L[:300, :300].toarray())
                    gaps = np.diff(np.sort(evals))
                    k = int(np.argmax(gaps[: min(20, len(gaps))]) + 1) if len(gaps) > 0 else 6
                    n_groups = max(2, min(n, k))
                sc = SpectralClustering(n_clusters=n_groups, affinity='precomputed', random_state=seed, assign_labels='kmeans')
                labels = sc.fit_predict(S.toarray() if S.shape[0] <= 2000 else (S + sp.eye(n)).astype(np.float64))
            else:
                from sklearn.cluster import AgglomerativeClustering
                maxw = W.data.max() if W.nnz > 0 else 1.0
                S = (W / maxw).astype(np.float64)
                D = (1.0 - S).toarray() if n <= 2000 else None
                if n_groups is None:
                    n_groups = 6
                if D is None:
                    # Fall back to spectral clustering.
                    from sklearn.cluster import SpectralClustering
                    sc = SpectralClustering(n_clusters=n_groups, affinity='precomputed', random_state=seed)
                    labels = sc.fit_predict(S.toarray())
                else:
                    labels = AgglomerativeClustering(n_clusters=n_groups, affinity='precomputed', linkage='average').fit_predict(D)
        except Exception as e:
            # Best-effort fallback, but report why the requested method failed
            # instead of hiding it.
            print(f"devide_groups: '{method}' clustering failed ({e!r}); "
                  "falling back to degree-ordered round-robin groups")
            deg = np.array(W.sum(axis=1)).ravel()
            order = np.argsort(-deg)
            k = n_groups or max(2, n // max(50, 1))
            labels = np.empty(n, dtype=int)
            for i, v in enumerate(order):
                labels[v] = i % k
        groups = []
        for g in np.unique(labels):
            group_indices = np.where(labels == g)[0]
            group_vars = [variable_names[i] for i in group_indices]
            groups.append(group_vars)
        print("\n=== Group Sizes ===")
        for i, group in enumerate(groups):
            print(f"Group {i}: {len(group)} variables")
        self.groups = groups
        return groups
    def devide_groups_random(self, MIP_model):
        all_var_names = [var.name for var in MIP_model.getVars()]
        random.shuffle(all_var_names)

        self.groups = []
        self.group_number = self.each_group_number

        for _ in range(self.each_group_number):
            self.groups.append([])

        for i, var_name in enumerate(all_var_names):
            group_index = i % self.each_group_number
            self.groups[group_index].append(var_name)
    
    def caculate_groups_strong(self):
        """Score every group by the objective spread between fixing it to 1 and to 0.

        For each group in ``self.groups`` a copy of the Gurobi model
        ``self.model`` is solved twice: once with all of the group's variables
        fixed to 1 (MIP gap 0.05) and once with all fixed to 0 (MIP gap 1).
        An infeasible "up" model scores ``1e10`` and an infeasible "down"
        model scores ``0``.

        Returns:
            A dict mapping ``tuple(group)`` to ``abs(up - down)``, ordered by
            decreasing score.
        """
        def _objective_with_group_fixed(group, value, mip_gap, cons_name, infeasible_value):
            # Solve a copy of the model with every variable of `group` fixed.
            self.model_clone = self.model.copy()
            for var_name in group:
                var = self.model_clone.getVarByName(var_name)
                self.model_clone.addConstr(var == value, name=cons_name)
            self.model_clone.Params.MIPGap = mip_gap
            self.model_clone.optimize()
            if self.model_clone.Status != gp.GRB.INFEASIBLE:
                # getObjective is a method: it must be called before
                # getValue() (the "down" branch used to omit the call).
                return self.model_clone.getObjective().getValue()
            return infeasible_value

        scores = {}
        for g in self.groups:
            up = _objective_with_group_fixed(g, 1, 0.05, 'cons=1', 1e10)
            down = _objective_with_group_fixed(g, 0, 1, 'cons=0', 0)
            scores[tuple(g)] = abs(up-down)
        scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
        return scores
    def group_violation_majority(self, solution_map, group_idx):
        """Return how far a group's values are from satisfying the majority bounds.

        Values ``>= 0.5`` count as ones and the rest as zeros. With
        ``t(s) = sqrt(max(s, 1e-12) * ln(1 / max(p_fail, 1e-15)) / 2)`` the
        bounds are ``ceil(n1 - t(n1)) + 1`` (minimum number of ones) and
        ``floor(n0 + t(n0)) - 1`` (maximum number of zeros). The result is the
        sum of the two non-negative shortfalls; a side with no variables
        contributes nothing.

        Args:
            solution_map: Mapping from variable name to value.
            group_idx: Iterable of variable names forming the group.

        Returns:
            The total violation as an ``int``.
        """
        values = np.array([solution_map[name] for name in group_idx])
        n_ones = int(np.sum(values >= 0.5))
        n_zeros = int(np.sum(values < 0.5))

        def slack(count):
            return np.sqrt(max(count, 1e-12) * np.log(1.0 / max(self.p_fail, 1e-15)) / 2.0)

        up_violation = 0
        if n_ones > 0:
            required_ones = int(np.ceil(n_ones - slack(n_ones))) + 1
            up_violation = max(0, required_ones - n_ones)

        low_violation = 0
        if n_zeros > 0:
            allowed_zeros = int(np.floor(n_zeros + slack(n_zeros))) - 1
            low_violation = max(0, n_zeros - allowed_zeros)

        return int(up_violation + low_violation)


    def choose_groups_score(self, solution_map, alpha=0.5, topM=None):
        """Rank the groups by size-normalised majority violation.

        Each group's score is its violation (see
        ``group_violation_majority``) divided by ``max(len(group), 1) ** alpha``.

        Args:
            solution_map: Mapping from variable name to value.
            alpha: Exponent of the size normalisation.
            topM: Keep only the first ``topM`` entries when not ``None``.

        Returns:
            A list of ``(group_id, score, violation, group_size)`` tuples in
            order of decreasing score; ties keep their group order.
        """
        def _entry(group_id, members):
            violation = self.group_violation_majority(solution_map, members)
            size = len(members)
            return group_id, violation / max(size, 1) ** alpha, int(violation), size

        ranked = sorted(
            (_entry(group_id, members) for group_id, members in enumerate(self.groups)),
            key=lambda entry: entry[1],
            reverse=True,
        )
        return ranked if topM is None else ranked[:topM]
    def select_groups_randomly_with_history(self, n_groups_to_select, seed=None):
        """Pick groups at random, favouring those selected less often so far.

        Group ``g`` is drawn with weight ``1 / (times_selected + 1)`` using
        ``random.choices`` (with replacement); duplicates are removed and the
        selection is topped up with uniformly chosen unused groups until
        ``n_groups_to_select`` groups are selected or none remain. The
        selection is appended to ``self.selection_history``.

        Args:
            n_groups_to_select: Number of distinct groups wanted.
            seed: When not ``None``, re-seeds the global ``random`` generator.

        Returns:
            ``(selected_group_indices, selected_variables)`` where the second
            item concatenates the variable names of the selected groups.
        """
        if seed is not None:
            random.seed(seed)

        group_ids = list(range(len(self.groups)))
        history = self.selection_history
        weights = [1.0 / (history.count(group_id) + 1) for group_id in group_ids]

        # Weighted draw with replacement, then drop duplicates.
        picked = list(set(random.choices(group_ids, weights=weights, k=n_groups_to_select)))

        # Top up with uniformly chosen groups that are not selected yet.
        while len(picked) < n_groups_to_select:
            leftovers = [group_id for group_id in group_ids if group_id not in picked]
            if not leftovers:
                break
            picked.append(random.choice(leftovers))

        chosen_variables = [name for group_id in picked for name in self.groups[group_id]]

        self.selection_history = history + picked
        return picked, chosen_variables

    def select_groups_randomly(self, n_groups_to_select, seed=None):
        if seed is not None:
            random.seed(seed)
        all_group_ids = list(range(len(self.groups)))
        if n_groups_to_select >= len(all_group_ids):
            selected_group_indices = all_group_ids.copy()
        else:
            selected_group_indices = random.sample(all_group_ids, n_groups_to_select)

        selected_variables = []
        for idx in selected_group_indices:
            selected_variables.extend(self.groups[idx])

        return selected_group_indices, selected_variables
    def caculate_groups_LP(self, solution_map, topM):
        """Select the variables of the groups whose fixing changes the root bound most.

        For every group a model with that group's variables fixed to their
        value in ``solution_map`` is built and solved with presolving,
        heuristics and separation switched off and a node limit of 1. Groups
        for which a solution is found are scored by ``self.obj - objective``.

        Args:
            solution_map: Mapping from variable name to reference value.
            topM: Number of best-scoring groups to keep.

        Returns:
            The concatenated variable names of the ``topM`` highest-scoring
            groups (empty when no group produced a solution).
        """
        scores = []
        # Scores are keyed by group *index*: the group itself is a list and
        # cannot be used to index self.groups afterwards.
        for group_id, g in enumerate(self.groups):
            branching_model = self.addBranching.add_branching_groups_fixed(solution_map, g)
            
            branching_model.resetParams()
            branching_model.setPresolve(pyscipopt.SCIP_PARAMSETTING.OFF)
            branching_model.setHeuristics(pyscipopt.SCIP_PARAMSETTING.OFF)
            branching_model.setSeparating(pyscipopt.SCIP_PARAMSETTING.OFF)
            branching_model.setIntParam("lp/solvefreq", 0)
            branching_model.setParam("limits/nodes", 1)
            branching_model.setParam("display/verblevel", 0)
            branching_model.setParam("lp/disablecutoff", 1)

            branching_model.setParam("lp/initalgorithm", 's')
            branching_model.setParam("lp/resolvealgorithm", 's')
            branching_model.optimize()
            
            # Only query the objective once a solution is known to exist.
            if branching_model.getNSols() > 0:
                obj = branching_model.getObjVal()
                print('obj', obj)
                scores.append((group_id, self.obj - obj))
                print('LP obj gap', self.obj - obj)
           
        scores.sort(key=lambda x: x[1], reverse=True)
        scores = scores[:topM]
        branching_variables = [var for group_id, _ in scores
                        for var in self.groups[group_id]]
        return branching_variables
    def update_N(self, action):
        """Return the new number of groups for the given action code.

        * ``actions['decrease']`` -> ``ceil(n_groups * (1 - stepsize))``
        * ``actions['increase']`` -> ``ceil(n_groups * (1 + stepsize))``
        * ``actions['reset']``    -> ``default_n_groups``

        Args:
            action: One of the values of ``self.actions``.

        Returns:
            The new group count, or the string
            ``'Error: Invilid k action!'`` for an unknown action.
        """
        current = self.n_groups
        step = self.n_groups_stepsize * current
        action_codes = self.actions

        outcomes = {}
        outcomes[action_codes['decrease']] = int(np.ceil(current - step))
        outcomes[action_codes['increase']] = int(np.ceil(current + step))
        outcomes[action_codes['reset']] = self.default_n_groups

        # NOTE(review): an unknown action yields an error *string*, not an
        # exception; callers receive it in place of an integer.
        if action in outcomes:
            return outcomes[action]
        return 'Error: Invilid k action!'
    def get_var_solution(self, round):
        """Refresh ``self.solution`` from the last solved branching model.

        The values of ``self.variables`` are read from
        ``self.branching_model`` when the model is not infeasible and either
        the new objective ``self.obj_now`` is no worse than ``self.obj`` or
        this is one of the first rounds (``round <= 1``). Otherwise, when the
        new objective is worse, the previous solution is kept.

        Args:
            round: Iteration counter of the caller.

        Returns:
            None. ``self.solution`` (and, in the "keep" case,
            ``self.solution2``) are updated in place.
        """
        if self.branching_model.getStatus() != "infeasible" and (self.obj_now <= self.obj or round <=1):
            # Round to the nearest integer: solvers return values such as
            # 0.9999999 for a variable at 1, which plain int() truncates to 0.
            # (np.rint is used because the builtin round() is shadowed by the
            # `round` parameter.)
            self.solution = [
                int(np.rint(self.branching_model.getVal(var)))
                for var in self.variables
            ]
        elif self.obj_now > self.obj:
            # Worse objective: keep the previous solution. The former code
            # indexed an empty list by variable and called list.get(), which
            # always raised.
            self.solution2 = self.solution.copy()
            self.solution = self.solution2
        return


    def get_time_allocation(self):
        """Return the solver time limit for the current number of groups.

        The limit is ``base * (n_groups / default_n_groups) ** alpha`` (both
        counts floored at 1), clamped to ``[min_time, max_time]``. Optional
        instance attributes override the defaults:

        * ``time_alpha`` (default ``1``)
        * ``base_time``  (default ``self.each_time_available``)
        * ``min_time``   (default ``0.5 * base``)
        * ``max_time``   (default ``10.0 * base``)
        """
        group_count = max(1, self.n_groups)
        reference_count = max(1, self.default_n_groups)

        exponent = getattr(self, "time_alpha", 1)
        base = getattr(self, "base_time", self.each_time_available)
        lower_bound = getattr(self, "min_time", 0.5 * base)
        upper_bound = getattr(self, "max_time", 10.0 * base)

        scaled = base * (group_count / reference_count) ** exponent
        capped = min(upper_bound, scaled)
        return max(lower_bound, capped)

    def update_agent(self, agent, optimizer):
        R = 0
        policy_losses = []
        returns = []
        for r in agent.rewards[::-1]:
            R = r + R
            returns.insert(0,R)
        returns = [float(x) for x in returns]
        if len(returns) == 1:
            returns = torch.tensor(returns)
            print('returns 1:', returns)
        else:
            returns = torch.tensor(returns)
            print('returns 1:', returns)
            returns = (returns - returns.mean()) / (returns.std() + self.eps)
            print('returns 2:', returns)
        with torch.set_grad_enabled(optimizer is not None):
            for log_prob, Return in zip(agent.log_probs, returns):
                policy_losses.append(-log_prob * Return)
                
            print(policy_losses)
            if optimizer is not None:
                optimizer.zero_grad()
        
        
                policy_losses = torch.stack(policy_losses).sum()
                # policy_losses = torch.cat(policy_losses).sum()
                policy_losses.backward()
                optimizer.step()

        del agent.rewards[:]
        del agent.log_probs[:]
        return agent, optimizer, R
    
    def solve(self, state, action, round, solution_map): 
        """Solve one branching sub-problem and update the search state.

        The number of groups is updated from ``action``, a set of branching
        variables is drawn, and the sub-problem built by
        ``add_branching_branching_fixed`` is optimized under the current time
        allocation.  The solver time is always subtracted from
        ``self.total_time_available``.

        Args:
            state: Unused; kept for interface compatibility.
            action: Action forwarded to ``self.update_N``.
            round: Unused; kept for interface compatibility.
            solution_map: Mapping ``variable name -> value`` used to build the
                sub-problem.

        Returns:
            A tuple ``(state, reward, solution_map)`` where ``state`` is
            ``[graph, current_solution, incumbent_solution, obj_now, obj,
            solving_time]``.  ``solution_map`` is rebuilt only when the new
            objective is not worse than the incumbent; otherwise the input
            mapping is returned unchanged.
        """
        # Guards the relative-gap divisions against a zero objective.
        eps = 1e-12

        self.n_groups = self.update_N(action)
        _, branching_variables = self.select_groups_randomly_with_history(n_groups_to_select=self.n_groups)

        current_time_limit = self.get_time_allocation()
        branching_model = self.addBranching.add_branching_branching_fixed(solution_map, branching_variables)
        try:
            branching_model.setParam('display/verblevel', 0)
            branching_model.setParam('limits/time', current_time_limit)

            print('Ncons1_1', branching_model.getNConss())
            branching_model.optimize()
            print(branching_model.getStatus())

            # Named solve_time so the module-level ``time`` import is not shadowed.
            solve_time = branching_model.getSolvingTime()
            self.total_time_available -= solve_time

            if branching_model.getNSols() <= 0:
                # No solution within the time limit: penalise by the time spent.
                self.no_improve += 1
                print('Infeasible')
                return [self.graph, self.incumbent_solution, self.incumbent_solution, self.obj_now, self.obj, solve_time], -solve_time, solution_map

            self.obj_now = branching_model.getObjVal()
            gap = self.obj - self.obj_now
            reward = gap / max(abs(self.obj_init), eps) * solve_time

            # ``improve`` counts consecutive rounds whose relative gain is below 0.5%.
            if gap / max(abs(self.obj_now), eps) < 0.005:
                self.improve += 1
            else:
                self.improve = 0
            print('time:', solve_time)
            print('Obj gap:', gap)
            print('Obj:', self.obj_now)

            if self.obj < self.obj_now:
                # Strictly worse than the incumbent: keep the incumbent and log it.
                print('ops')
                self.no_improve += 1
                with open(self.output_file, mode="a", newline="", encoding="utf-8") as file:
                    csv.writer(file).writerow([solve_time, self.obj])
                self.current_solution = branching_model.getBestSol()
                return [self.graph, self.current_solution, self.incumbent_solution, self.obj_now, self.obj, solve_time], reward, solution_map

            with open(self.output_file, mode="a", newline="", encoding="utf-8") as file:
                csv.writer(file).writerow([solve_time, self.obj_now])

            self.no_improve = 0
            self.incumbent_solution = branching_model.getBestSol()
            self.obj = self.obj_now
            print(self.state.size)
            print(branching_model.getNVars())

            # Single pass: the incumbent is the model's best solution, so the same
            # value feeds both the name->value map and the state vector.
            solution_map = {}
            for idx, var in enumerate(branching_model.getVars()):
                value = branching_model.getSolVal(self.incumbent_solution, var)
                solution_map[var.name] = value
                self.state[idx] = value
                self.records_var_branch[idx] = 1 if var.name in branching_variables else 0

            # The graph must be extracted before the model is released below.
            self.graph = self.mip_to_graph(branching_model)
            self.current_solution = branching_model.getBestSol()
            return [self.graph, self.current_solution, self.incumbent_solution, self.obj_now, self.obj, solve_time], reward, solution_map
        finally:
            # Release the sub-problem on every exit path, including exceptions.
            # NOTE(review): the solution objects stored on ``self`` were obtained
            # from this model before it is freed, as in the original code —
            # confirm they remain valid afterwards.
            branching_model.freeTransform()
            branching_model.freeProb()

           

        
    def solve_initial(self, state, action, MIP_model):
        """Solve the initial branching model and initialise the search state.

        Builds the model with ``self.addBranching.add_branching``, optimizes it
        under ``self.each_time_available``, and stores the best objective and
        solution as the incumbent (``self.obj_init``, ``self.obj``,
        ``self.obj_now``, ``self.incumbent_solution``).  ``self.state``,
        ``self.records_var_branch`` and ``self.graph`` are (re)initialised and
        a header plus result rows are appended to ``self.output_file``.

        Args:
            state: Unused; kept for interface compatibility.
            action: Unused; kept for interface compatibility.
            MIP_model: Model handed to ``add_branching``.

        Returns:
            A tuple ``(state, reward, solution_map)`` where ``state`` is
            ``[graph, incumbent_solution, incumbent_solution, obj_now, obj,
            elapsed_time]`` and ``solution_map`` maps variable names to their
            value in the incumbent.

        Raises:
            ValueError: If no model produced a feasible solution.
        """
        branching_model = self.addBranching.add_branching(MIP_model, self.incumbent_solution)
        n_vars = branching_model.getNVars()
        self.state = np.zeros(n_vars)
        self.records_var_branch = np.zeros(n_vars)

        obj_list = []
        sol_list = []
        solved_models = []  # kept parallel to obj_list so indices always line up
        elapsed_t = 0
        for model in [branching_model]:
            model.setParam('limits/time', self.each_time_available)
            model.setParam('display/verblevel', 0)
            model.optimize()

            if model.getNSols() > 0:
                obj = model.getObjVal()
                obj_list.append(obj)
                sol_list.append(model.getBestSol())
                solved_models.append(model)
                elapsed_t += model.getSolvingTime()
                print(model.getStatus())
                print('elapsed_t: ', elapsed_t)
                print('obj:', obj)

        if not obj_list:
            raise ValueError("solve_initial: no feasible solution found within the time limit")

        min_idx, self.obj_now = min(enumerate(obj_list), key=lambda x: x[1])
        self.incumbent_solution = sol_list[min_idx]
        best_model = solved_models[min_idx]
        self.obj_init = self.obj = self.obj_now
        # The incumbent was just set to obj_now, so the relative improvement is
        # exactly zero; stating it directly avoids dividing by a zero objective
        # or a zero solving time.
        reward = 0.0

        solution_map = {}
        for idx, var in enumerate(best_model.getVars()):
            value = best_model.getSolVal(self.incumbent_solution, var)
            self.state[idx] = value
            self.records_var_branch[idx] = 1
            solution_map[var.name] = value
        self.graph = self.mip_to_graph(best_model)

        with open(self.output_file, mode="a", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow([3, 3, self.instance_name])
            writer.writerow(obj_list)
            writer.writerow([elapsed_t, self.obj_now])
        # The last entry is the elapsed solving time (the original returned the
        # ``time`` module object here).
        return [self.graph, self.incumbent_solution, self.incumbent_solution, self.obj_now, self.obj, elapsed_t], reward, solution_map
    def policy_vanilla(self, iter, state):
        """Return the scripted action for round ``iter`` on a 10-round cycle.

        Position 0 of the cycle resets, positions 1-3 increase, and the
        remaining positions decrease.  ``state`` is not used.
        """
        phase = iter % 10
        if phase == 0:
            choice = 'reset'
        elif phase <= 3:
            choice = 'increase'
        else:
            choice = 'decrease'
        return self.actions[choice]
    def train_agent_per_instance(self, agent, optimizer, MIP_model):
        """Run one training episode on a single instance.

        Solves the initial model, then repeatedly calls ``self.solve`` until the
        iteration limit is reached or one of the stagnation counters
        (``self.improve`` / ``self.no_improve``) reaches
        ``self.stop_noimprove_num``.  Every reward is appended to
        ``agent.rewards``.  The elapsed time and final objective are appended to
        ``output_branching.csv``.

        Args:
            agent: Policy wrapper exposing ``rewards`` and ``select_action``.
            optimizer: When ``None`` the scripted ``policy_vanilla`` is used
                instead of the agent.
            MIP_model: Instance to solve.

        Returns:
            ``(agent, elapsed_seconds)``.
        """
        compute_bits = 0
        # solve_initial writes self.instance_name to the output file, so set it
        # here as the evaluation path does.
        self.instance_name = MIP_model.getProbName()

        start_time = time.time()
        state, reward, solution_map = self.solve_initial(self.state, self.actions['reset'], MIP_model)
        agent.rewards.append(reward)

        done = compute_bits >= self.total_iter_limit_N
        self.improve = 0
        self.no_improve = 0
        N_action = self.actions['reset']
        while not done:
            compute_bits += 1
            done = compute_bits >= self.total_iter_limit_N
            state, reward, solution_map = self.solve(state, N_action, compute_bits, solution_map)

            if self.improve >= self.stop_noimprove_num or self.no_improve >= self.stop_noimprove_num:
                done = True

            if optimizer is not None:
                if self.n_groups > 2:
                    N_action = agent.select_action(state)
                else:
                    N_action = self.actions['increase']
                # Near the upper bound on groups the choice is overridden.
                if self.n_groups + 2 >= len(self.groups):
                    N_action = self.actions['decrease']
            else:
                N_action = self.policy_vanilla(compute_bits, state)
            agent.rewards.append(reward)
            print('Round: {:.0f}'.format(compute_bits),
              'Number of branched groups:', self.n_groups,
              'Final reward: {:.4f}'.format(reward),
              )

        # No model is released here: ``solve`` frees each sub-problem itself.
        # The original referenced an undefined local ``branching_model`` at this
        # point, which raised UnboundLocalError on every call.
        elapsed_time = time.time() - start_time
        with open("output_branching.csv", mode="a", newline="", encoding="utf-8") as file:
            csv.writer(file).writerow([elapsed_time, self.obj])
        return agent, elapsed_time
    def evaluate_agent_per_instance(self, agent, optimizer, MIP_model):
        """Run the search on one instance and report time and objective.

        After the initial solve, ``self.solve`` is called until the iteration
        limit is reached, ``self.total_time_available`` is used up, or one of the
        stagnation counters reaches ``self.stop_noimprove_num``.

        Args:
            agent: Policy wrapper exposing ``rewards`` and ``select_action``.
            optimizer: When ``None`` the scripted ``policy_vanilla`` is used
                instead of the agent.
            MIP_model: Instance to solve.

        Returns:
            ``(agent, elapsed_seconds, final_objective)``.
        """
        compute_bits = 0

        self.instance_name = MIP_model.getProbName()
        print(self.instance_name)
        print("N of variables: {}".format(MIP_model.getNVars()))
        print("N of binary vars: {}".format(MIP_model.getNBinVars()))
        print("N of constraints: {}".format(MIP_model.getNConss()))

        start_time = time.time()
        state, reward, solution_map = self.solve_initial(self.state, self.actions['reset'], MIP_model)

        # Same stopping rule as at the end of each round: either budget being
        # exhausted stops the search (the original used ``and`` here only).
        done = (compute_bits >= self.total_iter_limit_N) or (self.total_time_available <= 0)
        self.improve = 0
        self.no_improve = 0
        N_action = self.actions['reset']
        while not done:
            state, reward, solution_map = self.solve(state, N_action, compute_bits, solution_map)
            stalled = (self.improve >= self.stop_noimprove_num
                       or self.no_improve >= self.stop_noimprove_num)

            if optimizer is not None:
                if self.n_groups > 2:
                    N_action = agent.select_action(state)
                else:
                    N_action = self.actions['increase']
                # Near the upper bound on groups the choice is overridden.
                if self.n_groups + 2 >= len(self.groups):
                    N_action = self.actions['decrease']
            else:
                N_action = self.policy_vanilla(compute_bits, state)

            agent.rewards.append(reward)
            print('Round: {:.0f}'.format(compute_bits),
              'Number of branched groups:', self.n_groups,
              'Final reward: {:.4f}'.format(reward),
              )
            compute_bits += 1
            done = stalled or (compute_bits >= self.total_iter_limit_N) or (self.total_time_available <= 0)

        elapsed_time = time.time() - start_time
        return agent, elapsed_time, self.obj
    def evaluate_agent_per_instance_onlybranching(self, agent, optimizer, MIP_model):
        """Run only the initial branching solve on one instance.

        Args:
            agent: Unused; kept for interface compatibility.
            optimizer: Unused; kept for interface compatibility.
            MIP_model: Instance to solve.

        Returns:
            ``(elapsed_seconds, objective)`` of the initial solve.
        """
        # Stored on ``self`` because solve_initial writes self.instance_name to
        # the output file; the original kept it in a local only.
        self.instance_name = MIP_model.getProbName()
        print(self.instance_name)
        print("N of variables: {}".format(MIP_model.getNVars()))
        print("N of binary vars: {}".format(MIP_model.getNBinVars()))
        print("N of constraints: {}".format(MIP_model.getNConss()))

        start_time = time.time()
        self.solve_initial(self.state, self.actions['reset'], MIP_model)
        elapsed_time = time.time() - start_time
        return elapsed_time, self.obj_now


    def evaluate_agent(self, train_instance_size='-small', lr=0.001, epsilon=0, n_epochs=20, n_batches=60):
        """Evaluate the saved policy on the test instances.

        Loads the checkpoint at ``self.model_N_path``, then for ``n_epochs``
        passes over the test set runs ``evaluate_agent_per_instance`` on every
        instance and appends ``[elapsed_seconds, objective]`` to
        ``self.output_file``.

        Args:
            train_instance_size: Size suffix used to locate the instance folder.
            lr: Learning rate used to rebuild the optimizer before its state is
                restored.
            epsilon: Unused; kept for interface compatibility.
            n_epochs: Number of passes over the test set.
            n_batches: Unused; kept for interface compatibility.
        """
        train_instance_type = self.instance_type

        direc = './data/generated_instances/' + train_instance_type + '/' + train_instance_size + '/'
        instances_directory = direc + 'transformedmodel' + '/'
        sols_directory = direc + 'firstsol' + '/'
        test_dataset_first = self.load_test_mip_dataset(instances_directory=instances_directory, sols_directory=sols_directory, incumbent_mode='firstsol')
        test_loader = DataLoader(test_dataset_first, shuffle=False, batch_size=1, collate_fn=custom_collate)
        print('size_testset', len(test_loader.dataset))

        self.output_file = './output_data/' + train_instance_type + train_instance_size + str(self.each_time_available) + '.csv'
        # Opening in append mode does not create missing directories.
        pathlib.Path(self.output_file).parent.mkdir(parents=True, exist_ok=True)

        # map_location lets a checkpoint saved on another device be restored here.
        checkpoint = torch.load(self.model_N_path, map_location=self.device)
        rl_policy = GNNPolicy()
        rl_policy.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): the policy is left in train() mode during evaluation, as
        # in the original — confirm this is intended.
        rl_policy.train()
        # Move the model first so the restored optimizer state lands on the same
        # device as the parameters.
        rl_policy = rl_policy.to(self.device)

        optim_N = torch.optim.Adam(rl_policy.parameters(), lr=lr)
        optim_N.load_state_dict(checkpoint['optimizer_state_dict'])

        greedy = False
        agent_N = AgentReinforce(rl_policy, self.device, greedy, optim_N, 0.0)

        # Exactly n_epochs passes (the original ran n_epochs + 1).
        for epoch in range(n_epochs):
            print(f"Epoch {epoch}")
            for batch in test_loader:
                MIP_model = batch['mip_model'][0]
                MIP_model.writeProblem('model.cip')

                MIP_model.setParam('display/verblevel', 0)
                # Each instance starts with the full time budget.
                self.total_time_available = self.total_time_available_input
                self.incumbent_solution = batch['incumbent_solution'][0]

                self.devide_groups(MIP_model)

                self.n_binvars = MIP_model.getNBinVars()
                half_groups = int(len(self.groups) / 2)
                self.default_n_groups = half_groups
                self.n_groups = half_groups
                agent_N, elapsed_time, obj = self.evaluate_agent_per_instance(agent_N, optim_N, MIP_model)
                print('Solving time:', elapsed_time)
                print('Objective value:', obj)
                with open(self.output_file, mode="a", newline="", encoding="utf-8") as file:
                    csv.writer(file).writerow([elapsed_time, obj])
    def mip_to_graph(self, MIP_model):
        """Build the bipartite graph observation for ``MIP_model``.

        The model is wrapped for ecole, ``self.env`` is reset on it, and the
        resulting observation is packed into a ``BipartiteNodeData``.  The
        variable features are the observation's variable features with
        ``self.records_var_branch`` appended as one extra column.

        Args:
            MIP_model: PySCIPOpt model to observe.

        Returns:
            The ``BipartiteNodeData`` graph, with ``num_nodes`` set to the
            number of constraint rows plus the number of variable rows.
        """
        instance = ecole.scip.Model.from_pyscipopt(MIP_model)
        observation, _action_set, _reward_offset, _done, _info = self.env.reset(instance)

        # NOTE(review): the original also concatenated ``self.state`` here, but
        # that result was overwritten on the next line and never used.  The
        # dead computation is dropped; the feature width is deliberately left
        # unchanged.  If the state column was meant to be included, the policy's
        # input size would presumably need to change as well — confirm.
        variable_features = np.concatenate(
            (observation.variable_features, self.records_var_branch[:, np.newaxis]),
            axis=1,
        )

        graph = BipartiteNodeData(observation.constraint_features,
                                    observation.edge_features.indices,
                                    observation.edge_features.values,
                                    variable_features)

        n_constraints = observation.constraint_features.shape[0]
        n_variables = observation.variable_features.shape[0]
        graph.num_nodes = n_constraints + n_variables
        print(graph)
        return graph
  

    def train_agent_2(self, train_instance_size='-small', lr=0.001, epsilon=0, n_epochs=20, n_batches=60, continue_train=False):
        """Train the policy on the training instances and save a checkpoint.

        For each epoch every training instance is run through
        ``train_agent_per_instance`` followed by ``update_agent``.  After the
        last epoch the training diary is written to ``results3/diary.pkl`` and
        the model and optimizer state are saved to ``self.model_N_path``.

        Args:
            train_instance_size: Size suffix used to locate the instance folder.
            lr: Learning rate for the Adam optimizer.
            epsilon: Passed to ``AgentReinforce`` when starting from scratch; a
                resumed run passes ``0.0`` instead.
            n_epochs: Number of passes over the training set.
            n_batches: Unused; kept for interface compatibility.
            continue_train: When true, model and optimizer state are restored
                from ``self.model_N_path`` before training.
        """
        train_instance_type = self.instance_type

        direc = './data/generated_instances/' + train_instance_type + '/' + train_instance_size + '/'
        instances_directory = direc + 'transformedmodel' + '/'
        sols_directory = direc + 'firstsol' + '/'
        train_dataset_first, _valid_dataset, _test_dataset = self.load_mip_dataset(instances_directory=instances_directory, sols_directory=sols_directory, incumbent_mode='firstsol')
        train_dataset = ConcatDataset([train_dataset_first])
        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, collate_fn=custom_collate)
        print('size_trainset', len(train_loader.dataset))

        device = self.device
        rl_policy = GNNPolicy()
        if continue_train:
            # map_location lets a checkpoint saved on another device be restored.
            checkpoint = torch.load(self.model_N_path, map_location=device)
            rl_policy.load_state_dict(checkpoint['model_state_dict'])
        rl_policy.train()
        # The model is moved before the optimizer is built so that restored
        # optimizer state ends up on the same device as the parameters.
        rl_policy = rl_policy.to(device)
        optim_N = torch.optim.Adam(rl_policy.parameters(), lr=lr)
        if continue_train:
            optim_N.load_state_dict(checkpoint['optimizer_state_dict'])

        greedy = False
        last_agent_arg = 0.0 if continue_train else epsilon
        agent_N = AgentReinforce(rl_policy, device, greedy, optim_N, last_agent_arg)

        returns = []
        epochs = []
        data = None
        optimizer_N = optim_N
        # Defined up front so the epoch summary works with an empty loader.
        eplased_time = 0.0
        for epoch in range(n_epochs):
            print(f"Epoch {epoch}")

            return_epoch = 0
            for batch in train_loader:
                MIP_model = batch['mip_model'][0]

                MIP_model.writeProblem('model.cip')
                MIP_model.setParam('display/verblevel', 0)

                self.incumbent_solution = batch['incumbent_solution'][0]
                self.devide_groups(MIP_model)
                self.n_binvars = MIP_model.getNBinVars()
                half_groups = int(len(self.groups) / 2)
                self.default_n_groups = half_groups
                self.n_groups = half_groups
                print(MIP_model.getProbName())
                agent_N, eplased_time = self.train_agent_per_instance(agent_N, optimizer_N, MIP_model)
                agent_N, optimizer_N, reward = self.update_agent(agent_N, optimizer_N)
                return_epoch += reward

            returns.append(return_epoch)
            epochs.append(epoch)

            print('Epoch round:', epoch)
            print(f"Return: {return_epoch:0.6f}")
            print('Time used:', eplased_time)

            # The time entry is that of the last instance of the epoch.
            data = [epochs, returns, eplased_time]

        filename = 'results3/diary.pkl'
        # gzip.open does not create missing directories.
        pathlib.Path(filename).parent.mkdir(parents=True, exist_ok=True)
        with gzip.open(filename, 'wb') as f:
            pickle.dump(data, f)

        # save checkpoint
        torch.save({'model_state_dict': rl_policy.state_dict(),
                    'optimizer_state_dict': optimizer_N.state_dict(),
                    'loss_data': data},
                   self.model_N_path)
        



# Driver script: configure one run and evaluate the saved policy.
# Silences every warning for the rest of the process.
warnings.filterwarnings("ignore")

# Instance family and size are picked by index from the lists imported from
# ``utility``.
instance_type = instancetypes[0]

instance_size = instancesizes[1]
incumbent_mode = 'firstsol'
lbconstraint_mode = 'symmetric'
seed = 0
# Presumably the threshold the stagnation counters are compared against
# (``self.stop_noimprove_num``) — verify against the constructor.
stop_noimprove_num = 20
# Presumably the overall and per-solve time budgets; the per-solve value is
# passed to SCIP's 'limits/time' — verify against the constructor.
total_time_available = 300 
each_time_available = 20 


choose_group = RLChooseGroup(addBranching(), instance_type, instance_size, lbconstraint_mode, incumbent_mode, seed, total_time_available, each_time_available, stop_noimprove_num)
# choose_group.train_agent_2(n_epochs=1, continue_train=False)
# NOTE(review): ``output_file`` is not read by the statements below;
# ``evaluate_agent`` builds its own ``self.output_file`` path.
output_file = instance_type + instance_size + str(each_time_available)
choose_group.evaluate_agent(n_epochs=1, train_instance_size=instance_size)





