#include <stdlib.h>
#include <string.h>

#include "opts.h"
#include "pddl/asnets_convert_from_sql.h"
#include "pddl/pddl.h"
#include "print_to_file.h"
#include "pddl/err.h"

/* Values of all command-line options; registered and filled in by
 * parseOpts(). */
static struct {
    /* --- generic options --- */
    pddl_bool_t help;    /* -h, --help */
    pddl_bool_t version; /* --version */
    int max_mem;         /* --max-mem: memory limit in MB, applied if > 0 */
    char* log_out;       /* --log-out: where logs go (default "stdout") */

    /* --- options of the train command --- */
    int train_seed; /* --train-seed: overrides random_seed (default -1) */

    /* --- options of the eval command --- */
    int eval_max_steps;       /* --eval-max-steps: rollout limit if > 0 */
    char* eval_out;           /* --eval-out: plan file prefix, NULL = none */
    pddl_bool_t eval_verbose; /* --eval-verbose */
} opt;

/* Command selected by the first positional argument (see parseOpts()). */
static enum {
    CMD_TRAIN,                 /* train */
    CMD_EVAL,                  /* eval */
    CMD_GEN_FD_ENC,            /* gen-fd-encoding */
    CMD_GEN_FD_OSP_ENC,        /* gen-fd-osp-encoding */
    CMD_CONVERT_OLD_MODEL,     /* convert-old-model */
    CMD_GEN_TRAIN_CONFIG_FILE, /* gen-train-config-file */
    CMD_GROUND                 /* ground */
} cmd;

/* Error/trace and logging context shared by all commands; main() prints it
 * when a command fails. */
static pddl_err_t err = PDDL_ERR_INIT;

/* Log stream opened from opt.log_out in main(); NULL until then. */
static FILE* log_out = NULL;

/**
 * Prints the version, the usage of every command and the list of options
 * (via optsPrint()) to fout.
 *
 * @param argv0 program name used in the usage lines
 * @param fout  output stream
 */
static void
help(const char* argv0, FILE* fout)
{
    fprintf(fout, "version: %s\n", pddl_version);
    fprintf(fout, "Usage: %s COMMAND [OPTIONS] ...\n", argv0);
    fprintf(fout, "\n");

    fprintf(fout, "COMMAND train:\n");
    fprintf(
        fout, "  %s train [OPTIONS] config.toml model-file-prefix\n", argv0);
    fprintf(fout, "\n");
    fprintf(fout, "  Train an ASNets model according to the configuration.\n");
    fprintf(
        fout,
        "  A model after each epoch is written into a file prefixed with "
        "the\n");
    fprintf(fout, "  {model-file-prefix} argument.\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    fprintf(fout, "COMMAND eval:\n");
    fprintf(
        fout,
        "  %s eval [OPTIONS] input-model-file domain.pddl problem.pddl ...\n",
        argv0);
    fprintf(fout, "\n");
    fprintf(
        fout,
        "  Evaluate the ASNets model stored in {input-model-file} and all "
        "specified\n");
    fprintf(fout, "  pddl problems.\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    fprintf(fout, "COMMAND gen-fd-encoding:\n");
    fprintf(
        fout,
        "  %s gen-fd-encoding domain.pddl problem.pddl output.fd\n",
        argv0);
    fprintf(fout, "\n");
    fprintf(
        fout,
        "  Generate FD encoding of the specified task as it is used by "
        "ASNets.\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    fprintf(fout, "COMMAND gen-fd-osp-encoding:\n");
    fprintf(
        fout,
        "  %s gen-fd-osp-encoding domain.pddl problem.pddl output.fd\n",
        argv0);
    fprintf(fout, "\n");
    fprintf(
        fout,
        "  Same as gen-fd-encoding, but it uses a variant of FD format\n");
    fprintf(
        fout, "  for OSP tasks and all goals are specified as soft-goals.\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    fprintf(fout, "COMMAND convert-old-model:\n");
    fprintf(fout, "  %s convert-old-model input-file output-file\n", argv0);
    fprintf(fout, "\n");
    fprintf(
        fout,
        "  Convert old sqlite-based model files to the current format.\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    fprintf(fout, "COMMAND gen-train-config-file:\n");
    fprintf(fout, "  %s gen-train-config-file config.toml\n", argv0);
    fprintf(fout, "\n");
    fprintf(fout, "  Generates a training config file with default values.\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    // The ground command is accepted by parseOpts() and was previously
    // missing from the help text.
    fprintf(fout, "COMMAND ground:\n");
    fprintf(
        fout,
        "  %s ground [OPTIONS] input-model-file domain.pddl problem.pddl"
        " [ground-config] output-file\n",
        argv0);
    fprintf(fout, "\n");
    fprintf(
        fout,
        "  Ground the ASNets model stored in {input-model-file} for the"
        " specified\n");
    fprintf(fout, "  task and write the grounded model into {output-file}.\n");
    fprintf(
        fout,
        "  If {ground-config} is not given, the FDR encoding of the task is"
        " written\n");
    fprintf(
        fout,
        "  to output.sas and the description of the network's inputs and"
        " outputs\n");
    fprintf(fout, "  to asnets.jani2nnet (both in the current directory).\n");
    fprintf(fout, "\n");
    fprintf(fout, "\n");

    fprintf(fout, "OPTIONS:\n");
    optsPrint(fout);
}

/**
 * Removes the command name (argv[1]) from the argument vector by shifting
 * every following argument one position to the left, so that the option
 * parser sees argv[0] followed directly by the command's own arguments.
 *
 * The terminating NULL pointer argv[*argc] (guaranteed by the C standard for
 * main's argv) is shifted too, so the vector stays NULL-terminated and the
 * last argument is not left duplicated behind the new end.
 *
 * Does nothing if there is no command to remove (*argc < 2).
 */
static void
shiftCmdArgs(int* argc, char* argv[])
{
    if (*argc < 2)
        return;

    for (int i = 2; i <= *argc; ++i)
        argv[i - 1] = argv[i];
    *argc -= 1;
}

static int
parseOpts(int* argc, char* argv[])
{
    optsAddFlag("help", 'h', &opt.help, 0, "Print this help.");
    optsAddFlag("version", 0x0, &opt.version, 0, "Print version and exit.");
    optsAddInt("max-mem", 0x0, &opt.max_mem, 0, "Maximum memory in MB if >0.");
    optsAddStr(
        "log-out", 0x0, &opt.log_out, "stdout", "Set output file for logs.");

    optsAddInt(
        "train-seed",
        0x0,
        &opt.train_seed,
        -1,
        "This takes effect only for the 'train' command."
        " It overwrites the random_seed parameter from the configutation "
        "file.");

    optsAddInt(
        "eval-max-steps",
        0x0,
        &opt.eval_max_steps,
        -1,
        "This takes effect only for the 'eval' command."
        " Sets the maximum number of steps a policy can take"
        " (it overwrites 'policy_rollout_limit' option in the input file.");
    optsAddStr(
        "eval-out",
        0x0,
        &opt.eval_out,
        NULL,
        "This takes effect only for the 'eval' command."
        " If set, it specifies a prefix for files where found plans"
        " are written, namely they are written to {prefix}-{task_index}.plan.");
    optsAddFlag(
        "eval-verbose",
        0x0,
        &opt.eval_verbose,
        0,
        "More logs of the 'eval' command.");

    if (*argc <= 1) {
        help(argv[0], stderr);
        return -1;
    }

    if (strcmp(argv[1], "train") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_TRAIN;

    } else if (strcmp(argv[1], "eval") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_EVAL;

    } else if (strcmp(argv[1], "gen-fd-encoding") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_GEN_FD_ENC;

    } else if (strcmp(argv[1], "gen-fd-osp-encoding") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_GEN_FD_OSP_ENC;

    } else if (strcmp(argv[1], "convert-old-model") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_CONVERT_OLD_MODEL;

    } else if (strcmp(argv[1], "gen-train-config-file") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_GEN_TRAIN_CONFIG_FILE;

    } else if (strcmp(argv[1], "ground") == 0) {
        shiftCmdArgs(argc, argv);
        cmd = CMD_GROUND;

    } else {
        fprintf(stderr, "Error: Unknown command %s\n", argv[1]);
        help(argv[0], stderr);
        return -1;
    }

    if (opts(argc, argv) != 0)
        return -1;

    if (opt.help) {
        help(argv[0], stderr);
        return -1;
    }

    if (cmd == CMD_TRAIN) {
        if (*argc != 3) {
            fprintf(stderr, "Error: Command train expects two arguments.");
            help(argv[0], stderr);
            return -1;
        }

    } else if (cmd == CMD_EVAL) {
        if (*argc < 4) {
            fprintf(stderr, "Error: Missing arguments to the command eval.\n");
            help(argv[0], stderr);
            return -1;
        }

    } else if (cmd == CMD_GEN_FD_ENC) {
        if (*argc != 4) {
            fprintf(
                stderr,
                "Error: Command gen-fd-encoding requires 3 arguments.\n");
            help(argv[0], stderr);
            return -1;
        }

    } else if (cmd == CMD_GEN_FD_OSP_ENC) {
        if (*argc != 4) {
            fprintf(
                stderr,
                "Error: Command gen-fd-osp-encoding requires 3 arguments.\n");
            help(argv[0], stderr);
            return -1;
        }

    } else if (cmd == CMD_CONVERT_OLD_MODEL) {
        if (*argc != 3) {
            fprintf(
                stderr,
                "Error: Command convert-old-model takes exactly"
                " two additional arguments, but %d were given.\n",
                *argc - 1);
            help(argv[0], stderr);
            return -1;
        }

    } else if (cmd == CMD_GEN_TRAIN_CONFIG_FILE) {
        if (*argc != 2) {
            fprintf(
                stderr,
                "Error: Command gen-train-config-file takes exactly"
                " one arguments, but %d were given.\n",
                *argc - 1);
            help(argv[0], stderr);
            return -1;
        }
    } else if (cmd == CMD_GROUND) {
        if (*argc < 5 || *argc > 6) {
            fprintf(stderr, "Error: Command ground expects four arguments.");
            help(argv[0], stderr);
            return -1;
        }
    }

    return 0;
}

/**
 * Implements the train command.
 *
 * argv[1] is the configuration file; argv[2] is the prefix of the model
 * files saved during training. --train-seed, if set, overrides random_seed
 * from the configuration.
 *
 * Returns 0 on success and non-zero on failure, with details recorded in err.
 */
static int
train(int argc, char* argv[])
{
    pddl_asnets_config_t cfg;
    if (pddlASNetsConfigInitFromFile(&cfg, argv[1], &err) != 0)
        PDDL_TRACE_RET(&err, -1);

    // --train-seed defaults to -1 meaning "not set"; every non-negative
    // value (including 0) is a user-provided seed.
    if (opt.train_seed >= 0)
        cfg.random_seed = opt.train_seed;

    cfg.save_model_prefix = argv[2];

    pddl_asnets_t* asnets = pddlASNetsNew(&cfg, &err);
    if (asnets == NULL)
        PDDL_TRACE_RET(&err, -1);

    int ret = pddlASNetsTrain(asnets, &err);
    pddlASNetsDel(asnets);
    if (ret != 0)
        PDDL_TRACE_RET(&err, ret);
    return 0;
}

/* Outcome of evaluating the policy on one task; collected by evaluate() and
 * printed in its final summary. */
struct eval_stats {
    /* Whether the policy rollout found a plan. */
    pddl_bool_t solved;
    /* Number of operators in the found plan, or -1 if no plan was found. */
    int plan_length;
    /* rollout.osp_reached_goal_size of the rollout. */
    int osp_goal_size;
    /* gt.osp_msgs_size_for_init of the task (printed as the denominator of
     * osp_goal_size). */
    int osp_msgs_size;
};

/**
 * Writes the plan found by a policy rollout on the task with index
 * task_index into the file "{opt.eval_out}-{task_index}.plan" (index padded
 * to six digits), one "(operator name)" per line.
 *
 * Returns 0 on success, -1 on failure with the reason recorded in err.
 */
static int
writeEvalPlan(
    int task_index,
    pddl_asnets_policy_rollout_t* rollout,
    pddl_asnets_ground_task_t* gt)
{
    char fn[512];
    int fn_len
        = snprintf(fn, sizeof(fn), "%s-%06d.plan", opt.eval_out, task_index);
    // Never write the plan under a silently truncated file name.
    if (fn_len < 0 || (size_t)fn_len >= sizeof(fn)) {
        PDDL_ERR_RET(
            &err,
            -1,
            "Plan file name with prefix %s is too long",
            opt.eval_out);
    }

    FILE* fout = fopen(fn, "w");
    if (fout == NULL)
        PDDL_ERR_RET(&err, -1, "Could not open file %s", fn);

    int op_id;
    PDDL_IARR_FOR_EACH(&rollout->plan, op_id)
    {
        fprintf(fout, "(%s)\n", gt->fdr.op.op[op_id]->name);
    }

    // fclose() flushes the buffered output; a failure means the plan file
    // is incomplete.
    if (fclose(fout) != 0)
        PDDL_ERR_RET(&err, -1, "Could not write file %s", fn);
    return 0;
}

/**
 * Implements the eval command.
 *
 * argv[1] is the model file, argv[2] the domain and argv[3..argc-1] the
 * problems. Each problem is grounded and the loaded policy is rolled out on
 * it (at most --eval-max-steps steps if set, otherwise the configured
 * policy_rollout_limit). Found plans are written to files if --eval-out is
 * set. A per-task summary and the overall success rate are logged.
 *
 * Returns 0 on success, -1 on failure with details recorded in err.
 */
static int
evaluate(int argc, char* argv[])
{
    pddl_asnets_t* asnets = pddlASNetsNewLoad(argv[1], argv[2], &err);
    if (asnets == NULL)
        PDDL_TRACE_RET(&err, -1);

    const pddl_asnets_config_t* cfg = pddlASNetsGetConfig(asnets);
    const pddl_asnets_lifted_task_t* lt = pddlASNetsGetLiftedTask(asnets);

    int policy_rollout_limit = cfg->policy_rollout_limit;
    if (opt.eval_max_steps > 0)
        policy_rollout_limit = opt.eval_max_steps;

    // Problem files start at argv[3]; parseOpts() guarantees at least one.
    int num_probs = argc - 3;
    struct eval_stats stats[num_probs];

    int num_solved = 0;
    for (int pi = 0; pi < num_probs; ++pi) {
        PDDL_CTX(&err, "Task %d", pi);
        pddl_asnets_ground_task_t gt;
        int st =
            pddlASNetsGroundTaskInit(&gt, lt, argv[2], argv[3 + pi], cfg, &err);
        if (st != 0) {
            pddlASNetsDel(asnets);
            PDDL_CTXEND(&err);
            PDDL_TRACE_RET(&err, -1);
        }

        PDDL_LOG(&err, "Rollout of policy ...");
        pddl_asnets_policy_rollout_t rollout;
        pddlASNetsPolicyRolloutInit(&rollout, &gt);
        if (opt.eval_verbose) {
            pddlASNetsPolicyRolloutVerbose(
                asnets, &rollout, &gt, policy_rollout_limit, &err);
        } else {
            pddlASNetsPolicyRollout(
                asnets, &rollout, &gt, policy_rollout_limit, &err);
        }
        PDDL_LOG(
            &err, "Found plan: %s", (rollout.found_plan ? "true" : "false"));
        PDDL_LOG(&err, "Num states: %d", rollout.states.num_states);
        PDDL_LOG(&err, "Num ops: %d", pddlIArrSize(&rollout.ops));
        PDDL_LOG(&err, "Plan size: %d", pddlIArrSize(&rollout.plan));
        PDDL_LOG(
            &err, "OSP reached goal size: %d", rollout.osp_reached_goal_size);

        stats[pi].solved = rollout.found_plan;
        stats[pi].plan_length = -1;
        if (rollout.found_plan)
            stats[pi].plan_length = pddlIArrSize(&rollout.plan);
        stats[pi].osp_goal_size = rollout.osp_reached_goal_size;
        stats[pi].osp_msgs_size = gt.osp_msgs_size_for_init;

        if (stats[pi].solved)
            num_solved += 1;

        if (rollout.found_plan && opt.eval_out != NULL
            && writeEvalPlan(pi, &rollout, &gt) != 0) {
            pddlASNetsPolicyRolloutFree(&rollout);
            pddlASNetsGroundTaskFree(&gt);
            pddlASNetsDel(asnets);
            PDDL_CTXEND(&err);
            PDDL_TRACE_RET(&err, -1);
        }

        pddlASNetsPolicyRolloutFree(&rollout);
        pddlASNetsGroundTaskFree(&gt);
        PDDL_CTXEND(&err);
    }

    for (int pi = 0; pi < num_probs; ++pi) {
        if (cfg->osp_all_soft_goals) {
            PDDL_LOG(
                &err,
                "Task %d solved: %s, length: %d, goal size: %d/%d :: %s",
                pi,
                (stats[pi].solved ? "true" : "false"),
                stats[pi].plan_length,
                stats[pi].osp_goal_size,
                stats[pi].osp_msgs_size,
                argv[pi + 3]);
        } else {
            PDDL_LOG(
                &err,
                "Task %d solved: %s, length: %d :: %s",
                pi,
                (stats[pi].solved ? "true" : "false"),
                stats[pi].plan_length,
                argv[pi + 3]);
        }
    }

    PDDL_LOG(&err, "Solved: %d / %d", num_solved, num_probs);
    PDDL_LOG(&err, "Success rate: %.2f", (float)num_solved / (float)num_probs);

    pddlASNetsDel(asnets);
    return 0;
}

/**
 * Implements the gen-fd-encoding and gen-fd-osp-encoding commands.
 *
 * Grounds the task given by argv[1] (domain) and argv[2] (problem) with the
 * ASNets grounding and the default ASNets configuration, then writes its
 * FDR in Fast Downward format to argv[3]. If osp is non-zero,
 * osp_all_soft_goals is set in the write configuration.
 *
 * Returns 0 on success, -1 on failure. Errors are only recorded in err;
 * main() prints them, so they are not printed twice.
 */
static int
genFDEnc(int argc, char* argv[], int osp)
{
    pddl_asnets_lifted_task_t lt;
    if (pddlASNetsLiftedTaskInit(&lt, argv[1], &err) != 0)
        PDDL_TRACE_RET(&err, -1);

    pddl_asnets_config_t cfg;
    pddlASNetsConfigInit(&cfg);

    pddl_asnets_ground_task_t gt;
    if (pddlASNetsGroundTaskInit(&gt, &lt, argv[1], argv[2], &cfg, &err) != 0) {
        pddlASNetsLiftedTaskFree(&lt);
        PDDL_TRACE_RET(&err, -1);
    }

    pddl_fdr_write_config_t wcfg = PDDL_FDR_WRITE_CONFIG_INIT;
    wcfg.filename = argv[3];
    if (osp)
        wcfg.osp_all_soft_goals = pddl_true;
    wcfg.fd = pddl_true;
    pddlFDRWrite(&gt.fdr, &wcfg);

    pddlASNetsGroundTaskFree(&gt);
    pddlASNetsLiftedTaskFree(&lt);
    return 0;
}

/**
 * Implements the convert-old-model command: converts the old (sqlite-based)
 * model file argv[1] into the current format and stores it in argv[2].
 * Returns the status reported by pddlASNetsConvertFromSql().
 */
static int
convertOldModel(int argc, char* argv[])
{
    char* old_model_file = argv[1];
    char* new_model_file = argv[2];
    int status = pddlASNetsConvertFromSql(old_model_file, new_model_file, &err);
    return status;
}

/**
 * Implements the gen-train-config-file command: writes a training
 * configuration filled with default values into the file argv[1].
 *
 * Returns 0 on success, -1 if the file cannot be opened or written (the
 * reason is recorded in err).
 */
static int
genTrainConfigFile(int argc, char* argv[])
{
    const char* fn = argv[1];
    FILE* fout = fopen(fn, "w");
    if (fout == NULL)
        PDDL_ERR_RET(&err, -1, "Could not open file %s", fn);

    pddl_asnets_config_t cfg;
    pddlASNetsConfigInit(&cfg);
    pddlASNetsConfigWrite(&cfg, fout);

    // fclose() flushes buffered output; report failures such as a full disk
    // instead of claiming success.
    if (fclose(fout) != 0)
        PDDL_ERR_RET(&err, -1, "Could not write file %s", fn);
    return 0;
}

static int
ground(int argc, char* argv[])
{
    pddl_asnets_t* asnets = pddlASNetsNewLoad(argv[1], argv[2], &err);
    if (asnets == NULL)
        PDDL_TRACE_RET(&err, -1);

    const pddl_asnets_config_t* cfg = pddlASNetsGetConfig(asnets);
    const pddl_asnets_lifted_task_t* lt = pddlASNetsGetLiftedTask(asnets);

    pddl_asnets_ground_task_t gt;
    int st = pddlASNetsGroundTaskInit(&gt, lt, argv[2], argv[3], cfg, &err);
    if (st != 0) {
        pddlASNetsDel(asnets);
        PDDL_TRACE_RET(&err, -1);
    }

    pddl_ground_asnets_conf_t gConf;
    pddlGroundASNetsConfInit(&gConf, gt.fact_size, gt.op_size);

    if (argc == 6) {
        pddlGroundASNetsConfLoad(&gConf, argv[4], &gt);
    } else {
        pddl_fdr_t* fdr = &gt.fdr;

        gConf.num_facts = 0 ;
        gConf.num_operators = fdr->op.op_size;
        gConf.num_variables = fdr->var.var_size;
        gConf.num_labels = fdr->op.op_size;

        for (int var_id = 0; var_id < fdr->var.var_size; ++var_id) {
            const pddl_fdr_var_t* var = fdr->var.var + var_id;
            for (int val_id = 0; val_id < var->val_size; ++val_id) {
                const pddl_fdr_val_t* val = var->val + val_id;
                int fact_id = val->strips_id;
                if (fact_id < 0 || fact_id >= gt.fact_size) {
                    continue;
                }
                gConf.variable[fact_id] = var_id;
                gConf.value[fact_id] = val_id;
                ++gConf.num_facts ;
            }
        }
        gConf.num_variables = fdr->var.var_size;
        gConf.num_labels = gt.op_size;
        for (int fdr_op_id = 0; fdr_op_id < fdr->op.op_size; ++fdr_op_id) {
            const pddl_fdr_op_t* op = fdr->op.op[fdr_op_id];
            PDDL_PANIC_IF(op->strips_id < 0 || op->strips_id >= fdr->op.op_size, "op_id out of bounds");
            gConf.label[op->strips_id] = fdr_op_id;
        }

        pddl_fdr_write_config_t write_cfg;
        write_cfg.filename = "output.sas";
        write_cfg.fout = NULL;
        write_cfg.fd = pddl_true;
        write_cfg.use_fd_fact_names = pddl_true;
        write_cfg.mgroups = NULL;
        write_cfg.encode_op_ids = pddl_false;
        write_cfg.osp_all_soft_goals = pddl_false;
        pddlFDRWrite(fdr, &write_cfg);

        FILE* interface = fopen("asnets.jani2nnet", "w");
        fprintf(interface, "{\n");
        fprintf(interface, "  \"elements\": [],\n");
        fprintf(interface, "  \"file\": \"\",\n");
        fprintf(interface, "  \"filter\": false,\n");
        fprintf(interface, "  \"input\": [\n");
        for (int var_id = 0; var_id < fdr->var.var_size; ++var_id) {
            fprintf(interface, "    { \"automaton\": null, \"name\": \"var%d\" }%s", var_id, var_id + 1 < fdr->var.var_size ? ",\n" : "\n");
        }
        fprintf(interface, "  ],\n");
        fprintf(interface, "  \"output\": [\n");
        for (int op_id = 0; op_id < fdr->op.op_size; ++op_id) {
            const pddl_fdr_op_t* op = gt.fdr.op.op[op_id];
            fprintf(interface, "    \"%s\"%s", op->name, op_id + 1 < gt.op_size ? ",\n" : "\n");
        }
        fprintf(interface, "  ]\n");
        fprintf(interface, "}\n");
        fclose(interface);
    }

    pddl_ground_asnets_t gAsnets;
    pddlASNetsPolicyGround(&gAsnets, asnets, &gt, &gConf, &err);
    pddlDumpGroundASNetsModel(&gAsnets, argc == 6 ? argv[5] : argv[4], &err);

    // for (int i = 0; i < gt.strips.fact.fact_size; ++i) {
    //     printf("fact#%d = %s\n", i, gt.strips.fact.fact[i]->name);
    // }
    //
    // for (int i = 0; i < gt.strips.op.op_size; ++i) {
    //     printf("op#%d = %s\n", i, gt.strips.op.op[i]->name);
    // }

    pddlGroundASNetsConfFree(&gConf);
    // pddlGroundASNetsFree(&gAsnets);
    // pddlASNetsGroundTaskFree(&gt);
    // pddlASNetsDel(asnets);

    return 0;
}

/**
 * Entry point: parses the command line, sets up logging and the optional
 * memory limit, then runs the selected command.
 *
 * Returns the command's status; errors recorded in err are printed to stderr.
 */
int
main(int argc, char* argv[])
{
    if (parseOpts(&argc, argv) != 0) {
        pddlErrPrint(&err, 1, stderr);
        return -1;
    }

    if (opt.log_out != NULL) {
        log_out = openFile(opt.log_out);
        pddlErrLogEnable(&err, log_out);
    }

    if (opt.max_mem > 0) {
        struct rlimit mem_limit;
        // Compute in rlim_t: the product of an int and two unsigned longs
        // overflows for limits of 4 GB or more where long is 32 bits.
        rlim_t mem_bytes = (rlim_t)opt.max_mem * 1024 * 1024;
        mem_limit.rlim_cur = mem_limit.rlim_max = mem_bytes;
        if (setrlimit(RLIMIT_AS, &mem_limit) != 0) {
            PDDL_LOG(
                &err,
                "Warning: Could not set memory limit to %d MB",
                opt.max_mem);
        }
    }

    PDDL_LOG(&err, "Version: %s", pddl_version);

    int ret = 0;
    switch (cmd) {
    case CMD_TRAIN: ret = train(argc, argv); break;

    case CMD_EVAL: ret = evaluate(argc, argv); break;

    case CMD_GEN_FD_ENC: ret = genFDEnc(argc, argv, 0); break;

    case CMD_GEN_FD_OSP_ENC: ret = genFDEnc(argc, argv, 1); break;

    case CMD_CONVERT_OLD_MODEL: ret = convertOldModel(argc, argv); break;

    case CMD_GEN_TRAIN_CONFIG_FILE: ret = genTrainConfigFile(argc, argv); break;

    case CMD_GROUND: ret = ground(argc, argv); break;
    }

    if (ret != 0)
        pddlErrPrint(&err, 1, stderr);

    if (log_out != NULL)
        closeFile(log_out);
    return ret;
}
