import numpy as np
import os
import matplotlib.pyplot as plt
from time import sleep
from env import Environment
import wandb

class HQLearning:
    """Tabular Q-learning agent that routes updates through "lensed" state-action pairs.

    At construction time the environment is probed once per (state, action) to
    build a forwards table (where each action leads) and a backwards table.
    ``equivalent_state`` then maps a pair ``(s, a)`` either to ``(s', 0)`` taken
    from the backwards table or back to ``(s, a)`` itself. Both the Q-update in
    ``train`` and greedy action selection use the mapped pair. This apparently
    lets pairs that reach the same successor share one Q-value; the exact
    mapping is flagged for review in ``build_backwards_model``.
    """

    # Arrow glyph shown by print_policy for each action index.
    _ACTION_ARROWS = {0: "<", 1: "^", 2: ">", 3: "v"}

    def __init__(self, env, num_of_cols=9):
        """Create the agent and probe ``env`` to build its transition tables.

        Args:
            env: Environment exposing ``accessible_states`` (a list),
                ``player_pos``, ``step``, ``reset`` and ``clear``.
            num_of_cols: Grid width, used to turn a flat state index into
                ``(row, col)`` when printing the policy.
        """
        self.num_of_cols = num_of_cols
        self.env = env
        self.Q = {}
        self.model = {}  # NOTE(review): never read in this class
        # NOTE(review): removes the element at *index* 8 (not the value 8) and
        # mutates the environment's own list; confirm which state is meant to go.
        self.env.accessible_states.pop(8)
        # One random initial value per action (4 actions) for every state.
        for s in self.env.accessible_states:
            self.Q[s] = [np.random.random() for _ in range(4)]
        self.forwards_model = self.build_forwards_model()
        self.backwards_model = self.build_backwards_model()

    def train(self, episode_nums, env, alpha, gamma, eval_epochs, render=False):
        """Run epsilon-greedy Q-learning for ``episode_nums`` episodes.

        Args:
            episode_nums: Number of episodes to run.
            env: Environment to interact with (``reset`` / ``step`` / ``clear``).
            alpha: Learning rate.
            gamma: Discount factor.
            eval_epochs: Currently unused.
            render: Currently unused.

        Returns:
            ``(episode_rewards, episode_lengths)``: total reward and number of
            steps of each completed episode, in order.
        """
        total_reward = 0
        episode_num = 0
        running_average = []  # per-episode total reward (no averaging is done)
        count = 0
        ep_len = []
        s = env.reset()
        while episode_num < episode_nums:
            count += 1
            a = self.sample_action(s)
            # The update is applied to the lensed pair, not to (s, a) itself.
            p_s_equivalent, a_equivalent = self.equivalent_state(s, a)
            s, r, done = env.step(a)
            if done:
                # Terminal transition: there is no future return to bootstrap
                # from, so the Q-learning target is the reward alone.
                best_value = 0.0
            else:
                next_states, next_actions = self.get_all_possible_lenses(s)
                best_value = max(
                    self.Q[ns][na] for ns, na in zip(next_states, next_actions)
                )
            total_reward += r
            print("action", a_equivalent)
            q_row = self.Q[p_s_equivalent]
            q_row[a_equivalent] += alpha * (
                r + gamma * best_value - q_row[a_equivalent]
            )
            if done:
                s = env.reset()
                env.clear()
                episode_num += 1
                running_average.append(total_reward)
                ep_len.append(count)
                count = 0
                print("episode reward", total_reward)
                total_reward = 0
        return running_average, ep_len

    def get_all_possible_lenses(self, s):
        """Return the lensed ``(states, actions)`` lists for all 4 actions from ``s``."""
        pairs = [self.equivalent_state(s, a) for a in range(4)]
        return [state for state, _ in pairs], [action for _, action in pairs]

    def build_forwards_model(self):
        """Probe ``self.env`` to record where every action leads.

        For each accessible state and each action, the player is placed on the
        state and ``env.step`` is called once (this mutates the environment).

        Returns:
            Float array of shape ``(4, 54)`` indexed ``[action][state]`` holding
            the reached state. Columns of states not in ``accessible_states``
            stay 0. NOTE(review): 54 presumably equals the 6x9 grid size.
        """
        forwards_model = np.zeros((4, 54))
        for a_starting_state in self.env.accessible_states:
            for an_action in range(4):
                self.env.player_pos = a_starting_state
                end_state, _, _ = self.env.step(an_action)
                forwards_model[an_action][a_starting_state] = int(end_state)
        return forwards_model

    def build_backwards_model(self):
        """Build the table ``equivalent_state`` uses to pick a representative pair.

        For each (start, action) the environment is stepped again. ``p_s_pred``
        is the first column of ``forwards_model[action]`` equal to the reached
        state. ``backwards_model[action][p_s_pred]`` is then overwritten with
        ``start``, so it ends up holding the last such start.

        NOTE(review): this table is indexed by ``p_s_pred`` (a start state), but
        ``equivalent_state`` indexes row 0 by an *end* state. Confirm this is
        the intended mapping.
        """
        backwards_model = np.zeros((4, 54))
        for a_starting_state in self.env.accessible_states:
            for an_action in range(4):
                self.env.player_pos = a_starting_state
                state, _, _ = self.env.step(an_action)
                end_states = self.forwards_model[an_action]
                p_s_pred = np.where(end_states == state)[0][0]
                backwards_model[an_action][p_s_pred] = a_starting_state
        return backwards_model

    def equivalent_state(self, state, action):
        """Map ``(state, action)`` to the pair whose Q-value is actually used.

        Looks up ``final = forwards_model[action][state]``. If
        ``backwards_model[0][final]`` is nonzero, returns that value (a numpy
        float, still usable as a dict key) together with action 0. Otherwise
        returns ``(state, action)`` unchanged.

        NOTE(review): 0 doubles as the "no entry" sentinel although state 0
        appears to be a real grid cell.
        """
        final_same_state = self.forwards_model[action][state]
        equivalent_state = self.backwards_model[0][int(final_same_state)]
        if equivalent_state == 0:
            return state, action
        return equivalent_state, 0

    def lensed_argmax(self, state):
        """Return the action whose lensed Q-value is highest (first one on ties)."""
        values = [
            self.Q[es][ea] for es, ea in zip(*self.get_all_possible_lenses(state))
        ]
        return np.argmax(values)

    def sample_action(self, s, epsilon=0.1):
        """Epsilon-greedy choice: uniform random action with probability ``epsilon``.

        Otherwise the greedy lensed action is returned.
        """
        if np.random.random() < epsilon:
            return np.random.choice([0, 1, 2, 3])
        return self.lensed_argmax(s)

    def print_policy(self):
        """Clear the env display and print each state's greedy action as a grid of arrows.

        NOTE(review): uses the raw ``argmax(self.Q[s])`` rather than
        ``lensed_argmax``; confirm this is the policy meant to be shown.
        """
        best_actions = {
            s: self._ACTION_ARROWS[int(np.argmax(self.Q[s]))]
            for s in self.env.accessible_states
        }
        self.env.clear()
        print("----------------BEST POLICY----------------")
        # Static 6x9 map; presumably X = wall, S = start, G = goal.
        rows = [
            ["-", "-", "-", "-", "-", "-", "-", "X", "G"],
            ["-", "-", "X", "-", "-", "-", "-", "X", "-"],
            ["S", "-", "X", "-", "-", "-", "-", "X", "-"],
            ["-", "-", "X", "-", "-", "-", "-", "-", "-"],
            ["-", "-", "-", "-", "-", "X", "-", "-", "-"],
            ["-", "-", "-", "-", "-", "-", "-", "-", "-"],
        ]
        for s, arrow in best_actions.items():
            row_num, col_num = divmod(s, self.num_of_cols)
            rows[row_num][col_num] = arrow
        rows[0][8] = "G"
        for row in rows:
            print(row)
        print("-------------------------------------------")


def main():
    """Train an HQLearning agent, log per-episode metrics to wandb, then print its policy."""
    wandb.init(project="HDYNA", config={"Homomorphic": True})
    env = Environment()
    agent = HQLearning(env)
    episode_rewards, episode_lengths = agent.train(50, env, 0.1, 0.95, 5)
    # The first episode is skipped, as in the original logging loop.
    # NOTE(review): the reason for skipping it is not visible here.
    for reward, length in zip(episode_rewards[1:], episode_lengths[1:]):
        # A single call per episode keeps both metrics on the same wandb step
        # (each wandb.log call advances the step counter by default).
        wandb.log({"reward": reward, "episode length": length})
    agent.print_policy()


if __name__ == "__main__":
    main()
