import numpy as np

# This file contains all the code necessary to evaluate a genome.

# The eight knight-move offsets as (dx, dy) pairs. Every move gene of a genome
# (sol[2:]) is an index into this table.
directions = np.array(
    [
        [1, 2],
        [2, 1],
        [-1, 2],
        [-2, 1],
        [1, -2],
        [2, -1],
        [-1, -2],
        [-2, -1],
    ]
)
# Board size: valid x coordinates are 0..grid_w-1, valid y are 0..grid_h-1.
grid_w, grid_h = 8, 8

# Mutation operator: returns a mutated copy of the genome.
def mutate(sol):
    """Return a mutated copy of the genome ``sol``; the input is not modified.

    Genome layout (as consumed by ``get_hand_metrics_and_obs``): ``sol[0]`` and
    ``sol[1]`` are the start x/y cell, and every later gene indexes
    ``directions``. One of four operators is applied, chosen by a single
    uniform draw ``r`` in [0, 1):

    * ``r <= 0.4``  -- advance one random move gene to the next direction,
      wrapping around ``len(directions)``;
    * ``r <= 0.8``  -- replace one random move gene with a random direction;
    * ``r <= 0.95`` -- draw a new random start cell on the
      ``grid_w x grid_h`` board;
    * otherwise     -- draw a completely new random genome of the same length.

    Genomes with no move genes (``len(sol) <= 2``) cannot take the first two
    operators. Those draws fall through to resetting the start cell instead of
    raising from ``np.random.randint(2, 2)``.

    Args:
        sol: 1-D integer numpy array with at least two genes.

    Returns:
        A new numpy array of the same length as ``sol``.
    """
    length = sol.shape[0]
    n_dirs = len(directions)
    # Genes 0 and 1 are the start cell; only genes 2.. are moves.
    has_moves = length > 2
    new_sol = sol.copy()
    r = np.random.rand()
    if has_moves and r <= 0.4:
        i = np.random.randint(2, length)
        new_sol[i] = (new_sol[i] + 1) % n_dirs
    elif has_moves and r <= 0.8:
        # Change a random move gene to a new random direction.
        i = np.random.randint(2, length)
        new_sol[i] = np.random.randint(0, n_dirs)
    elif r <= 0.95:
        # Reset the starting position to a random on-board cell.
        new_sol[0] = np.random.randint(0, grid_w)
        new_sol[1] = np.random.randint(0, grid_h)
    else:
        # Fresh genome: move genes index `directions`. The start genes are
        # redrawn so they lie on the board even when it is not n_dirs wide/tall.
        new_sol = np.random.randint(0, n_dirs, (length,))
        new_sol[0] = np.random.randint(0, grid_w)
        new_sol[1] = np.random.randint(0, grid_h)

    return new_sol

def get_hand_metrics_and_obs(sol):
    """Simulate the walk encoded by a genome and compute its behaviour metrics.

    The genome ``sol`` is ``[start_x, start_y, move_0, move_1, ...]``. The first
    two genes are the starting cell, and every later gene indexes
    ``directions``. The walk is the cumulative sum of the start cell and the
    move offsets.

    The walk is *valid* up to, but not including, the first step that either
    leaves the ``grid_w x grid_h`` board or lands on a cell already visited by
    the valid prefix. Everything from that step on is ignored by the "valid"
    metrics.

    Args:
        sol: integer sequence or 1-D integer array of genes (at least two).

    Returns:
        A tuple ``(measures, n_valid, pos_visited, rows_visited, cols_visited,
        visited_trbl, visited_tlbr)``:

        * ``measures`` -- int array ``[num_rows, num_cols, num_tlbr_diags,
          num_trbl_diags, num_rows_nonvalid, num_cols_nonvalid, end_x, end_y]``.
          "Rows" are x values and "cols" are y values. The ``*_nonvalid``
          counts use every in-bounds cell of the full walk, including cells
          reached after the first rule violation. ``end_x``/``end_y`` is the
          last cell of the valid prefix.
        * ``n_valid`` -- number of cells in the valid prefix.
        * ``pos_visited`` -- flattened ``(grid_w, grid_h)`` float 0/1
          occupancy grid of the valid prefix.
        * ``rows_visited`` / ``cols_visited`` -- float per-x / per-y visit
          counts of the valid prefix.
        * ``visited_trbl`` / ``visited_tlbr`` -- float per-diagonal visit
          counts of the valid prefix (``grid_w + grid_h - 1`` entries each).
          The TR-BL diagonal of cell (x, y) is ``x + y``. The TL-BR diagonal
          is ``x - y + grid_h - 1``.

    Raises:
        IndexError: if the starting cell itself is off the board. The valid
            prefix is then empty, so there is no end position.
    """
    start = np.array([[sol[0], sol[1]]])
    moves = np.concatenate((start, directions[sol[2:]]), axis=0)
    visited = np.cumsum(moves, axis=0)
    xs, ys = visited[:, 0], visited[:, 1]
    in_bounds = (xs >= 0) & (xs < grid_w) & (ys >= 0) & (ys < grid_h)

    # Length of the valid prefix: stop at the first off-board or repeated cell.
    # A set keeps the revisit check O(1) per step.
    seen = set()
    n_valid = 0
    for cell, on_board in zip(map(tuple, visited.tolist()), in_bounds.tolist()):
        if not on_board or cell in seen:
            break
        seen.add(cell)
        n_valid += 1

    valid = visited[:n_valid]
    vx, vy = valid[:, 0], valid[:, 1]

    # Occupancy grid and per-row / per-column counts of the valid prefix.
    pos_visited = np.zeros((grid_w, grid_h))
    pos_visited[vx, vy] = 1
    rows_visited = np.bincount(vx, minlength=grid_w).astype(float)
    cols_visited = np.bincount(vy, minlength=grid_h).astype(float)
    num_rows_visited = np.count_nonzero(rows_visited)
    num_cols_visited = np.count_nonzero(cols_visited)

    # Rows/columns touched by the whole walk, ignoring the revisit rule and
    # skipping only off-board cells.
    num_rows_visited_nonvalid = np.unique(xs[in_bounds]).size
    num_cols_visited_nonvalid = np.unique(ys[in_bounds]).size

    # Diagonals. For x in [0, grid_w) and y in [0, grid_h), x - y + grid_h - 1
    # spans exactly [0, grid_w + grid_h - 2], so it is a valid index on
    # non-square boards too.
    total_diags = grid_w + grid_h - 1
    visited_tlbr = np.bincount(vx - vy + grid_h - 1, minlength=total_diags).astype(float)
    visited_trbl = np.bincount(vx + vy, minlength=total_diags).astype(float)
    num_tlbr_diags = np.count_nonzero(visited_tlbr)
    num_trbl_diags = np.count_nonzero(visited_trbl)

    # Raises IndexError when the valid prefix is empty (start cell off-board).
    endx, endy = valid[-1]

    measures = np.array([
        num_rows_visited,
        num_cols_visited,
        num_tlbr_diags,
        num_trbl_diags,
        num_rows_visited_nonvalid,
        num_cols_visited_nonvalid,
        endx,
        endy,
    ])
    return measures, n_valid, pos_visited.flatten(), rows_visited, cols_visited, visited_trbl, visited_tlbr

def _section_grid(num_sections):
    """Return ``(rows, cols)`` with ``rows * cols == num_sections`` and ``rows >= cols``, as square as possible."""
    cols = int(np.sqrt(num_sections))
    while num_sections % cols:
        cols -= 1
    return num_sections // cols, cols


def get_subaggregation(matrix, num_objectives, deagg):
    """Split a square matrix into ``num_objectives`` partial sums.

    Two de-aggregation modes are supported.

    ``"time"``
        ``n = int(sum(matrix))`` is taken as a count of leading ones in a
        vector of length ``matrix.size``. For example, for a 0/1 occupancy
        grid, ``n`` is the number of visited cells. That vector is cut into
        ``num_objectives`` consecutive chunks, and the ones in each chunk are
        counted. When the length does not divide evenly, the first chunks are
        one element longer (``np.array_split``), so every element is counted
        exactly once.

    ``"space"``
        The matrix is tiled into ``num_objectives`` rectangular sections, and
        each section is summed in row-major section order. Think of it as
        levels of granularity:

        * level 0 (1 objective) is the sum of the whole matrix;
        * level 1 (4) sums each quarter;
        * level 2 (16) sums each sixteenth;
        * a half level (2, 8, ...) splits one dimension one step further than
          the other.

        In general, the most square factorisation ``rows * cols`` with
        ``rows >= cols`` is used. Side lengths that do not divide evenly are
        split with ``np.array_split``.

    Args:
        matrix: 2-D square array.
        num_objectives: number of partial sums to produce (positive int).
        deagg: ``"time"`` or ``"space"``.

    Returns:
        For ``"time"``, a float ndarray of length ``num_objectives``. For
        ``"space"``, a list of ``num_objectives`` section sums. For any other
        ``deagg``, prints ``"DEAGG NOT VALID"`` and returns ``None``.
    """
    assert len(matrix.shape) == 2, "Input must be a 2-dimensional matrix"
    assert matrix.shape[0] == matrix.shape[1], "Input must be a square matrix"

    if deagg == "time":
        n = int(np.sum(matrix))
        # Vector the size of the matrix with n leading ones.
        scores = np.zeros(matrix.shape[0] * matrix.shape[1])
        scores[:n] = 1
        sums = np.array([np.sum(chunk) for chunk in np.array_split(scores, num_objectives)])

    elif deagg == "space":
        n_row_bands, n_col_bands = _section_grid(num_objectives)
        sums = [
            np.sum(section)
            for band in np.array_split(matrix, n_row_bands, axis=0)
            for section in np.array_split(band, n_col_bands, axis=1)
        ]
    else:
        print("DEAGG NOT VALID")
        return None

    assert len(sums) == num_objectives
    return sums

def hamming_distance_matrix(vectors):
    """Return the symmetric matrix of pairwise Hamming distances between rows.

    Entry ``[i, j]`` is the number of positions where ``vectors[i]`` and
    ``vectors[j]`` differ. Rows with more than one trailing axis are compared
    element-wise over all of them. The diagonal is always 0, even for rows
    containing NaN (where ``nan != nan``).

    Args:
        vectors: array of shape ``(n, ...)``, one vector per row.

    Returns:
        float ndarray of shape ``(n, n)``.
    """
    vectors = np.asarray(vectors)
    n = vectors.shape[0]
    flat = vectors.reshape(n, int(np.prod(vectors.shape[1:], dtype=int)))
    dist_matrix = np.zeros((n, n))

    # One vectorised comparison per row instead of a Python loop over pairs.
    for i in range(n):
        dist_matrix[i] = np.count_nonzero(flat != flat[i], axis=1)

    # A row is at distance 0 from itself by definition.
    np.fill_diagonal(dist_matrix, 0)
    return dist_matrix

def get_pop_diversity(visited):
    """Return the mean pairwise Hamming distance of a population's behaviours.

    ``visited`` holds one behaviour vector per row and is passed straight to
    ``hamming_distance_matrix``. The mean is taken over the full ``n x n``
    distance matrix, so the zero self-distances on the diagonal are included.
    """
    pairwise = hamming_distance_matrix(visited)
    return pairwise.mean()

# print(get_hand_metrics_and_obs([0, 0, 1, 2, 3, 1, 1, 1, 1, 1, 1]))