import argparse
import csv
import importlib
import os
import sys
import time

import gurobipy as gp
from gurobipy import *


def _gurobi_stats(gp_model):
    """Collect solve statistics from a Gurobi model after ``optimize()``.

    Returns a dict with keys ``'nnodes'``, ``'gap'``, ``'status'``, ``'obj'``
    and ``'stime'``. MIP-only attributes (``NodeCount``, ``MIPGap``) and
    solution attributes (``ObjVal``) are only queried when the model is a MIP
    and/or has at least one feasible solution. Otherwise the placeholders
    0 / -1 / None are returned instead of letting Gurobi raise.
    """
    status_names = {
        gp.GRB.OPTIMAL: 'optimal',
        gp.GRB.INFEASIBLE: 'infeasible',
        gp.GRB.INF_OR_UNBD: 'inforunbd',
        gp.GRB.UNBOUNDED: 'unbounded',
        gp.GRB.TIME_LIMIT: 'timelimit',
        gp.GRB.INTERRUPTED: 'userinterrupt',
    }
    has_solution = gp_model.SolCount > 0
    is_mip = bool(gp_model.IsMIP)
    return {
        'nnodes': int(gp_model.NodeCount) if is_mip else 0,
        'gap': gp_model.MIPGap if (is_mip and has_solution) else -1,
        # Unmapped status codes are written as their numeric value.
        'status': status_names.get(gp_model.Status, str(gp_model.Status)),
        # None is written as an empty CSV cell when no solution was found.
        'obj': gp_model.ObjVal if has_solution else None,
        'stime': gp_model.Runtime,
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cats', 'facilities', 'indset', 'maxcut', 'item', 'miplib', 'hybrid', 'cats_trans', 'trans_big'],
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=2,
    )
    args = parser.parse_args()

    result_file = "{}_{}.csv".format(args.problem, time.strftime('%Y%m%d-%H%M%S'))
    instances = []
    seeds = [0]
    gcnn_models = ['baseline']  # not referenced below: no gcnn policies are registered
    other_models = ['extratrees_gcnn_agg', 'lambdamart_khalil', 'svmrank_khalil']  # not referenced below
    internal_branchers = ['relpscost']
    time_limit = 1000

    # Select the instance set (and optionally a per-problem time limit).
    if args.problem == 'maxcut':
#        instances += [{'type': 'big', 'path': "data/instances/test_4950_2975/instance_{}.lp".format(i+121)} for i in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/test_4950_2975/instance_168.lp"}]
        time_limit = 500
    elif args.problem == 'setcover':
        # instances += [{'type': 'big', 'path': "data/instances/setcover/transfer_5000r_4000c_0.05d/instance_{}.lp".format(i+1)} for i in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/setcover/test_5000r_1000c_0.05d/instance_{}.lp".format(i+101)} for i in range(50)]
    elif args.problem == 'item':
        instances += [{'type': 'big', 'path': "data/instances/item_placement/test/item_placement_{}.mps.gz".format(i + 10000)} for i in range(30)]
    elif args.problem == 'miplib':
#        instances += [{'type': 'big', 'path': "data/instances/anonymous/test/anonymous_{}.mps.gz".format(i + 119)} for i in range(20)]
        instances += [{'type': 'big', 'path': "data/instances/anonymous/test/anonymous_126.mps.gz"}]
        time_limit = 1800
    elif args.problem == 'cats':
        instances += [{'type': 'big', 'path': "data/instances/cauctions/test_2000_4000/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/indset/test_1500_4/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/indset/transfer_6000_4/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
    elif args.problem == 'hybrid':
        instances += [{'type': 'big', 'path': "data/instances/transfer_9950_5975/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/transfer_19950_11975/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/setcover/transfer_5000r_2000c_0.05d/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/indset/transfer_3000_4/instance_{}.lp".format(i + 1)} for
                      i in range(50)]
    elif args.problem == 'cats_trans':
        instances += [{'type': 'big', 'path': "data/instances/cauctions/transfer_4000_8000/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/cauctions/transfer_8000_16000/instance_{}.lp".format(i + 1)} for i
                      in range(15)]
    elif args.problem == 'trans_big':
        instances += [{'type': 'big', 'path': "data/instances/transfer_19950_11975/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/cauctions/transfer_4000_8000/instance_{}.lp".format(i + 1)} for i
                      in range(50)]
        instances += [{'type': 'big', 'path': "data/instances/cauctions/transfer_8000_16000/instance_{}.lp".format(i + 1)} for i
                      in range(15)]
    else:
        # 'facilities' and 'indset' are accepted by argparse but have no instance set here.
        raise NotImplementedError("no instance set configured for problem '{}'".format(args.problem))

    # Internal brancher baselines: one policy per (brancher, seed) pair.
    branching_policies = [
        {'type': 'internal', 'name': brancher, 'seed': seed}
        for brancher in internal_branchers
        for seed in seeds
    ]

    print("problem: {}".format(args.problem))
    print("gpu: {}".format(args.gpu))
    print("time limit: {} s".format(time_limit))

    ### TENSORFLOW SETUP ###
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(args.gpu)

    # Load and assign tensorflow models to 'gcnn' policies (models are shared
    # between policies of the same name). No 'gcnn' policies are registered
    # above, so this loop is currently a no-op.
    loaded_models = {}
    for policy in branching_policies:
        if policy['type'] == 'gcnn':
            if policy['name'] not in loaded_models:
                sys.path.insert(0, os.path.abspath("models/{}".format(policy['name'])))
                import model
                importlib.reload(model)
                loaded_models[policy['name']] = model.GCNPolicy()
                del sys.path[0]
            policy['model'] = loaded_models[policy['name']]

    print("running Gurobi...")

    fieldnames = [
        'policy',
        'seed',
        'type',
        'instance',
        'nnodes',
        'nlps',
        'stime',
        'gap',
        'status',
        'ndomchgs',
        'ncutoffs',
        'walltime',
        'proctime',
        'obj',
    ]

    os.makedirs('gurobi_results', exist_ok=True)
    with open("gurobi_results/{}".format(result_file), 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for instance in instances:
            print("{}: {}...".format(instance['type'], instance['path']))
            for policy in branching_policies:
                # Named gp_model so it does not shadow the 'model' module imported above.
                gp_model = gp.read(instance['path'])
                try:
                    gp_model.Params.TimeLimit = time_limit
                    # Apply the seed that is recorded in the CSV row.
                    gp_model.Params.Seed = policy['seed']

                    walltime = time.perf_counter()
                    proctime = time.process_time()

                    gp_model.optimize()

                    walltime = time.perf_counter() - walltime
                    proctime = time.process_time() - proctime

                    stats = _gurobi_stats(gp_model)
                finally:
                    # Release the native model before loading the next instance.
                    gp_model.dispose()

                # Placeholders: these columns are not collected from Gurobi here.
                nlps = 0
                ndomchgs = 0
                ncutoffs = 0

                writer.writerow({
                    'policy': "{}:{}".format(policy['type'], policy['name']),
                    'seed': policy['seed'],
                    'type': instance['type'],
                    'instance': instance['path'],
                    'nnodes': stats['nnodes'],
                    'nlps': nlps,
                    'stime': stats['stime'],
                    'gap': stats['gap'],
                    'status': stats['status'],
                    'ndomchgs': ndomchgs,
                    'ncutoffs': ncutoffs,
                    'walltime': walltime,
                    'proctime': proctime,
                    'obj': stats['obj'],
                })

                csvfile.flush()

                nnodes = stats['nnodes']
                print("  {}:{} {} - {} ({}) nodes {} lps {:.2f} ({:.2f} wall {:.2f} proc) s. {}".format(
                    policy['type'], policy['name'], policy['seed'],
                    nnodes, nnodes + 2 * (ndomchgs + ncutoffs), nlps,
                    stats['stime'], walltime, proctime, stats['status']))
