import torch
from a2c_ppo_acktr.model import Policy
import numpy as np
from tqdm import tqdm
from itertools import combinations

from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld, OvercookedState, ObjectState, SoupState
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.mdp.actions import Action, Direction


# Object types that twist_start_state_all places on a counter ('X') position,
# one variant per entry.
VAR_COUNTER = ['onion', 'dish']
# Onion counts that twist_start_state_all puts into the soup created at a pot
# ('P') position, one variant per entry.
VAR_POT = [1, 2]


def get_feasible_stat(mdp: OvercookedGridworld, *args, **kwargs):
    """Find the 'X' terrain positions that border at least one floor tile.

    NOTE: every ' ' (floor) position is treated as reachable; no connectivity
    search from the player start positions is performed.

    Args:
        mdp: Layout whose ``terrain_pos_dict`` maps terrain symbols to lists
            of positions.
        *args: Accepted for call compatibility and ignored.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        dict with two entries:
            'all_reachable_counter': the 'X' positions (in their original
                order) for which ``Action.move_in_direction`` with at least
                one of ``Action.MOTION_ACTIONS`` lands on a floor position.
            'all_empty_counter': ``terrain_pos_dict['X']`` itself (the same
                list object, not a copy).
    """
    terrain_pos_dict = mdp.terrain_pos_dict
    all_empty_counter = terrain_pos_dict['X']
    # A set gives O(1) membership tests for the neighbour check below.
    all_reachable_floor = set(terrain_pos_dict[' '])

    all_reachable_counter = [
        counter
        for counter in all_empty_counter
        if any(
            Action.move_in_direction(counter, act) in all_reachable_floor
            for act in Action.MOTION_ACTIONS
        )
    ]

    return {
        'all_reachable_counter': all_reachable_counter,
        'all_empty_counter': all_empty_counter,
    }


def get_reverse_grad(policy: Policy, obs, rnn_hxs=None):
    """Differentiate every action probability with respect to one observation.

    Args:
        policy: Model whose ``base`` and ``dist`` sub-modules are called
            directly; ``recurrent_hidden_state_size`` is read when
            ``rnn_hxs`` is None.
        obs: A single observation without a batch axis, in any form
            ``torch.tensor`` accepts. When its last axis has length 26 it is
            permuted with ``[2, 0, 1]`` (so it must then have three axes);
            presumably this converts channel-last to channel-first.
        rnn_hxs: Hidden state forwarded to ``policy.base``. Defaults to zeros
            of shape ``(1, policy.recurrent_hidden_state_size)``.

    Returns:
        Tuple ``(grads, probs)`` of numpy arrays. ``grads`` concatenates, along
        axis 0, one gradient per action; each gradient has the shape of the
        batched input (leading axis 1). ``probs`` is ``dist.probs``.
    """
    # torch.tensor always copies and returns a tensor without autograd
    # history, so no extra detach()/clone() is needed.
    my_obs = torch.tensor(obs, dtype=torch.float32)
    if my_obs.shape[-1] == 26:
        my_obs = torch.permute(my_obs, [2, 0, 1])

    # Add the batch axis and track gradients on the input itself.
    my_obs = my_obs.unsqueeze(0).requires_grad_()

    if rnn_hxs is None:
        rnn_hxs = torch.zeros(1, policy.recurrent_hidden_state_size)
    _, feats, _ = policy.base(my_obs, rnn_hxs, None)
    probs = policy.dist(feats).probs

    num_actions = probs.shape[-1]
    grads = []
    for i in range(num_actions):
        # autograd.grad returns the gradient for my_obs only; unlike
        # backward() it does not accumulate into the policy parameters'
        # .grad buffers. The graph is kept until the last action is done.
        (grad,) = torch.autograd.grad(
            probs[0, i], my_obs, retain_graph=i < num_actions - 1
        )
        grads.append(grad.detach().numpy())

    return np.concatenate(grads), probs.detach().numpy()


def twist_start_state_all(mdp: OvercookedGridworld, start_state: OvercookedState, chosen_pos):
    """Enumerate every perturbed copy of ``start_state`` over ``chosen_pos``.

    For each position, every state built so far is expanded into one new
    state per variant:
      * terrain 'X': an ``ObjectState`` of each type in ``VAR_COUNTER``;
      * terrain 'P': a ``SoupState`` holding n 'onion' ingredients for each
        n in ``VAR_POT``.
    A position with any other terrain type has no variants, which leaves the
    result empty.

    Args:
        mdp: Layout queried through ``get_terrain_type_at_pos``.
        start_state: Base state. It is deep-copied and never mutated.
        chosen_pos: Iterable of positions, or a single position. Any
            ``tuple`` is interpreted as ONE position, so several positions
            must be passed as a list.

    Returns:
        List of perturbed states, one per combination of variants across the
        positions. With no positions it contains just a copy of
        ``start_state``.
    """
    if isinstance(chosen_pos, tuple):
        chosen_pos = [chosen_pos]

    all_twisted_states = [start_state.deepcopy()]
    for pos in chosen_pos:
        # The terrain type depends only on the position, so look it up once
        # instead of once per accumulated state.
        pos_type = mdp.get_terrain_type_at_pos(pos)
        tmp_twisted_states = []
        for based_state in all_twisted_states:
            if pos_type == 'X':
                for obj_type in VAR_COUNTER:
                    twisted_state = based_state.deepcopy()
                    twisted_state.add_object(ObjectState(obj_type, pos))
                    tmp_twisted_states.append(twisted_state)
            elif pos_type == 'P':
                for onion_num in VAR_POT:
                    twisted_state = based_state.deepcopy()
                    ingredients = [ObjectState('onion', pos) for _ in range(onion_num)]
                    twisted_state.add_object(SoupState(pos, ingredients=ingredients))
                    tmp_twisted_states.append(twisted_state)

        # Only states perturbed at every position seen so far are carried on.
        all_twisted_states = tmp_twisted_states

    return all_twisted_states


def feasible_set_to_choice(env: OvercookedEnv, stat_dict, epi=2):
    """Build all start states perturbed at ``epi`` positions or fewer.

    Candidate positions are ``stat_dict['all_reachable_counter']`` followed
    by the pot locations of ``env.mdp``. Every ``epi``-sized combination of
    them is expanded with ``twist_start_state_all``; the function then
    recurses with ``epi - 1`` while ``epi > 1`` and appends those results.

    Args:
        env: Environment providing ``mdp`` and ``lossless_state_encoding_mdp``.
        stat_dict: Mapping containing the key 'all_reachable_counter'.
        epi: Number of positions perturbed together at the top level.

    Returns:
        Tuple ``(states, diff)`` of equally long lists. ``diff[i]`` is element
        0 of the encoding of ``states[i]`` minus element 0 of the encoding of
        the standard start state.
    """
    mdp = env.mdp
    candidate_positions = stat_dict['all_reachable_counter'] + mdp.get_pot_locations()
    position_groups = combinations(candidate_positions, epi)
    base_state = mdp.get_standard_start_state()

    twisted = []
    for position_group in tqdm(position_groups):
        twisted.extend(
            twist_start_state_all(mdp, base_state, chosen_pos=list(position_group))
        )

    base_encoding = env.lossless_state_encoding_mdp(base_state)[0]
    deltas = []
    for state in twisted:
        deltas.append(env.lossless_state_encoding_mdp(state)[0] - base_encoding)

    # Lowest level reached: nothing smaller to append.
    if epi <= 1:
        return twisted, deltas

    fewer_states, fewer_deltas = feasible_set_to_choice(env, stat_dict, epi - 1)
    twisted.extend(fewer_states)
    deltas.extend(fewer_deltas)
    return twisted, deltas


def combine_units(env: OvercookedEnv, unit_perturbations, epi=2):
    """Merge unit perturbations into states carrying ``epi`` objects or fewer.

    Every ``epi``-sized combination of ``unit_perturbations`` is merged into
    one state: the first member is deep-copied, and each object of the other
    members is added (as a deep copy) unless its key is already present. The
    merged state is kept only if it ends up with exactly ``epi`` objects, so
    combinations whose members collide on a key are dropped. While
    ``epi > 1`` the function recurses with ``epi - 1`` and appends those
    results.

    Args:
        env: Environment providing ``mdp`` and ``lossless_state_encoding_mdp``.
        unit_perturbations: Re-iterable collection of states exposing
            ``objects`` (a mapping), ``deepcopy()`` and ``add_object()``.
        epi: Number of unit perturbations combined at the top level.

    Returns:
        Tuple ``(states, diff)`` of equally long lists. ``diff[i]`` is element
        0 of the encoding of ``states[i]`` minus element 0 of the encoding of
        the standard start state.
    """
    all_twisted_states = []
    standard_start_state = env.mdp.get_standard_start_state()

    # One progress bar over the combinations, not one per (tiny) combination.
    for up_comb in tqdm(combinations(unit_perturbations, epi)):
        merged_state = None
        for up in up_comb:
            if merged_state is None:
                merged_state = up.deepcopy()
                continue
            for key, obj in up.objects.items():
                if key not in merged_state.objects:
                    merged_state.add_object(obj.deepcopy())

        # Check once, after the merge is complete: a state stored earlier
        # would be mutated by later merges or stored more than once.
        if merged_state is not None and len(merged_state.objects) == epi:
            all_twisted_states.append(merged_state)

    standard_obs_0 = env.lossless_state_encoding_mdp(standard_start_state)[0]
    diff = [
        env.lossless_state_encoding_mdp(twisted_state)[0] - standard_obs_0
        for twisted_state in all_twisted_states
    ]

    if epi > 1:
        lower_level_states, lower_level_diff = combine_units(env, unit_perturbations, epi - 1)
        all_twisted_states += lower_level_states
        diff += lower_level_diff

    return all_twisted_states, diff


def state_filter(all_twisted_states, diff, obss, threshold=0.01):
    """Return the indices of states whose diff barely overlaps ``obss``.

    ``obss`` is summed over axis 0; index ``i`` is kept when the sum of the
    element-wise product of ``diff[i]`` with that total is strictly below
    ``obss.shape[0] * threshold``.

    Args:
        all_twisted_states: Iterable that only determines how many indices
            are examined; its items are not inspected.
        diff: Indexable collection with one array per state.
        obss: Array of observations stacked along axis 0.
        threshold: Per-observation bound on the overlap score.

    Returns:
        List of kept indices in ascending order.
    """
    summed_obs = np.sum(obss, axis=0)
    return [
        idx
        for idx, _ in enumerate(all_twisted_states)
        if np.sum(np.multiply(diff[idx], summed_obs)) < obss.shape[0] * threshold
    ]


# def get_irrational_actions(mdp: OvercookedGridworld, state: OvercookedState, ai_actions):
#     # irrational action: no effect and not the target
#     irrational_actions = []
#     for i, player in state.players:
#         bad_actions = []
#         for action in Action.ALL_ACTIONS:
#             # use math sqrt to make probability proportional to area of the image
#             size = math.sqrt(probs[Action.ACTION_TO_INDEX[action]])
#             if action == "interact":
#                 img = pygame.transform.rotozoom(rescaled_interact, 0, size)
#                 self._render_on_tile_position(surface, img, player.position, horizontal_align="left",
#                                               vertical_align="center")
#             elif action == Action.STAY:
#                 img = pygame.transform.rotozoom(rescaled_stay, 0, size)
#                 self._render_on_tile_position(surface, img, player.position, horizontal_align="right",
#                                               vertical_align="center")
#             else:
#                 position = Action.move_in_direction(player.position, action)
#                 img = pygame.transform.rotozoom(rescaled_arrow, direction_to_rotation[action], size)
#                 self._render_on_tile_position(surface, img, position, **direction_to_aligns[action])
#         if
