import os
import logging

from collections import Counter

import h5py
import tqdm
import numpy as np
import matplotlib.pyplot as plt

import path_config

# Let wide NumPy arrays print on one line instead of wrapping.
np.set_printoptions(linewidth=1000)

# Directory layout, resolved from the project path configuration.
paths = path_config.get_paths()
hcas_root = paths["acas"]
tabledir = os.path.join(hcas_root, "GenerateTable")
trainingdir = os.path.join(hcas_root, "TrainingData")
os.makedirs(trainingdir, exist_ok=True)

# Which Q-table to analyse (an alternative table name is "oneSpeed_onePsi").
ident = "baseline"
table_filename = f"{ident}.h5"
table_fullfilename = os.path.join(tabledir, "Qtables", table_filename)
# Training-data file name template; the pra and tau fields are filled via str.format.
training_data_file_pattern = f"{ident}_pra{{:d}}_tau{{:02}}.h5"

# State-space definition: action indices 0..4. Keep these, and the grid files
# below, in sync with the constants used to generate the MDP table!
acts = list(range(5))

# Comma-separated grid files, one per state axis, written next to the table generator.
(ranges_fullfilename,
 thetas_fullfilename,
 psis_fullfilename,
 intrspeeds_fullfilename,
 ownspeeds_fullfilename) = (os.path.join(tabledir, stem + ".txt")
                            for stem in ("ranges", "thetas", "psis", "intrspeeds", "ownspeeds"))


def load_state_grid():
    """Load the discretized state-space axes from the table-generation directory.

    Returns:
        Tuple ``(ranges, thetas, psis, vints, vowns)`` of at-least-1-D arrays
        parsed from the comma-separated grid files (``ranges``, ``thetas``,
        ``psis``, ``intrspeeds``, ``ownspeeds``), in that order.
    """
    filenames = (
        ranges_fullfilename,
        thetas_fullfilename,
        psis_fullfilename,
        intrspeeds_fullfilename,
        ownspeeds_fullfilename,
    )
    return tuple(np.loadtxt(fn, delimiter=",", ndmin=1) for fn in filenames)


def build_polar_states(ranges, thetas, psis, vowns, vints, dim):
    """Enumerate grid states as ``[r, theta, psi, vown, vint]`` rows cut to ``dim`` columns.

    Range varies fastest, then theta, psi and own speed; intruder speed varies
    slowest. This order is assumed to match the row layout of the Q table
    (NOTE(review): confirm against the table generator).

    Args:
        ranges, thetas, psis, vowns, vints: 1-D grids for each state axis.
        dim: Number of leading features to keep; one of 2, 3 or 5.

    Returns:
        Float array of shape ``(num_states, dim)``.

    Raises:
        ValueError: If ``dim`` is unsupported, or if an axis that ``dim`` drops
            has more than one grid value (its states would be indistinguishable).
    """
    if dim not in (2, 3, 5):
        raise ValueError("Dim = {} not configured".format(dim))
    if dim < 5 and (np.size(vowns) != 1 or np.size(vints) != 1):
        raise ValueError(
            "dim={} requires exactly one own speed and one intruder speed".format(dim))
    if dim == 2 and np.size(psis) != 1:
        raise ValueError("dim=2 requires exactly one psi value")

    grid = np.array(
        [(r, t, p, vo, vi)
         for vi in vints for vo in vowns for p in psis for t in thetas for r in ranges],
        dtype=float,
    ).reshape(-1, 5)  # reshape keeps a 2-D shape even when the grid is empty
    return grid[:, :dim].copy()


def polar_to_cartesian(x_polar):
    """Convert ``[r, theta, ...]`` rows to ``[r*cos(theta), r*sin(theta), ...]``.

    Columns after the first two are passed through unchanged.
    """
    x_polar = np.asarray(x_polar, dtype=float)
    r, theta = x_polar[:, 0], x_polar[:, 1]
    return np.column_stack((r * np.cos(theta), r * np.sin(theta), x_polar[:, 2:]))


def group_rows_by_sorted_values(values_sorted, mask):
    """Group row indices by identical row vectors, keeping only rows where ``mask`` is set.

    Args:
        values_sorted: 2-D array (in this script, Q-value rows already sorted
            within each row, so rows that are permutations of each other coincide).
        mask: Boolean array with one entry per row of ``values_sorted``.

    Returns:
        ``(unique_rows, groups)``. ``unique_rows`` is ``np.unique(values_sorted,
        axis=0)``. ``groups[k]`` is the ascending array of indices ``i`` with
        ``mask[i]`` true and ``values_sorted[i]`` equal to ``unique_rows[k]``.
        A group may be empty.
    """
    values_sorted = np.asarray(values_sorted)
    mask = np.asarray(mask, dtype=bool)
    if values_sorted.shape[0] == 0:
        return values_sorted.copy(), []

    unique_rows, inverse = np.unique(values_sorted, axis=0, return_inverse=True)
    # Some NumPy 2.0.x releases return a non-1-D inverse when axis is given.
    inverse = np.asarray(inverse).reshape(-1)

    selected = np.flatnonzero(mask)
    group_ids = inverse[selected]
    # A stable sort keeps indices ascending within each group.
    ordered = selected[np.argsort(group_ids, kind="stable")]
    counts = np.bincount(group_ids, minlength=unique_rows.shape[0])
    groups = np.split(ordered, np.cumsum(counts)[:-1])
    return unique_rows, groups


def load_q_table(filename):
    """Read the ``q`` dataset of an HDF5 file into memory and return it transposed."""
    with h5py.File(filename, "r") as f:
        return np.array(f["q"]).T


def main(dim=3, tau=0, pra=0, qry=39597.98621665497):
    """Run an exploratory check on one ``(tau, pra)`` slice of the Q table.

    States with a positive heading (``0 < psi < pi`` after rounding to 4
    decimals) are grouped by their multiset of Q-values, rounded to 4 decimals
    and sorted within each row. For every group with at least two states:

    * if the group has exactly two states whose Q-values are not identical in
      column order, print the group index and the difference of their polar
      coordinates;
    * print the Cartesian states, and their Q-values, whose ``|x|`` lies
      within 0.1 of ``qry``.

    This appears to probe a reflection symmetry of the table.
    NOTE(review): confirm the intent.

    Args:
        dim: Feature count passed to :func:`build_polar_states`.
        tau: Index of the tau block of the Q table to inspect.
        pra: Index of the pra block within that tau block.
        qry: Probe ``|x|`` distance that selects which states are printed.

    Raises:
        ValueError: If ``dim`` is invalid for the grid, or the Q slice does not
            have one row per grid state.
    """
    ranges, thetas, psis, vints, vowns = load_state_grid()
    x_polar = build_polar_states(ranges, thetas, psis, vowns, vints, dim)
    x_rounded = np.round(polar_to_cartesian(x_polar), decimals=4)

    # Slicing assumes rows are laid out as [tau][pra][state], with len(acts)
    # pra blocks per tau block.
    num_states = len(ranges) * len(thetas) * len(psis) * len(vowns) * len(vints)
    rows_per_tau = num_states * len(acts)
    q_table = load_q_table(table_fullfilename)
    q_tau = q_table[tau * rows_per_tau:(tau + 1) * rows_per_tau]
    q_slice = q_tau[pra * num_states:(pra + 1) * num_states]
    if q_slice.shape[0] != x_rounded.shape[0]:
        raise ValueError("Q slice has {} rows but the state grid has {}".format(
            q_slice.shape[0], x_rounded.shape[0]))

    q_rounded = np.round(q_slice, decimals=4)
    q_rounded_sorted = np.sort(q_rounded, axis=1)

    # psi is feature column 2 when present. In 2-D mode the single grid psi
    # applies to every state. Rounding first turns psi = pi (e.g. 3.14159)
    # into 3.1416, so the upper bound excludes it.
    if dim >= 3:
        psi_rounded = x_rounded[:, 2]
    else:
        psi_rounded = np.full(x_rounded.shape[0], np.round(psis[0], decimals=4))
    is_pos_psi = (psi_rounded > 0) & (psi_rounded < np.pi - 1e-7)

    _, groups = group_rows_by_sorted_values(q_rounded_sorted, is_pos_psi)
    for idx, rows in enumerate(groups):
        if rows.size < 2:
            continue
        x_rows = x_rounded[rows, :]
        q_rows = q_rounded[rows, :]

        # Same Q-value multiset, but the values sit in different columns.
        if rows.size == 2 and not np.array_equal(q_rows[0], q_rows[1]):
            print(idx)
            print(x_polar[rows[0], :] - x_polar[rows[1], :])

        near_qry = np.abs(np.abs(x_rows[:, 0]) - qry) < .1
        print(x_rows[near_qry, :])
        print(q_rows[near_qry, :])


if __name__ == "__main__":
    main()

    # is_nonzero_psi = (0 != X_raw[:, -1])
    #
    # input_reflector = np.diag(np.array([+1, -1, -1]))
    # output_reflector = np.array(([[1, 0, 0, 0, 0],
    #                               [0, 0, 1, 0, 0],
    #                               [0, 1, 0, 0, 0],
    #                               [0, 0, 0, 0, 1],
    #                               [0, 0, 0, 1, 0]]))
    #
    # is_nonneg_psi = X_raw[:, 2] >= 0
    # is_not_coincident = np.logical_and(X_raw[:, 0] != 0, X_raw[:, 1] != 0)
    # is_examine_row = np.logical_and(is_nonneg_psi, is_not_coincident)
    # examine_rows = np.argwhere(is_examine_row).flatten().tolist()
    # num_examine_rows = len(examine_rows)
    #
    # for idx, row in tqdm.tqdm(enumerate(examine_rows), total=num_examine_rows):
    #     row_x = X_raw[row, :]
    #     row_q = Qsubsub[row, :]
    #
    #     ref_row_x = input_reflector @ row_x
    #     is_eq_row = np.all(ref_row_x == X_raw, axis=1)
    #     assert 1 == np.sum(is_eq_row)
    #
    #     ref_row_idx = np.argmax(is_eq_row)
    #     ref_row_q = Qsubsub[ref_row_idx, :]
    #
    #     _ = output_reflector @ row_q
    #     np.testing.assert_allclose(_, ref_row_q, atol=1e-5)
    #
    # r1 = 52480
    # r2 = 0
    # print(Qsubsub[r1, :])
    # print(Qsubsub[r2, :])
    #
    # print(X_raw[53537, :])
    # print(X_raw[225, :])
    #
    # print(Qsubsub[53537, :])
    # print(Qsubsub[225, :])
    #
    # ref_row_x
    # # array([-11.34986886, -22.27510891, -3.14159])
    #
    # # num_trials = 10000
    # #
    # # for _ in range(num_trials):
    # #
    # #     row_idx = np.random.randint(X_raw.shape[0])
    # #
    # #     row_x = X_raw[row_idx, :]
    # #     row_q = Qsubsub[row_idx, :]
    # #
    # #     ref_row_x = input_reflector @ row_x
    # #     ref_row_idx = np.argmax(np.all(ref_row_x == X_raw, axis=1))
    # #     ref_row_q = Qsubsub[ref_row_idx, :]
    # #     #
    # #     # print(row_q @ output_reflector)
    #
    # q_subsub_rounded = np.round(Qsubsub, decimals=4)
    # q_subsub_rounded_sorted = np.sort(q_subsub_rounded, axis=1)
    # unique_sorted_rows = np.unique(q_subsub_rounded_sorted, axis=0)
    #
    # num_unique_sorted_rows = unique_sorted_rows.shape[0]
    # corresp_rows = [None] * num_unique_sorted_rows
    # corresp_rows_count = np.full((num_unique_sorted_rows,), np.nan)
    #
    # for idx in tqdm.tqdm(range(num_unique_sorted_rows), leave=True):
    #     # idx = 0
    #     u_row = unique_sorted_rows[idx, :]
    #     r = np.all(u_row == q_subsub_rounded_sorted, axis=1)
    #     r_nonzero = np.logical_and(r, is_nonzero_psi)
    #
    #     corresp_rows[idx] = np.where(r_nonzero)
    #     corresp_rows_count[idx] = np.sum(r_nonzero)
    #
    # # nonzero_psi = np.where(is_nonzero_psi)[0]
    # #     # plt.plot(corresp_rows_count)
    # ge1 = corresp_rows_count > 1
    # lemax = corresp_rows_count < np.max(corresp_rows_count)
    # inds = np.logical_and(ge1, lemax)
    # c = corresp_rows_count[inds]
    #
    # tt = np.argwhere(inds)
    # # _ = np.argmax(np.logical_and(inds, np.arange(num_unique_sorted_rows) > 103))
    # # idx = 15947
    # # idx = 103
    # # idx = 105
    #
    # for idx in tt.flatten():
    #     # idx = 142
    #     # idx = 15947
    #     rows = corresp_rows[idx][0]
    #     print(idx)
    #     assert rows.size >= 2
    #     # if False:
    #     #     print(idx)
    #     #     print(X_raw[rows, :])
    #     #     print(q_subsub_rounded[rows, :])
    #     if 2 == rows.size:
    #
    #         x_raw_rows = X_raw[rows, :]
    #         q_raw_rows = q_subsub_rounded[rows, :]
    #         both_same = np.all(q_raw_rows[0, :] == q_raw_rows[1, :])
    #         if both_same:
    #             continue
    #         input_raw = x_raw_rows[0, :]
    #         input_ref = input_reflector @ input_raw
    #         np.testing.assert_allclose(input_ref, x_raw_rows[1, :])
    #
    #         output_raw = q_raw_rows[0, :]
    #         output_ref = output_reflector @ output_raw
    #         np.testing.assert_allclose(output_ref, q_raw_rows[1, :])
    #
    # # valsets = [sorted(q_subsub_rounded[idx, :]) for idx in range(num_rows)]
    # # len(valsets)#
    # # len(set(valsets))
    # #
    # # if False:
    # #     unique_rows = np.unique(q_subsub_rounded, axis=0)
    # #     num_unique_rows = unique_rows.shape[0]
    # #     corresp_rows = [None] * num_unique_rows
    # #     corresp_rows_count = np.full((num_unique_rows,), np.nan)
    # #
    # #     for idx in tqdm.tqdm(range(num_unique_rows)):
    # #         # idx = 0
    # #         # idx = 3
    # #         u_row = unique_rows[idx, :]
    # #         r = np.all(u_row == q_subsub_rounded, axis=1)
    # #         corresp_rows[idx] = np.where(r)
    # #         corresp_rows_count[idx] = np.sum(r)
    # #
    # rows = corresp_rows[3][0]
    #
    #
    #
    # u_psi = np.unique(X_raw[:, 2])
    #
    # x = +5.0e+01
    # # y = +5.0e+01
    # y = +5.0e+01
    #
    # # psi = 1.5708
    # # theta = 1.5708
    #
    # # psi = -.785398
    # # theta = .785398
    #
    # # psi = -2.35619
    # # psi = 2.35619
    # # theta = 2.35619
    #
    # psi = 2.35619
    # theta = 2.35619
    #
    # r = ranges[10]
    #
    # r_rows = x_polar[:, 0] == r
    # theta_rows = x_polar[:, 1] == theta
    # psi_rows = x_polar[:, 2] == psi
    #
    # rows = np.logical_and(np.logical_and(r_rows, theta_rows), psi_rows)
    #
    # tmp1 = x_polar[rows, :]
    # tmp2 = Qsubsub[rows, :]
    #
    # print(tmp1)
    # print(tmp2)
    #
    # print(tmp2[:+3, :])
    # print(tmp2[-3:, :])
    #
    # theta_rows = (x_polar[:, 1] == theta)
    # np.sum(theta_rows)
    #
    # x_col = 0
    # y_col = 1
    # psi_col = 2
    #
    # x_rows = np.abs(X_raw[:, x_col] - x) < .01
    # y_rows = np.abs(X_raw[:, y_col] - y) < .01
    # psi_rows = X_raw[:, psi_col] == psi
    #
    # if False:
    #     plt.plot(np.sort(X_raw[:, x_col]))
    #     plt.plot(np.sort(X_raw[:, y_col]))
    #
    #
    # print(np.sum(x_rows))
    # print(np.sum(y_rows))
    # print(np.sum(psi_rows))
    #
    # rows = np.logical_and(np.logical_and(x_rows, y_rows), psi_rows)
    # # np.mean(psi_rows)
    #
    # x_psi = X_raw[psi_rows, :]
    # q_psi = Qsubsub[psi_rows, :]
    #
    # x_psi[-3:, :]
    # q_psi[-3:, :]
    #
    # print(q_psi)
    #
    #
    #
