import numpy as np
from scipy.linalg import null_space
from joblib import Parallel, delayed

##### generate training and teaching dataset

# generate i.i.d. Gaussian training and test sets labeled (0/1) by the halfspace w_star
def generate_nature_dataset(w_star, n_train, n_test):
    """Draw standard-normal train/test inputs and label them with the halfspace ``w_star``.

    Parameters:
        w_star (ndarray): target weights as a row; the dimension ``d`` is read
            from ``w_star.shape[1]``.
        n_train (int): number of training points.
        n_test (int): number of test points.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)``. The ``X`` arrays have
        shape (n, d) with i.i.d. standard-normal entries. The ``y`` arrays are
        1-D int arrays holding 1 where ``x @ w_star.T >= 0`` and 0 otherwise.
    """
    dim = w_star.shape[1]

    # The training sample is drawn before the test sample (order matters for seeded runs).
    samples = [np.random.randn(count, dim) for count in (n_train, n_test)]
    labels = [(np.matmul(X, w_star.T) >= 0).astype(int).flatten() for X in samples]

    return samples[0], labels[0], samples[1], labels[1]

def compute_teaching_dataset(w_star, teaching_budget, margin=0):
    """Build a teaching ("nurture") dataset of at most ``d + 1`` points for ``w_star``.

    Positive points come from an orthonormal basis of the hyperplane
    orthogonal to ``w_star`` (``scipy.linalg.null_space``). Each positive point
    is shifted by ``margin * w_star``:

    * budget 1: one basis vector (for d == 1, the origin);
    * budget ``b`` in [2, d]: ``b - 1`` basis vectors plus the negative of
      their sum, so the unshifted positive points sum to zero;
    * budget >= d + 1: the ``d`` positive points of the ``b = d`` case, plus
      ``-w_star`` labeled 0. Budget beyond ``d + 1`` is not used.

    Parameters:
        w_star (ndarray): target weights, shape (1, d).
        teaching_budget (int): maximum number of teaching points, >= 0.
        margin (float): offset along ``w_star``. Every positive point ``x``
            satisfies ``x @ w_star.T == margin * ||w_star||**2`` up to rounding.

    Returns:
        tuple: ``(X_teach, y_teach)``. ``X_teach`` always has shape (m, d) and
        ``y_teach`` is an int array of m labels: 1 for positive points and 0
        for ``-w_star``. Returns ``(None, None)`` when ``teaching_budget == 0``.

    Raises:
        ValueError: if ``teaching_budget`` is negative.
    """
    if teaching_budget < 0:
        raise ValueError("teaching_budget must be non-negative, got {}".format(teaching_budget))
    if teaching_budget == 0:
        return None, None

    d = w_star.shape[1]

    # Orthonormal basis of the hyperplane {x : x @ w_star.T == 0}, one vector per row: (d-1, d).
    X_ortho = null_space(w_star.reshape(1, -1)).T
    shift = margin * w_star  # (1, d); broadcast onto every positive point

    n_positive = min(teaching_budget, d)
    if n_positive == 1:
        # For d == 1 the hyperplane is only the origin, so the null-space basis is empty.
        X_pos = X_ortho[:1, :] if X_ortho.shape[0] > 0 else np.zeros((1, d))
    else:
        X_pos = X_ortho[:n_positive - 1, :]
        # Append the point that makes the (unshifted) positive points sum to zero.
        X_pos = np.vstack((X_pos, -np.sum(X_pos, axis=0)))

    # Shift all positive points consistently (budget 1 used to ignore margin and return 1-D).
    X_teach = X_pos + shift
    y_teach = np.ones(n_positive, dtype=int)

    if teaching_budget >= d + 1:
        # One point labeled 0 on the negative side: -w_star @ w_star.T < 0.
        X_teach = np.vstack((X_teach, -w_star))
        y_teach = np.hstack((y_teach, np.zeros(1, dtype=int)))

    return X_teach, y_teach


def join_dataset(X_tuples, y_tuples):
    """Concatenate several datasets into one.

    Parameters:
        X_tuples: sequence of input arrays, stacked row-wise with ``np.vstack``
            (a 1-D array counts as a single row).
        y_tuples: sequence of label arrays, concatenated end-to-end with
            ``np.hstack`` in the same order as ``X_tuples``.

    Returns:
        tuple: ``(X, y)`` with the stacked inputs and the concatenated labels.
    """
    return np.vstack(X_tuples), np.hstack(y_tuples)


def generate_nd_sphere_grid(dim, resolution, round_off=10):
    """
    Generate a grid of points on the unit sphere in ``dim``-dimensional space.

    Uses hyperspherical coordinates. The first ``dim - 2`` angles span
    [0, pi] and the last (azimuthal) angle spans [0, 2*pi]. Each angle is
    sampled at ``resolution`` evenly spaced values, endpoints included, so the
    grid may contain repeated points (e.g. azimuth 0 and 2*pi, or the poles).

    Parameters:
        dim (int): Dimension of the ambient space, >= 2 (3 gives the ordinary sphere).
        resolution (int): Number of samples per angle.
        round_off (int): Number of decimal places the coordinates are rounded to.

    Returns:
        ndarray: Array of shape (resolution ** (dim - 1), dim), one point per row.

    Raises:
        ValueError: if ``dim < 2`` (there is no angle to sample).
    """
    if dim < 2:
        raise ValueError("dim must be at least 2, got {}".format(dim))

    # Create arrays for each angle in spherical coordinates
    angles = [np.linspace(0, np.pi, resolution) for _ in range(dim - 1)]
    angles[-1] = np.linspace(0, 2 * np.pi, resolution)  # Full range for azimuthal angle

    # Generate a meshgrid for all angles
    grids = np.meshgrid(*angles, indexing='ij')

    # Convert spherical coordinates to Cartesian coordinates:
    # x_i = sin(a_0) * ... * sin(a_{i-1}) * cos(a_i), and the last coordinate has no cos factor.
    coords = []
    for i in range(dim):
        coord = np.ones_like(grids[0])
        for j in range(i):
            coord *= np.sin(grids[j])
        if i < dim - 1:
            coord *= np.cos(grids[i])
        coords.append(coord)

    # Stack Cartesian coordinates and reshape into a list of points
    cartesian_coords = np.stack(coords, axis=-1).reshape(-1, dim)

    # Round to `round_off` decimal places to remove floating-point noise (e.g. cos(pi/2)).
    cartesian_coords = np.round(cartesian_coords, round_off)

    return cartesian_coords

# # Example usage
# dim = 2  # Dimensionality of the sphere (e.g., 3 for a standard sphere)
# resolution = 5  # Number of divisions per angle
# sphere_points = generate_nd_sphere_grid(dim, resolution)

# print(f"Generated {sphere_points.shape[0]} points on a {dim}D sphere.", sphere_points.shape)
# sphere_points



# define a grid of candidate w in [-1, 1]^d with k evenly spaced values along each dimension

def create_grid(d, k):
    """Return the Cartesian grid of ``k`` evenly spaced values in [-1, 1] along each of ``d`` axes.

    Points follow the layout of ``np.meshgrid`` with its default
    ``indexing='xy'``, flattened in C order. The first coordinate therefore
    varies fastest.

    Returns:
        ndarray: shape (k ** d, d), one grid point per row.
    """
    axis_values = np.linspace(-1, 1, k)
    mesh = np.meshgrid(*([axis_values] * d))
    # Put the coordinate index last, then flatten all grid positions into rows.
    return np.stack(mesh, axis=-1).reshape(-1, d)

# # define a grid of w in d dimension unit sphere with k splits along each angle

# def create_sphere_grid(d, k):
#     # Create a 1D array of k evenly spaced values between 0 and 2pi
#     linspace = np.linspace(0, 2*np.pi, k)
    
#     # Create a grid in d dimensions
#     grid = np.meshgrid(*[linspace]*d)
    
#     # Reshape the grid to a list of points
#     grid_points = np.vstack([np.cos(g.ravel()) for g in grid]).T
    
#     return grid_points


def sample_iid_weights(d, n, round_off=None):
    """Sample ``n`` weight vectors whose entries are i.i.d. standard normal.

    Parameters:
        d (int): dimension of each weight vector.
        n (int): number of vectors to draw.
        round_off (int or None): if given, round the samples to this many
            decimal places, like ``generate_nd_sphere_grid`` does with its
            ``round_off``. ``None`` (the default) returns the raw samples.

    Returns:
        ndarray: shape (n, d).
    """
    w = np.random.randn(n, d)
    if round_off is not None:
        w = np.round(w, round_off)
    return w


# create a version space of w_star based on X_train, y_train
def evaluate_version_space(n_teach, n_env, repeat_index, X_train, y_train, X_test, y_test, grid_points, sim_id=None):
    """Return the grid hypotheses consistent with the training data, with their test risk.

    A hypothesis ``w`` predicts 1 for ``x`` iff ``x @ w.T >= 0`` and 0
    otherwise. ``w`` is kept when its predictions match ``y_train`` on every
    training point. Its test risk is the fraction of test points it mislabels.

    Parameters:
        n_teach, n_env, repeat_index, sim_id: identifiers copied into the
            returned run key; they are not used in the computation.
        X_train, y_train: training inputs (one per row) and their 0/1 labels.
        X_test, y_test: test inputs and their labels.
        grid_points: iterable of candidate weight vectors.

    Returns:
        tuple: ``(key, version_space)``.
        ``key`` is ``(n_teach, n_env, repeat_index)``, or
        ``(n_teach, n_env, sim_id, repeat_index)`` when ``sim_id`` is given.
        ``version_space`` maps ``tuple(w)`` to the test risk of ``w``; duplicate
        grid points share one entry.
    """
    def _predict(X, w):
        # Threshold rule shared with the data generator: 1 iff <x, w> >= 0.
        return (np.matmul(X, w.T).flatten() >= 0).astype(int)

    consistent = {}
    for candidate in grid_points:
        if not np.all(_predict(X_train, candidate) == y_train):
            continue
        consistent[tuple(candidate)] = np.mean(_predict(X_test, candidate) != y_test)

    if sim_id is not None:
        key = (n_teach, n_env, sim_id, repeat_index)
    else:
        key = (n_teach, n_env, repeat_index)
    return key, consistent


# training nature only (old version)
def train_nature_only(w_star, max_total_budget, w_set, n_test=1000, repeat_runs=5):
    """Worst-case test risk over the version space when learning from natural samples only.

    For every training-set size ``n_train`` in ``1..max_total_budget`` and each
    of ``repeat_runs`` repetitions, the function draws a fresh dataset with
    ``generate_nature_dataset``. It then evaluates the version space over
    ``w_set`` with ``evaluate_version_space`` and records the largest test
    risk in it.

    Returns:
        dict: ``n_train -> list`` of ``repeat_runs`` worst-case risks. An empty
        version space (no hypothesis in ``w_set`` fits the sample) is recorded
        as ``nan`` instead of crashing the whole sweep.
    """
    # tqdm is not imported at module level; use it for a progress bar only if available.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    worst_case_risk = {}
    for n_train in progress(np.arange(1, max_total_budget + 1)):
        risks = []
        for k in range(repeat_runs):
            X_train, y_train, X_test, y_test = generate_nature_dataset(w_star, n_train, n_test)

            # evaluate_version_space takes run identifiers first and returns (run_key, version_space);
            # this setting has no teaching points (n_teach=0).
            _, version_space = evaluate_version_space(0, n_train, k, X_train, y_train, X_test, y_test, w_set)
            risks.append(max(version_space.values(), default=np.nan))
        worst_case_risk[n_train] = risks

    return worst_case_risk

# training nurture then nature (old version)
def train_nurture_then_nature(w_star, max_total_budget, w_set, teaching_budget, n_test=1000, repeat_runs=5):
    """Worst-case version-space risk when a teaching set is topped up with natural samples.

    For every total budget ``n_total`` in ``teaching_budget..max_total_budget``,
    the training set is the teaching set from ``compute_teaching_dataset``
    plus ``n_total - teaching_budget`` natural samples from
    ``generate_nature_dataset``. When no natural samples are added, the
    training set is deterministic, so a single run is recorded. Otherwise
    ``repeat_runs`` runs are recorded.

    Returns:
        dict: ``n_total -> list`` of worst-case test risks over the version
        space in ``w_set``. An empty version space is recorded as ``nan``.
    """
    # tqdm is not imported at module level; use it for a progress bar only if available.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    # The teaching set does not depend on n_total, so build it once.
    X_teach, y_teach = compute_teaching_dataset(w_star, teaching_budget)
    if X_teach is None:
        # teaching_budget == 0: use an empty (0, d) set so the joins below still work.
        X_teach = np.empty((0, w_star.shape[1]))
        y_teach = np.empty(0, dtype=int)

    worst_case_risk = {}
    for n_total in progress(np.arange(teaching_budget, max_total_budget + 1)):
        worst_case_risk[n_total] = []
        n_env = n_total - teaching_budget
        n_runs = 1 if n_env == 0 else repeat_runs

        for k in range(n_runs):
            if n_env == 0:
                # Teaching data only; one natural point is drawn just to obtain a test set.
                X_train, y_train = X_teach, y_teach
                _, _, X_test, y_test = generate_nature_dataset(w_star, 1, n_test)
            else:
                X_env, y_env, X_test, y_test = generate_nature_dataset(w_star, n_env, n_test)
                X_train, y_train = join_dataset([X_env, X_teach], [y_env, y_teach])

            # evaluate_version_space takes run identifiers first and returns (run_key, version_space).
            _, version_space = evaluate_version_space(teaching_budget, n_env, k, X_train, y_train, X_test, y_test, w_set)
            worst_case_risk[n_total].append(max(version_space.values(), default=np.nan))

    return worst_case_risk

# training nurture then nature (new version) - can handle nature only case as well and is parallelized
def train_nurture_then_nature_parallel(n_teach_list, n_env_list, w_star, w_set, n_test=1000, repeat_runs=5, total_threads=4):
    """Evaluate version spaces for every (n_teach, n_env, repeat) combination in parallel.

    How the jobs are built:

    * For each teaching budget in ``n_teach_list``, a teaching set is built
      once with ``compute_teaching_dataset``.
    * For each environment size in ``n_env_list`` and each of ``repeat_runs``
      repetitions, a natural sample and a test set are drawn with
      ``generate_nature_dataset``.
    * The pair (n_teach, n_env) == (0, 0) is skipped because it has no
      training data.

    The jobs run through ``evaluate_version_space`` via joblib, in
    consecutive batches of ``total_threads`` jobs, with a banner printed
    before each batch.

    Returns:
        list: the ``(run_key, version_space)`` results from
        ``evaluate_version_space``, in job order.
    """

    def _datasets_for(X_teach, y_teach, n_teach, n_env):
        # Build (X_train, y_train, X_test, y_test) for a single run.
        if n_env > 0:
            X_env, y_env, X_test, y_test = generate_nature_dataset(w_star, n_env, n_test)
            if n_teach > 0:
                X_train, y_train = join_dataset([X_teach, X_env], [y_teach, y_env])
            else:
                X_train, y_train = X_env, y_env
            return X_train, y_train, X_test, y_test
        # Teaching data only; one natural point is drawn just to obtain a test set.
        _, _, X_test, y_test = generate_nature_dataset(w_star, 1, n_test)
        return X_teach, y_teach, X_test, y_test

    jobs = []
    for n_teach in n_teach_list:
        # The teaching set depends only on n_teach, so build it once per budget.
        X_teach, y_teach = compute_teaching_dataset(w_star, n_teach)

        for n_env in n_env_list:
            if n_teach == 0 and n_env == 0:
                continue  # no training data at all

            # The generator is consumed in order, so random draws happen in the same sequence.
            jobs.extend(
                (n_teach, n_env, repeat_index, *_datasets_for(X_teach, y_teach, n_teach, n_env), w_set)
                for repeat_index in range(repeat_runs)
            )

    total_runs = len(jobs)
    print("Total number of runs : ", total_runs)

    banner = '*' * 75
    all_output_list = []
    for batch_id, batch_start_index in enumerate(np.arange(0, total_runs, total_threads)):
        batch_stop_index = batch_start_index + total_threads
        print(banner)
        print("Current process batch id : {0:2d}, batch index range : {1}".format(
            batch_id, str(batch_start_index) + ':' + str(batch_stop_index)))
        print(banner)

        batch = jobs[batch_start_index:batch_stop_index]
        all_output_list.extend(
            Parallel(n_jobs=total_threads)(delayed(evaluate_version_space)(*job) for job in batch)
        )

    return all_output_list