import os
import logging

from collections import Counter

import h5py
import numpy as np

import path_config

# Resolve the project root from the local path configuration.
paths = path_config.get_paths()
hcas_root = paths["acas"]

# Input tables are read from GenerateTable, training data is written to
# TrainingData (created on demand).
tabledir = os.path.join(hcas_root, "GenerateTable")
trainingdir = os.path.join(hcas_root, "TrainingData")
os.makedirs(trainingdir, exist_ok=True)

# Identifier shared by the Q-table file name and the generated training files.
# ident = "oneSpeed_onePsi"
ident = "baseline"

table_filename = ident + ".h5"
table_fullfilename = os.path.join(tabledir, "Qtables", table_filename)

# Filled in with (pra, tau) when the training files are written.
training_data_file_pattern = ident + "_pra{:d}_tau{:02}.h5"

# Define state space. Make sure this matches up with the constants used to generate the MDP table!
acts = [0, 1, 2, 3, 4]


def _grid_file(basename):
    """Return the full path of a grid file located in the table directory."""
    return os.path.join(tabledir, basename)


def _read_grid(fullfilename):
    """Read comma-delimited values as an array with at least one dimension."""
    return np.loadtxt(fullfilename, delimiter=",", ndmin=1)


# Locations of the text files holding the discretization of each state variable.
ranges_fullfilename = _grid_file("ranges.txt")
thetas_fullfilename = _grid_file("thetas.txt")
psis_fullfilename = _grid_file("psis.txt")
intrspeeds_fullfilename = _grid_file("intrspeeds.txt")
ownspeeds_fullfilename = _grid_file("ownspeeds.txt")

# ndmin=1 keeps single-value files as 1-element arrays so len()/indexing work.
ranges = _read_grid(ranges_fullfilename)
thetas = _read_grid(thetas_fullfilename)
psis = _read_grid(psis_fullfilename)
vints = _read_grid(intrspeeds_fullfilename)
vowns = _read_grid(ownspeeds_fullfilename)

# psis = np.linspace(-np.pi, +np.pi, 41)

# Number of input columns to generate:
#   5 -> (x, y, psi, vown, vint)
#   3 -> (x, y, psi); vowns and vints must each hold a single value
#   2 -> (x, y); psis must hold a single value as well
# do_5d = True
# do_5d = False
dim = 3
# dim = 2
taus = np.linspace(0, 60, 61)

# Build one input row per grid point. The (range, theta) pair is converted to
# Cartesian coordinates x = r*cos(theta), y = r*sin(theta). In every branch the
# range index varies fastest, then theta, then the remaining variables.
if 5 == dim:
    X_raw = np.array([[r * np.cos(t), r * np.sin(t), p, vo, vi] for vi in vints for vo in vowns for p in psis for t in thetas for r in ranges])
elif 3 == dim:
    # The dropped dimensions must be singletons, otherwise the number of rows
    # would not cover the whole grid.
    assert 1 == vowns.size, "dim = 3 requires exactly one value in vowns"
    assert 1 == vints.size, "dim = 3 requires exactly one value in vints"
    X_raw = np.array([[r * np.cos(t), r * np.sin(t), p] for p in psis for t in thetas for r in ranges])
elif 2 == dim:
    assert 1 == vowns.size, "dim = 2 requires exactly one value in vowns"
    assert 1 == vints.size, "dim = 2 requires exactly one value in vints"
    assert 1 == psis.size, "dim = 2 requires exactly one value in psis"
    X_raw = np.array([[r * np.cos(t), r * np.sin(t)] for t in thetas for r in ranges])
else:
    raise ValueError("Dim = {} not configured".format(dim))

# Per-column mean and spread (max - min) of the raw inputs.
means = X_raw.mean(axis=0)
rnges = X_raw.max(axis=0) - X_raw.min(axis=0)

# A column with a single distinct value has zero spread; use 1 as the divisor
# for such columns so the normalization below stays finite.
rnges[rnges == 0.0] = 1.0

# Center every column and scale it by its spread.
X = (X_raw - means) / rnges

# Compile table values. The context manager closes the HDF5 file even if
# reading the "q" dataset raises.
with h5py.File(table_fullfilename, "r") as f:
    Q = np.array(f["q"])
Q = Q.T

# Row counts used to slice Q further below: ns2 rows per tau value,
# ns3 rows per pra value inside one tau slice.
ns2 = len(ranges) * len(thetas) * len(psis) * len(vowns) * len(vints) * len(acts)
ns3 = len(ranges) * len(thetas) * len(psis) * len(vowns) * len(vints)

# Normalize the whole table with a single global mean and range.
meanQ = np.mean(Q)
rangeQ = np.max(Q) - np.min(Q)
# A constant table has zero range. Divide by 1 in that case, mirroring the
# treatment of the input columns, instead of filling Q with NaNs.
if rangeQ == 0.0:
    rangeQ = 1.0
Q = (Q - meanQ) / rangeQ

# Append the output statistics after the per-input statistics.
means = np.concatenate((means, [meanQ]))
rnges = np.concatenate((rnges, [rangeQ]))

# Bounds of the input columns, in the same column order as X_raw.
# The first two columns are Cartesian x and y (r*cos(theta), r*sin(theta)),
# so they are bounded by +/- the last entry of `ranges` in every branch.
# NOTE(review): this assumes `ranges`, `psis`, `vowns` and `vints` are sorted
# in ascending order -- confirm against the table generator.
if 5 == dim:
    # Previously this branch used ranges[0]/thetas[0] (polar bounds) for the
    # first two columns, which did not match the Cartesian columns of X_raw
    # nor the 3-D and 2-D branches below.
    min_inputs = np.array([-ranges[-1], -ranges[-1], psis[0], vowns[0], vints[0]])
    max_inputs = np.array([ranges[-1], ranges[-1], psis[-1], vowns[-1], vints[-1]])
elif 3 == dim:
    min_inputs = np.array([-ranges[-1], -ranges[-1], psis[0]])
    max_inputs = np.array([ranges[-1], ranges[-1], psis[-1]])
elif 2 == dim:
    min_inputs = np.array([-ranges[-1], -ranges[-1]])
    max_inputs = np.array([ranges[-1], ranges[-1]])
else:
    raise ValueError("Dim = {} not configured".format(dim))

# Write one training file per (tau, pra) combination. Each file holds the
# normalized inputs X, the matching slice of the normalized Q table as y,
# and the statistics / bounds needed to undo or apply the normalization.
for tau in [0, 5, 10, 15, 20, 30, 40, 60]:
    # Rows of Q belonging to this tau value.
    Qsub = Q[tau * ns2:(tau + 1) * ns2]
    # ns2 is ns3 * len(acts), so a tau slice holds len(acts) blocks of ns3 rows.
    for pra in range(len(acts)):
        Qsubsub = Qsub[pra * ns3:(pra + 1) * ns3]

        # Slicing past the end of Q silently yields a short (or empty) array;
        # refuse to write a file whose targets do not line up with the inputs.
        if Qsubsub.shape[0] != X.shape[0]:
            raise ValueError(
                "Q table slice for tau={}, pra={} has {} rows but X has {} rows".format(
                    tau, pra, Qsubsub.shape[0], X.shape[0]
                )
            )

        training_data_filename = training_data_file_pattern.format(pra, tau)
        training_data_fullfilename = os.path.join(trainingdir, training_data_filename)
        print("Saving {} values to {}".format(Qsubsub.shape[0], training_data_fullfilename))
        with h5py.File(training_data_fullfilename, "w") as h:
            h.create_dataset("X", data=X)
            h.create_dataset("y", data=Qsubsub)
            h.create_dataset("means", data=means)
            h.create_dataset("ranges", data=rnges)
            h.create_dataset("min_inputs", data=min_inputs)
            h.create_dataset("max_inputs", data=max_inputs)

#
# for k, v in c.items():
#     print(k, v / Q.shape[0])




