import glob
import logging
import random as rd
from time import time

import pyscipopt
import ecole
import numpy as np
import torch

from pyscipopt import scip
from scipy.stats import gmean, gstd

from rl.agents.tree_dqn import TreeDQNAgent
from rl.environments.ecole import Branching
from rl.learners.tree_dqn import extract_state_from_obs
from rl.networks.gcnn import GNNParsonsonPolicy, GNNPolicy
from rl.observations.node_bipartite import ParsonsonNodeBipartite



class BenchmarkRunner:
    """Run one branching agent over a set of instances and collect statistics.

    The agent name selects both the setup performed in :meth:`run` and the
    per-instance routine used by :meth:`solve_benchmark`:

    * ``"scip"`` / ``"SCIP"`` / ``"SB"`` -- solved with the PySCIPOpt model
      created in ``__init__`` (``"SB"`` additionally raises the priority of
      the ``vanillafullstrong`` branching rule);
    * ``"random"`` -- random choice among the branching candidates;
    * ``"il"`` / ``"pg_tMDP"`` -- a ``GNNPolicy`` network;
    * ``"dqn_bbmdp"`` / ``"dqn_retro"`` / ``"dqn_tMDP"`` -- a ``TreeDQNAgent``.
    """

    # Time budget in seconds, used both for the SCIP "limits/time" parameter
    # and for the "time_limit" entry of the environment parameters.
    TIME_LIMIT = 3600

    # Agent names grouped by the routine that solves a single instance.
    _SCIP_AGENTS = ("scip", "SCIP", "SB")
    _DQN_AGENTS = ("dqn_bbmdp", "dqn_retro", "dqn_tMDP")
    _NN_AGENTS = ("il", "pg_tMDP")

    def __init__(
        self,
        agent: str,
        instance_dict: dict,
        benchmark: str,
        DEVICE: torch.device,
        seed: int = 0,
    ) -> None:
        """Store the configuration and configure the PySCIPOpt model.

        Args:
            agent: Name of the agent to evaluate (see the class docstring).
            instance_dict: Maps a difficulty label to the list of instance
                file paths to solve for that difficulty.
            benchmark: Benchmark name; used to locate the trained weights.
            DEVICE: Torch device on which the networks are evaluated.
            seed: Seed for the environment and for the random policy.
        """
        self.agent = agent
        self.benchmark = benchmark
        self.pyscipopt_model = pyscipopt.Model("reoptimization")
        self.pyscipopt_model.setHeuristics(scip.PY_SCIP_PARAMSETTING.DEFAULT)  # activate heuristics
        self.pyscipopt_model.setParam("separating/maxrounds", 0)  # activate cutting planes only at root node
        # self.pyscipopt_model.setParam("separating/maxroundsroot", 0)
        self.pyscipopt_model.setParam("estimation/restarts/restartpolicy", "n")  # disable restarts
        self.pyscipopt_model.setParam("limits/time", self.TIME_LIMIT)
        self.instance_dict = instance_dict
        self.DEVICE = DEVICE
        self.seed = seed
        self.random = False
        # Default environment parameters so that build_env() and the solve
        # methods are usable even before run() installs agent-specific ones.
        self.params = {"time_limit": self.TIME_LIMIT}
        # Dedicated generator so that the random policy honours ``seed``
        # instead of depending on the global ``random`` module state.
        self._rng = rd.Random(seed)

    def build_env(self, observation_function=None) -> "Branching":
        """Create and seed the branching environment.

        Args:
            observation_function: Observation function handed to the
                environment; ``ecole.observation.NodeBipartite()`` is used
                when ``None``.

        Returns:
            A ``Branching`` environment built with ``self.params`` as static
            parameters, reporting ``"time"`` and ``"n_nodes"`` in its info
            dictionary, and seeded with ``self.seed``.
        """
        if observation_function is None:
            observation_function = ecole.observation.NodeBipartite()
        env = Branching(
            observation_function=observation_function,
            information_function={
                "time": ecole.reward.SolvingTime(wall=True),
                "n_nodes": ecole.reward.NNodes(),
            },
            static_params=self.params,
        )
        env.seed(self.seed)
        return env

    def solve_instance_with_nn_agent(self, instance: str) -> tuple[tuple[float, float], float, float]:
        """Solve ``instance`` by branching on the arg-max of ``self.policy``.

        The reported solving time is the sum of the ``"time"`` values returned
        by the environment after each step plus the wall-clock time spent
        building the input tensors and running the policy.

        Returns:
            ``((presolve_time, solving_time), n_nodes, dual_gap)``.
        """
        device = self.DEVICE
        observation, action_set, _, done, info = self.env.reset(instance)
        solving_time = 0
        presolve_time = info["time"]
        n_nodes = info["n_nodes"]
        with torch.no_grad():
            while not done:
                start_forward_pass = time()
                bipartite_graph = (
                    torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
                    torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
                    torch.from_numpy(observation.edge_features.values.astype(np.float32))
                    .view(-1, 1)
                    .to(device),
                    torch.from_numpy(observation.variable_features.astype(np.float32)).to(device),
                )
                logits = self.policy(*bipartite_graph)
                # Only the candidates in ``action_set`` can be selected, so the
                # scores are restricted to them directly; no mask over all
                # variables is needed inside the timed section.
                candidate_logits = logits[action_set.astype(np.int64)]
                assert not torch.isnan(candidate_logits).any()
                action = action_set[np.nanargmax(candidate_logits.cpu().numpy())]
                solving_time += time() - start_forward_pass
                observation, action_set, _, done, info = self.env.step(action)
                solving_time += info["time"]
                if solving_time > self.params["time_limit"]:
                    # Stop stepping once the accumulated time exceeds the limit.
                    done = True
                n_nodes += info["n_nodes"]
        dual_gap = self.env.model.as_pyscipopt().getGap()
        return (presolve_time, solving_time), n_nodes, dual_gap

    def solve_instance_with_scip(self, instance: str) -> tuple[tuple[float, float], int, float]:
        """Solve ``instance`` with the PySCIPOpt model configured in ``__init__``.

        Returns:
            ``((presolve_time, solving_time), n_nodes, dual_gap)`` as reported
            by the model after ``optimize()``.
        """
        model = self.pyscipopt_model
        model.readProblem(instance)
        model.hideOutput()
        model.optimize()
        presolve_time = model.getPresolvingTime()
        solving_time = model.getSolvingTime()
        n_nodes = model.getNTotalNodes()
        dual_gap = model.getGap()
        return (presolve_time, solving_time), n_nodes, dual_gap

    def solve_instance_with_random_policy(self, instance: str) -> tuple[tuple[float, float], float, float]:
        """Solve ``instance`` by picking a random candidate at every step.

        Candidates are drawn with the runner's own seeded generator, so two
        runners created with the same ``seed`` draw the same sequence.

        Returns:
            ``((presolve_time, solving_time), n_nodes, dual_gap)``.
        """
        _, action_set, _, done, info = self.env.reset(instance)
        solving_time = 0
        presolve_time = info["time"]
        n_nodes = info["n_nodes"]
        while not done:
            action = self._rng.choice(action_set)
            _, action_set, _, done, info = self.env.step(action)
            solving_time += info["time"]
            n_nodes += info["n_nodes"]
            if solving_time > self.params["time_limit"]:
                # Stop stepping once the accumulated time exceeds the limit.
                done = True
        dual_gap = self.env.model.as_pyscipopt().getGap()
        return (presolve_time, solving_time), n_nodes, dual_gap

    def solve_instance_with_dqn_agent(self, instance: str) -> tuple[tuple[float, float], float, float]:
        """Solve ``instance`` with the greedy action of ``self.dqn_agent``.

        Returns:
            ``((presolve_time, solving_time), n_nodes, dual_gap)``.
        """
        obs, action_set, _, done, info = self.env.reset(instance)
        state = extract_state_from_obs(obs, action_set) if obs is not None else None
        solving_time = 0
        presolve_time = info["time"]
        n_nodes = info["n_nodes"]

        with torch.no_grad():
            while not done:
                action = self.dqn_agent.select_greedy_action(state)
                obs, action_set, _, done, info = self.env.step(action)
                if not done:
                    # No new state is extracted once the episode has ended.
                    state = extract_state_from_obs(obs, action_set)
                solving_time += info["time"]
                n_nodes += info["n_nodes"]
                if solving_time > self.params["time_limit"]:
                    # Stop stepping once the accumulated time exceeds the limit.
                    done = True
        dual_gap = self.env.model.as_pyscipopt().getGap()
        return (presolve_time, solving_time), n_nodes, dual_gap

    def _select_solver(self):
        """Return the bound method that solves one instance for ``self.agent``.

        Raises:
            ValueError: If the agent name is not one of the supported names.
        """
        if self.agent == "random":
            return self.solve_instance_with_random_policy
        if self.agent in self._SCIP_AGENTS:
            return self.solve_instance_with_scip
        if self.agent in self._DQN_AGENTS:
            return self.solve_instance_with_dqn_agent
        if self.agent in self._NN_AGENTS:
            return self.solve_instance_with_nn_agent
        raise ValueError(f"Unknown agent: {self.agent!r}")

    def solve_benchmark(self) -> dict:
        """Solve every instance of every difficulty and aggregate the results.

        Returns:
            A dictionary with the per-instance lists (``"PresolvingTime"``,
            ``"SolvingTime"``, ``"NNodes"``, ``"DualGap"``) and their
            geometric means / geometric standard deviations. The dictionary is
            rebuilt for each difficulty, so when ``instance_dict`` holds
            several difficulties only the last one is returned; it is empty
            when ``instance_dict`` is empty.

        Raises:
            ValueError: If ``self.agent`` is not a supported agent name.
        """
        solve_instance = self._select_solver()
        agent_performance = {}
        for difficulty, evaluation_instances in self.instance_dict.items():
            logging.info(f"{difficulty}\n")
            presolving_time_lst, solving_time_list = [], []
            n_nodes_list, dual_gap_list = [], []
            for instance in evaluation_instances:
                logging.info(f"Instance : {instance}")
                (presolve_time, solving_time), n_nodes, dual_gap = solve_instance(instance)
                logging.info(
                    f"Number of nodes : {n_nodes}, Solving time : {solving_time}, Presolving time : {presolve_time}\n"
                )
                presolving_time_lst.append(presolve_time)
                solving_time_list.append(solving_time)
                n_nodes_list.append(n_nodes)
                dual_gap_list.append(dual_gap)

            agent_performance = {
                "PresolvingTime": presolving_time_lst,
                "SolvingTime": solving_time_list,
                "NNodes": n_nodes_list,
                "DualGap": dual_gap_list,
                "AverageSolvingTime": gmean(solving_time_list),
                "AveragePresolvingTime": gmean(presolving_time_lst),
                "AverageNNodes": gmean(n_nodes_list),
                "StdSolvingTime": gstd(solving_time_list),
                "StdPresolvingTime": gstd(presolving_time_lst),
                "StdNNodes": gstd(n_nodes_list),
                "AverageDualGap": gmean(dual_gap_list),
            }

            # Logged per difficulty so the summary never refers to lists that
            # were not created (empty ``instance_dict``).
            logging.info(
                f"AverageNNodes: {agent_performance['AverageNNodes']}, AverageSolvingtime: {agent_performance['AverageSolvingTime']}, AveragePresolvingTime: {agent_performance['AveragePresolvingTime']}\n"
            )

        return agent_performance

    def _load_best_weights(self, network) -> None:
        """Load ``<agent>_best.pkl`` of the current benchmark into ``network``."""
        network_path = f"../data/trained_models/{self.benchmark}/{self.agent}_best.pkl"
        # NOTE: torch.load unpickles the file; only load checkpoints that come
        # from a trusted source.
        dict_params = torch.load(network_path, map_location=self.DEVICE)
        network.load_state_dict(dict_params)

    def run(self) -> dict:
        """Set up the selected agent, then solve the benchmark.

        Depending on ``self.agent`` this installs the environment parameters,
        builds the environment, and creates the networks and loads their
        trained weights.

        Returns:
            The dictionary produced by :meth:`solve_benchmark`.
        """
        device = self.DEVICE
        logging.info(f"{self.agent}\n")

        if self.agent in ("scip", "SCIP"):
            self.params = {"time_limit": self.TIME_LIMIT}
            self.random = False

        if self.agent == "SB":
            self.random = False
            self.params = {"branching_rule": "fullstrong", "time_limit": self.TIME_LIMIT}
            self.pyscipopt_model.setParam("branching/vanillafullstrong/priority", 10001)

        if self.agent in self._NN_AGENTS:
            self.policy = GNNPolicy().to(device)
            self._load_best_weights(self.policy)
            self.params = {"time_limit": self.TIME_LIMIT}
            if self.agent == "pg_tMDP" and self.benchmark != "MultipleKnapsack":
                self.params["node_selection"] = "dfs"
            self.env = self.build_env()

        if self.agent == "random":
            self.random = True
            self.params = {"time_limit": self.TIME_LIMIT}
            self.env = self.build_env()

        if self.agent.startswith("dqn"):
            observation_function = ParsonsonNodeBipartite()
            self.params = {"time_limit": self.TIME_LIMIT}

            # "dqn_bbmdp" uses the classification variant of the network; the
            # other DQN agents use the final inverted activation instead.
            is_classification = self.agent == "dqn_bbmdp"
            if is_classification:
                network_kwargs = {"classification": True}
            else:
                network_kwargs = {"final_invert_activation": True}
            value_network = GNNParsonsonPolicy(**network_kwargs).to(device)
            target_network = GNNParsonsonPolicy(**network_kwargs).to(device)
            # NOTE(review): the agent is created with train=True although it is
            # only queried through select_greedy_action here -- confirm intended.
            self.dqn_agent = TreeDQNAgent(
                value_network=value_network,
                target_network=target_network,
                train=True,
                classification=is_classification,
            )

            if self.agent in ["dqn_bbmdp", "dqn_tMDP"] and self.benchmark != "MultipleKnapsack":
                self.params["node_selection"] = "dfs"

            self.env = self.build_env(observation_function)

            # The agent holds a reference to ``value_network``, so loading the
            # weights after constructing the agent still affects it.
            self._load_best_weights(value_network)

        return self.solve_benchmark()


if __name__ == "__main__":
    # Command-line entry point: evaluate one agent on one benchmark and write
    # the resulting statistics to a JSON file.
    import argparse
    import json
    from pathlib import Path

    from experiments.utils import sort_instances

    parser = argparse.ArgumentParser()
    # Both names are interpolated into file paths and drive the agent setup,
    # so they are mandatory rather than silently defaulting to ``None``.
    parser.add_argument("--agent", required=True, type=str)
    parser.add_argument("--benchmark", required=True, type=str)
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()

    log_file = f"../experiments/logs/{args.benchmark}/evaluation_{args.agent}.log"
    results_file = f"../data/results/{args.benchmark}/{args.agent}.json"
    # Create the output directories up front so that a long benchmark run
    # cannot fail at the very end because a directory is missing.
    Path(log_file).parent.mkdir(parents=True, exist_ok=True)
    Path(results_file).parent.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        filename=log_file,
        format="%(asctime)s %(message)s",
        level=logging.DEBUG,
    )

    logging.info("\n\n\n")
    logging.info("~" * 89)
    logging.info("~" * 40 + " NEW RUN " + "~" * 40)
    logging.info("~" * 89)
    logging.info("\n\n\n")

    # At most the first 100 easy instances, in the order given by sort_instances.
    instance_dict = {
        "easy_instances": sort_instances(glob.glob(f"../data/instances/{args.benchmark}/easy/*.lp"))[:100]
    }

    # Uncomment to test performance over medium instances
    # instance_dict = {
    #     "medium_instances": sort_instances(glob.glob(f"../data/instances/{args.benchmark}/medium/*.lp"))[:20]
    # }

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Start solving benchmark ...")
    benchmark_runner = BenchmarkRunner(args.agent, instance_dict, args.benchmark, DEVICE, args.seed)
    agent_performance = benchmark_runner.run()

    with open(results_file, "w") as f:
        json.dump(agent_performance, f, indent=4)
