import os
import sys
import importlib
import argparse
import csv
import time
import pickle
import pyscipopt as scip

import torch
import pandas as pd
import pathlib


def _result_file_path(problem, difficulty, method, time_limit, gap_limit, limit_para, argtask):
    # Build results/<argtask>/<problem>_<difficulty>/<csv name>, creating the directory if needed.
    if limit_para == 'time':
        file_name = f"{problem}_{method}_time_{time_limit}s.csv"
    else:
        file_name = f"{problem}_{method}_gap_{gap_limit}.csv"
    eval_dir = os.path.join("results", argtask, f"{problem}_{difficulty}")
    os.makedirs(eval_dir, exist_ok=True)
    return os.path.join(eval_dir, file_name)


def _configure_model(m, problem, method, seed, time_limit, gap_limit, limit_para, glo_time):
    # Apply heuristic/restart, clock, randomization, branching and limit settings to a loaded model.
    if problem != 'item_placement':
        m.setHeuristics(scip.SCIP_PARAMSETTING.OFF)
        m.setIntParam('limits/restarts', 0)

    m.setIntParam('timing/clocktype', 2)  # 1: CPU user seconds, 2: wall clock time
    m.setBoolParam("randomization/permuteconss", True)
    m.setBoolParam('randomization/permutevars', True)
    m.setIntParam('randomization/permutationseed', seed)
    m.setIntParam('randomization/randomseedshift', seed)
    m.setIntParam("randomization/lpseed", seed)

    # SCIP tries branching rules in decreasing priority order, so a very large
    # priority makes the selected rule the one that is used.
    if method == 'fullstrong':
        m.setIntParam('branching/allfullstrong/priority', 666666)
    elif method == 'relpscost':
        m.setIntParam('branching/relpscost/priority', 666666)

    if limit_para == 'time':
        m.setRealParam('limits/time', time_limit)
    else:
        m.setRealParam('limits/gap', gap_limit)
        m.setRealParam('limits/time', glo_time)


def scip_main(problem, difficulty, lp_path, time_limit, tmp_instance, seed, method,
              gap_limit=0.05, limit_para='time'):
    """Solve one instance with SCIP and append a result row to a CSV file.

    The CSV is written to
    ``results/dual/<problem>_<difficulty>/<problem>_<method>_time_<time_limit>s.csv``
    (or ``..._gap_<gap_limit>.csv`` when ``limit_para == 'gap'``). The header
    row is written only when the file is new or empty.

    :param problem: str: problem type, e.g. 'setcover', 'cauctions',
        'facilities', 'indset', 'item_placement'. For every problem except
        'item_placement', primal heuristics and restarts are disabled.
    :param difficulty: str: difficulty tag; only used to name the output directory.
    :param lp_path: str: directory holding the instances
        (e.g. 'data/instances/cauctions/'). Not used here; kept for interface
        compatibility.
    :param time_limit: int: solver time limit in seconds (used when
        ``limit_para == 'time'``).
    :param tmp_instance: path-like: the instance file to read.
    :param seed: int: seed for SCIP's permutation/LP randomization (also
        passed to ``torch.manual_seed``).
    :param method: str: 'fullstrong' or 'relpscost' boost the priority of the
        matching SCIP branching rule; any other value (e.g. 'scip_default')
        leaves SCIP's branching defaults.
    :param gap_limit: float: relative gap limit (used when ``limit_para == 'gap'``).
    :param limit_para: str: 'time' to stop on ``time_limit``; 'gap' to stop on
        ``gap_limit`` with a 3600 s time cap.
    :raises ValueError: if ``limit_para`` is neither 'time' nor 'gap'.
    :return: None. Results are only appended to the CSV file and printed.
    """
    # Validate before touching the filesystem; otherwise the output path is undefined.
    if limit_para not in ('time', 'gap'):
        raise ValueError(f"limit_para must be 'time' or 'gap', got {limit_para!r}")

    args_gpu = -1     # -1 hides all CUDA devices (CPU only)
    glo_time = 3600   # time cap in seconds for gap-limited runs
    argtask = 'dual'  # results sub-directory name

    result_file = _result_file_path(problem, difficulty, method, time_limit,
                                    gap_limit, limit_para, argtask)
    instances = [{'type': 'original', 'path': str(tmp_instance)}]

    print(f"problem: {problem}")
    print(f"gpu: {args_gpu}")
    print(f"time limit: {time_limit} s")

    os.environ['CUDA_VISIBLE_DEVICES'] = '' if args_gpu == -1 else f'{args_gpu}'

    print("running SCIP...")

    fieldnames = [
        'problem',
        'policy',
        'seed',
        'instance',
        'nnodes',
        'stime',
        'gap',
        'dualbound',
        'objVal',
        'status',
    ]

    # Decide on the header before opening: append mode creates the file if missing.
    write_header = not os.path.exists(result_file) or os.path.getsize(result_file) == 0

    with open(result_file, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()

        for instance in instances:
            print(f"{instance['type']}: {instance['path']}...")
            torch.manual_seed(seed)

            m = scip.Model()
            try:
                m.setIntParam('display/verblevel', 0)
                m.readProblem(f"{instance['path']}")
                _configure_model(m, problem, method, seed, time_limit,
                                 gap_limit, limit_para, glo_time)

                walltime = time.perf_counter()
                proctime = time.process_time()

                m.optimize()

                walltime = time.perf_counter() - walltime
                proctime = time.process_time() - proctime

                stime = m.getSolvingTime()
                nlps = m.getNLPs()
                status = m.getStatus()
                # getObjVal() raises when no solution is available (infeasible, or a
                # limit reached before any solution was found -- plausible with
                # heuristics off). Record None, which csv writes as an empty field.
                obj_val = m.getObjVal() if m.getNSols() > 0 else None

                writer.writerow({
                    'problem': problem,
                    'policy': method,
                    'seed': seed,
                    'instance': instance['path'],
                    'nnodes': m.getNNodes(),
                    'stime': stime,
                    'gap': m.getGap(),
                    'dualbound': m.getDualbound(),
                    'objVal': obj_val,
                    'status': status,
                })
                csvfile.flush()
            finally:
                # Release SCIP's problem data even if reading/solving failed.
                m.freeProb()

            print(f" {nlps} lps {stime:.2f} ({walltime:.2f} wall {proctime:.2f} proc) s. {status}")


if __name__ == '__main__':
    # Run SCIP's default configuration on every test instance of one problem.
    problem = '2_load_balancing'
    difficulty = 'hard'
    time_limit = 900
    lp_path = '../../mlsolverresearch/instances/' + problem + '/test/'
    instances_path = pathlib.Path(lp_path)
    # Each instance's seed is its index, and Path.glob yields files in an
    # arbitrary, filesystem-dependent order. Sort so that the instance/seed
    # pairing is reproducible across machines and runs.
    instance_files = sorted(instances_path.glob('*.mps.gz'))
    if not instance_files:
        print(f"no *.mps.gz instances found in {instances_path}")
    for seed, tmp_instance in enumerate(instance_files):
        scip_main(problem, difficulty, lp_path, time_limit, tmp_instance, seed, method='scip_default')
