import random
from turtle import st
import einops
import numpy as np


def extract_query_mat(x, num_slots):
    """Greedily pick ``num_slots`` dissimilar patches per batch item and return their attention rows.

    A softmax self-attention matrix is computed from the scaled dot products
    of the patches in ``x``. For each batch item, a start patch is drawn
    uniformly at random. Each further patch is the one with the smallest
    cumulative attention to the patches selected so far.

    Args:
        x: array of shape (batch_size, num_patches, dim).
        num_slots: number of patches to select per batch item. The start
            patch is always selected, so values below 1 still yield one slot.

    Returns:
        Array of shape (batch_size, num_slots, num_patches). Row ``c`` of
        batch item ``b`` is the attention row of the ``c``-th selected patch.
    """
    batch_size, num_patches, dim = x.shape
    rows = np.arange(batch_size)

    dots = np.einsum('b i d, b j d -> b i j', x, x)
    attn = softmax(dots / np.sqrt(dim), axis=-1)
    # Adding the identity gives every already-selected patch an extra +1 of
    # cumulative similarity, steering argmin away from picking it again.
    attn_with_eye = attn + np.eye(num_patches)[np.newaxis, :, :]

    # np.random.randint excludes its upper bound, so passing num_patches
    # (not num_patches - 1) makes every patch, including the last, a possible
    # start; it also works when num_patches == 1.
    indices_list = [np.random.randint(0, num_patches, size=batch_size)]
    # Fancy indexing returns a new array, so the in-place += below never
    # touches attn_with_eye.
    cum_similarities = attn_with_eye[rows, indices_list[-1], :]
    for _ in range(num_slots - 1):
        next_idx = np.argmin(cum_similarities, axis=-1)
        indices_list.append(next_idx)
        cum_similarities += attn_with_eye[rows, next_idx, :]

    indices_np = np.stack(indices_list, -1)
    # One-hot selection matrix (batch, num_slots, num_patches); the einsum
    # gathers the selected rows of the (identity-free) attention matrix.
    select_mat = np.eye(num_patches)[indices_np]
    return np.einsum('b c n, b n m -> b c m', select_mat, attn)
    

def softmax(x, axis=-1):
    """Return the numerically stable softmax of ``x`` along ``axis``.

    The maximum of each slice is subtracted before exponentiating so large
    scores cannot overflow; the shift cancels out in the normalisation.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = weights.sum(axis=axis, keepdims=True)
    return weights / total

def batched_extract_query_mat(x, num_slots):
    """Apply ``extract_query_mat`` over any extra leading dimensions of ``x``.

    Args:
        x: array of shape (..., batch_size, num_patches, dim). The trailing
            three axes are passed to ``extract_query_mat`` as one call. A 3-D
            input therefore gives the same result as calling it directly.
        num_slots: number of slots, broadcast against the leading dimensions
            of ``x``. Every call must produce the same number of slots.

    Returns:
        Array of shape (..., batch_size, num_slots, num_patches).
    """
    # np.vectorize requires every core dimension to be named ('...' is not
    # valid gufunc syntax and raises ValueError). The output-only dimension
    # 's' (num_slots) is taken from the first call's result.
    vec_func = np.vectorize(extract_query_mat, signature='(b,n,d),()->(b,s,n)')
    return vec_func(x, num_slots)

def unbatched_extract_query_mat(x, num_slots, *, verbose=False):
    """Greedily pick ``num_slots`` dissimilar patches of one example and return their attention rows.

    A softmax self-attention matrix is computed from the scaled dot products
    of the rows of ``x``. The start patch is drawn uniformly at random using
    Python's ``random`` module. Each further patch is the one with the
    smallest cumulative attention to the patches selected so far.

    Args:
        x: array of shape (num_patches, dim).
        num_slots: number of patches to select. The start patch is always
            selected, so values below 1 still yield one slot.
        verbose: if True, print the attention matrix and per-slot progress
            (the debugging output this function used to emit unconditionally).

    Returns:
        Array of shape (num_slots, num_patches); row ``c`` is the attention
        row of the ``c``-th selected patch.
    """
    num_patches, dim = x.shape

    dots = np.einsum('i d, j d -> i j', x, x)
    attn = softmax(dots / np.sqrt(dim), axis=-1)
    if verbose:
        print(f'attn: {attn}')

    # random.randint includes both ends, so any patch can be the start.
    indices = [random.randint(0, num_patches - 1)]
    # attn[i] is a view; copy so the in-place += below does not modify attn.
    cum_similarities = attn[indices[-1]].copy()
    # NOTE(review): unlike extract_query_mat, no identity term is added here,
    # so an already-selected patch can be chosen again — confirm intended.
    for i in range(num_slots - 1):
        if verbose:
            print(f'slot {i + 1}')
            print(f'cum_similarities: {cum_similarities}')
        indices.append(np.argmin(cum_similarities))
        if verbose:
            print(f'new index: {indices[-1]}')
        cum_similarities += attn[indices[-1]]

    # Keep the index array 1-D: squeezing it turned the single-slot case into
    # a 0-d index and returned a (num_patches,) vector instead of a
    # (1, num_patches) matrix.
    select_mat = np.eye(num_patches)[np.asarray(indices)]
    return select_mat @ attn