from pyscipopt import Model
import glob
import os
import csv
import numpy as np
import time
import argparse


def _select_instances(problem):
    """Return the instance file paths to evaluate for ``problem``.

    Glob results are sorted so the evaluation order (and therefore the row
    order of the result CSV) is reproducible across machines; ``glob.glob``
    itself returns matches in arbitrary, filesystem-dependent order.

    Raises:
        NotImplementedError: if no instance set is defined for ``problem``.
    """
    if problem == 'maxcut':
        instances = sorted(glob.glob('data/instances/test_4950_2975/*.lp'))
        # Transfer sets currently not evaluated:
        #   data/instances/transfer_9950_5975/*.lp
        #   data/instances/transfer_19950_11975/*.lp

    elif problem == 'cats':
        instances = sorted(glob.glob('data/instances/cauctions/test_2000_4000/*.lp'))
        # Transfer sets currently not evaluated:
        #   data/instances/cauctions/transfer_4000_8000/*.lp
        #   data/instances/cauctions/transfer_8000_16000/*.lp

    elif problem == 'indset':
        instances = sorted(glob.glob('data/instances/indset/test_1500_4/*.lp'))
        instances += sorted(glob.glob('data/instances/indset/transfer_6000_4/*.lp'))
        instances += sorted(glob.glob('data/instances/indset/transfer_3000_4/*.lp'))

    elif problem == 'setcover':
        # Test instances are numbered 101..150.
        instances = [
            "data/instances/setcover/test_5000r_1000c_0.05d/instance_{}.lp".format(i + 101)
            for i in range(50)
        ]
        instances += sorted(glob.glob('data/instances/setcover/transfer_5000r_4000c_0.05d/*.lp'))
        instances += sorted(glob.glob('data/instances/setcover/transfer_5000r_2000c_0.05d/*.lp'))

    elif problem == 'item':
        # Test instances are numbered 10000..10099.
        instances = [
            "data/instances/item_placement/test/item_placement_{}.mps.gz".format(i + 10000)
            for i in range(100)
        ]

    elif problem == 'miplib':
        # Only a single anonymous instance is evaluated at the moment.
        # Alternatives previously used here:
        #   data/instances/anonymous/test/anonymous_{119..138}.mps.gz
        #   data/instances/anonymous/valid/anonymous_102.mps.gz
        instances = ["data/instances/anonymous/test/anonymous_126.mps.gz"]

    else:
        raise NotImplementedError(
            "no instance set defined for problem {!r}".format(problem))

    return instances


def learn(args):
    """Solve every evaluation instance of ``args.problem`` with SCIP and log to CSV.

    For each instance a fresh SCIP model is created, configured (time limit,
    no separation rounds, no restarts, only the RINS heuristic enabled among
    the listed LNS heuristics), optimized, and its primal bound is appended
    to ``ddpg_test_results/<problem>_<timestamp>.csv``.

    Args:
        args: object with a ``problem`` attribute naming the instance family
            ('maxcut', 'cats', 'indset', 'setcover', 'item' or 'miplib').

    Raises:
        NotImplementedError: if ``args.problem`` has no instance set. This is
            raised before any file or directory is created.
    """
    test_time_limit = 200  # default:200; passed to SCIP's 'limits/time'

    instances_valid = _select_instances(args.problem)

    # Large-neighbourhood-search heuristics whose frequency is set to -1
    # before a subset is switched back on below.
    all_disable = [
        'heuristics/crossover/freq',
        'heuristics/dins/freq',
        'heuristics/gins/freq',
        'heuristics/localbranching/freq',
        'heuristics/alns/freq',
        'heuristics/ofins/freq',
        'heuristics/rens/freq',
        'heuristics/rins/freq',
    ]

    fieldnames = [
        'instance',
        'obj',
        'initial',
        'bestroot',
        'imp',
        'mean',
        'time',
        'Integral',
    ]
    result_file = "{}_{}.csv".format(args.problem, time.strftime('%Y%m%d-%H%M%S'))
    os.makedirs('ddpg_test_results', exist_ok=True)
    with open("ddpg_test_results/{}".format(result_file), 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for instance_path in instances_valid:
            model = Model()
            model.readProblem(instance_path)

            model.setRealParam('limits/time', test_time_limit)

            model.setBoolParam('randomization/permutevars', True)

            # Separation rounds are capped at 0 both at the root node and at
            # all other nodes.
            model.setIntParam('separating/maxroundsroot', 0)
            model.setIntParam('separating/maxrounds', 0)

            # no restart
            model.setIntParam('presolving/maxrestarts', 0)

            for para in all_disable:
                model.setParam(para, -1)

            # Re-enable RINS only. To compare against other heuristics, set
            # 'heuristics/dins/freq', 'heuristics/gins/freq' or
            # 'heuristics/rens/freq' to 1 here instead.
            model.setParam('heuristics/rins/freq', 1)

            model.optimize()

            # Only 'instance' and 'obj' carry measured values; the remaining
            # columns are constant placeholders.
            # NOTE(review): 'time' records the configured limit, not the
            # measured solving time -- confirm this is intended.
            writer.writerow({
                'instance': instance_path,
                'obj': model.getPrimalbound(),
                'initial': 0,
                'bestroot': 0,
                'imp': 0,
                'mean': 0,
                'time': test_time_limit,
                'Integral': 0
            })
            # Flush after every instance so partial results survive a crash
            # or an interrupted run.
            csvfile.flush()


if __name__ == '__main__':
    # Command-line entry point: takes one positional argument naming the
    # instance family and hands the parsed namespace to learn().
    # NOTE: 'facilities' is accepted by the parser, but learn() has no branch
    # for it and raises NotImplementedError.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        'problem',
        choices=['setcover', 'cats', 'facilities', 'indset', 'maxcut', 'item', 'miplib'],
        help='MILP instance type to process.',
    )
    parsed = cli.parse_args()

    # Legacy note carried over from the original comment: "is_maximum
    # parameter, 1 means maximization problem, 0 means minimization problem".
    # learn() takes no such parameter in this file.
    learn(args=parsed)
