import glob
from itertools import cycle
import logging
import random as rd

import pyscipopt  # noqa: F401
import ecole
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter  # noqa: F401

from experiments.params.learner import real_test_params
from experiments.params.retro import unit_retrospective_trajectory_params
from experiments.utils import sort_instances
from rl.agents.tree_dqn import TreeDQNAgent
from rl.environments.ecole import Branching
from rl.learners.tree_dqn import TreeDQNLearner
from rl.networks.gcnn import GNNParsonsonPolicy
from rl.observations.node_bipartite import ParsonsonNodeBipartite
from rl.observations.node_depth import NodeDepth
from rl.rewards.retro_branching import RetroBranching
from rl.rewards.reward_agent import RewardAgent, TreeMDPRewardAgent


def seed_torch(seed):
    """Seed PyTorch's random number generators.

    Always seeds the default generator through ``torch.manual_seed``. When the
    cuDNN backend flag is enabled, the CUDA generator is seeded as well and
    cuDNN is configured with ``benchmark = False`` and ``deterministic = True``.

    Args:
        seed: Integer seed handed to the PyTorch seeding functions.
    """
    torch.manual_seed(seed)

    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        # Nothing CUDA/cuDNN related to configure.
        return

    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


def run(
    agent: str,
    instance_path: str,
    log_path: str,
):
    """Train a tree DQN agent on one benchmark.

    Args:
        agent: Name of the agent variant.
            ``"dqn_retro"`` uses the ``RetroBranching`` reward agent, overrides
            ``gamma`` to 0.99 and passes no node-selection override to the
            environment. ``"dqn_bbmdp"`` uses ``RewardAgent`` with
            ``classification=True`` and DFS node selection. Any other value
            uses ``TreeMDPRewardAgent`` with DFS node selection.
        instance_path: Benchmark name. ``"SetCovering"``,
            ``"CombinatorialAuction"`` and ``"MaximumIndependentSet"`` train on
            the matching ecole instance generator; any other value cycles over
            the ``.lp`` files in ``../data/instances/{instance_path}/easy``.
            The same directory always provides the evaluation instances, and
            the name is also part of the model save path.
        log_path: File that receives the log records of this run.
    """
    logging.basicConfig(filename=log_path, format="%(asctime)s %(message)s", level=logging.DEBUG)
    logging.info("\n\n\n")
    logging.info("~" * 89)
    logging.info("~" * 40 + " NEW RUN " + "~" * 40)
    logging.info("~" * 89)
    logging.info("\n\n\n")

    # Learner params: work on a private copy. ``real_test_params`` is a
    # module-level dict, so overriding "gamma" in place would silently change
    # it for every other user of that module.
    learner_params = dict(real_test_params)

    # Pick the reward agent, the network head type and the node selection.
    if agent == "dqn_retro":
        # Retro branching
        learner_params["gamma"] = 0.99
        reward_agent = RetroBranching(**unit_retrospective_trajectory_params)
        classification = False
        use_dfs = False
    elif agent == "dqn_bbmdp":
        reward_agent = RewardAgent()
        classification = True
        use_dfs = True
    else:
        reward_agent = TreeMDPRewardAgent()
        classification = False
        use_dfs = True

    # Seeding happens before the networks are built.
    seed = 42
    rd.seed(seed)
    np.random.seed(seed)  # noqa: NPY002
    seed_torch(seed)

    # DQN value network and its target copy share the same configuration.
    final_invert_activation = True
    value_network = GNNParsonsonPolicy(final_invert_activation, classification=classification)
    target_network = GNNParsonsonPolicy(final_invert_activation, classification=classification)
    save_path = f"../data/trained_models/{instance_path}/train/{agent}_"

    # DQN agent; kept under its own name so the ``agent`` string argument is
    # not shadowed by the agent object.
    dqn_agent = TreeDQNAgent(value_network, target_network, classification=classification)
    logging.info("Initialised DQN agent.")

    # Initialise the training instance source.
    lp_pattern = f"../data/instances/{instance_path}/easy/*.lp"
    if instance_path == "SetCovering":
        instances = ecole.instance.SetCoverGenerator()
    elif instance_path == "CombinatorialAuction":
        instances = ecole.instance.CombinatorialAuctionGenerator()
    elif instance_path == "MaximumIndependentSet":
        instances = ecole.instance.IndependentSetGenerator()
    else:
        instances = cycle(glob.glob(lp_pattern))

    # Evaluation always uses the first 20 on-disk instances after sorting.
    evaluation_instances = sort_instances(glob.glob(lp_pattern))[:20]

    # Initialise the branch-and-bound environment.
    static_params = {"node_selection": "dfs"} if use_dfs else {}

    env = Branching(
        observation_function=ParsonsonNodeBipartite(),
        information_function={
            "time": ecole.reward.SolvingTime(wall=True),
            "n_nodes": ecole.reward.NNodes(),
            "primal_dual_integral": ecole.reward.PrimalDualIntegral(),
            "depth": NodeDepth(),
        },
        reward_function=-ecole.reward.NNodes(),
        static_params=static_params,
    )
    env.seed(seed)
    logging.info("Initialised environment.")

    # Tensorboard writer. NOTE(review): the original comment suggests passing
    # ``None`` instead deactivates Tensorboard logging; confirm in TreeDQNLearner.
    writer = SummaryWriter()

    # Initialise the DQN learner.
    learner = TreeDQNLearner(
        agent=dqn_agent,
        reward_agent=reward_agent,
        env=env,
        instances=instances,
        evaluation_instances=evaluation_instances,
        save_path=save_path,
        writer=writer,
        **learner_params,
    )
    logging.info("Initialised learner.")

    # Train the DQN agent.
    logging.info("Training DQN agent...\n")
    learner.train()


if __name__ == "__main__":
    import argparse

    # Command-line interface: both options are optional strings defaulting to None.
    cli = argparse.ArgumentParser()
    for option_name in ("--benchmark", "--agent"):
        cli.add_argument(option_name, default=None, type=str)
    args = cli.parse_args()

    # The log file is placed per benchmark and named after the agent variant.
    run(
        agent=args.agent,
        instance_path=args.benchmark,
        log_path=f"../experiments/logs/{args.benchmark}/training_{args.agent}.log",
    )
