import os
import numpy as np
from minigrid_basics.examples.visualizer import Visualizer
from os.path import join
from collections import deque
from itertools import islice
from minigrid_basics.examples.ROD_cycle import RODCycle


# testing imports
import gym
import gin
from minigrid_basics.reward_envs import maxent_mon_minigrid
from minigrid_basics.custom_wrappers import maxent_mdp_wrapper
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import matplotlib.pyplot as plt
import subprocess
import glob

class ROD_SR_Q(RODCycle):
    """ROD cycle that additionally evaluates batch Q-learning on the collected data.

    After every round of sample collection in ``rod_cycle``, a greedy policy is
    extracted by running tabular Q-learning sweeps over ``self.dataset`` and is
    rolled out once in the environment. The resulting episode return is
    appended to ``self.Q_performance``.
    """

    def __init__(self, env, n_steps=100, p_option=0.05, dataset_size=None,
                 learn_rep_iteration=10, representation_step_size=0.1,
                 gamma=0.99, num_options=None, eigenoption_step_size=0.1, plot=True):
        super().__init__(env, n_steps, p_option, dataset_size, learn_rep_iteration,
                         representation_step_size, gamma, num_options,
                         eigenoption_step_size, plot)

        # Episode returns of the greedy Q-learning policy, one entry per call
        # to evaluate_Q_policy (i.e. one per iteration of rod_cycle).
        self.Q_performance = []

    def learn_Q_policy(self, tol=1e-5, max_sweeps=None):
        """Learn a greedy policy by batch Q-learning over ``self.dataset``.

        Each transition ``(s, a, r, ns)`` in the dataset receives the full
        (step size 1) update ``Q[s, a] <- r + gamma * max_a' Q[ns, a']``.
        The dataset is swept repeatedly until the largest absolute change
        within one sweep falls below ``tol``.

        Args:
            tol: Convergence threshold on the per-sweep maximum |delta|.
            max_sweeps: Optional cap on the number of sweeps. ``None`` (the
                default) sweeps until convergence. A cap guards against
                non-termination, e.g. if the dataset holds conflicting next
                states for the same (s, a) pair.

        Returns:
            Array of length ``env.num_states`` holding the greedy action for
            each state. (s, a) pairs never seen in the data keep the
            pessimistic initial value. ``np.argmax`` breaks ties toward the
            lowest action index, so states with no data map to action 0.
        """
        # Pessimistic initialization: unseen actions are never preferred over
        # actions backed by data.
        Q = np.full((self.env.num_states, self.env.num_actions), -10000.0)

        # Every terminal state has value 0 (not only the first one), so
        # bootstrapping into any terminal state uses the correct value.
        terminal_states = {int(t) for t in self.env.terminal_idx}
        if terminal_states:
            Q[sorted(terminal_states)] = 0.0

        sweeps = 0
        while max_sweeps is None or sweeps < max_sweeps:
            sweeps += 1
            max_delta = 0.0
            # Presumably the dataset is stored in collection order, so a
            # reversed sweep propagates values backward along trajectories
            # within a single pass.
            for (s, a, r, ns) in reversed(self.dataset):
                if s in terminal_states:
                    continue

                delta = r + self.gamma * Q[ns].max() - Q[s, a]
                Q[s, a] += delta

                # Track the largest change in this sweep for the stopping test.
                max_delta = max(np.abs(delta), max_delta)

            if max_delta < tol:
                break

        return np.argmax(Q, axis=1)

    def evaluate_Q_policy(self,):
        """Roll out the Q-learned greedy policy once and record its return.

        The policy comes from ``learn_Q_policy``. The rollout lasts at most
        ``self.n_steps`` steps and stops early when the environment reports
        ``done``. The undiscounted episode return is printed and appended to
        ``self.Q_performance``.

        Returns:
            The undiscounted episode return.
        """
        pi = self.learn_Q_policy()

        episode_return = 0
        # Environment and policy are both deterministic, so one episode is enough.
        obs = self.env.reset()
        for _ in range(self.n_steps):
            action = pi[obs['state']]
            obs, reward, done, _info = self.env.step(action)
            episode_return += reward

            if done:
                break

        print(episode_return)
        self.Q_performance.append(episode_return)
        return episode_return

    def rod_cycle(self, n_iterations=100, stop_when_all_visited=False):
        """Run the ROD cycle, evaluating batch Q-learning after each data collection.

        Each iteration does the following:
            1. Collect samples.
            2. Evaluate the Q-learning policy on the current dataset.
            3. Update the representation.
            4. Compute an eigenvector and its eigenoption.
            5. Plot, if ``self.plot`` is set.
            6. Append the option to ``self.options``.

        Args:
            n_iterations: Maximum number of iterations to run.
            stop_when_all_visited: If True, stop early once the latest entry of
                ``self.state_visit_percentage`` reaches 1.0 (assumes it is a
                fraction in [0, 1] -- confirm against RODCycle). Defaults to
                False, which always runs all ``n_iterations``.

        Returns:
            Tuple ``(self.cumulative_reward, self.state_visit_percentage)``.
        """
        print("-----------------------")
        print("ROD Cycle Start")
        print("-----------------------")

        for i in range(n_iterations):

            print(f"  [Iteration: {i + 1}]", end="  ")

            # collect samples
            self.collect_samples()

            # extra step: evaluate Q-learning on collected dataset
            self.evaluate_Q_policy()

            # update representation
            self.learn_representation()

            # compute eigenvector
            e0 = self.compute_eigenvector()

            # compute eigenoption
            option = self.compute_eigenoption(e0)

            if self.plot:
                self.visualize_cycle(i, e0, option)

            # append option to set of options
            self.options.append(option)

            print(f"State Visit %: {self.state_visit_percentage[-1]:.2f}")

            # Optional early termination once every state has been visited.
            if stop_when_all_visited and self.state_visit_percentage[-1] >= 1.0:
                break

        print("-----------------------")
        print("ROD Cycle End")
        print("-----------------------")

        return self.cumulative_reward, self.state_visit_percentage


if __name__ == "__main__":
    # Experiment setup: load the gin config for the chosen grid layout and
    # register the corresponding environment id.
    env_name = "gridroom_2"
    gin_path = join(maxent_mon_minigrid.GIN_FILES_PREFIX, f"{env_name}.gin")
    gin.parse_config_file(gin_path)
    env_id = maxent_mon_minigrid.register_environment()

    # Seed NumPy's global RNG before building the environment (kept ahead of
    # gym.make in case anything downstream draws from it).
    np.random.seed(1)

    raw_env = gym.make(env_id, seed=42, no_goal=False)
    env = maxent_mdp_wrapper.MDPWrapper(raw_env)

    agent = ROD_SR_Q(
        env,
        learn_rep_iteration=10,
        dataset_size=100,
        p_option=0.1,
        eigenoption_step_size=0.1,
        num_options=1,
    )

    # rod_cycle returns (cumulative rewards, state-visit percentages). Only the
    # per-iteration Q-learning returns are reported here.
    _rewards, _visit_percentage = agent.rod_cycle(n_iterations=120)

    print(agent.Q_performance)
