from sotaRLagent import Agent, PseudocostBranchingAgent, StrongBranchingAgent

import torch
import numpy as np
import glob
import time
import pathlib
import csv
import shutil


def _results_file_for(argtask, argsproblem, argsdifficulty):
    """Return the CSV path where results for this task/problem family are appended.

    Args:
        argtask: task name used as the first directory level (e.g. 'dual').
        argsproblem: problem family; one of 'item_placement', 'cauctions',
            'indset', 'setcover' or 'facilities'.
        argsdifficulty: difficulty tag embedded in the directory name (not used
            for 'item_placement').

    Raises:
        ValueError: if argsproblem is not a supported problem family.
    """
    if argsproblem == 'item_placement':
        return pathlib.Path(f"results/{argtask}/1_item_placement_SOTARL.csv")
    if argsproblem in ('cauctions', 'indset', 'setcover', 'facilities'):
        return pathlib.Path(
            f"results/{argtask}/{argsproblem}_{argsdifficulty}/{argsproblem}_SOTARL.csv")
    raise ValueError(
        f"Unsupported problem {argsproblem!r}; expected one of 'item_placement', "
        "'cauctions', 'indset', 'setcover', 'facilities'.")


def _load_agent(argsproblem, agent_name, path_to_load_agent, device):
    """Construct the branching agent called agent_name and return (agent, device).

    Heuristic agents ('pseudocost_branching', 'strong_branching',
    'scip_branching') are built directly and always run on the CPU. Any other
    name is treated as an ML agent whose 'config.json' and
    '<network_name>_params.pkl' files live in
    '<path_to_load_agent>/<argsproblem>/<agent_name>/'.

    Returns:
        A tuple (agent, device), where device is the torch device string that
        observation tensors must be moved to.
    """
    # Heuristic agents are constructed from their name here. The original note
    # says ecole objects cannot be pickled, so these were passed around as str.
    if agent_name == 'pseudocost_branching':
        return PseudocostBranchingAgent(name='pseudocost_branching'), 'cpu'
    if agent_name == 'strong_branching':
        return StrongBranchingAgent(name='strong_branching'), 'cpu'
    if agent_name == 'scip_branching':
        class SCIPBranchingAgent:
            def __init__(self):
                self.name = 'scip_branching'

        return SCIPBranchingAgent(), 'cpu'

    agent_dir = pathlib.Path(path_to_load_agent) / argsproblem / agent_name
    agent = Agent(device=device, config=str(agent_dir / 'config.json'), name=agent_name)
    for network_name, network in agent.get_networks().items():
        if network is None:
            print(f'{network_name} is None.')
            continue
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(agent_dir / f'{network_name}_params.pkl', map_location=device)
        try:
            # Network saved under the same attribute name as network_name ...
            agent.__dict__[network_name].load_state_dict(state_dict)
        except KeyError:
            # ... or under the generic 'network' attribute (as in the Agent class).
            agent.__dict__['network'].load_state_dict(state_dict)
    agent.eval()  # put in test mode
    return agent, device


def _observation_to_tensors(observation, device):
    """Convert an environment observation into the tensor 4-tuple given to the agent.

    Returns (row_features, edge_indices, edge_values[:, None], column_features),
    each moved to device. Edge indices are converted as int64: the previous
    int16 cast wrapped around for any index above 32767.
    """
    return (
        torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
        torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
        torch.from_numpy(observation.edge_features.values.astype(np.float32)).to(device).unsqueeze(1),
        torch.from_numpy(observation.column_features.astype(np.float32)).to(device),
    )


def _append_result_row(results_file, fieldnames, row):
    """Append row to the CSV at results_file, writing the header if the file is new or empty."""
    write_header = not results_file.exists() or results_file.stat().st_size == 0
    with open(results_file, mode='a', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def SOTARL_Main(argsproblem, argsdifficulty, lp_path, time_limit, tmp_instance, seed):
    """Solve tmp_instance with the pretrained 'retro' branching agent and log the result.

    The agent is loaded from
    'retro_branching_paper_validation_agents/<argsproblem>/retro/'. It is run
    in a BranchingOpen environment with a dual-integral reward, and one row
    (bounds, cumulated reward, solving time, node count) is appended to the
    per-problem results CSV under 'results/dual/'.

    Args:
        argsproblem: problem family ('item_placement', 'cauctions', 'indset',
            'setcover' or 'facilities').
        argsdifficulty: difficulty tag used in the results directory name.
        lp_path: instance directory; only printed for information.
        time_limit: time limit forwarded to the environment.
        tmp_instance: path of the instance file to solve.
        seed: seed passed to env.seed and recorded in the CSV.

    Raises:
        ValueError: if argsproblem is not a supported problem family.
    """
    agent_name = 'retro'
    path_to_load_agent = 'retro_branching_paper_validation_agents'
    argtask = 'dual'
    argsdebug = False

    # Resolve (and validate) the output path before loading any model weights.
    results_file = _results_file_for(argtask, argsproblem, argsdifficulty)
    instances_path = pathlib.Path(lp_path)

    # Prefer the GPU, but fall back to the CPU so the script also runs without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    agent, device = _load_agent(argsproblem, agent_name, path_to_load_agent, device)

    print(f"Processing instances from {instances_path.resolve()}")

    results_file.parent.mkdir(parents=True, exist_ok=True)
    results_fieldnames = ['instance', 'seed', 'dual_bound', 'primal_bound',
                          'objective_offset', 'cumulated_reward', 'solvingtime', 'nnodes']

    start = time.time()

    if argtask == "primal":
        from rewards import TimeLimitPrimalIntegral as BoundIntegral
    elif argtask == "dual":
        from rewards import TimeLimitDualIntegral as BoundIntegral
    elif argtask == "config":
        from rewards import TimeLimitPrimalDualIntegral as BoundIntegral
    else:
        raise ValueError(f"Unknown task {argtask!r}.")

    # Only the dual task has an environment configured here. Fail clearly rather
    # than with a NameError on Environment further down.
    if argtask != "dual":
        raise NotImplementedError(f"No environment configured for task {argtask!r}.")
    from environments import BranchingOpen as Environment
    memory_limit = 8796093022207  # maximum (effectively no memory limit)

    from sotaRLagent import NodeBipariteWith43VariableFeatures
    observation_function = NodeBipariteWith43VariableFeatures()

    instance = tmp_instance
    tmp_instance_name = pathlib.Path(instance).name.split('.')[0]

    # NOTE(review): the reward is kept positive for indset/cauctions and negated
    # otherwise. Presumably these two are maximisation problems; confirm.
    if argsproblem in ('indset', 'cauctions'):
        integral_function = BoundIntegral()
    else:
        integral_function = -BoundIntegral()

    env = Environment(
        time_limit=time_limit,
        observation_function=observation_function,
        scip_params={'limits/memory': memory_limit,
                     'separating/maxrounds': 0,
                     'presolving/maxrestarts': 0},
        reward_function=integral_function,
    )
    env.seed(seed)
    print('Initialised env.')

    objective_offset = 0

    print()
    print(f"Instance {tmp_instance_name}")
    print(f"  seed: {seed}")
    print(f"  objective offset: {objective_offset}")

    observation, action_set, reward, done, info = env.reset(str(instance))
    if argsdebug:
        print(f"  info: {info}")
        print(f"  reward: {reward}")
        print(f"  action_set: {action_set}")

    cumulated_reward = 0  # discard initial reward

    while not done:
        obs = _observation_to_tensors(observation, device)
        # Integer action set so it converts cleanly to torch.LongTensor inside the agent.
        action_set = action_set.astype(int)

        action, _ = agent.action_select(action_set=action_set, obs=obs, munchausen_tau=0,
                                        epsilon=0, model=env.model, done=done, agent_idx=0)
        if argsdebug:
            print(f"  action: {action}")

        observation, action_set, reward, done, info = env.step(action)
        if argsdebug:
            print(f"  info: {info}")
            print(f"  reward: {reward}")
            print(f"  action_set: {action_set}")

        cumulated_reward += reward

    print(f"  cumulated reward (to be maximized): {cumulated_reward}")

    _append_result_row(results_file, results_fieldnames, {
        'instance': str(instance),
        'seed': seed,
        'dual_bound': info['dual_bound'],
        'primal_bound': info['primal_bound'],
        'objective_offset': objective_offset,
        'cumulated_reward': cumulated_reward,
        'solvingtime': info['solvingtime'],
        'nnodes': info['nnodes'],
    })

    end = time.time()

    print(f'Finished validating agent {agent_name} in {end - start:.3f} s.')


if __name__ == '__main__':
    # SOTARL_Main takes six required arguments; calling it bare always raised
    # TypeError, so read them from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(
        description='Run the retro RL branching agent on a single MILP instance.')
    parser.add_argument('problem',
                        choices=['item_placement', 'cauctions', 'indset', 'setcover', 'facilities'],
                        help='problem family')
    parser.add_argument('difficulty', help='difficulty tag used in the results directory name')
    parser.add_argument('lp_path', help='instance directory (only printed)')
    parser.add_argument('time_limit', type=float, help='time limit forwarded to the environment')
    parser.add_argument('instance', help='path of the instance file to solve')
    parser.add_argument('seed', type=int, help='environment seed')
    cli_args = parser.parse_args()

    SOTARL_Main(cli_args.problem, cli_args.difficulty, cli_args.lp_path,
                cli_args.time_limit, cli_args.instance, cli_args.seed)
