import pickle
import torch
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.image as mpimg

def _join_maze_rows(*rows):
    """Join maze rows into the backslash-separated layout string stored in MAZE_SHAPES."""
    return "\\".join(rows)


# Maze layouts keyed by environment name. Each layout is one string whose rows
# are separated by a single backslash; the plotting helpers split on that
# backslash and draw a wall square for every "#" cell. "0", "G" and "R" are
# non-wall cells ("G"/"R" presumably mark goal/start cells -- TODO confirm).
_UMAZE_LAYOUT = _join_maze_rows(
    "#####",
    "#G00#",
    "###0#",
    "#000#",
    "#####",
)

_MEDIUM_LAYOUT = _join_maze_rows(
    "########",
    "#00##00#",
    "#00#000#",
    "##000###",
    "#00#000#",
    "#0#00#0#",
    "#000#0G#",
    "########",
)

_LARGE_LAYOUT = _join_maze_rows(
    "############",
    "#0000#00000#",
    "#0##0#0#0#0#",
    "#000000#000#",
    "#0####0###0#",
    "#00#0#00000#",
    "##0#0#0#0###",
    "#00#000#0G0#",
    "############",
)

_ULTRA_LAYOUT = _join_maze_rows(
    "################",
    "#R000000#000000#",
    "#0###0#0#0##0#0#",
    "#0###0#0000#0#0#",
    "#000#0##0###0#0#",
    "#0#000#00000000#",
    "#000#000#0###0##",
    "#0#0###0#000#0##",
    "#0000000#0#0000#",
    "##0##0#000####0#",
    "##0#00#0#00000G#",
    "################",
)

# This layout string ends with a trailing separator (an empty last row); the
# final "" reproduces that exact string.
_EXTREME_LAYOUT = _join_maze_rows(
    "########################",
    "#0000000000##000#000#00#",
    "#0###0#0##0000#000#0#0##",
    "#0###0#0000######00000##",
    "#00000##0#00000##0#0#00#",
    "#0###00#0#0###0000#0##0#",
    "#00###0#0000000#0##0000#",
    "#000000#0##0##0#0#####0#",
    "#0##0#0#0##0##0#0##0000#",
    "#0##0#0#0000000000#0##0#",
    "#0000#00000######0#0000#",
    "#0##0#####000000#0######",
    "#0##0#00000####00000000#",
    "#00#0#####0000####0###0#",
    "#00#0000000##0000000000#",
    "########################",
    "",
)

MAZE_SHAPES = {
    "antmaze-umaze-v2": _UMAZE_LAYOUT,
    "antmaze-medium-v2": _MEDIUM_LAYOUT,
    "antmaze-medium-diverse-v2": _MEDIUM_LAYOUT,
    "antmaze-medium-play-v2": _MEDIUM_LAYOUT,
    "antmaze-large-diverse-v2": _LARGE_LAYOUT,
    "antmaze-large-play-v2": _LARGE_LAYOUT,
    "AntMaze_Large_Diverse_GR-v4": _LARGE_LAYOUT,
    "antmaze-ultra-diverse-v0": _ULTRA_LAYOUT,
    "antmaze-ultra-play-v0": _ULTRA_LAYOUT,
    "antmaze-extreme-diverse-v0": _EXTREME_LAYOUT,
    "antmaze-extreme-play-v0": _EXTREME_LAYOUT,
}

# Plot extent per environment as [[x_min, x_max], [y_min, y_max]]; used as the
# imshow extent and as the sampling range for the token backgrounds.
# Environments that share a layout in MAZE_SHAPES share a range.
# NOTE: "antmaze-umaze-v2" has no entry, so plot_maze raises KeyError for it.
MAZE_RANGE = {
    # Same layout as the other medium mazes; previously missing, which made
    # plot_maze fail with KeyError for this environment.
    "antmaze-medium-v2": [[-2, 22.5], [-2, 22.5]],
    "antmaze-medium-diverse-v2": [[-2, 22.5], [-2, 22.5]],
    "antmaze-medium-play-v2": [[-2, 22.5], [-2, 22.5]],
    "antmaze-large-diverse-v2": [[-2, 40], [-2, 30]],
    "antmaze-large-play-v2": [[-2, 40], [-2, 30]],
    "AntMaze_Large_Diverse_GR-v4": [[-2, 40], [-2, 30]],
    "antmaze-ultra-diverse-v0": [[-2, 55], [-2, 40]],
    "antmaze-ultra-play-v0": [[-2, 55], [-2, 40]],
    "antmaze-extreme-play-v0": [[-2, 90], [-2, 55]],
    "antmaze-extreme-diverse-v0": [[-2, 90], [-2, 55]],
}

# Palette for token ids and point classes; callers index it either directly or
# modulo its length. Every entry is a distinct named colour, so two ids never
# share a colour. The second occurrences of "mediumpurple", "rosybrown",
# "royalblue" and "saddlebrown", and "fuchsia" (same RGB as "magenta"), were
# replaced in place by unused colours, so every other index keeps its colour
# and the list length is unchanged.
COLORS = [
    "red", "green", "blue", "cyan", "magenta", "yellow", "black", "deeppink", "darkmagenta", "teal", "saddlebrown", "rosybrown",
    "indianred", "lightgreen", "olive", "deepskyblue", "skyblue", "turquoise", "lavender", "lightgrey",
    "lightpink", "crimson", "mediumpurple", "lightsteelblue", "slategrey", "whitesmoke", "antiquewhite", "royalblue",
    "bisque", "azure", "cadetblue", "aliceblue", "khaki",
    "gold", "coral", "chocolate", "darkblue", "darkcyan", "darkgoldenrod", "darkgreen",
    "darkkhaki", "darkolivegreen", "darkorange", "darkorchid", "darksalmon",
    "darkseagreen", "darkslateblue", "darkslategrey", "darkturquoise", "darkviolet",
    "dodgerblue", "firebrick", "forestgreen", "indigo", "gainsboro", "ghostwhite",
    "goldenrod", "greenyellow", "honeydew", "hotpink", "ivory", "lavenderblush",
    "lawngreen", "lemonchiffon", "lightblue", "lightcoral", "lightcyan", "lightgoldenrodyellow",
    "lightseagreen", "lightsalmon", "lightyellow", "lime", "limegreen", "linen",
    "maroon", "mediumaquamarine", "mediumblue", "mediumorchid", "blueviolet",
    "mediumseagreen", "mediumslateblue", "mediumspringgreen", "mediumturquoise",
    "mediumvioletred", "midnightblue", "mintcream", "mistyrose", "moccasin", "navajowhite",
    "navy", "oldlace", "olivedrab", "orange", "orangered", "orchid", "palegoldenrod",
    "palegreen", "paleturquoise", "palevioletred", "papayawhip", "peachpuff", "peru",
    "pink", "plum", "powderblue", "purple", "rebeccapurple", "brown", "cornflowerblue",
    "chartreuse", "salmon", "sandybrown", "seagreen", "seashell", "sienna", "silver",
    "snow", "springgreen", "steelblue", "tan", "thistle", "tomato", "violet", "wheat",
    "white", "yellowgreen",
]

def plot_maze(tokenizer, env_name, device, resolution=1, background_file=None, points=None, classes=None, labels=None,
              output_file=None):
    """Draw the token map of a maze environment and save it to a file.

    The background colours each grid point by the token id it is assigned
    (via ``_get_tokens_background``). The maze walls from ``MAZE_SHAPES`` are
    drawn on top as black 4x4 squares. Optional point sets and token-id labels
    are overlaid.

    Args:
        tokenizer: Callable tokenizer. Its ``_number_of_tokens`` and ``_tokens``
            mapping (key -> token id) are read to place the token-id labels.
            If ``None``, no labels are drawn, but a background must then
            already be cached in ``background_file``.
        env_name: Key into ``MAZE_SHAPES`` and ``MAZE_RANGE``.
        device: Torch device passed through to the background computation.
        resolution: Grid step used when computing the background.
        background_file: Optional pickle cache path for the background array.
        points: Optional iterable of 2-D arrays. Columns 0/1 are scattered as
            x/y, one colour per array.
        classes: Optional mapping whose values are sequences of 2-D points.
            Each value is scattered in one colour from ``COLORS``.
        labels: Unused; kept for backward compatibility. Labels are always
            derived from ``tokenizer``.
        output_file: Path of the saved figure. Defaults to
            ``background_file + '.pdf'`` when ``background_file`` is given,
            otherwise ``'<env_name>_tokens.pdf'``. Previously a missing
            ``background_file`` crashed here with ``TypeError``.
    """
    layout = MAZE_SHAPES[env_name]
    x_range, y_range = MAZE_RANGE[env_name]

    _, ax = plt.subplots()
    ax.set_aspect('equal')

    background = _get_tokens_background(tokenizer, env_name, device, resolution, background_file)
    ax.imshow(background, extent=(x_range[0], x_range[1], y_range[0], y_range[1]), origin='lower', alpha=.75,
              interpolation="bilinear")

    # Walls: maze cell (col, row) is a 4x4 square with lower-left corner (4*col - 5.5, 4*row - 5.5).
    for row_idx, row in enumerate(layout.split("\\")):
        for col_idx, cell in enumerate(row):
            if cell == '#':
                ax.add_patch(plt.Rectangle((4 * col_idx - 5.5, 4 * row_idx - 5.5), 4, 4, color='black'))

    if points is not None:
        for data in points:
            ax.scatter(data[:, 0], data[:, 1])

    if classes is not None:
        for i, class_points in enumerate(classes.values()):
            class_points = np.array(class_points)
            # Wrap around the palette so more classes than colours cannot raise IndexError.
            ax.scatter(class_points[:, 0], class_points[:, 1], c=COLORS[i % len(COLORS)])

    if tokenizer is not None:
        # One slot per token id; if several keys share an id, the last one wins.
        token_positions = [[] for _ in range(tokenizer._number_of_tokens)]
        for key, token_id in tokenizer._tokens.items():
            token_positions[token_id] = np.array(key)
        for token_id, position in enumerate(token_positions):
            # Only 2-D keys can be placed on the map. Checking the shape (instead
            # of len) also tolerates scalar keys and empty slots.
            if np.shape(position) == (2,):
                ax.text(position[0], position[1], token_id, ha='center', va='center')

    ax.autoscale_view()
    ax.axis('off')
    plt.tight_layout()
    if output_file is None:
        output_file = background_file + '.pdf' if background_file is not None else f"{env_name}_tokens.pdf"
    plt.savefig(output_file)

def plot_godot(tokenizer, device, resolution, background_file, points=None, map_image='godot_map.png',
               output_file=None):
    """Draw the token map of the Godot "map1_large" environment and save it to a file.

    The token background is computed over [-60, 60] x [-60, 60] by tokenizing
    the ``'sensor/position'`` key. The map image is drawn over it, then optional
    point sets and token-id labels are added.

    Args:
        tokenizer: Callable tokenizer. Its ``_number_of_tokens`` and ``_tokens``
            mapping (key -> token id) are read to place token-id labels; may be
            ``None`` only if the background is already cached.
        device: Torch device passed through to the background computation.
        resolution: Grid step used when computing the background.
        background_file: Pickle cache path for the background array (may be ``None``).
        points: Optional iterable of 2-D arrays. Columns 0/1 are scattered as x/y.
        map_image: Path of the map picture drawn over the background.
        output_file: Path of the saved figure. Defaults to
            ``background_file + '.pdf'`` when ``background_file`` is given,
            otherwise ``'map1_large_tokens.pdf'``. Previously ``None`` crashed
            with ``TypeError``.
    """
    extent = [-60, 60, -60, 60]
    _, ax = plt.subplots()

    background = _get_tokens_background(tokenizer, "map1_large", device, resolution, background_file,
                                        keys_to_tokenize='sensor/position')
    ax.imshow(background, extent=extent, origin='lower', alpha=.7, interpolation="bilinear")

    # Flipping rows while keeping imshow's default origin='upper' puts the image's
    # first row at the bottom, the same row convention as the origin='lower' background.
    # NOTE(review): drawn opaque on top of the background; presumably the PNG has
    # transparent areas so the tokens remain visible -- confirm.
    img = mpimg.imread(map_image)
    ax.imshow(img[::-1, :], extent=extent)

    if points is not None:
        for data in points:
            ax.scatter(data[:, 0], data[:, 1])

    if tokenizer is not None:
        # One slot per token id; if several keys share an id, the last one wins.
        token_positions = [[] for _ in range(tokenizer._number_of_tokens)]
        for key, token_id in tokenizer._tokens.items():
            token_positions[token_id] = np.array(key)
        for token_id, position in enumerate(token_positions):
            # Only 2-D keys can be placed on the map; shape check tolerates scalar keys.
            if np.shape(position) == (2,):
                ax.text(position[0], position[1], token_id, ha='center', va='center')

    ax.autoscale_view()
    ax.axis('off')
    plt.tight_layout()
    if output_file is None:
        output_file = background_file + '.pdf' if background_file is not None else 'map1_large_tokens.pdf'
    plt.savefig(output_file)


def _get_tokens_background(tokenizer, env_name, device, resolution = 1, background_file=None, keys_to_tokenize ='obs/pos'):
    """Colour a regular 2-D grid by the token id each grid point is mapped to.

    Every grid point ``(x, y)`` is fed to the tokenizer as
    ``{keys_to_tokenize: tensor of shape (1, 1, 2)}``. The scalar in the
    ``'token/ids'`` output selects a colour from ``COLORS`` (cycled by id).

    Args:
        tokenizer: Callable returning a dict with a single-element ``'token/ids'`` tensor.
        env_name: Key into ``MAZE_RANGE``. Unknown names use [-60, 60] on both axes.
        device: Torch device for the position tensors.
        resolution: Grid step. Both range endpoints are included up to ``np.arange`` rounding.
        background_file: Optional pickle cache. If the file exists it is returned
            as is. NOTE: the cache is not keyed on ``env_name`` / ``resolution``,
            so a stale file is returned silently.
        keys_to_tokenize: Input key under which positions are passed to the tokenizer.

    Returns:
        Integer array of shape ``(len(y_values), len(x_values), 3)`` holding
        RGB values in [0, 255].
    """
    if background_file is not None and os.path.exists(background_file):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted cache files.
        with open(background_file, 'rb') as f:
            return pickle.load(f)

    x_range, y_range = MAZE_RANGE.get(env_name, [[-60, 60], [-60, 60]])
    x_values = np.arange(x_range[0], x_range[1] + resolution, resolution)
    y_values = np.arange(y_range[0], y_range[1] + resolution, resolution)
    X, Y = np.meshgrid(x_values, y_values)

    token_colors = {}

    def color_of(x, y):
        """Return the RGB (0-255) colour of the token assigned to position (x, y)."""
        position = torch.tensor([[x, y]], device=device).float()
        token_id = tokenizer({keys_to_tokenize: position.unsqueeze(0)})['token/ids'].item()
        if token_id not in token_colors:
            rgb = mcolors.to_rgba(COLORS[token_id % len(COLORS)])[:3]
            token_colors[token_id] = [int(channel * 255) for channel in rgb]
        return token_colors[token_id]

    # Pure inference: no_grad keeps autograd from recording a graph for every grid point.
    with torch.no_grad():
        Z = np.array([[color_of(x, y) for x, y in zip(x_row, y_row)] for x_row, y_row in zip(X, Y)])

    if background_file is not None:
        try:
            with open(background_file, 'wb') as f:
                pickle.dump(Z, f)
        except Exception as e:
            # Best effort: failing to write the cache must not lose the computed background.
            print(f"Error occurred while writing to file: {e}")

    return Z

######################## For images
import imageio
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, IterableDataset

class VisualAntmazeDataset(Dataset):
    """Dataset of maze images stored one PNG per position in a directory.

    File names are expected to look like ``<x>_<y>.png``. ``__getitem__``
    returns the image as a float32 ``(3, H, W)`` array scaled to [0, 1],
    together with the two strings parsed from the file name.
    """

    def __init__(
        self,
        data_path = 'datasets/test_data',
    ):
        self.data_path = os.path.abspath(data_path)
        # Keep only PNGs: any other file (e.g. ".DS_Store") cannot be parsed by
        # __getitem__. Sort so the index -> file mapping does not depend on the
        # arbitrary order returned by os.listdir.
        self.file_names = sorted(name for name in os.listdir(self.data_path) if name.endswith('.png'))
        self.n_imgs = len(self.file_names)

    def __getitem__(self, i):
        """Return ``(image, (x, y))`` for the i-th file; x and y are strings."""
        name = self.file_names[i]
        x, y = os.path.splitext(name)[0].split('_')
        # Drop any alpha channel, then scale to [0, 1] and move channels first.
        image = np.array(imageio.imread(os.path.join(self.data_path, name))[:, :, :3]).astype(np.float32)
        image = image / 255
        image = image.transpose(2, 0, 1)
        return image, (x, y)

    def __len__(self):
        return len(self.file_names)


def _get_images_tokens_background(tokenizer, dataset, env_name, background_file=None, device='cuda'):
    """Build an RGB token map from a dataset of position-labelled images.

    Each image is tokenized and its grid cell is painted with the token's
    colour from ``COLORS`` (cycled by id). Cells without an image stay black.

    Args:
        tokenizer: Callable taking ``{'observations': tensor}`` and returning a
            dict with a single-element ``'token/ids'`` tensor.
        dataset: Indexable dataset yielding ``(CHW image array, (ix, iy))``;
            ``ix`` and ``iy`` must be int-convertible.
        env_name: Key into ``MAZE_RANGE``; defines the grid extent (step 1).
        background_file: Optional pickle cache of the colour array ``Z``.
        device: Torch device the images are moved to (previously hard-coded 'cuda').

    Returns:
        ``(Z, X, Y)``: ``Z`` is an int32 array of shape
        ``(len(y_values), len(x_values), 3)``, and ``X, Y`` come from
        ``np.meshgrid``. The cached path now returns the same triple. It used
        to return a bare ``Z``, which broke ``Z, X, Y = ...`` callers.
    """
    x_range, y_range = MAZE_RANGE[env_name]
    x_values = np.arange(x_range[0], x_range[1], 1)
    y_values = np.arange(y_range[0], y_range[1], 1)
    X, Y = np.meshgrid(x_values, y_values)

    if background_file is not None and os.path.exists(background_file):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted cache files.
        with open(background_file, 'rb') as f:
            Z = pickle.load(f)
        return Z, X, Y

    Z = np.zeros(shape=(X.shape[0], X.shape[1], 3), dtype=np.int32)
    token_colors = {}

    # Pure inference: no_grad keeps autograd from recording a graph per image.
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            image, (ix, iy) = dataset[i]
            ix, iy = int(ix), int(iy)
            # CHW -> (1, H, W, C), then an extra leading dim for the tokenizer input.
            observation = torch.from_numpy(image).to(device=device).unsqueeze(0).permute(0, 2, 3, 1)
            token_id = tokenizer({'observations': observation.unsqueeze(0)})['token/ids'].item()

            if token_id not in token_colors:
                # Cycle the palette so token ids >= len(COLORS) no longer raise IndexError.
                rgb = mcolors.to_rgba(COLORS[token_id % len(COLORS)])[:3]
                token_colors[token_id] = [int(channel * 255) for channel in rgb]

            # NOTE(review): Z is laid out (len(y_values), len(x_values)) by meshgrid
            # but indexed [ix, iy] with the values parsed from the file name. This is
            # correct only if the names encode (row, column) grid indices -- confirm.
            Z[ix, iy] = token_colors[token_id]

    if background_file is not None:
        try:
            with open(background_file, 'wb') as f:
                pickle.dump(Z, f)
        except Exception as e:
            # Best effort: failing to write the cache must not lose the computed map.
            print(f"Error occurred while writing to file: {e}")

    return Z, X, Y

def plot_visual_maze(tokenizer, background_file=None, points=None, classes=None,
                     env_name='antmaze-large-diverse-v2', data_path='datasets/test_data',
                     output_file='tokenized_map.png'):
    """Plot the maze with image-based tokenized positions and save it.

    Args:
        tokenizer: Image tokenizer passed to ``_get_images_tokens_background``.
        background_file: Optional pickle cache for the token colour map.
        points: Optional iterable of 2-D arrays. Columns 0/1 are scattered as x/y.
        classes: Optional mapping whose values are sequences of 2-D points.
            Each value is scattered in one colour from ``COLORS``.
        env_name: Maze to draw. It is used consistently for the layout, the range
            and the background; previously it was hard-coded in two places.
        data_path: Directory of ``VisualAntmazeDataset`` images.
        output_file: Path of the saved figure.
    """
    dataset = VisualAntmazeDataset(data_path)
    layout = MAZE_SHAPES[env_name]

    result = _get_images_tokens_background(tokenizer, dataset, env_name, background_file)
    # Only Z is needed. Accept a bare array as well as a (Z, X, Y) tuple so a
    # cached background also works.
    Z = result[0] if isinstance(result, tuple) else result

    x_range, y_range = MAZE_RANGE[env_name]
    _, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.imshow(Z, extent=(x_range[0], x_range[1], y_range[0], y_range[1]), origin='lower', alpha=.75,
              interpolation="bilinear")

    # Walls: maze cell (col, row) is a 4x4 square with lower-left corner (4*col - 5.5, 4*row - 5.5).
    for row_idx, row in enumerate(layout.split("\\")):
        for col_idx, cell in enumerate(row):
            if cell == '#':
                ax.add_patch(plt.Rectangle((4 * col_idx - 5.5, 4 * row_idx - 5.5), 4, 4, color='black'))

    if points is not None:
        for data in points:
            ax.scatter(data[:, 0], data[:, 1])

    if classes is not None:
        for i, class_points in enumerate(classes.values()):
            class_points = np.array(class_points)
            # Wrap around the palette so more classes than colours cannot raise IndexError.
            ax.scatter(class_points[:, 0], class_points[:, 1], c=COLORS[i % len(COLORS)])

    ax.autoscale_view()
    ax.axis('off')
    plt.tight_layout()
    plt.savefig(output_file)


