from enum import Enum
# import threading
import numpy as np
import matplotlib.pyplot as plt

from ..ataritools.actions import actions
from .options import options
from .plans import Plans
from collections import deque

# Enum built with the functional API; start=0 numbers the members
# ANY=0, TOP=1, LEFT=2, RIGHT=3.
ladders = Enum('ladders', start=0, names=['ANY', 'TOP', 'LEFT', 'RIGHT'])

# Two-line message text about a detected enemy. Nothing in the visible part of
# this file references it - presumably used by callers elsewhere; TODO confirm.
enemy_detection_message = "Enemy spotted!\nTake evasive action!"

def find_nearest(array, value):
    """Return the element of 'array' whose value lies closest to 'value'.

    The element is taken from the original 'array' object (not from the numpy
    copy), and the first element wins when several are equally close.
    """
    distances = np.abs(np.asarray(array) - value)
    closest_index = distances.argmin()
    return array[closest_index]

def pixelsSame(a, b):
    """Tell whether pixels 'a' and 'b' agree on their first three channels.

    Any channel beyond index 2 is ignored.
    """
    first = (a[0], a[1], a[2])
    second = (b[0], b[1], b[2])
    return first == second

def isPixelBlack(pixel, frame):
    """Tell whether the pixel at (row, col) == 'pixel' in 'frame' is black.

    A pixel counts as black when its first three channels are all 0.
    """
    row, col = pixel[0], pixel[1]
    channels = frame[row][col]
    return all(channels[idx] == 0 for idx in range(3))

def allPixelsBlack(pixels):
    """Return True when every pixel of the image region 'pixels' is black.

    'pixels' is indexed as [row][col][channel]; a pixel is black when its
    first three channels are all 0. A region without any pixels (no rows or
    no columns) is considered all black.

    The check is one vectorised numpy reduction rather than a Python loop
    over individual pixels.
    """
    region = np.asarray(pixels)
    if region.size == 0:
        # Nothing to contradict "all black".
        return True
    # Only the first three channels take part in the test.
    return not region[..., :3].any()

def allPixelsSame(pixels):
    """Return True when every pixel of 'pixels' matches its top-left pixel.

    'pixels' is indexed as [row, col, channel]; only the first three channels
    are compared. Both numpy arrays and nested lists are accepted.

    Raises IndexError when the region contains no pixel at [0, 0].
    """
    region = np.asarray(pixels)
    reference = region[0, 0, :3]
    # Broadcast the reference colour against the whole region in one step.
    return bool((region[..., :3] == reference).all())

def borderNonBlackInnerBlack(pixels):
    """Return True when 'pixels' is a one-pixel frame around a uniform inside.

    Should only be used for ladder detection. The window must be exactly
    7 rows by 10 columns. It matches when
      * every pixel on the outer border has the colour of pixels[0][0],
      * every interior pixel has the colour of pixels[1][1], and
      * those two colours differ.
    Despite the name, neither colour is compared against black.

    Any window of another size - including one with no rows at all, as
    produced by slicing past the edge of a frame - yields False.
    """
    # Check the row count first so that an empty window never reaches
    # pixels[0].
    if len(pixels) != 7 or len(pixels[0]) != 10:
        return False

    window = np.asarray(pixels)
    border_color = window[0, 0]
    inner_color = window[1, 1]
    if np.array_equal(inner_color, border_color):
        return False

    if not (window[1:-1, 1:-1] == inner_color).all():
        return False

    # Boolean mask selecting only the outermost ring of the window.
    on_border = np.ones(window.shape[:2], dtype=bool)
    on_border[1:-1, 1:-1] = False
    return bool((window[on_border] == border_color).all())

def leq(l1, l2):
    """List equality: True when 'l1' and 'l2' have equal length and contents.

    Elements are compared pairwise, position by position.
    """
    if len(l1) != len(l2):
        return False
    return all(left == right for left, right in zip(l1, l2))

def remove_columns(pixels):
    """Blank out, in place, every column of 'pixels' that is one solid colour.

    'pixels' is indexed as [row][col] -> pixel. A column whose pixels all
    equal the pixel in its first row is overwritten with [0, 0, 0]; other
    columns are left untouched. An image without rows is left as is.

    Returns None; the image is modified in place.
    """
    if len(pixels) == 0:
        # No first row to take the reference colours from.
        return
    for col in range(len(pixels[0])):
        reference = pixels[0][col]
        # The uniformity test completes before any pixel of this column is
        # overwritten, so 'reference' is still the original colour here.
        if all(np.array_equal(reference, row[col]) for row in pixels):
            for row in pixels:
                row[col] = [0, 0, 0]

def remove_rectangle(pixels, start):
    """Blank out, in place, a solid-colour rectangle anchored at 'start'.

    'pixels' is indexed as [row][col] -> pixel and 'start' is a (row, col)
    pair. The candidate rectangle extends from 'start' downwards and to the
    right for as long as the pixels in start's column / start's row keep the
    colour of the start pixel. If the candidate passes the shape check and
    every pixel inside it has that colour, all of its pixels are set to
    [0, 0, 0]. In every other case the image is left unchanged.

    Returns None.
    """

    # start is the coordinates of a pixel that determines the color of the rectangle
    # by convention, this is the top left corner of the rectangle
    end_row = 0
    end_col = 0
    color = pixels[start[0]][start[1]]

    # next 2 loops identify the largest potential rectangle height and width

    # Walk down start's column: end_row ends up as the last row that still has
    # the start colour (the final row of the image if no mismatch is found).
    for row in range(start[0], len(pixels)):
        end_row = row
        if not leq(color, pixels[row][start[1]]):
            end_row -= 1
            break
    # Same walk along start's row to find the last matching column.
    for col in range(start[1], len(pixels[0])):
        end_col = col
        if not leq(color, pixels[start[0]][col]):
            end_col -= 1
            break

    # verify that we have an actual rectangle (based on shape)
    # The outer condition holds for a single pixel, for a one-row-high
    # candidate that does not sit on the last image row, and for a
    # one-column-wide candidate that does not sit on the last image column.
    # NOTE(review): the inner test parses as
    # `(not end_col == 0) or (end_col == len(pixels[0]))`, so such candidates
    # are skipped unless end_col is 0. end_col never reaches len(pixels[0])
    # in the loop above, so the second operand looks redundant - confirm the
    # intended condition.
    if (end_row == start[0] and end_col == start[1]) or \
       ((end_row == start[0] and not end_row == len(pixels) - 1) or \
       (end_col == start[1] and not end_col == len(pixels[0]) - 1)):
        if not end_col == 0 or end_col == len(pixels[0]):
            return

    # if end_col == len(pixels[0]) - 1:
    #     end_col += 1
    # if end_row == len(pixels) - 1:
    #     end_row += 1

    # Now, we want to verify that all pixels in our rectangular region are the
    # same color.
    for row in range(start[0], end_row + 1):
        for col in range(start[1], end_col + 1):
            if not leq(color, pixels[row][col]):
                # our rectangle wasn't uniform color, bail out
                return

    # we've found a rectangle we can remove, modify the pixel array to set the
    # rectangle pixels to black.
    for row in range(start[0], end_row + 1):
        for col in range(start[1], end_col + 1):
            pixels[row][col] = [0, 0, 0]

def show(pixels):
    """Debug helper: draw 'pixels' with matplotlib's imshow and show the figure."""
    plt.imshow(pixels)
    plt.show()

class SkillController:
    """Controller for hand-coded skills

    Every skill is a method that takes the current 'state' mapping and returns
    a tuple (canExecute, action, lastFrame):
        canExecute - whether the skill considers itself valid in 'state'
        action - the low-level action to emit on this frame
        lastFrame - whether the skill finishes on this frame

    'isActualRun' separates real execution from the side-effect-free probing
    done by getValidSkills(): while it is False, skills skip their bookkeeping
    (position history, wait counters, save/load prompts).

    The pixel coordinates used by the detectors index into the window returned
    by state['env'].get_pixels_around_player(); their exact meaning depends on
    that method, which is not part of this file.
    """

    # Number of recent player x positions kept for the "stuck" check in run().
    _X_HISTORY_LEN = 10

    def __init__(self, initial_plan=None):
        """Initialize SkillController (optionally by specifying a plan)

        If provided, 'initial_plan' should be a collections.deque of options.
        A missing (or empty) plan is replaced by Plans.GetDefaultStart().
        """
        self.option = options.NONE
        self.frame = 0
        self.noop_max = 180
        self.noop_count = 0
        self.skillFunction = {
            options.NONE: self.noop,
            options.RUN_LEFT: self.runLeft,
            options.RUN_RIGHT: self.runRight,
            options.RUN_LEFT3: self.runLeft3,
            options.RUN_RIGHT3: self.runRight3,
            options.JUMP_LEFT: self.jumpLeft,
            options.JUMP_RIGHT: self.jumpRight,
            options.JUMP: self.jumpUp,
            options.CLIMB_UP: self.climbUp,
            options.CLIMB_DOWN: self.climbDown,
            options.WAIT_FOR_SKULL: self.waitForSkull,
            options.WAIT_1: self.wait1,
            options.WAIT_5: self.wait5,
            options.WAIT_10: self.wait10,
            options.STEP_RIGHT: self.stepRight,
            options.STEP_LEFT: self.stepLeft,
            options.SAVE: self.save,
            options.LOAD: self.load,
        }
        self.nQueuedSkillsExecuted = 0
        self.initial_plan = initial_plan if initial_plan else Plans.GetDefaultStart()
        self.isInitialized = False
        self.lastX = self._newXHistory()
        self.isActualRun = True
        # Shared countdown state of the wait*/step* skills.
        self.wasWaiting = False
        self.waitedAlready = 0

    @classmethod
    def _newXHistory(cls):
        """Return a zero-filled, fixed-length history of player x positions."""
        return deque([0] * cls._X_HISTORY_LEN, maxlen=cls._X_HISTORY_LEN)

    @staticmethod
    def _isLadderWindow(cut):
        """Return True when 'cut' is a 7x10 window that matches the ladder pattern.

        The size is checked here first so that windows sliced past the edge of
        the frame (possibly without any rows) are rejected without being
        indexed.
        """
        return len(cut) == 7 and len(cut[0]) == 10 and borderNonBlackInnerBlack(cut)

    @staticmethod
    def _enemyInColumns(rgb, lbound, rbound):
        """Look for an enemy in rows 31-39, columns lbound..rbound-1 of 'rgb'.

        Works on a copy of the region: solid-colour columns and solid-colour
        rectangles are blanked out first. The result is True when some row
        then still contains both black and non-black pixels.
        """
        cut = np.copy(rgb[31:40, lbound:rbound])
        if len(cut) == 0:
            # The window does not reach these rows; nothing to inspect.
            return False
        remove_columns(cut)
        for row in range(len(cut)):
            all_black = True
            none_black = True
            for col in range(len(cut[0])):
                if not isPixelBlack((row, col), cut):
                    remove_rectangle(cut, (row, col))
                black = isPixelBlack((row, col), cut)
                all_black = all_black and black
                none_black = none_black and not black
            if not (none_black or all_black):
                return True
        return False

    def getQueuedSkill(self, state):
        """Make the next skill of the initial plan the current option.

        Prints a message and re-raises if the skill raises ValueError. If the
        queue is empty (IndexError), the user is asked for a skill instead.
        The executed-skill counter only advances when a queued skill was
        actually taken.
        """
        # Attempt to run the skill that is next in the queue
        try:
            o = self.initial_plan.popleft()
            # NOTE(review): this probe runs with isActualRun still True, so
            # the skill's bookkeeping (wait counters, x history, save/load
            # prompts) is executed once here and again by runSkillPolicy().
            # Left untouched because existing plans may be tuned to this
            # behaviour.
            (_canExecute, _, _) = self.skillFunction[o](state)
            self.option = o
        except ValueError:
            print("Invaild skill.")
            raise
        except IndexError:
            print("Queue empty.")
            self.skillPrompt(state)
        else:
            # Only count skills that really came from the queue.
            self.nQueuedSkillsExecuted += 1

    def getValidSkills(self, state):
        """Return the list of skills that are valid in the specified state

        This runs the skills' implicit classifiers with isActualRun switched
        off. The result is a numpy array with one entry per option (1 = valid,
        0 = invalid), in the iteration order of 'options'.
        """
        valid_ops = np.zeros(len(options))
        was_actual_run = self.isActualRun
        self.isActualRun = False
        try:
            for idx, op in enumerate(list(options)):
                (canExecute, _, _) = self.skillFunction[op](state)
                if canExecute:
                    valid_ops[idx] = 1
        finally:
            # Never stay in dry-run mode, even if a classifier raised.
            self.isActualRun = was_actual_run
        return valid_ops

    def getRandomSkill(self, state):
        """Choose the next skill uniformly at random from the valid skills.

        Falls back to options.NONE when no skill reports itself as valid.
        """
        valid_ops = self.getValidSkills(state)
        n_valid = np.sum(valid_ops)
        if n_valid == 0:
            # Avoid dividing by zero when building the probabilities.
            self.option = options.NONE
            return
        # Assumes option values are consecutive starting at options.NONE.
        op_id = (options.NONE.value +
                 np.random.choice(len(options), p=(valid_ops / n_valid)))
        self.option = options(op_id)

    def skillPrompt(self, state):
        """Ask the user to specify a skill to run next.

        Keeps asking until a skill that can execute in 'state' was chosen.
        NOTE(review): the menu numbers are passed straight to options(...);
        verify that they still line up with the enum values.
        """
        while self.option == options.NONE:
            skill = input('''
    Select a skill:
    1. RUN_LEFT
    2. RUN_RIGHT
    3. JUMP_LEFT
    4. JUMP_RIGHT
    5. JUMP
    6. CLIMB_UP
    7. CLIMB_DOWN
    8. WAIT_FOR_SKULL
    > ''')
            try:
                skill = int(skill)
                if skill in range(1, 9):
                    print("You selected {}.".format(options(skill).name))
                    o = options(skill)
                    (canExecute, _, _) = self.skillFunction[o](state)
                    if canExecute:
                        self.option = o
                    else:
                        raise ValueError
            except ValueError:
                print("Invaild skill.")

    def isRunningSkill(self):
        """Return True while an option other than options.NONE is active."""
        return self.option != options.NONE

    def chooseNextSkill(self, state):
        """Pick the next option: from the plan while it lasts, then at random."""
        # self.skillPrompt(state)
        if self.initial_plan:
            try:
                self.getQueuedSkill(state)
            except ValueError:
                # Abandon the plan once at least one queued skill has run.
                if self.nQueuedSkillsExecuted > 0:
                    self.initial_plan = None
                    self.getRandomSkill(state)
        else:
            self.isInitialized = True
            self.getRandomSkill(state)
        self.frame = 0

    def didSkillTimeout(self, action, op_valid):
        """Track consecutive NOOPs of an invalid skill; True once noop_max is hit."""
        timeout = False
        if self.noop_count >= self.noop_max:
            self.noop_count = 0
            timeout = True
        elif not op_valid and action == actions.NOOP:
            # Count number of NOOPs in a row if the current skill is invalid
            self.noop_count += 1
        else:
            self.noop_count = 0
        return timeout

    def runSkillPolicy(self, state):
        """Run the current skill (or if it has completed, run the next skill)

        Returns a tuple consisting of (action, option, frame, op_done)
            action - the next low-level action
            option - the option that is currently running
            op_frame - the number of frames since the current option started
            op_done - whether this frame is the last one for this option
        """
        if not self.isRunningSkill():
            self.chooseNextSkill(state)

        skill_fn = self.skillFunction[self.option]
        (op_valid, action, op_done) = skill_fn(state)
        timeout = self.didSkillTimeout(action, op_valid)
        if not op_done and (state['respawned'] or timeout):
            action = actions.NOOP
            op_done = True
            # The skill is cut short: drop any half-finished wait/step
            # countdown so that it cannot leak into the next wait/step skill.
            self.wasWaiting = False

        action_frame_tuple = (action, self.option, self.frame, op_done)
        self.frame += 1
        if op_done:
            self.option = options.NONE
            self.frame = 0
        return action_frame_tuple

    def noop(self, state):
        """Do nothing for one frame; valid only if no other option can run."""
        # NOTE(review): the other skills are probed with the current
        # isActualRun value, so their bookkeeping runs when noop is executed
        # for real.
        canExecute = not any(
            self.skillFunction[o](state)[0] for o in list(options)[1:])
        return (canExecute, actions.NOOP, True)

    def save(self, state):
        """Prompt for a file name and save the environment under saves/."""
        if self.isActualRun:
            print("Type saveawd")
            filePath = "saves/" + input()
            state['env'].save(filePath)
        return self.noop(state)

    def load(self, state):
        """Prompt for a file name and load the environment from saves/."""
        if self.isActualRun:
            print("Type loading")
            filePath = "saves/" + input()
            state['env'].load(filePath)
        return self.noop(state)

    def runRight(self, state):
        return self.run(state, actions.RIGHT)

    def runLeft(self, state):
        return self.run(state, actions.LEFT)

    def runRight3(self, state):
        return self.run(state, actions.RIGHT, True)

    def runLeft3(self, state):
        return self.run(state, actions.LEFT, True)

    def wait1(self, state):
        return self.wait(state, 1)

    def wait5(self, state):
        return self.wait(state, 5)

    def wait10(self, state):
        return self.wait(state, 10)

    def _countdown(self, times, action):
        """Emit 'action' on each call and report completion on call number 'times'.

        Shared by wait() and step(). In dry-run mode no state is touched and
        the skill is reported as valid and not finished.
        """
        if not self.isActualRun:
            return (True, action, False)
        if self.wasWaiting:
            remaining = times - self.waitedAlready
        else:
            self.wasWaiting = True
            self.waitedAlready = 0
            remaining = times
        # '<=' instead of '==': a leftover count larger than 'times' must end
        # the countdown instead of running forever.
        if remaining <= 1:
            self.wasWaiting = False
        self.waitedAlready += 1
        return (True, action, not self.wasWaiting)

    def wait(self, state, times):
        """Emit NOOP for 'times' consecutive calls."""
        return self._countdown(times, actions.NOOP)

    def stepRight(self, state):
        return self.step(state, 1, actions.RIGHT)

    def stepLeft(self, state):
        return self.step(state, 1, actions.LEFT)

    def step(self, state, times, direction):
        """Emit 'direction' for 'times' consecutive calls."""
        return self._countdown(times, direction)

    def _leftStopTriggered(self, rgb, is3):
        """Return True when a leftward run should stop (enemy, edge or ladder)."""
        # enemy stop_condition (the "3" variant scans a wider band of columns)
        if self._enemyInColumns(rgb, 7 if is3 else 10, 15):
            return True

        # platform edge stop_condition
        cut = rgb[45:53, -6:, :]
        if len(cut) > 0 and len(cut[0]) > 0 and allPixelsBlack(cut[2:6, 0:1, :]):
            return True

        # Ladder detection below and at level
        # NOTE(review): this slice has at most 6 rows and 6 columns, so it can
        # never be a 7x10 ladder window - the slice bounds look wrong.
        if self._isLadderWindow(rgb[45:51, -6:, :]):
            return True

        for i in range(12, 55):
            if self._isLadderWindow(rgb[i:i + 7, 18:, :]):
                return True
            if self._isLadderWindow(rgb[i:i + 7, 17:-1, :]):
                return True
        return False

    def _rightStopTriggered(self, rgb, is3):
        """Return True when a rightward run should stop (enemy, edge or ladder)."""
        # enemy stop_condition (the "3" variant scans a wider band of columns)
        if self._enemyInColumns(rgb, 12, 20 if is3 else 17):
            return True

        # platform edge detection
        cut = rgb[45:53, :7, :]
        if allPixelsBlack(cut[2:6, 6:7, :]):
            return True

        # Ladder detection
        for i in range(43, 55):
            if self._isLadderWindow(rgb[i:i + 7, 1:11, :]):
                return True

        # This is wrong, fix it (remark kept from the original author)
        return self._isLadderWindow(rgb[17:24, 2:12, :])

    def run(self, state, direction, is3=False):
        """Run in 'direction' until something ahead (or being stuck) stops it.

        Valid only while the player is on the ground. In an actual run the
        skill finishes when the player's x position has not changed over the
        last 10 calls, or when the pixel detectors for the given direction
        report an enemy, a platform edge or a ladder. 'is3' selects the wider
        enemy-detection window used by the RUN_*3 options. In dry-run mode
        the skill never reports being finished.
        """
        x = state['player_x']
        if self.isActualRun:
            # The bounded deque drops the oldest position automatically.
            self.lastX.append(x)
            rgb = state['env'].get_pixels_around_player(height=26, trim_direction=direction)

        on_ground = not (state['player_falling'] or state['player_jumping']) and (
            state['player_status'] in ['standing', 'running']) and not state['just_died']
        if not on_ground:
            return (False, direction, False)

        stop_condition = False
        if self.isActualRun:
            # Stuck: the whole recent history equals the current position.
            stuck = all(x == prevX for prevX in self.lastX)
            if stuck:
                self.lastX = self._newXHistory()
            stop_condition = stuck

        if direction == actions.LEFT:
            if self.isActualRun:
                stop_condition = stop_condition or self._leftStopTriggered(rgb, is3)
            return (True, actions.LEFT, stop_condition)
        if direction == actions.RIGHT:
            if self.isActualRun:
                stop_condition = stop_condition or self._rightStopTriggered(rgb, is3)
            return (True, actions.RIGHT, stop_condition)
        return (True, actions.LEFT, False)

    def jumpUp(self, state):
        return self.jump(state, actions.FIRE)

    def jumpRight(self, state):
        return self.jump(state, actions.RIGHT_FIRE)

    def jumpLeft(self, state):
        return self.jump(state, actions.LEFT_FIRE)

    def jump(self, state, direction):
        """Jump using the action 'direction'.

        Valid on the ground, or on a rope unless the jump is straight up
        (actions.FIRE). Once a jump option has been running for more than 4
        frames and the player is in a valid position again, the skill ends
        with a NOOP.
        """
        canExecute = False
        action = actions.NOOP
        lastFrame = False

        airborne = state['player_jumping'] or state['player_falling']
        on_ground = (
            state['player_status'] in ['standing', 'running']
            and not (airborne or state['just_died']))
        on_rope = (state['player_status'] in ['on-rope', 'climbing-rope']
                   and not airborne)

        if on_ground or (on_rope and direction != actions.FIRE):
            # Assume we have yet to jump
            canExecute = True
            action = direction
            if (self.frame > 4
                    and self.option in [options.JUMP, options.JUMP_LEFT, options.JUMP_RIGHT]):
                # Override if we've already jumped + landed
                lastFrame = True
                action = actions.NOOP

        return (canExecute, action, lastFrame)

    def climbUp(self, state):
        return self.climb(state, actions.UP)

    def climbDown(self, state):
        return self.climb(state, actions.DOWN)

    def climb(self, state, direction):
        """Climb in 'direction' (actions.UP or actions.DOWN).

        Valid when the player status allows climbing and a ladder window is
        found in rows 46-59 of the pixels around the player. The skill
        finishes when the floor pattern is detected.
        """
        rgb = state['env'].get_pixels_around_player(height=26, trim_direction=direction)

        ladder_below = any(
            self._isLadderWindow(rgb[i:i + 7, 12:22, :])
            or self._isLadderWindow(rgb[i:i + 7, 13:23, :])
            for i in range(46, 60))

        # Floor patterns: black pixels on both sides in row 46 or row 47, or a
        # single-colour line across columns 2-27 in row 45 or row 46.
        at_floor = (
            (allPixelsBlack(rgb[46:47, 8:11, :]) and allPixelsBlack(rgb[46:47, 21:24, :]))
            or (allPixelsBlack(rgb[47:48, 8:11, :]) and allPixelsBlack(rgb[47:48, 21:24, :]))
            or allPixelsSame(rgb[45:46, 2:28, :])
            or allPixelsSame(rgb[46:47, 2:28, :]))

        can_climb = (
            not (state['just_died'] or state['player_jumping'] or state['player_falling'])
            and state['player_status'] in [
                'standing', 'running', 'on-ladder', 'climbing-ladder', 'on-rope', 'climbing-rope']
        )

        canExecute = bool(can_climb and ladder_below)
        return (canExecute, direction, at_floor)

    def render(self, state):
        """Debug skill: display the pixel strips used by the floor detection."""
        if self.isActualRun:
            rgb = state['env'].get_pixels_around_player(height=26, trim_direction=actions.RIGHT)
            show(rgb[46:47, 8:11, :])
            show(rgb[46:47, 21:24, :])
            show(rgb[46:47, 2:28, :])
        return (True, actions.NOOP, True)

    def waitForSkull(self, state):
        """Emit NOOP until an enemy shows up in one of the nearby detection windows.

        Always valid. Unlike run(), the pixels are inspected in dry-run mode
        as well.
        """
        rgb = state['env'].get_pixels_around_player(height=26, trim_direction=actions.RIGHT)
        lastFrame = self._enemyInColumns(rgb, 10, 15) or self._enemyInColumns(rgb, 12, 17)
        return (True, actions.NOOP, lastFrame)
