import argparse
import csv
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import pyscipopt as scip
from tqdm import tqdm
from utils import GraphicalModel, ScipUtils

DATASET_DIR = "dataset"


class UserDefinedRootDecision(scip.Branchrule):
    """Branching rule that imposes a caller-chosen branching at the root node.

    At depth 0 it branches on the variable named
    ``x_{root_var}_{root_value}``; at any other depth it reports that it
    did not run.
    """

    def __init__(self, model, root_var, root_value):
        """Remember the model and the (variable, value) pair used at the root."""
        self.model = model
        self.root_var = root_var
        self.root_value = root_value

    def branchexeclp(self, allowaddcons):
        """Branch on the configured variable when the current node is the root.

        Returns a dict whose ``"result"`` entry is ``BRANCHED`` at depth 0
        and ``DIDNOTRUN`` at every other depth. ``allowaddcons`` is unused.
        """
        # Only the root node (depth 0) is handled by this rule.
        if self.model.getDepth() != 0:
            return {"result": scip.SCIP_RESULT.DIDNOTRUN}

        target_name = f"x_{self.root_var}_{self.root_value}"
        branch_var = ScipUtils.get_variable_by_name(self.model, target_name)
        # The child nodes returned by branchVarVal are not needed afterwards.
        _down_child, _eq_child, _up_child = self.model.branchVarVal(branch_var, 0.5)
        return {"result": scip.SCIP_RESULT.BRANCHED}


def solve_scip_program(program):
    """Optimize ``program`` and return ``ScipUtils.get_statistics`` for it."""
    program.optimize()
    return ScipUtils.get_statistics(program)


class SampleDataGenerator:
    """Writes one CSV of solver statistics per sample line.

    For each sample the evidence variables are fixed to the sample's values,
    the resulting program is solved once as a baseline, and then once more
    for every (query variable, value in {0, 1}) pair with that assignment
    added as an extra constraint.
    """

    def __init__(self, pgm, evid_vars, query_vars, save_dir):
        """Store the model, the evidence/query variable lists and the output dir.

        ``save_dir`` must support the ``/`` operator (e.g. ``pathlib.Path``).
        """
        self.pgm = pgm
        self.evid_vars = evid_vars
        self.query_vars = query_vars
        self.save_dir = save_dir

    def generate_dataset(self, *args):
        """Solve all programs for one sample and write ``sample<id>.csv``.

        ``args[0]`` is a ``(sample_count, sample_line)`` pair, as produced by
        ``enumerate`` over the sample lines. ``sample_count`` is used as the
        tqdm bar position; ``sample_line`` is a comma-separated line whose
        first token is the sample id and whose remaining tokens are integer
        values indexed by variable.

        The CSV has a header row, one baseline row with empty query columns,
        and one row per (query variable, value) pair. Returns ``None``.
        """
        sample_count, sample_line = args[0]
        # First token of the sample is the sample id which will be used to
        # name the csv file
        sample_id, *raw_values = sample_line.split(",")
        sample = [int(value) for value in raw_values]
        evidence = {var: sample[var] for var in self.evid_vars}

        final_path = self.save_dir / f"sample{sample_id}.csv"
        # Write to a side file and rename once complete: an existing
        # sample<id>.csv is treated as "already processed" on later runs, so
        # a run interrupted mid-sample must not leave a truncated one behind.
        partial_path = self.save_dir / f"sample{sample_id}.csv.partial"

        # newline="" is what the csv module requires of files it writes to.
        with open(partial_path, "w", newline="") as f:
            writer = csv.writer(f, delimiter=",")

            # Baseline: evidence only, no query variable fixed.
            curr_program, _ = ScipUtils.get_mpe_program(self.pgm, evidence)
            statistics = solve_scip_program(curr_program)
            writer.writerow(
                ["query_variable", "query_variable_value"] + list(statistics.keys())
            )
            writer.writerow(["", ""] + list(statistics.values()))

            for root_var in tqdm(
                self.query_vars,
                position=sample_count,
                desc=f"Query Variables for sample{sample_id}",
            ):
                for root_value in range(2):
                    # A fresh program is built for every pair so constraints
                    # from earlier iterations do not accumulate.
                    curr_program, program_variables = ScipUtils.get_mpe_program(
                        self.pgm, evidence
                    )
                    # Add constraint on current variable with current value
                    curr_program.addCons(program_variables[root_var][root_value] == 1)
                    statistics = solve_scip_program(curr_program)
                    writer.writerow([root_var, root_value] + list(statistics.values()))

        partial_path.replace(final_path)


def initialize(branch_rule_dataset_dir, sample_lines):
    """Work out which samples already have an output CSV and which remain.

    Args:
        branch_rule_dataset_dir: directory (``pathlib.Path``) scanned for
            files named ``sample<id>.csv``.
        sample_lines: comma-separated sample lines; the first token of each
            line is its sample id.

    Returns:
        A tuple ``(sample_ids, sample_id_to_idx, done_processing,
        remaining_samples)``:

        * ``sample_ids`` - the id of every line, in line order.
        * ``sample_id_to_idx`` - id -> index into ``sample_lines`` (the last
          index wins if an id is repeated).
        * ``done_processing`` - ids found as ``sample<id>.csv`` files, in
          directory-listing order.
        * ``remaining_samples`` - unique ids without an output file, in the
          order they first appear in ``sample_lines``.
    """
    prefix, suffix = "sample", ".csv"

    # Get done processing samples. Only files that follow the
    # sample<id>.csv naming scheme count; anything else in the directory
    # (the evidence file, unrelated CSVs, ...) is ignored.
    done_processing = []
    for entry in branch_rule_dataset_dir.iterdir():
        name = entry.name
        if name.startswith(prefix) and name.endswith(suffix):
            # Slice rather than split so ids containing "." survive intact.
            done_processing.append(name[len(prefix):-len(suffix)])

    sample_ids = [line.split(",")[0] for line in sample_lines]
    sample_id_to_idx = {sample_id: idx for idx, sample_id in enumerate(sample_ids)}

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic across runs (a plain set difference is not).
    already_done = set(done_processing)
    remaining_samples = [
        sample_id
        for sample_id in dict.fromkeys(sample_ids)
        if sample_id not in already_done
    ]
    return sample_ids, sample_id_to_idx, done_processing, remaining_samples


if __name__ == "__main__":
    # Command-line entry point. Picks (or, with --append, reloads) the set of
    # evidence variables for a network, selects samples that do not yet have
    # an output CSV, and processes them in a worker pool.
    parser = argparse.ArgumentParser(
        description="Generate Imitation Learning from Strong branching scores"
    )
    parser.add_argument("--network-name", required=True, dest="network_name")
    parser.add_argument("--num-workers", dest="num_workers", default=4, type=int)
    parser.add_argument("--query-ratio", dest="qr", default=0.5, type=float)
    parser.add_argument("--num-samples", dest="num_samples", default=1000, type=int)
    parser.add_argument("--append", dest="append", action="store_true", default=False)
    args = parser.parse_args()

    network_dir = Path(DATASET_DIR) / args.network_name
    branch_rule_dataset_dir = (
        network_dir / "variable-value-pair-dataset" / f"query_ratio{args.qr}"
    )
    branch_rule_dataset_dir.mkdir(parents=True, exist_ok=True)
    evid_path = branch_rule_dataset_dir / "pgm-model.evid"

    with open(network_dir / "samples.txt") as f:
        sample_lines = f.readlines()

    sample_ids, sample_id_to_idx, done_processing, remaining_samples = initialize(
        branch_rule_dataset_dir, sample_lines
    )

    pgm = GraphicalModel.from_uai_file(network_dir / "pgm-model.uai")
    all_vars = list(pgm.graph.nodes)
    if args.append:
        # Reuse the evidence variables chosen by a previous run. Lines are
        # tested after stripping: a raw "\n" is truthy, so blank lines would
        # otherwise reach int() and raise ValueError.
        with open(evid_path) as f:
            evid_vars = [int(line) for line in f if line.strip()]
    else:
        # Draw a fresh random evidence set; the remaining fraction
        # (--query-ratio) of the variables become query variables.
        evid_vars = sorted(
            np.random.choice(
                np.array(all_vars),
                size=int((1 - args.qr) * len(all_vars)),
                replace=False,
            ).tolist()
        )
        if evid_path.exists():
            proceed = input("You are about to overwrite the evid file, Proceed?[Y/n]")
            if proceed != "Y":
                # SystemExit instead of exit(): the latter is only injected
                # by the site module and is not guaranteed to exist.
                raise SystemExit(0)

        with open(evid_path, "w") as f:
            f.writelines(f"{var}\n" for var in evid_vars)

    query_vars = list(set(all_vars) - set(evid_vars))

    print(
        f"{len(remaining_samples)} remaining to process for network {args.network_name}"
    )
    # Cap the amount of work at --num-samples by sampling without replacement.
    if args.num_samples < len(remaining_samples):
        ids_to_process = np.random.choice(
            remaining_samples, size=args.num_samples, replace=False
        )
    else:
        ids_to_process = remaining_samples

    samples_to_process = [
        sample_lines[sample_id_to_idx[sample_id]] for sample_id in ids_to_process
    ]
    print(f"Processing: {len(samples_to_process)} samples")
    with Pool(processes=args.num_workers) as pool:
        generator = SampleDataGenerator(
            pgm, evid_vars, query_vars, branch_rule_dataset_dir
        )
        # enumerate() supplies each worker call with (position, line); the
        # position is used to place that sample's progress bar.
        pool.map(generator.generate_dataset, enumerate(samples_to_process))
