import time
import sys

import matplotlib.pyplot as plt
import numpy as np
from tqdm import trange, tqdm

#from https://docs.pyribs.org/en/stable/tutorials/arm_repertoire.html
def get_qd_measures(solutions, link_lengths):
    """Compute the objective and the two measures for a batch of arm solutions.

    Args:
        solutions (np.ndarray): A (batch_size, dim) array; each row holds the
            relative joint angles of one arm (dim is 12 in the tutorial).
        link_lengths (np.ndarray): A (dim,) array of link lengths (all ones
            in the tutorial).
    Returns:
        objs (np.ndarray): (batch_size,) negated variance of each row's
            joint angles (flatter arms score higher).
        meas (np.ndarray): (batch_size, 2) array of the (x, y) position of
            each arm's end effector.
    """
    # Absolute heading of link k is theta_1 + ... + theta_k.
    link_angles = np.cumsum(solutions, axis=1)
    lengths_row = link_lengths[np.newaxis, :]

    # The end effector is the sum of every link's displacement vector.
    tip_x = np.sum(lengths_row * np.cos(link_angles), axis=1)
    tip_y = np.sum(lengths_row * np.sin(link_angles), axis=1)
    meas = np.stack((tip_x, tip_y), axis=1)

    objs = -np.var(solutions, axis=1)
    return objs, meas

def _section_neg_variance(solutions, num_sections):
    """Split each row's negated variance into `num_sections` equal joint sections.

    Column ``i`` is ``-sum((s_j - mean)**2) / dim`` over the joints ``j`` in
    contiguous section ``i``, where ``mean`` is the mean of the WHOLE row, so
    the columns of a row sum to ``-np.var(row)``.

    Returns:
        np.ndarray: float64 array of shape (batch_size, num_sections).
    """
    batch_size, dim = solutions.shape
    section_size = dim // num_sections
    row_mean = np.mean(solutions, axis=1)
    sq_dev = (solutions - row_mean[:, None]) ** 2
    sections = sq_dev.reshape(batch_size, num_sections, section_size)
    return np.asarray(-sections.sum(axis=2) / dim, dtype=np.float64)


def _coord_to_cell(coord, max_arm_length, grid_size):
    """Map coordinates in [-max_arm_length, max_arm_length] to grid indices."""
    normalized = (coord + max_arm_length) / (2 * max_arm_length)
    cell = (normalized * grid_size).astype(int)
    # A fully stretched arm gives normalized == 1.0, i.e. index grid_size,
    # which is one past the last cell; clipping keeps it in the edge cell.
    return np.clip(cell, 0, grid_size - 1)


def get_objectives(solutions, link_lengths, num_objectives=1, deagg="time"):
    """Returns a series of deaggregated objectives for a batch of solutions.

    The aggregate objective is the negated variance of each row's joint angles.
    It is split into ``num_objectives`` parts by partitioning the joints into
    equal contiguous sections; each section's deviations are measured from the
    mean of the whole row.

    Args:
        solutions (np.ndarray): A (batch_size, dim) array where each row
            contains the joint angles for the arm.
        link_lengths (np.ndarray): A (dim,) array with the lengths of each
            arm link. Only used when ``deagg == "space"``.
        num_objectives (int): The number of objectives to deaggregate into.
            Must divide ``dim``; for ``"space"`` it must also be a perfect
            square.
        deagg (str): How section objectives are laid out.
            ``"time"``: column ``i`` is the objective of joint section ``i``.
            ``"space"``: the workspace square
            ``[-sum(link_lengths), sum(link_lengths)]^2`` is divided into a
            ``sqrt(num_objectives)`` x ``sqrt(num_objectives)`` grid and each
            section's objective is written into the cell containing that
            section's last joint. The grid is flattened row-major
            (flat index = x_cell * grid + y_cell). Cells no section ends in
            hold -100; if several sections end in one cell, the later wins.
    Returns:
        objs (np.ndarray): (batch_size, num_objectives) float64 array of
            objectives.
    Raises:
        ValueError: if ``deagg`` is unknown, or if ``deagg == "space"`` and
            ``num_objectives`` is not a perfect square.
    """
    assert num_objectives > 0  # make sure we have at least one objective
    assert solutions.shape[1] % num_objectives == 0  # sections must be equal-sized

    if deagg == "time":
        return _section_neg_variance(solutions, num_objectives)

    if deagg != "space":
        raise ValueError(f"unknown deagg mode {deagg!r}; expected 'time' or 'space'")

    grid_size = int(round(np.sqrt(num_objectives)))
    if grid_size * grid_size != num_objectives:
        raise ValueError(
            f"deagg='space' needs a perfect-square num_objectives, got {num_objectives}"
        )

    batch_size, dim = solutions.shape
    section_size = dim // num_objectives
    section_objs = _section_neg_variance(solutions, num_objectives)

    # Cumulative (x, y) position of every joint along the arm.
    cum_theta = np.cumsum(solutions, axis=1)
    x_joint = np.cumsum(link_lengths[None] * np.cos(cum_theta), axis=1)
    y_joint = np.cumsum(link_lengths[None] * np.sin(cum_theta), axis=1)
    max_arm_length = np.sum(link_lengths)

    # -100 is a sentinel for "no section ends here"; presumably below any
    # reachable objective for the angle ranges used (confirm for your bounds).
    objs = np.full((batch_size, grid_size, grid_size), -100.0)
    rows = np.arange(batch_size)

    for i in range(num_objectives):
        end_joint = (i + 1) * section_size - 1  # last joint of section i
        x_cell = _coord_to_cell(x_joint[:, end_joint], max_arm_length, grid_size)
        y_cell = _coord_to_cell(y_joint[:, end_joint], max_arm_length, grid_size)
        # Pair each row with its own cell; `objs[:, x, y]` would instead
        # broadcast every row's score into every row's cells.
        objs[rows, x_cell, y_cell] = section_objs[:, i]

    return objs.reshape(batch_size, num_objectives)

def get_objectives_space(solutions, link_lengths, num_objectives=1):
    """Split each solution's negated variance into equal contiguous joint sections.

    Note: despite the name, arm positions are not used (``link_lengths`` is
    ignored). The joints are partitioned by index into ``num_objectives``
    equal sections, and each section's squared deviations from the WHOLE
    row's mean are summed and divided by ``dim``. The columns of a row
    therefore add up to ``-np.var(row)``.

    Args:
        solutions (np.ndarray): A (batch_size, dim) array where each row
            contains the joint angles for the arm.
        link_lengths (np.ndarray): A (dim,) array of link lengths (unused).
        num_objectives (int): The number of objectives to deaggregate into;
            must divide ``dim``.
    Returns:
        objs (np.ndarray): (batch_size, num_objectives) float64 array of
            objectives.
    """
    assert num_objectives > 0  # need at least one objective
    assert solutions.shape[1] % num_objectives == 0  # sections must be equal-sized

    batch_size, dim = solutions.shape
    section_len = int(dim / num_objectives)

    # Squared deviation of every joint from its own row's mean angle.
    row_mean = np.mean(solutions, axis=1)
    sq_dev = (solutions - row_mean[:, None]) ** 2

    # View joints as (num_objectives, section_len) blocks and sum each block.
    grouped = sq_dev.reshape(batch_size, num_objectives, section_len)
    per_section = -grouped.sum(axis=2) / dim
    return per_section.astype(np.float64)

def visualize(solution, link_lengths, objective, ax):
    """Draw the planar arm described by ``solution`` onto ``ax``.

    Args:
        solution (np.ndarray): A (dim,) array of relative joint angles.
        link_lengths (np.ndarray): The length of each link of the arm.
        objective (float): Objective value, shown in the axes title.
        ax (plt.Axes): Matplotlib axes to draw on.
    """
    # Square view large enough for a fully stretched arm, plus a 5% margin.
    half_width = 1.05 * np.sum(link_lengths)
    ax.set_aspect("equal")
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)

    ax.set_title(f"Objective: {objective}")

    # Joint positions from the base outward; each link points along the
    # running sum of the joint angles.
    joints = [np.array([0, 0])]
    for length, heading in zip(link_lengths, np.cumsum(solution)):
        direction = np.array([np.cos(heading), np.sin(heading)])
        joints.append(joints[-1] + length * direction)

    for start, end in zip(joints[:-1], joints[1:]):
        ax.plot([start[0], end[0]], [start[1], end[1]], "-ko", ms=3)

    # Base in red; end effector in green, labelled with its coordinates.
    tip = joints[-1]
    ax.plot(0, 0, "ro", ms=6)
    ax.plot(tip[0], tip[1], "go", ms=6, label=f"Final: ({tip[0]:.2f}, {tip[1]:.2f})")
    ax.legend()

