import math
import numpy as np
import cvxpy as cp
import torch
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Seed NumPy and PyTorch (the default generator and every CUDA device) with 0.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# Selects which of the two grid layouts below becomes MAP.
USE_MINI_MAP = False

# Grid layouts, one number per cell.
# NOTE(review): this looks like a Deep-Sea-Treasure style map where positive
# entries are treasure values, -10 marks blocked cells and 0 is open water --
# confirm against the environment code that consumes MAP.
MINI_MAP = np.array(
    [[0,    0,    0,    0,    0,    0],
     [0.7,  0,    0,    0,    0,    0],
     [-10,  8.2,  0,    0,    0,    0],
     [-10, -10,  11.5,  0,    0,    0],
     [-10, -10, -10,   14.0, 15.1, 16.1],
     [0,    0,    0,    0,    0,    0]]
)

DEFAULT_MAP = np.array(
    [[0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
     [0.7,  0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
     [-10,  8.2,  0,    0,    0,    0,    0,    0,    0,    0,    0],
     [-10, -10,  11.5,  0,    0,    0,    0,    0,    0,    0,    0],
     [-10, -10, -10,   14.0, 15.1, 16.1,  0,    0,    0,    0,    0],
     [-10, -10, -10,  -10,  -10,  -10,    0,    0,    0,    0,    0],
     [-10, -10, -10,  -10,  -10,  -10,    0,    0,    0,    0,    0],
     [-10, -10, -10,  -10,  -10,  -10,   19.6, 20.3,  0,    0,    0],
     [-10, -10, -10,  -10,  -10,  -10,  -10,  -10,    0,    0,    0],
     [-10, -10, -10,  -10,  -10,  -10,  -10,  -10,   22.4,  0,    0],
     [-10, -10, -10,  -10,  -10,  -10,  -10,  -10,  -10,   23.7,  0],
     [0,    0,    0,    0,    0,    0,    0,    0,    0,    0,  -10]]
)

MAP = MINI_MAP if USE_MINI_MAP else DEFAULT_MAP

# Number of columns in the active map; state_encode/state_decode use it as the
# row stride. Derived from MAP so it cannot drift from the selected layout
# (6 for MINI_MAP, 11 for DEFAULT_MAP). ndarray.shape holds plain Python ints.
ele_per_row = MAP.shape[1]


def regularizer(x):
    """Return the expression ``x - log(x)``, with the log taken via ``cp.log``."""
    log_term = cp.log(x)
    return x - log_term


def occupancy(states, actions, is_terminals, gamma, max_timesteps, state_num, action_num, virtual_ratio):
    """Build a discounted (state, action) occupancy table from recorded steps.

    Visits are tallied per (encoded state, timestep); the timestep counter
    restarts after every step whose ``is_terminals`` entry is truthy. The
    tally is divided by its grand total, each timestep column is weighted by
    ``gamma ** t`` and accumulated per state, and the per-state weights are
    finally multiplied element-wise by the table returned by ``policy``.

    ``virtual_ratio`` is accepted but not used by this function.

    Returns:
        A float tensor of shape ``(state_num, action_num)``.
    """
    action_probs = policy(states, actions, is_terminals, state_num, action_num)
    occupancy_table = torch.zeros((state_num, action_num), dtype=torch.float)
    visit_counts = torch.zeros((state_num, max_timesteps + 1), dtype=torch.float)

    timestep = 0
    for step, state in enumerate(states):
        visit_counts[state_encode(state)][timestep] += 1
        # A terminal flag restarts the timestep counter for the following step.
        timestep = 0 if is_terminals[step] else timestep + 1

    visit_counts /= visit_counts.sum()
    # 0/0 yields NaN when nothing was tallied; treat those entries as zero.
    visit_counts[torch.isnan(visit_counts)] = 0.0

    # The (state_num, 1) column broadcasts across every action column.
    for t, per_state in enumerate(visit_counts.T):
        occupancy_table += gamma**t * per_state.unsqueeze(1)
    occupancy_table *= action_probs

    return occupancy_table


def P(states, actions, is_terminals, state_num, action_num):
    """Estimate a row-normalised transition matrix over (state, action) pairs.

    Every consecutive pair of steps ``(i, i + 1)`` adds one count to the entry
    for ``(state_i, action_i) -> (state_{i+1}, action_{i+1})``, unless either
    step of the pair has a truthy ``is_terminals`` entry, in which case the
    pair is skipped.

    Args:
        states: Indexable sequence of states accepted by ``state_encode``.
        actions: Indexable sequence of action indices, aligned with ``states``.
        is_terminals: Indexable sequence of flags, aligned with ``states``.
        state_num: Number of encoded states.
        action_num: Number of actions.

    Returns:
        A float tensor of shape
        ``(state_num * action_num, state_num * action_num)`` in which each row
        sums to 1, or is all zeros when that pair was never counted.
    """
    pair_num = state_num * action_num
    counts = torch.zeros((state_num, action_num, state_num, action_num), dtype=torch.float)

    # Only indices that have a successor are visited. The final step has no
    # i + 1 entry, and reading it used to raise IndexError whenever the last
    # recorded step was not terminal.
    for i in range(len(states) - 1):
        if is_terminals[i] or is_terminals[i + 1]:
            continue
        counts[state_encode(states[i])][actions[i]][state_encode(states[i + 1])][actions[i + 1]] += 1

    counts = counts.reshape(pair_num, pair_num)
    counts /= counts.sum(dim=1, keepdim=True)
    counts[torch.isnan(counts)] = 0.0  # rows with no counts give 0/0 = NaN
    return counts

def policy(states, actions, is_terminals, state_num, action_num):
    """Estimate per-state action frequencies from recorded (state, action) steps.

    Each step adds one count to ``[encoded state][action]``; every row is then
    divided by its own total. Rows with no counts end up all zeros.

    ``is_terminals`` is accepted but not used by this function.

    Returns:
        A float tensor of shape ``(state_num, action_num)``.
    """
    cpu = torch.device('cpu')
    visits = torch.zeros((state_num, action_num), dtype=torch.float)

    for step, state in enumerate(states):
        # Both the state and the action are moved to the CPU before indexing.
        state_idx = state_encode(state.to(cpu))
        action_idx = actions[step].to(cpu)
        visits[state_idx][action_idx] += 1

    row_totals = visits.sum(axis=1).unsqueeze(1)
    visits /= row_totals
    visits[torch.isnan(visits)] = 0.0  # unvisited rows give 0/0 = NaN
    return visits


def state_encode(state):
    """Flatten a two-component state tensor into one integer index.

    Size-1 dimensions are squeezed away first; the index is then
    ``state[0] * ele_per_row + state[1]`` converted to a Python ``int``.
    """
    flat = torch.squeeze(state)
    index = flat[0] * ele_per_row + flat[1]
    return int(index.item())


def state_decode(encoded_state):
    """Turn a flat state index back into a ``(1, 2)`` float tensor on the CPU.

    The first component is ``floor(encoded_state / ele_per_row)`` and the
    second is ``encoded_state % ele_per_row``.
    """
    first = math.floor(encoded_state / ele_per_row)
    second = encoded_state % ele_per_row
    coords = np.array([first, second]).reshape(1, -1)
    return torch.FloatTensor(coords).to(torch.device('cpu'))

def mean_reward(mean_r, admm_r, t):
    """Move ``mean_r`` towards ``admm_r`` by ``1 / t`` of their difference.

    This is the incremental form of a running average when ``t`` is the
    number of samples seen so far, including ``admm_r``.
    """
    correction = (admm_r - mean_r) / t
    return mean_r + correction

def visualize(reward_signal, avg_reward, avg_length, preference, path):
    """Draw colour-coded reward differences over 'submarine.png' and save the plot.

    ``reward_signal @ preference`` gives a per-state table ``f_value`` that is
    indexed as ``f_value[state][k]`` with ``k`` in 0..3, where ``state`` is the
    row-major cell index of ``MAP``. For every cell two differences are formed:

    * "down":  ``f_value[cell][1] - f_value[cell below][0]``
    * "right": ``f_value[cell][3] - f_value[cell to the right][2]``

    (on the last row / column only the first term is used).
    NOTE(review): columns 0..3 presumably correspond to the actions
    up / down / left / right -- confirm against the environment.

    The differences are clipped to the largest magnitude found on cells where
    ``MAP`` is 0 (ignoring the last row), mapped through the 'seismic'
    colormap, and drawn as the bottom and right edge of each grid cell.

    Args:
        reward_signal: Left operand of ``@`` producing ``f_value``.
        avg_reward: Number shown (rounded to 2 places) in the x-axis label.
        avg_length: Number shown (rounded to 2 places) in the x-axis label.
        preference: Right operand of ``@`` producing ``f_value``.
        path: Output path without extension; '.png' is appended. Its last
            component (after the final '\\' or '/') is appended to the title.
    """
    img = plt.imread('submarine.png')
    # Only the first two axes are needed, so 2-D (grayscale) images work too.
    height, width = img.shape[:2]

    he, wi = MAP.shape
    grid_height = height // he
    grid_width = width // wi

    f_value = reward_signal @ preference
    diff_rewards = np.zeros((he, wi, 2))
    for i in range(he):
        for j in range(wi):
            cell = i * wi + j
            down = f_value[cell][1]
            if i != he - 1:
                down = down - f_value[cell + wi][0]
            right = f_value[cell][3]
            if j != wi - 1:
                right = right - f_value[cell + 1][2]
            diff_rewards[i][j][0] = down
            diff_rewards[i][j][1] = right

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in both older and newer releases.
    cmap = plt.get_cmap('seismic')

    # The clipping bound is the largest |difference| on cells where MAP == 0,
    # with the last row excluded.
    magnitude = np.absolute(diff_rewards)
    magnitude[-1] = 0
    magnitude = np.where(MAP[:, :, np.newaxis] == 0.0, magnitude, 0)
    limited_max = magnitude.max()
    diff_rewards = np.clip(diff_rewards, -limited_max, limited_max)

    vmin = np.min(diff_rewards)
    vmax = np.max(diff_rewards)
    arr_range = vmax - vmin
    if arr_range > 0:
        normalized_arr = (diff_rewards - vmin) / arr_range
    else:
        # All differences are equal: avoid 0/0 and use the colormap midpoint.
        normalized_arr = np.full_like(diff_rewards, 0.5)
    colors = cmap(normalized_arr)

    fig = plt.figure()
    try:
        for i in range(he):
            for j in range(wi):
                top = i * grid_height
                left = j * grid_width
                bottom = (i + 1) * grid_height
                right_edge = (j + 1) * grid_width
                # Bottom edge carries the "down" colour, right edge the "right" colour.
                plt.hlines(bottom, left, right_edge, colors[i][j][0])
                plt.vlines(right_edge, top, bottom, colors[i][j][1])

        plt.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])
        # Strip the directory part whichever separator the caller used.
        name_start = max(path.rfind('\\'), path.rfind('/')) + 1
        plt.title('Agent' + path[name_start:], fontsize=18, fontweight="bold")
        plt.xlabel('Avg_reward: ' + str(round(avg_reward, 2)) + ', Avg_length: ' + str(round(avg_length, 2)), fontsize=16)
        plt.savefig(path + '.png')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)