import numpy as np
import pickle
# from .mazerunner import MazeRunnerEnv
from copy import deepcopy
import os
import torch
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
# from PIL import Image


import os
# from gym import utils
# from gym.envs.robotics import rotations, robot_env, fetch_env
# import gym.envs.robotics.utils as robot_utils
from gymnasium import utils
from gymnasium_robotics.utils import rotations
from gymnasium_robotics.envs import robot_env
from gymnasium_robotics.envs.fetch import fetch_env
import gymnasium_robotics.utils.mujoco_py_utils as robot_utils
from gymnasium.spaces import Discrete

# from raDT.baselines.envs.reach_obstacle import FetchReachObstacleEnv

from pyboolnet.file_exchange import bnet2primes, primes2bnet
from pyboolnet.prime_implicants import find_constants, create_variables
from pyboolnet.repository import get_primes

from pyboolnet.repository import get_primes
from pyboolnet.state_transition_graphs import create_stg_image
from pyboolnet.state_transition_graphs import energy, random_walk, add_style_path, add_style_anonymous, stg2image, \
    primes2stg
from pyboolnet.state_transition_graphs import sccgraph2image
from pyboolnet.state_transition_graphs import stg2sccgraph, stg2condensationgraph, best_first_reachability

from pyboolnet.attractors import compute_attractors_tarjan, compute_attractors, find_attractor_state_by_randomwalk_and_ctl


import numpy as np
import pandas as pd


class CardiogenesisEnv:
    """Goal-reaching environment on a 15-node cardiogenesis Boolean network (pyboolnet).

    States are strings of "0"/"1" characters (the ``*_readable`` attributes), mirrored as
    integer numpy vectors. An action is the index of one node whose value is flipped; the
    network then evolves for ``fixed_interval`` steps of an asynchronous random walk.
    """

    def __init__(self, fixed_interval=1, avoids_readable=None, fixed_goal_readable=None, fixed_start_readable=None, initial_state_list_readable=None):
        """Build the network, candidate start/goal states and avoid list, then reset.

        Args:
            fixed_interval: ``length`` passed to pyboolnet ``random_walk`` after each action.
            avoids_readable: state strings that yield cost 1 when visited; ``None`` means none.
            fixed_goal_readable: if set, the goal string used by every episode.
            fixed_start_readable: if set, the start string used by every episode.
            initial_state_list_readable: candidate start/goal strings. If empty or ``None``,
                one state per attractor from ``compute_attractors(bnet, "asynchronous")`` is used.
        """
        grn = """
            Bmp2,   (!canWnt & exogen_Bmp2_II)
            canWnt,    exogen_canWnt_II
            Dkk1,   (Mesp1 | (canWnt & !exogen_Bmp2_II))
            Fgf8,   (!Mesp1 & (Foxc1_2 | Tbx1))
            Foxc1_2,    (canWnt & exogen_canWnt_II)
            GATAs,    (Nkx2_5 | Mesp1 | Tbx5)
            Isl1,    (Tbx1 | Mesp1 | Fgf8 | (canWnt & exogen_canWnt_II))
            Mesp1,    (canWnt & !exogen_Bmp2_II)
            Nkx2_5,    ((Isl1 & GATAs) | Tbx1 | (Mesp1 & Dkk1) | (Bmp2 & GATAs) | Tbx5)
            Tbx1,    Foxc1_2
            Tbx5,    (!(Tbx1 | canWnt) & (Nkx2_5 | Tbx5 | Mesp1) & !(Dkk1 & !(Mesp1 | Tbx5)))
            exogen_Bmp2_I,    exogen_Bmp2_I
            exogen_Bmp2_II,    exogen_Bmp2_I
            exogen_canWnt_I,    exogen_canWnt_I
            exogen_canWnt_II,    exogen_canWnt_I
            """
        bnet = bnet2primes(grn)
        self.bnet = bnet

        if not initial_state_list_readable:
            attrs_readable = [a["state"]['str'] for a in compute_attractors(bnet, "asynchronous")["attractors"]]
        else:
            attrs_readable = initial_state_list_readable

        self.initial_state_readable_list = attrs_readable
        self.initial_state_matrix = np.array([self._readable_to_vector(r) for r in attrs_readable])

        self.fixed_interval = fixed_interval

        # None instead of a shared mutable [] default.
        avoids_readable = [] if avoids_readable is None else avoids_readable
        self.avoid_states_readable_list = avoids_readable
        self.avoid_states_matrix = np.array([self._readable_to_vector(r) for r in avoids_readable])

        self.fixed_goal_readable = fixed_goal_readable
        self.fixed_goal = self._readable_to_vector(fixed_goal_readable) if fixed_goal_readable else None

        self.fixed_start_readable = fixed_start_readable
        self.fixed_start = self._readable_to_vector(fixed_start_readable) if fixed_start_readable else None

        # One action per network node: the index of the node to flip.
        self.action_space = Discrete(15)

        self.reset()

    @staticmethod
    def _readable_to_vector(readable):
        """Convert a state string such as "0110" into an integer numpy vector."""
        return np.array([int(x) for x in readable])

    def _sample_goal(self):
        """Set and return ``(self.goal, self.goal_readable)``: the fixed goal or a random candidate."""
        if self.fixed_goal_readable:
            self.goal = self.fixed_goal
            self.goal_readable = self.fixed_goal_readable
        else:
            random_i = np.random.choice(np.arange(len(self.initial_state_readable_list))) # choose goal from attractors
            self.goal = self.initial_state_matrix[random_i, :]
            self.goal_readable = "".join([str(x) for x in self.goal])
        return self.goal, self.goal_readable

    def _sample_start(self):
        """Set and return ``(self.start, self.start_readable)``: the fixed start or a random candidate."""
        if self.fixed_start_readable:
            self.start = self.fixed_start
            self.start_readable = self.fixed_start_readable
        else:
            random_i = np.random.choice(np.arange(len(self.initial_state_readable_list))) # choose start from attractors
            self.start = self.initial_state_matrix[random_i, :]
            self.start_readable = "".join([str(x) for x in self.start])
        return self.start, self.start_readable

    # def _sample_avoids(self):
    #     choices = [np.array([0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1]),
    #                 np.array([0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1]),
    #                 np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0]),
    #                 np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0]),
    #                 np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    #                 np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]),
    #                 np.array([0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0]),
    #                 np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]),
    #                 np.array([0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1]),
    #                 np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1]),
    #                 np.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1]),
    #                 np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    #                 np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1]),
    #                 np.array([0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0]),
    #                 np.array([0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1]),
    #                 np.array([0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0]),
    #                 np.array([0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1]),
    #                 np.array([0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0]),
    #                 np.array([0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1]),
    #                 np.array([0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1])]

    def _get_obs(self):
        """Return copies of the current state (as observation and achieved goal) and the goal."""
        return {"observation": self.current_state.copy(),
                "achieved_goal": self.current_state.copy(),
                "desired_goal": self.goal.copy()
                }

    def _check_start_goal_can_differ(self):
        """Raise ValueError if no (start, goal) pair with start != goal can ever be sampled.

        Without this guard the resampling loop in ``reset`` would never terminate, e.g. when
        the fixed start equals the fixed goal or only one candidate state exists.
        """
        candidates = set(self.initial_state_readable_list)
        goals = {self.fixed_goal_readable} if self.fixed_goal_readable else candidates
        starts = {self.fixed_start_readable} if self.fixed_start_readable else candidates
        if not any(g != s for g in goals for s in starts):
            raise ValueError("No start/goal pair with start != goal is available; reset() would loop forever.")

    def step(self, action_int):
        """Flip node ``action_int``, run the asynchronous random walk, and return the transition.

        Returns:
            (obs, reward, cost, done, info): reward is always 0 and done always False here;
            cost is 1 if the new state is in the avoid list, else 0; info["is_success"] is 1
            if the new state equals the goal, else 0.
        """
        self.ep += 1
        current = self.current_state_readable

        flipped_bit = str(1 - int(current[action_int]))
        perturbed_state_readable = current[:action_int] + flipped_bit + current[action_int + 1:]
        path = random_walk(self.bnet, "asynchronous", initial_state=perturbed_state_readable, length=self.fixed_interval)

        # Only the final state of the walk is used.
        new_state_readable = "".join(str(x) for x in path[-1].values())
        self.current_state_readable = new_state_readable
        # Build the vector from the NEW state; it was previously built from the pre-step
        # string, so the observation lagged one step behind current_state_readable.
        self.current_state = self._readable_to_vector(new_state_readable)

        reward = 0
        cost = int(self.current_state_readable in self.avoid_states_readable_list)
        done = False
        info = {"is_success": int(self.current_state_readable == self.goal_readable)}

        return self._get_obs(), reward, cost, done, info

    def reset(self):
        """Sample a start and a distinct goal, reset the step counter, and return ``(obs, info)``.

        Raises:
            ValueError: if the configuration can never produce start != goal.
        """
        self.ep = 0
        self._check_start_goal_can_differ()
        goal, goal_readable = self._sample_goal()
        start, start_readable = self._sample_start()
        while goal_readable == start_readable:
            goal, goal_readable = self._sample_goal()
            start, start_readable = self._sample_start()

        self.current_state = start
        self.current_state_readable = start_readable

        infos = {"is_success": int(self.current_state_readable == self.goal_readable)}

        return self._get_obs(), infos


###########################################

# '''
# For online evaluation: MazeRunner env with a fixed maze and goal across episodes
# '''
# class MazeRunnerEvalEnv(MazeRunnerEnv):
#     def __init__(self, maze_dim, min_num_goals, max_num_goals, maze, goal_pos):
#         self.env_name = 'mazerunner'
#         self.maze = deepcopy(maze)
#         self.goal_positions = deepcopy(goal_pos)
#         self.prompt_dim = 2 # goal: (x,y)
#         super().__init__(maze_dim, min_num_goals, max_num_goals)
#         self.discrete_action = True if hasattr(self.action_space, 'n') else False
    
#     def reset(self, *args, **kwargs):
#         self.start = (self.maze_dim - 2, self.maze_dim // 2)
#         empty_locations = [x for x in zip(*np.where(self.maze == 0))]
#         empty_locations.remove(self.start)
#         self.active_goal_idx = 0
#         self.pos = self.start
#         self._enforce_reset = False
#         self._plotting = False
#         self._goal_render_texts = [None for _ in range(len(self.goal_positions))]
#         self.episode_visited_goals = [] # save prompt goals in episode experience, for online adaptation
#         return self._get_obs()

#     def step(self, act):
#         obs, rew, terminated, truncated, info = super().step(act)
#         self.episode_visited_goals.append(np.asarray(obs[0:2])) # save prompt goals in episode experience, for online adaptation
#         return obs, rew, terminated, info

#     def render_with_prompt(self, prompt, return_rgb=False):
#         if not self._plotting:
#             self.start_plotting()
#             plt.ion()
#             self._plotting = True

#         plt.tight_layout()
#         background = np.ones((self.maze_dim, self.maze_dim, 3), dtype=np.uint8)
#         maze_img = (
#             background * abs(np.expand_dims(self.maze, -1) - 1) * 255
#         )  # zero out (white) where there is a valid path

#         prompt_color = [240, 3, 252]
#         for p in prompt:
#             maze_img[p[0], p[1], :] = prompt_color

#         maze_img[self.pos[0], self.pos[1], :] = [110, 110, 110]  # grey
#         plt.imshow(maze_img)

#         for i, p in enumerate(prompt):
#             plt.text(
#                 p[1], p[0], str(i), ha="center", va="center"
#             )

#         plt.draw()
#         plt.pause(0.1)

#         if return_rgb:
#             canvas = FigureCanvasAgg(plt.gcf())
#             canvas.draw()
#             w, h = canvas.get_width_height()
#             buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8)
#             buf.shape = (w,h,4)
#             buf = np.roll(buf, 3, axis=2)
#             image = Image.frombytes("RGBA", (w,h), buf.tostring())
#             image = np.asarray(image)[:,:,:3]
#             return image

#     def get_task_prompt(self):
#         return [np.array(p)/self.maze_dim for p in self.goal_positions]

# make envs for eval with the tasks in trajs
def get_env_list(trajs, device, max_ep_len=50, **kwargs):
    """Create one CardiogenesisEnv and one model-config dict per trajectory.

    Args:
        trajs: trajectories; each must provide an 'observations' array whose second
            dimension is used as ``state_dim``.
        device: stored verbatim in every config dict.
        max_ep_len: stored verbatim in every config dict.
        **kwargs: forwarded to ``CardiogenesisEnv``.

    Returns:
        (infos, env_list): two parallel lists with one entry per trajectory.
    """
    infos = []
    env_list = []
    for traj in trajs:
        env = CardiogenesisEnv(**kwargs)
        env.prompt_dim = 15
        state_dim = traj["observations"].shape[1]
        infos.append({
            'max_ep_len': max_ep_len,
            'state_dim': state_dim,
            'act_dim': env.action_space.n,
            'device': device,
            'prompt_dim': env.prompt_dim,
            'discrete_action': True,
        })
        env_list.append(env)
    return infos, env_list

# load training dataset, and envs of some training/test tasks
def get_train_test_dataset_envs(dataset_path, device, max_ep_len=50, n_train_env=50, n_test_env=50, **kwargs):
    """Load a pickled list of trajectories and build evaluation envs for seen and unseen tasks.

    The first ``n_test_env`` trajectories are held out as unseen test tasks. The rest form the
    offline training data, and the first ``n_train_env`` of those are also evaluated as seen tasks.

    Args:
        dataset_path: path to a pickle file containing a list of trajectory dicts.
        device: forwarded to ``get_env_list``.
        max_ep_len: forwarded to ``get_env_list``.
        n_train_env: number of training trajectories to evaluate as seen tasks.
        n_test_env: number of leading trajectories held out as unseen tasks.
        **kwargs: forwarded to ``CardiogenesisEnv`` through ``get_env_list``.

    Returns:
        (info, env_list, val_trajectories_list, test_info, test_env_list,
         test_trajectories_list, trajectories_list)
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted dataset files.
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    test_trajectories_list = trajectories[:n_test_env]  # unseen tasks to eval
    trajectories_list = trajectories[n_test_env:]  # offline data for training
    val_trajectories_list = trajectories_list[:n_train_env]  # seen tasks to eval

    info, env_list = get_env_list(val_trajectories_list, device, max_ep_len, **kwargs)
    test_info, test_env_list = get_env_list(test_trajectories_list, device, max_ep_len, **kwargs)
    return info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list


# given a trajectory, sample a goal-sequence as prompt
def get_prompt(trajectory, max_prompt_length=5, prompt_length=None, device=None, use_optimal_prompt=False, prompt_dim=15):
    """Sample a left-padded goal-sequence prompt from a trajectory.

    The prompt contains (prompt_length - 1) distinct intermediate goals, sorted by time,
    followed by the trajectory's final goal (the task goal). Goals come from
    ``trajectory['optimal_prompts']`` when ``use_optimal_prompt`` is set and that key exists.
    Otherwise they are the first ``prompt_dim`` entries of ``trajectory['next_observations']``.
    Padding slots hold the goal at time index 0 and are masked out.

    Args:
        trajectory: dict with 'next_observations' and 'timesteps' (optionally 'optimal_prompts').
        max_prompt_length: length of the padded prompt.
        prompt_length: number of real goals. If None, sampled uniformly from
            [1, max_prompt_length]. Always clamped to at least 1, at most
            ``max_prompt_length``, and at most the number of goals the trajectory can supply.
        device: torch device for the returned tensors.
        use_optimal_prompt: prefer the saved optimal prompt when present.
        prompt_dim: feature size of each goal (previously hard-coded as 15).

    Returns:
        prompt: float32 tensor of shape (1, max_prompt_length, prompt_dim).
        mask: float64 tensor of shape (1, max_prompt_length); 0 = padding, 1 = real goal.
    """
    if prompt_length is None:
        # sample a prompt length between [1, max_prompt_length] for training
        prompt_length = np.random.randint(1, max_prompt_length + 1)

    use_optimal = use_optimal_prompt and 'optimal_prompts' in trajectory
    if use_optimal:
        goal_source = trajectory['optimal_prompts']
        last_t = goal_source.shape[0] - 1
    else:
        goal_source = trajectory['next_observations']
        last_t = trajectory['timesteps'] - 1  # the last goal in the prompt is the task-goal

    # Keep the task goal (>= 1), fit the padded prompt, and never request more intermediates
    # than np.arange(1, last_t) holds. Previously a one-row optimal prompt produced
    # prompt_length 0 and a prompt longer than its mask.
    prompt_length = max(1, min(prompt_length, max_prompt_length, last_t))

    goal_timesteps = []
    if prompt_length > 1:
        intermediate_range = np.arange(1, last_t)
        goal_timesteps = np.random.choice(intermediate_range, prompt_length - 1, replace=False).tolist()
        goal_timesteps.sort()
    goal_timesteps.append(last_t)

    # padding to the left, with the initial position
    n_pad = max_prompt_length - prompt_length
    goal_timesteps = [0] * n_pad + goal_timesteps
    mask = np.concatenate([np.zeros((1, n_pad)), np.ones((1, prompt_length))], axis=1)

    if use_optimal:
        goals = [goal_source[t] for t in goal_timesteps]
    else:
        goals = [goal_source[t][0:prompt_dim] for t in goal_timesteps]
    prompt = np.array(goals).reshape(1, -1, prompt_dim)

    prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask


# Compute the upper-bound return for the task specified in each trajectory
# (for comparing different methods). Takes a list of trajectories and returns a
# list of floats: one upper-bound return per trajectory's task.

# def get_oracle_returns(trajectories, reward_step_penalty=-0.1):
#     ret = []
#     for t in trajectories:
#         l = len(t['optimal_prompts'])
#         ret.append(1+l*reward_step_penalty)
#     return ret


def get_avoid_prompt_old(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build a left-padded prompt of 3-feature avoid states (legacy variant).

    When ``prompt_length`` is None it is the number of rows of ``trajectory["avoid_states"]``
    (capped at ``max_avoid_prompt_length``) if that array is 2-D, else 0. With real states,
    padding slots repeat row 0 and are masked out. With no states, every slot is a zero vector.
    ``use_optimal_prompt`` is unused.

    Returns:
        prompt: float32 tensor of shape (1, max_avoid_prompt_length, 3).
        mask: tensor of shape (1, max_avoid_prompt_length); 0 = padding, 1 = real.
    """
    if prompt_length is None:
        avoid_shape = trajectory["avoid_states"].shape
        prompt_length = min(avoid_shape[0], max_avoid_prompt_length) if len(avoid_shape) == 2 else 0

    n_pad = max_avoid_prompt_length - prompt_length
    row_indices = [0] * n_pad + list(np.arange(prompt_length))
    mask = np.concatenate([np.zeros((1, n_pad)), np.ones((1, prompt_length))], axis=1)

    if prompt_length <= 0:
        rows = [np.array([0, 0, 0]) for _ in row_indices]
    else:
        rows = [trajectory["avoid_states"][i, :] for i in row_indices]

    prompt_array = np.array(rows).reshape(1, -1, 3)
    prompt_tensor = torch.from_numpy(prompt_array).to(dtype=torch.float32, device=device)
    mask_tensor = torch.from_numpy(mask).to(device=device)
    return prompt_tensor, mask_tensor

def get_avoid_prompt_old2(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build an avoid prompt of random cubes placed along a trajectory's start-to-end segment (legacy).

    Each entry is an axis-aligned box ``[x_min, y_min, z_min, x_max, y_max, z_max]``. Its centre
    is a uniformly random point on the segment between the first three coordinates of the first
    and last observation, and its half-width is drawn uniformly from [0, 0.1). All
    ``max_avoid_prompt_length`` entries are random boxes; the mask marks only the last
    ``prompt_length`` as real.

    Args:
        trajectory: dict with an 'observations' array; its first three columns are treated as a
            position (assumed xyz — TODO confirm against the dataset).
        max_avoid_prompt_length: padded prompt length.
        prompt_length: number of unmasked entries; sampled from [1, max_avoid_prompt_length] if None.
        device: torch device for the returned tensors.
        use_optimal_prompt: unused.

    Returns:
        prompt: float32 tensor of shape (1, max_avoid_prompt_length, 6).
        mask: float64 tensor of shape (1, max_avoid_prompt_length); 0 = padding, 1 = real.
    """
    if prompt_length is None:
        # sample a prompt length between [1, max_avoid_prompt_length] for training
        prompt_length = np.random.randint(1, max_avoid_prompt_length + 1)

    # padding to the left
    n_pad = max_avoid_prompt_length - prompt_length
    goal_timesteps = [0] * n_pad + list(np.arange(prompt_length))
    mask = np.concatenate([np.zeros((1, n_pad)), np.ones((1, prompt_length))], axis=1)

    observations = trajectory['observations']
    start_pos = observations[0][:3]
    end_pos = observations[-1][:3]

    prompt = []
    for _ in goal_timesteps:
        # Interpolate from start towards end. The previous `start + (start - end) * r`
        # extrapolated away from the end point, so boxes never lay on the path.
        random_point_along_axis = start_pos + (end_pos - start_pos) * np.random.rand()
        random_size = np.random.rand() * 0.1
        center_x, center_y, center_z = random_point_along_axis[:3]
        avoid_box = np.array([center_x - random_size, center_y - random_size, center_z - random_size,
                              center_x + random_size, center_y + random_size, center_z + random_size])
        prompt.append(avoid_box)

    prompt = np.array(prompt).reshape(1, -1, 6)

    prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask

def get_avoid_prompt(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build a left-padded prompt of 30-feature avoid states, plus the trajectory's success flag.

    When ``prompt_length`` is None it is the row count of ``trajectory["avoid_states"]`` (capped
    at ``max_avoid_prompt_length``) if that array is 2-D, else 0. With real states, padding slots
    repeat row 0 and are masked out. Without them, every slot is a zero vector.
    ``use_optimal_prompt`` is unused.

    Returns:
        prompt: float32 tensor of shape (1, max_avoid_prompt_length, 30).
        mask: tensor of shape (1, max_avoid_prompt_length); 0 = padding, 1 = real.
        success: ``trajectory["success"]``, returned unchanged.
    """
    if prompt_length is None:
        shape = trajectory["avoid_states"].shape
        if len(shape) != 2:
            prompt_length = 0
        else:
            prompt_length = min(shape[0], max_avoid_prompt_length)

    pad_count = max_avoid_prompt_length - prompt_length
    slot_rows = [0] * pad_count + list(np.arange(prompt_length))
    mask = np.concatenate([np.zeros((1, pad_count)), np.ones((1, prompt_length))], axis=1)

    if prompt_length > 0:
        entries = [trajectory["avoid_states"][row, :] for row in slot_rows]
    else:
        entries = [np.array([0] * 30) for _ in slot_rows]

    prompt = torch.from_numpy(np.array(entries).reshape(1, -1, 30)).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask, trajectory["success"]


if __name__=='__main__':
    # NOTE(review): running this module as a script currently does nothing. The only statement
    # below is a string literal holding old MazeRunner dataset-inspection code, and everything
    # after it is commented out. It is kept for reference only.
    '''
    with open('envs/mazerunner/mazerunner-d15-g1-astar.pkl', 'rb') as f:
        trajectories = pickle.load(f)

    main_keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals']
    info_keys = ['timesteps', 'maze', 'goal_pos']
    print(len(trajectories))

    for k in main_keys:
        print(k, type(trajectories[0][k]), trajectories[0][k].dtype, trajectories[0][k].shape)

    for k in info_keys:
        print(k, trajectories[0][k])
    '''

    # import torch
    # device=torch.device('cuda')
    # info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list = \
    #     get_train_test_dataset_envs('mazerunner-d30-g4-4-t500-multigoal-astar.pkl', device, max_ep_len=500)
    # #for i in range(100):
    # #    get_prompt(trajectories_list[0], device=device)
    # #variant={'K':50, 'batch_size': 16, 'max_prompt_len': 5}
    # #fn = get_prompt_batch(trajectories_list, info[0], variant)
    # #prompt, batch = fn()

    # for env, traj in zip(test_env_list, test_trajectories_list):
    #     env.reset()
    #     prompts = (traj['optimal_prompts']*env.maze_dim).astype(int)
    #     for p in prompts:
    #         env.pos = p            
    #         env.render()
    
    
