#include <pybind11/pybind11.h>
#include <cstdlib>

namespace py = pybind11;
using namespace std;

// Action codes for SoccerEnv::step (same encoding for both players).
#define UP      0   // y + 1, clamped to HEIGHT-1
#define DOWN    1   // y - 1, clamped to 0
#define LEFT    2   // x - 1, clamped to 0
#define RIGHT   3   // x + 1, clamped to WIDTH-1
#define NOOP    4   // stay in place (step treats any code other than 0..3 this way)
#define NACTION 5   // number of action codes above (0..4)

// Pitch size in cells: x ranges over [0, WIDTH), y over [0, HEIGHT).
#define WIDTH   9
#define HEIGHT  6

struct SoccerEnv {
    int state_repr;
    int ax, ay;
    int bx, by;
    int ball, step_count;
    const int max_steps = 100;
    static bool sequential;
    SoccerEnv(int state_repr) : state_repr(state_repr) {
        reset();
    }
    py::tuple reset() {
        ax = WIDTH / 2 - 2 + rand() % 2;
        ay = HEIGHT / 2 - 1 + rand() % 2;
        bx = WIDTH / 2 + 1 + rand() % 2;
        by = HEIGHT / 2 - 1 + rand() % 2;
        ball = rand() % 2;
        step_count = 0;
        return state();
    }
    py::tuple state() {
        // coordinates + ball
        if (state_repr == 0) {
            return py::make_tuple(ax, ay, bx, by, ball);
        } //else if (state_repr == 1) {
        // int s[9+6+9+6+2];
        // for (auto &x : s) x = 0;
        auto s = py::tuple(9+6+9+6+2);
        for (size_t i = 0; i < s.size(); ++i) s[i] = 0;
        s[0 + ax] = 1;
        s[9 + ay] = 1;
        s[15 + bx] = 1;
        s[24 + by] = 1;
        s[30 + ball] = 1;
        return s;
        //}
    }
    py::tuple step(int a, int b) {
        if (step_count == max_steps)
            return py::make_tuple(state(), 0, true, 0);
        step_count++;

        int ax_ = ax, ay_ = ay;
        if (a == 0)
            ay = min(HEIGHT-1, ay+1);
        else if (a == 1)
            ay = max(0, ay-1);
        else if (a == 2)
            // ax = max(1, ax-1); // don't allow A to go back into goal area?
            ax = max(0, ax-1);
        else if (a == 3)
            ax = min(WIDTH-1, ax+1);
        
        int bx_ = bx, by_ = by;
        if (b == 0)
            by = min(HEIGHT-1, by+1);
        else if (b == 1)
            by = max(0, by-1);
        else if (b == 2)
            bx = max(0, bx-1);
        else if (b == 3)
            // bx = min(WIDTH-2, bx+1);
            bx = min(WIDTH-1, bx+1);

        if (ax == bx && ay == by) {
            ball = 1 - ball;
            ax = ax_; ay = ay_;
            bx = bx_; by = by_;
        }

        int reward = 0;
        bool terminal = false;
        if (ball == 0 && ax == WIDTH-1 && ay >= HEIGHT/2-1 && ay <= HEIGHT/2) {
            reward = 1;
            terminal = true;
        } else if (ball == 1 && bx == 0 && by >= HEIGHT/2-1 && by <= HEIGHT/2) {
            reward = -1;
            terminal = true;
        }
        return py::make_tuple(state(), reward, terminal, 0);
    }
};

// Storage for the class-level flag exposed to Python as the read-only
// attribute SoccerEnv.sequential; it is initialised to false.
bool SoccerEnv::sequential{false};

// Python extension module `soccer`. Exposes SoccerEnv(state_repr) with its
// reset() and step(a, b) methods, plus the read-only class attribute
// `sequential`.
PYBIND11_MODULE(soccer, m) {
    py::class_<SoccerEnv> env_cls(m, "SoccerEnv");
    env_cls.def(py::init<int>());
    env_cls.def("reset", &SoccerEnv::reset);
    env_cls.def("step", &SoccerEnv::step);
    env_cls.def_readonly_static("sequential", &SoccerEnv::sequential);
}
