import argparse
import csv
import json
import pathlib

import ecole as ec
import numpy as np
import pandas as pd


class ExploreThenStrongBranch:
    """Observation function that yields strong-branching scores for branching candidates.

    A thin adapter around ``ec.observation.StrongBranchingScores``: both the
    pre-reset hook and the per-step extraction are forwarded unchanged to the
    wrapped observer stored in ``self.strong_branching_function``.
    """

    def __init__(self):
        self.strong_branching_function = ec.observation.StrongBranchingScores()

    def before_reset(self, model):
        """Forward the pre-reset hook to the wrapped strong-branching observer."""
        observer = self.strong_branching_function
        observer.before_reset(model)

    def extract(self, model, done):
        """Return whatever the wrapped strong-branching observer extracts for this step."""
        observer = self.strong_branching_function
        scores = observer.extract(model, done)
        return scores


def _scipsb_results_file(argtask, argsproblem, argsdifficulty):
    """Return the CSV path that SCIPSB results for ``argsproblem`` are appended to.

    :param argtask: task name, used as the sub-directory under ``results/``.
    :param argsproblem: 'item_placement', 'cauctions', 'indset', 'setcover' or 'facilities'.
    :param argsdifficulty: difficulty tag embedded in the path (unused for 'item_placement').
    :raises ValueError: if ``argsproblem`` is not one of the supported problems.
    """
    if argsproblem == 'item_placement':
        return pathlib.Path(f"results/{argtask}/1_item_placement_SCIPSB.csv")
    if argsproblem in ('cauctions', 'indset', 'setcover', 'facilities'):
        return pathlib.Path(
            f"results/{argtask}/{argsproblem}_{argsdifficulty}/{argsproblem}_SCIPSB.csv"
        )
    raise ValueError(
        f"Unsupported problem {argsproblem!r}; expected one of "
        "'item_placement', 'cauctions', 'indset', 'setcover', 'facilities'."
    )


def _append_scipsb_row(results_file, fieldnames, row):
    """Append one result ``row`` to ``results_file``, writing the header if the file is new or empty."""
    # Decide before opening: append mode creates the file, and only a missing or
    # zero-byte file should receive the header row.
    try:
        write_header = results_file.stat().st_size == 0
    except FileNotFoundError:
        write_header = True
    with open(results_file, mode='a', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def SCIPSB_Main(argsproblem, argsdifficulty, lp_path, time_limit, tmp_instance, seed):
    """Solve one instance with SCIP, branching on the best strong-branching score, and log the result.

    At every branching step the candidate with the highest strong-branching score
    is chosen. The cumulated reward and the final solver statistics reported in
    ``info`` are appended as one row to a per-problem CSV under ``results/dual/``.

    :param argsproblem: 'item_placement', 'setcover', 'cauctions', 'indset' or 'facilities'.
    :param argsdifficulty: 'easy', 'medium' or 'hard'; only used to build the results path.
    :param lp_path: instance directory; only printed for information.
    :param time_limit: time limit passed to the environment.
    :param tmp_instance: path of the single instance file to solve.
    :param seed: seed given to the environment.
    :raises ValueError: if ``argsproblem`` is not supported.
    :return: None; results are written to the CSV file.
    """
    argtask = 'dual'
    argsdebug = False

    # Validate the problem first so an unknown name fails with a clear message
    # instead of an UnboundLocalError further down.
    results_file = _scipsb_results_file(argtask, argsproblem, argsdifficulty)
    instances_path = pathlib.Path(lp_path)

    # check the Ecole version installed
    assert ec.__version__ == "0.7.3", "Wrong Ecole version."

    print(f"Evaluating the {argtask} task agent on the {argsproblem} problem.")
    print(f"Processing instances from {instances_path.resolve()}")

    # This call evaluates a single instance with a single seed.
    instance_files = [tmp_instance]

    print(f"Saving results to {results_file.resolve()}")
    results_file.parent.mkdir(parents=True, exist_ok=True)
    results_fieldnames = ['instance', 'seed', 'dual_bound', 'primal_bound',
                          'objective_offset', 'cumulated_reward', 'solvingtime', 'nnodes']

    # Make the project-local `environments` and `rewards` modules importable.
    # Add the entry only once so repeated calls do not keep growing sys.path.
    import sys
    cwd = str(pathlib.Path.cwd())
    if cwd not in sys.path:
        sys.path.insert(1, cwd)

    # Only the dual task is wired up. item_placement uses the BranchingOpen
    # variant (ablation study); every other problem uses the standard Branching env.
    if argsproblem == 'item_placement':
        from environments import BranchingOpen as Environment
    else:
        from environments import Branching as Environment
    from rewards import TimeLimitDualIntegral as BoundIntegral

    memory_limit = 8796093022207  # maximum value for SCIP's limits/memory (no memory cap)

    # evaluation loop
    for instance in instance_files:
        # File name up to its first '.', e.g. "dir/inst.mps.gz" -> "inst".
        tmp_instance_name = pathlib.Path(str(instance)).name.split('.')[0]
        observation_function = {
            "scores": ExploreThenStrongBranch(),
            "node_observation": ec.observation.NodeBipartite(),
        }

        # indset/cauctions keep the integral's sign; the other problems negate it.
        # NOTE(review): presumably tied to maximization vs. minimization objectives — confirm.
        if argsproblem in ('indset', 'cauctions'):
            integral_function = BoundIntegral()
        else:
            integral_function = -BoundIntegral()

        env = Environment(
            time_limit=time_limit,
            observation_function=observation_function,
            scip_params={'limits/memory': memory_limit},
            reward_function=integral_function,
        )

        # seed the environment (deterministic behavior)
        env.seed(seed)

        objective_offset = 0

        print()
        print(f"Instance {tmp_instance_name}")
        print(f"  seed: {seed}")
        print(f"  objective offset: {objective_offset}")

        # reset the environment
        observation, action_set, reward, done, info = env.reset(str(instance))

        if argsdebug:
            print(f"  info: {info}")
            print(f"  reward: {reward}")
            print(f"  action_set: {action_set}")

        cumulated_reward = 0  # discard initial reward

        # loop over the environment
        while not done:
            # Branch on the candidate whose strong-branching score is highest.
            scores = observation["scores"]
            action = action_set[scores[action_set].argmax()]

            if argsdebug:
                print(f"  action: {action}")

            observation, action_set, reward, done, info = env.step(action)

            if argsdebug:
                print(f"  info: {info}")
                print(f"  reward: {reward}")
                print(f"  action_set: {action_set}")

            cumulated_reward += reward

        print(f"  cumulated reward (to be maximized): {cumulated_reward}")

        # save instance results
        _append_scipsb_row(results_file, results_fieldnames, {
            'instance': str(instance),
            'seed': seed,
            'dual_bound': info['dual_bound'],
            'primal_bound': info['primal_bound'],
            'objective_offset': objective_offset,
            'cumulated_reward': cumulated_reward,
            'solvingtime': info['solvingtime'],
            'nnodes': info['nnodes'],
        })