import numpy as np
import torch


def Dv(v, K):
    '''
    Fast evaluation of D @ v, where D is the blocked difference matrix.

    v is viewed as K rows of equal length. In every row the first entry is kept
    and each later entry is replaced by its difference from the entry before it.
    The rows are then flattened back into a single 1d array.
    '''
    rows = v.reshape(K, -1)
    first_col = rows[:, :1]  # slice (not index) so the column stays 2d for hstack
    return np.hstack((first_col, np.diff(rows, axis=1))).flatten()

def Dv_torch(v, K):
    '''
    Fast evaluation of D @ v (torch tensors), where D is the blocked difference matrix.

    v is viewed as K rows of equal length. In every row the first entry is kept
    and each later entry is replaced by its difference from the entry before it.
    The rows are then flattened back into a single 1d tensor.
    '''
    rows = v.reshape(K, -1)
    first_col = rows[:, :1]  # slice (not index) so the column stays 2d for hstack
    return torch.hstack((first_col, torch.diff(rows, dim=1))).flatten()


def normalize(vals):
    """
    min-max rescale values to the closed interval [0, 1]
    input:
      vals: 1d array
    output:
      array of the same shape as vals; the smallest entry maps to 0 and the
      largest to 1. A constant input (max == min) maps to all zeros.
    """
    min_val = np.min(vals)
    span = np.max(vals) - min_val
    if span == 0:
        # Constant input: dividing by the zero range would give 0/0 = NaN.
        # Dividing by 1 keeps the usual result dtype and returns zeros.
        span = 1
    return (vals - min_val) / span


def create_joint_maps(individual_map1, individual_map2, diff_maps, width=None, height=None):
    """
    Stack single-agent feature maps and distance-based feature maps into
    feature maps over the joint state space of two agents on a grid.

    Joint state index is s1 * N_STATES + s2, and a single-agent state s sits at
    grid position (x, y) = (s // height, s % height).

    INPUT:
        individual_map1: (K1 x N_STATES) - m(s1)
        individual_map2: (K2 x N_STATES) - n(s2)
        diff_maps: (N_MAPS x N_diff) - phi(|s1 - s2|); column j belongs to the
            j-th smallest distinct squared Euclidean grid distance
        width, height: grid size. If both are None a square grid is assumed;
            if only one is given the other is derived from N_STATES.
    OUTPUT:
        joint_map: ((K1+K2+N_MAPS) x N_STATES**2) float array
    RAISES:
        ValueError: the individual maps disagree on N_STATES, or
            width * height != N_STATES
        AssertionError: N_diff does not match the number of distinct squared
            distances on the grid
    """
    K1, N = individual_map1.shape
    K2, N2 = individual_map2.shape
    if N2 != N:
        raise ValueError(
            f'individual maps must share the number of states, got {N} and {N2}')

    # Resolve the grid shape, deriving whichever dimension was not supplied.
    if width is None and height is None:
        height = width = int(round(np.sqrt(N)))
    elif width is None:
        width = N // height
    elif height is None:
        height = N // width
    if width * height != N:
        raise ValueError(
            f'grid of width {width} and height {height} does not hold {N} states')

    # Squared distance between every ordered state pair, computed once.
    # Flattening row-major puts pair (s1, s2) at joint index s1 * N + s2.
    state = np.arange(N)
    y, x = state % height, state // height
    sq_dist = (y[:, None] - y[None, :]) ** 2 + (x[:, None] - x[None, :]) ** 2
    # diff_square: sorted distinct squared distances;
    # diff_col[j]: position in diff_square of the distance at joint index j.
    diff_square, diff_col = np.unique(sq_dist.ravel(), return_inverse=True)

    K3, diff_N = diff_maps.shape
    assert diff_square.shape[0] == diff_N, 'incompatible size between diff_maps and individual_maps'

    joint_maps = np.zeros((K1 + K2 + K3, N ** 2))
    joint_maps[:K1, :] = np.repeat(individual_map1, N, axis=1)    # varies with s1
    joint_maps[K1:(K1 + K2), :] = np.tile(individual_map2, N)     # varies with s2
    # Gather whole columns so that every one of the K3 maps is filled correctly.
    joint_maps[(K1 + K2):, :] = diff_maps[:, diff_col]
    return joint_maps



def create_joint_maps_pytorch(individual_map1, individual_map2, diff_maps, width=None, height=None):
    """
    Stack single-agent feature maps and distance-based feature maps into
    feature maps over the joint state space of two agents on a grid (torch).

    Joint state index is s1 * N_STATES + s2, and a single-agent state s sits at
    grid position (x, y) = (s // height, s % height).

    INPUT:
        individual_map1: (K1 x N_STATES) - m(s1)
        individual_map2: (K2 x N_STATES) - n(s2)
        diff_maps: (N_MAPS x N_diff) - phi(|s1 - s2|); column j belongs to the
            j-th smallest distinct squared Euclidean grid distance
        width, height: grid size. If both are None a square grid is assumed;
            if only one is given the other is derived from N_STATES.
    OUTPUT:
        joint_map: ((N_MAPS+K1+K2) x N_STATES**2) tensor created by torch.zeros
            (default dtype) and filled from the inputs
    RAISES:
        ValueError: the individual maps disagree on N_STATES, or
            width * height != N_STATES
        AssertionError: N_diff does not match the number of distinct squared
            distances on the grid
    """
    K1, N = individual_map1.shape
    K2, N2 = individual_map2.shape
    if N2 != N:
        raise ValueError(
            f'individual maps must share the number of states, got {N} and {N2}')

    # Resolve the grid shape, deriving whichever dimension was not supplied.
    if width is None and height is None:
        height = width = int(round(N ** 0.5))
    elif width is None:
        width = N // height
    elif height is None:
        height = N // width
    if width * height != N:
        raise ValueError(
            f'grid of width {width} and height {height} does not hold {N} states')

    # Squared distance between every ordered state pair, computed once.
    # Flattening row-major puts pair (s1, s2) at joint index s1 * N + s2.
    state = torch.arange(N)
    y = state % height
    x = torch.div(state, height, rounding_mode='floor')
    sq_dist = (y[:, None] - y[None, :]) ** 2 + (x[:, None] - x[None, :]) ** 2
    # diff_square: sorted distinct squared distances;
    # diff_col[j]: position in diff_square of the distance at joint index j.
    diff_square, diff_col = torch.unique(sq_dist.reshape(-1), return_inverse=True)

    K3, diff_N = diff_maps.shape
    assert diff_square.shape[0] == diff_N, 'incompatible size between diff_maps and individual_maps'

    joint_maps = torch.zeros((K1 + K2 + K3, N ** 2))
    joint_maps[:K1, :] = individual_map1.repeat_interleave(N, dim=1)   # varies with s1
    joint_maps[K1:(K1 + K2), :] = individual_map2.repeat(1, N)         # varies with s2
    # Gather whole columns so that every one of the K3 maps is filled correctly.
    joint_maps[(K1 + K2):, :] = diff_maps[:, diff_col.to(diff_maps.device)]
    return joint_maps


def get_diff_idx(diff_sq,width,height):
    """
    List the joint-space indices whose two grid positions are a given squared
    Euclidean distance apart.

    A single-agent state s lies at column s // height and row s % height, and a
    joint index i encodes the pair (s1, s2) = (i // N, i % N) with N = width * height.
    INPUT:
        diff_sq: squared grid distance to look for
        width, height: grid dimensions
    OUTPUT:
        ascending list of joint indices i with
        (row1 - row2)**2 + (col1 - col2)**2 == diff_sq
    """
    n_states = width * height
    matches = []
    for first in range(n_states):
        col1, row1 = divmod(first, height)
        for second in range(n_states):
            col2, row2 = divmod(second, height)
            if (row1 - row2) ** 2 + (col1 - col2) ** 2 == diff_sq:
                # joint index of the ordered pair (first, second)
                matches.append(first * n_states + second)
    return matches




def collapse_permutation_matrix(P_a, i):
    '''
    collapse the joint permutation matrix to a single agent

    N_STATES and N_ACTIONS are the integer square roots of the two dimensions of P_a.
    Agent 1: keep every N_STATES-th row and every N_ACTIONS-th column, entries
    floor-divided by N_STATES.
    Agent 2: keep the first N_STATES rows and first N_ACTIONS columns, entries
    taken modulo N_STATES.
    INPUT:
        P_a: 2d array of joint indices
        i (= 1 or 2): agent index to keep
    OUTPUT:
        the collapsed array for the chosen agent
    '''
    assert i in [1,2], 'error in agent index'
    n_rows, n_cols = P_a.shape
    n_states = int(np.sqrt(n_rows))
    n_actions = int(np.sqrt(n_cols))
    if i == 1:
        return P_a[::n_states, ::n_actions] // n_states
    return P_a[:n_states, :n_actions] % n_states


def collapse_permutation_matrix_pytorch(P_a, i):
    '''
    collapse the joint permutation matrix to a single agent (torch tensors)

    N_STATES and N_ACTIONS are the integer square roots of the two dimensions of P_a.
    Agent 1: keep every N_STATES-th row and every N_ACTIONS-th column, entries
    floor-divided by N_STATES.
    Agent 2: keep the first N_STATES rows and first N_ACTIONS columns, entries
    taken modulo N_STATES.
    INPUT:
        P_a: 2d tensor of joint indices
        i (= 1 or 2): agent index to keep
    OUTPUT:
        the collapsed tensor for the chosen agent
    '''
    assert i in [1,2], 'error in agent index'
    n_rows, n_cols = P_a.shape
    n_states = int(torch.sqrt(torch.tensor(n_rows)))
    n_actions = int(torch.sqrt(torch.tensor(n_cols)))
    if i == 1:
        return P_a[::n_states, ::n_actions] // n_states
    return P_a[:n_states, :n_actions] % n_states


def collapse_reward(reward,agent):
    '''
    collapse the joint reward to a single agent by summing over the other
    agent's states (joint state index = s1 * N_STATES + s2)
    INPUT:
        reward: (N_STATES**2 x 1)
        agent (= 1 or 2): agent index to keep
    OUTPUT:
        reward_single: (N_STATES x 1) float array; entry s is the total joint
        reward over all joint states in which the kept agent is in state s
    RAISES:
        ValueError: agent is not 1 or 2, or the length of reward is not a
            perfect square
    '''
    if agent not in (1, 2):
        # Previously an unknown agent silently produced an all-zero reward.
        raise ValueError('error in agent index')
    n_joint = reward.shape[0]
    N_STATES = int(round(np.sqrt(n_joint)))
    if N_STATES * N_STATES != n_joint:
        raise ValueError(
            f'reward must have N_STATES**2 rows, got {n_joint}')

    values = np.asarray(reward)
    # Total per joint state (folds any trailing axes, e.g. the single column).
    per_joint = values.sum(axis=tuple(range(1, values.ndim)))
    grid = per_joint.reshape(N_STATES, N_STATES)  # grid[s1, s2]

    reward_single = np.zeros((N_STATES, 1))
    # Keeping agent 1 sums over s2 (axis 1); keeping agent 2 sums over s1 (axis 0).
    reward_single[:, 0] = grid.sum(axis=1 if agent == 1 else 0)
    return reward_single


def collapse_reward_pytorch(reward,agent):
    '''
    collapse the joint reward to a single agent by summing over the other
    agent's states (joint state index = s1 * N_STATES + s2)
    INPUT:
        reward: (N_STATES**2 x 1) tensor; only column 0 is read
        agent (= 1 or 2): agent index to keep
    OUTPUT:
        reward_single: (N_STATES x 1) tensor created by torch.zeros; entry s is
        the total joint reward over all joint states in which the kept agent
        is in state s
    RAISES:
        ValueError: agent is not 1 or 2, or the length of reward is not a
            perfect square
    '''
    if agent not in (1, 2):
        # Previously an unknown agent silently produced an all-zero reward.
        raise ValueError('error in agent index')
    n_joint = reward.shape[0]
    N_STATES = int(round(n_joint ** 0.5))
    if N_STATES * N_STATES != n_joint:
        raise ValueError(
            f'reward must have N_STATES**2 rows, got {n_joint}')

    grid = reward[:, 0].reshape(N_STATES, N_STATES)  # grid[s1, s2]

    reward_single = torch.zeros((N_STATES, 1))
    # Keeping agent 1 sums over s2 (dim 1); keeping agent 2 sums over s1 (dim 0).
    reward_single[:, 0] = grid.sum(dim=1 if agent == 1 else 0)
    return reward_single