
from random import random
import numpy as np

from simple_rl.mdp import MDPDistribution

from llrl.envs.gridworld import GridWorld

def sample_tight(gamma, env_name, version, w, h, stochastic, verbose, walls_num=30):
    """
    Sample a tight environment.

    The candidate goals are (w, h), (w, h - 1), (w - 1, h) and (0, 0). Walls are
    placed on the three cells diagonally adjacent to the (w, h) corner, on a
    diagonal starting past the grid centre, and on one randomly chosen straight
    segment (vertical or horizontal) next to the initial location.

    :param gamma: (float) discount factor, forwarded to GridWorld
    :param env_name: (str) name of the environment; 'tight' is used when None
    :param version: (int) 1: a single goal drawn at random, goals are terminal;
        2: every candidate goal is used, goals are not terminal and the reward
        of the last goal is halved
    :param w: (int) width of the grid
    :param h: (int) height of the grid
    :param stochastic: (bool) if True the slip probability is drawn uniformly
        from [0, 0.2), otherwise it is 0
    :param verbose: (bool) print the sampled parameters if True
    :param walls_num: (int) currently unused; kept so existing callers keep working
    :return: (GridWorld) Tight environment.
    :raises ValueError: if version is neither 1 nor 2
    """
    env_name = 'tight' if env_name is None else env_name
    r_min = 0.9
    r_max = 1.0

    goals_locations = [(w, h), (w, h - 1), (w - 1, h), (0, 0)]

    half_w = int(w / 2.)
    half_h = int(h / 2.)
    # Empty for small grids (e.g. the 5x5 default), in which case only the
    # three corner walls below are produced.
    offsets = range(2, half_w - 3)

    candidate_walls = [(w - 1, h - 1), (w - 1, h - 2), (w - 2, h - 1)]
    candidate_walls.extend((half_w + i, half_h + i) for i in offsets)
    # NOTE: this draw uses the stdlib `random` generator while every other draw
    # uses `np.random`; both must be seeded for reproducible sampling.
    if random() > 0.5:
        candidate_walls.extend((half_w + 1, half_h + i) for i in offsets)
    else:
        candidate_walls.extend((half_w + i, half_h + 1) for i in offsets)

    # Keep each wall once, in order, and never on a goal cell. The list is
    # rebuilt rather than edited in place: removing items from a list while
    # iterating over it skips elements and used to drop valid walls.
    goal_cells = set(goals_locations)
    seen = set()
    walls = []
    for wall in candidate_walls:
        if wall in goal_cells or wall in seen:
            continue
        seen.add(wall)
        walls.append(wall)

    init_loc = (half_w + 1, half_h + 1)
    slip = np.random.uniform(0.0, 0.2) if stochastic else 0.0

    if version == 1:
        is_goal_terminal = True
        goals = [goals_locations[np.random.randint(0, len(goals_locations))]]
        rewards = np.random.uniform(low=r_min, high=r_max, size=len(goals))
    elif version == 2:
        print('Tight version 2')
        is_goal_terminal = False
        goals = goals_locations
        rewards = np.random.uniform(low=r_min, high=r_max, size=len(goals))
        # The last goal, (0, 0), is worth half of its sampled reward.
        rewards[-1] = rewards[-1] / 2
    else:
        # Single formatted message; passing several positional arguments to
        # ValueError would render the message as a tuple.
        raise ValueError('Tight version not implemented (version = {})'.format(version))

    env = GridWorld(width=w, height=h, init_loc=init_loc, goal_locs=goals, gamma=gamma, slip_prob=slip,
                    goal_rewards=rewards, name=env_name, is_goal_terminal=is_goal_terminal, walls=walls)

    if verbose:
        print('Sampled tight:')
        print('  Goals:', goals)
        print('  Initial location:', init_loc)
        print('  Rewards:', rewards)
        print('  Slip probability:', slip)

    return env




def make_env_distribution(
        env_class='grid-world',
        env_name=None,
        n_env=10,
        gamma=.9,
        version=1,
        w=5,
        h=5,
        stochastic=False,
        horizon=0,
        verbose=True
):
    """
    Create a uniform distribution over sampled environments.
    This function is specialized to the included environments; only the
    'tight' class currently has a sampler.
    :param env_class: (str) name of the environment class
    :param env_name: (str) name of the environment for save path
    :param n_env: (int) number of environments to sample, must be at least 1
    :param gamma: (float) discount factor
    :param version: (int) in case a version indicator is needed
    :param w: (int) width for grid-world
    :param h: (int) height for grid-world
    :param stochastic: (bool) some environments may be stochastic
    :param horizon: (int) forwarded to MDPDistribution
    :param verbose: (bool) print info if True
    :return: (MDPDistribution)
    :raises ValueError: if env_class has no sampler or n_env is smaller than 1
    """
    if verbose:
        print('Creating environments of class', env_class)

    # Fail early with a clear message: without this check an unsupported class
    # (including the 'grid-world' default, whose sampler is disabled) would
    # reach the dictionary insertion with no environment ever sampled.
    if env_class != 'tight':
        raise ValueError('Environment class not implemented (env_class = {!r})'.format(env_class))
    if n_env < 1:
        raise ValueError('n_env must be at least 1 (n_env = {})'.format(n_env))

    # Every sampled environment receives the same probability mass.
    sampling_probability = 1. / float(n_env)
    env_dist_dict = {}

    for _ in range(n_env):
        new_env = sample_tight(gamma, env_name, version, w, h, stochastic, verbose)
        # NOTE(review): environments are used as dictionary keys; if two samples
        # compare and hash equal they share one entry — confirm against GridWorld.
        env_dist_dict[new_env] = sampling_probability
    return MDPDistribution(env_dist_dict, horizon=horizon)
