from argparse import Namespace
from collections import defaultdict
from threading import Thread, Lock, Event
import time
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch

from mas_sat.agent.decide.base import BaseDecideAgent
from mas_sat.dataset.graph import GraphDataset
from mas_sat.env.kissat.base import KissatBaseEnv
from mas_sat.env.kissat.data_structure import LP_c_uint
from mas_sat.graph.base import BaseGraph
from mas_sat.learn.base import BaseLearner
from mas_sat.utils.meters import AverageMeter, SumMeter, ProgressMeter

class Engine(object):
    """
    An engine that runs solving process on the given dataset

    The engine has two solving modes:
    - solver: rely on the model as heuristic for the solver
        - the solver mode can also be assisted by a standalone model,
          if the model provides solution in the graph
    - standalone: rely on the model to directly solve the instance

    The engine also has two running modes:
    - train: run an episode/batch, record data for training
    - evaluate: run the whole dataset, evaluate the performance

    The engine is invoked as `engine.<running_mode>(<solving_mode>, ...)`,
    e.g. `engine.train("solver", learner)` or
    `engine.evaluate("standalone", counter)`
    """
    def __init__(
        self,
        dataset: Dataset,
        graph: BaseGraph,
        env: KissatBaseEnv,
        model: torch.nn.Module,
        agent: BaseDecideAgent,
        device: torch.device,
        writer: torch.utils.tensorboard.SummaryWriter,
        args: Namespace,
    ) -> None:
        # components
        self.dataset = dataset
        self.graph = graph
        self.env = env
        self.model = model
        self.agent = agent
        self.device = device
        self.writer = writer

        # components for standalone mode
        self.pt_dataset = GraphDataset(self.env, self.graph, args.dim)
        self.dataloader = DataLoader(
            self.pt_dataset,
            batch_size=args.batch_size,
            collate_fn=BaseGraph.batch_graph,
            shuffle=True
        )
        self.dataloader_iter = iter(self.dataloader)

        # hyper-parameters
        self.dim = args.dim
        self.recurrent = args.recurrent
        self.budget = args.budget
        self.async_mode = args.async_mode
        self.influence = (args.agent == "model_influence")

        # counters
        ## for solver mode
        self.episode_counter = 0
        self.sat_counter = 0
        self.unsat_counter = 0
        self.unknown_counter = 0
        self.solver_solved = 0
        ## for standalone mode
        self.batch_counter = 0
        self.sample_counter = 0
        self.standalone_solved = 0

    # basic get/set functions
    def state_dict(self) -> dict:
        return {
            "episode_counter": self.episode_counter,
            "sat_counter": self.sat_counter,
            "unsat_counter": self.unsat_counter,
            "unknown_counter": self.unknown_counter,
            "solver_solved": self.solver_solved,
            "batch_counter": self.batch_counter,
            "sample_counter": self.sample_counter,
            "standalone_solved": self.standalone_solved,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self.episode_counter = state_dict["episode_counter"]
        self.sat_counter = state_dict["sat_counter"]
        self.unsat_counter = state_dict["unsat_counter"]
        self.unknown_counter = state_dict["unknown_counter"]
        self.solver_solved = state_dict["solver_solved"]
        self.batch_counter = state_dict["batch_counter"]
        self.sample_counter = state_dict["sample_counter"]
        self.standalone_solved = state_dict["standalone_solved"]

    # six main internal methods
    # 1. run_solver
    # 2. run_standalone
    # 3. train_solver
    # 4. train_standalone
    # 5. evaluate_solver
    # 6. evaluate_standalone
    def run_solver(self, learner: BaseLearner = None):
        """
        Run one episode in solver mode
        If a learner is provided, will add transitions to the learner
        """
        observation, info = self.env.reset()
        original_graph = self.graph.from_observation(observation, self.dim, original=True).to(self.device)

        # do a dummy call and go to first call of heuristic
        dummy_action = self.env.unwrapped.get_dummy_action()
        observation, reward, terminated, truncated, info = self.env.step(dummy_action)

        accumulative_reward = 0
        model_step = 0
        n_step = 0
        solution = None
        while not terminated and not truncated:
            learned_graph = self.graph.from_observation(observation, self.dim).to(self.device)
            combined_graph = self.graph.combine_graph(original_graph, learned_graph)
            if combined_graph.is_trivial():
                action_idx, action = self.agent.get_random_action(combined_graph)
                ret_dict = {}
            elif solution is not None:
                action_idx, action = self.agent.get_action_from_solution(combined_graph, solution)
                ret_dict = {}
            else:
                eps = learner.get_eps() if learner is not None else 0.0
                action_idx, action, ret_dict = self.agent.act(combined_graph, eps)
                model_step += ret_dict["model_step"]
                n_step += 1
                if solution is None and ret_dict.get("solution", None) is not None:
                    solution = ret_dict["solution"]
                if self.budget > 0 and n_step >= self.budget:
                    self.env.unwrapped.out_of_budget()
                if self.recurrent:
                    original_graph = self.graph.update_graph(original_graph, ret_dict["updated_graph"])
                if self.influence and "heuristic" in ret_dict:
                    variable_heuristic = ret_dict["heuristic"].softmax(-1).reshape(-1, 2).sum(dim=1)
                    variable_indices = combined_graph.get_candidate_indices().reshape(-1, 2)[:,0]
                    variable_indices = torch.divide(variable_indices, 2, rounding_mode="floor")
                    rank = torch.argsort(variable_heuristic, descending=True)
                    indices_ranked = variable_indices[rank]
                    num = len(indices_ranked)
                    indices_ranked = indices_ranked.cpu().numpy().astype(np.uint32)
                    indices_ranked = indices_ranked.ctypes.data_as(LP_c_uint)
                    self.env.unwrapped.influence(indices_ranked, num)
            observation, reward, terminated, truncated, info = self.env.step(action)
            accumulative_reward += reward
            if learner is not None:
                learner.add_transition(
                    combined_graph, action_idx, reward, terminated or truncated,
                    original_graph, ret_dict
                )

        metadata = self.env.close()
        metadata["reward"] = accumulative_reward
        metadata["model_step"] = model_step
        metadata["solver_solved"] = (solution is not None)
        return metadata

    def async_agent(self, original_observation, shared_dict, locks, events):
        # initialize
        original_graph = self.graph.from_observation(original_observation, self.dim, original=True).to(self.device)
        heuristic = shared_dict["heuristic"].clone()

        while not events["solver_end"].is_set():
            # update observation
            if events["new_observation"].is_set():
                with locks["observation"]:
                    observation = {k: v.copy() for k, v in shared_dict["observation"].items()}
                learned_graph = self.graph.from_observation(observation, self.dim).to(self.device)
                combined_graph = self.graph.combine_graph(original_graph, learned_graph)
                events["new_observation"].clear()
            # call agent
            if combined_graph.is_trivial():
                events["agent_end"].set()
                break
            else:
                action_idx, action, ret_dict = self.agent.act(combined_graph)
                if self.recurrent:
                    original_graph = self.graph.update_graph(original_graph, ret_dict["updated_graph"])
                shared_dict["model_step"] += ret_dict["model_step"]
            # update heuristic
            if hasattr(original_graph, "solution"):
                solution = observation["literal_values"].flatten()
                solution[original_graph["solution"]] = 1
                with locks["heuristic"]:
                    shared_dict["heuristic"] = solution
                events["agent_end"].set()
                shared_dict["solver_solved"] = True
                break
            else:
                if ret_dict["heuristic"] is not None:
                    heuristic[combined_graph.get_candidate_indices().flatten().cpu()] = ret_dict["heuristic"].flatten().cpu()
                    with locks["heuristic"]:
                        shared_dict["heuristic"] = heuristic.clone()
                    events["new_heuristic"].set()

    def run_solver_async(self):
        """
        Run one episode in solver mode asynchronously
        """
        # initialize
        original_observation, info = self.env.reset()
        n_literal = len(original_observation["literal_values"])
        heuristic = torch.zeros(n_literal)
        indices = torch.arange(n_literal)

        # do a dummy call and go to first call of heuristic
        dummy_action = self.env.unwrapped.get_dummy_action()
        observation, reward, terminated, truncated, info = self.env.step(dummy_action)
        candidate = observation["literal_candidates"].flatten()
        candidate_indices = indices[candidate]
        if self.influence:
            num = n_literal // 2
            variable_indices = torch.arange(num)

        # start async agent, wait for the first heuristic
        shared_dict = {
            "observation": {k: v.copy() for k, v in observation.items()},
            "heuristic": heuristic.clone(),
            "model_step": 0,
            "solver_solved": False
        }
        locks = {
            "observation": Lock(),
            "heuristic": Lock()
        }
        events = {
            "new_observation": Event(),
            "new_heuristic": Event(),
            "agent_end": Event(),
            "solver_end": Event()
        }
        thread = Thread(
            target=self.async_agent,
            args=(original_observation, shared_dict, locks, events)
        )
        thread.start()
        events["new_observation"].set()
        events["new_heuristic"].wait()

        # statistics
        accumulative_reward = 0
        solver_update = 0
        agent_update = 0

        # start running
        while not terminated and not truncated:
            # update heuristic
            if events["new_heuristic"].is_set():
                with locks["heuristic"]:
                    heuristic = shared_dict["heuristic"].clone()
                events["new_heuristic"].clear()
                agent_update += 1
            # step the solver with the heuristic
            if self.influence:
                variable_heuristic = heuristic.softmax(-1).reshape(-1, 2).sum(dim=1)
                rank = torch.argsort(variable_heuristic, descending=True)
                indices_ranked = variable_indices[rank]
                indices_ranked = indices_ranked.cpu().numpy().astype(np.uint32)
                indices_ranked = indices_ranked.ctypes.data_as(LP_c_uint)
                self.env.unwrapped.influence(indices_ranked, num)
            candidate_heuristic = heuristic[candidate]
            action_idx = torch.argmax(candidate_heuristic)
            action = candidate_indices[action_idx]
            observation, reward, terminated, truncated, info = self.env.step(action)
            accumulative_reward += reward
            # update observation
            if not events["agent_end"].is_set() and observation:
                candidate = observation["literal_candidates"].flatten()
                candidate_indices = indices[candidate]
                with locks["observation"]:
                    for k, v in observation.items():
                        shared_dict["observation"][k] = v.copy()
                events["new_observation"].set()
                solver_update += 1

        # cleanup the thread
        events["solver_end"].set()
        thread.join()
        
        # return metadata
        metadata = self.env.close()
        metadata["solver_update"] = solver_update
        metadata["agent_update"] = agent_update
        metadata["reward"] = accumulative_reward
        metadata["model_step"] = shared_dict["model_step"]
        metadata["solver_solved"] = shared_dict["solver_solved"]
        return metadata

    def run_standalone(self, data: Batch, learner: BaseLearner = None):
        """
        Run a batch of data in standalone mode
        If a learner is provided, will add transitions to the learner
        """
        data = data.to(self.device)
        ret_dict = self.model(data.to(self.device))
        if learner is not None:
            learner.add_transition(None, None, None, None, None, ret_dict)
        assignment_loss = ret_dict["scores"].mean()
        model_step = ret_dict["model_step"]
        standalone_solved = ret_dict["solved"]
        return assignment_loss, model_step, standalone_solved
        
    def train_solver(self, learner: BaseLearner):
        """
        Run one episode in train solver mode
        """
        metadata = self.run_solver(learner)
        self.episode_counter += 1
        if metadata["result"] is None:
            self.unknown_counter += 1
        elif metadata["result"]:
            self.sat_counter += 1
        else:
            self.unsat_counter += 1
        if metadata["solver_solved"]:
            self.solver_solved += 1
        self.writer.add_scalar("reward/train", metadata["reward"], self.episode_counter)
        self.writer.add_scalar("propagation/train", metadata["propagations"], self.episode_counter)
        self.writer.add_scalar("solver_step/train", metadata["decisions"], self.episode_counter)
        self.writer.add_scalar("solver_model_step/train", metadata["model_step"], self.episode_counter)
        self.writer.add_scalar("sat/train", self.sat_counter, self.episode_counter)
        self.writer.add_scalar("unsat/train", self.unsat_counter, self.episode_counter)
        self.writer.add_scalar("unknown/train", self.unknown_counter, self.episode_counter)
        self.writer.add_scalar("solver_solved/train", self.solver_solved, self.episode_counter)

    def evaluate_solver(self, counter: int, record: bool = False):
        """
        Evaluate on the whole dataset in solver mode
        """
        reward_meter = AverageMeter("reward", "{:.4f}")
        prop_meter = AverageMeter("prop", "{:.4e}")
        solver_step_meter = AverageMeter("solver_step", "{:.2f}")
        model_step_meter = AverageMeter("model_step", "{:.2f}")
        sat_meter = SumMeter("sat", "{:4d}")
        unsat_meter = SumMeter("unsat", "{:4d}")
        unknown_meter = SumMeter("unknown", "{:4d}")
        solved_meter = SumMeter("solved", "{:4d}")
        meters = [
            reward_meter, prop_meter, solver_step_meter, model_step_meter,
            sat_meter, unsat_meter, unknown_meter, solved_meter
        ]
        if self.async_mode:
            solver_update_meter = AverageMeter("solver_update", "{:.4e}")
            agent_update_meter = AverageMeter("agent_update", "{:.4e}")
            meters += [solver_update_meter, agent_update_meter]
        title = "Eval Solver [{}]".format(counter)
        progress_meter = ProgressMeter(self.dataset.len(), meters, title)
        tqdm.write(progress_meter.title_str())
        for i_instance in tqdm(range(self.dataset.len())):
            if self.async_mode:
                metadata = self.run_solver_async()
            else:
                metadata = self.run_solver()
            if metadata["result"] is None:
                unknown_meter.update(1)
            else:
                reward_meter.update(metadata["reward"])
                prop_meter.update(metadata["propagations"])
                solver_step_meter.update(metadata["decisions"])
                model_step_meter.update(metadata["model_step"])
                if self.async_mode:
                    solver_update_meter.update(metadata["solver_update"])
                    agent_update_meter.update(metadata["agent_update"])
                if metadata["result"] is True:
                    sat_meter.update(1)
                else:
                    unsat_meter.update(1)
            if metadata["solver_solved"]:
                solved_meter.update(1)
        if record:
            self.writer.add_scalar("reward/evaluate", reward_meter.result, counter)
            self.writer.add_scalar("propagation/evaluate", prop_meter.result, counter)
            self.writer.add_scalar("solver_step/evaluate", solver_step_meter.result, counter)
            self.writer.add_scalar("solver_model_step/evaluate", model_step_meter.result, counter)
            self.writer.add_scalar("sat/evaluate", sat_meter.result, counter)
            self.writer.add_scalar("unsat/evaluate", unsat_meter.result, counter)
            self.writer.add_scalar("unknown/evaluate", unknown_meter.result, counter)
            self.writer.add_scalar("solver_solved/evaluate", solved_meter.result, counter)
        tqdm.write(progress_meter.summary_str())

    def train_standalone(self, learner: BaseLearner):
        """
        Run one batch of PT data in train standalone mode
        """
        try:
            data = next(self.dataloader_iter)
        except:
            self.dataloader_iter = iter(self.dataloader)
            data = next(self.dataloader_iter)
        assignment_loss, model_step, standalone_solved = self.run_standalone(data, learner)
        self.batch_counter += 1
        self.sample_counter += len(data)
        self.standalone_solved += standalone_solved

        self.writer.add_scalar("assignment_loss/train", assignment_loss.item(), self.batch_counter)
        self.writer.add_scalar("standalone_model_step/train", model_step, self.batch_counter)
        self.writer.add_scalar("standalone_solved/train", self.standalone_solved, self.sample_counter)

    def evaluate_standalone(self, counter: int, record: bool = False):
        """
        Evaluate the whole PT dataset in standalone mode
        """
        assignment_loss_meter = AverageMeter("assignment_loss", "{:.4e}")
        model_step_meter = AverageMeter("model_step", "{:.2f}")
        solved_meter = SumMeter("solved", "{:d}")
        meters = [assignment_loss_meter, model_step_meter, solved_meter]
        title = "Eval Standalone [{}]".format(counter)
        progress_meter = ProgressMeter(len(self.dataloader), meters, title)
        tqdm.write(progress_meter.title_str())
        evaluate_info = []
        for data in tqdm(self.dataloader):
            assignment_loss, model_step, solved = self.run_standalone(data)
            evaluate_info.append([model_step, solved])
            assignment_loss_meter.update(assignment_loss, len(data))
            model_step_meter.update(model_step, len(data))
            solved_meter.update(solved)
        tqdm.write(progress_meter.summary_str())
        if record:
            self.writer.add_scalar("assignment_loss/evaluate", assignment_loss_meter.result, counter)
            self.writer.add_scalar("standalone_model_step/evaluate", model_step_meter.result, counter)
            self.writer.add_scalar("standalone_solved/evaluate", solved_meter.result, counter)
        return evaluate_info

    # Below are the main interfaces: train and evaluate
    def train(self, mode: str, learner: BaseLearner) -> None:
        """
        Run one episode/batch in training mode
        """
        if mode == "solver":
            self.train_solver(learner)
        elif mode == "standalone":
            self.train_standalone(learner)
        else:
            raise ValueError("Unrecognized mode: {}.".format(mode))

    def evaluate(self, mode: str, counter: int, record: bool = False) -> list:
        with torch.no_grad():
            if mode == "solver":
                self.evaluate_solver(counter, record)
            elif mode == "standalone":
                self.evaluate_standalone(counter, record)
            else:
                raise ValueError("Unrecognized mode: {}.".format(mode))
