import torch
import numpy as np

def draw_uniform_cube(batch_size, d, precision=torch.float64):
    """Sample ``batch_size`` points uniformly from the unit cube [0, 1)^d.

    Args:
        batch_size: number of samples (rows of the result).
        d: dimension of each sample (columns of the result).
        precision: torch dtype of the returned tensor (float64 by default).

    Returns:
        A tensor of shape (batch_size, d) drawn with ``torch.rand``.
    """
    sample_shape = (batch_size, d)
    return torch.rand(sample_shape, dtype=precision)


def example_1(m=100, d=3, dtype=None):
    r"""
    Example 1: a fully symmetric problem whose optimal potential is known.

    \mu is uniform on [0,1]^d. \nu is uniform on m atoms: the first coordinate
    of atom j is (2j+1)/(2m), i.e. it ranges over
    {1/(2m), 3/(2m), ..., (2m-1)/(2m)}, and every other coordinate is 0.5.
    This example is the same as in: @inproceedings{pooladian2023minimax,
        title={Minimax estimation of discontinuous optimal transport maps: The semi-discrete case},
        author={Pooladian, Aram-Alexandre and Divol, Vincent and Niles-Weed, Jonathan},
        booktitle={International Conference on Machine Learning},
        pages={28128--28150},
        year={2023},
        organization={PMLR}
    }

    Args:
        m: number of atoms of \nu.
        d: ambient dimension.
        dtype: torch dtype of ``nu_points``, ``nu_weights`` and ``g_opt``.
            ``None`` (default) uses torch's default dtype, as before.
            Pass ``torch.float64`` to match the float64 samples of \mu.

    Returns:
        Tuple ``(d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt)``:
        ``nu_points`` has shape (m, d), ``nu_weights`` and ``g_opt`` have
        shape (m,), and ``max_cost`` is 1.
    """
    # Target measure: every coordinate is 0.5 except the first one.
    nu_points = torch.full((m, d), 0.5, dtype=dtype)
    # First coordinate (j + 0.5) / m. It is computed in float64 and cast on
    # assignment, which gives exactly the same values as per-element python
    # float writes.
    nu_points[:, 0] = (torch.arange(m, dtype=torch.float64) + 0.5) / m
    nu_weights = torch.ones(m, dtype=dtype) / m

    # Theoretical optimal potential: identically zero.
    g_opt = torch.zeros(m, dtype=dtype)

    # Sampler for the source measure \mu (uniform on [0, 1)^d).
    def draw_samples_mu(batch_size, d, precision=torch.float64):
        return torch.rand(batch_size, d, dtype=precision)

    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt

def example_2(data_dir=None):
    r"""
    Example 2: a less symmetric problem, where the weights of \nu are not uniform.

    The Diracs of \nu are generated randomly (in generate_ex2.py).
    \mu is fixed to be the uniform measure on [0,1]^d with d = 10.
    We then fix g_opt randomly and compute the weights of \nu corresponding
    to the Laguerre cells L(g_opt).

    Args:
        data_dir: optional directory containing ``nu_points_ex2.npy``,
            ``nu_weights_ex2.npy`` and ``g_opt_ex2.npy``. ``None`` (default)
            loads them by bare filename, i.e. relative to the current working
            directory, as before.

    Returns:
        Tuple ``(d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt)``,
        with the loaded arrays converted to float64 torch tensors.
    """
    import os

    d = 10

    def draw_samples_mu(batch_size, d, precision=torch.float64):
        # Honour ``precision``; it used to be ignored (always float64).
        return torch.rand(batch_size, d, dtype=precision)

    def _load(filename):
        # Load one generated array and convert it to a float64 torch tensor.
        path = filename if data_dir is None else os.path.join(data_dir, filename)
        return torch.tensor(np.load(path), dtype=torch.float64)

    nu_points = _load("nu_points_ex2.npy")
    nu_weights = _load("nu_weights_ex2.npy")
    g_opt = _load("g_opt_ex2.npy")

    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt
def example_2_old(data_dir=None):
    r"""
    Example 2 (old version): an optimal vector where \mu is uniform on [0,1]^d with d = 10.

    \nu was generated randomly (in generate_ex2.py).
    g_opt was then approximated using drpasgd with 10**8 iterations.

    Args:
        data_dir: optional directory containing ``nu_points_ex2.npy``,
            ``nu_weights_ex2.npy`` and ``g_opt_ex2.npy``. ``None`` (default)
            loads them by bare filename, i.e. relative to the current working
            directory, as before.

    Returns:
        Tuple ``(d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt)``.
        ``draw_samples_mu`` is the module-level ``draw_uniform_cube``, and the
        loaded arrays are converted to float64 torch tensors.
    """
    import os

    d = 10
    draw_samples_mu = draw_uniform_cube

    def _load(filename):
        # Load one generated array and convert it to a float64 torch tensor.
        path = filename if data_dir is None else os.path.join(data_dir, filename)
        return torch.tensor(np.load(path), dtype=torch.float64)

    nu_points = _load("nu_points_ex2.npy")
    nu_weights = _load("nu_weights_ex2.npy")
    g_opt = _load("g_opt_ex2.npy")

    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt

def example_3(m=20, delta=0.1, dtype=None):
    r"""
    Example 3: the optimal vector from the proof of Theorem 4 (minimax
    estimation of the discrete potential), in dimension d = 1.

    \nu is uniform on the atoms {1/m, 2/m, ..., 1} and \mu is uniform on
    [delta, 1 + delta]. In this case g* is optimal iff
    g*_{j+1} - g*_j = 1/(2m^2) - delta/m  for j = 1, ..., m-1.

    Args:
        m: number of atoms of the target measure \nu.
        delta: shift of the source measure \mu, which is uniform on
            [delta, 1 + delta].
        dtype: torch dtype of ``nu_points``, ``nu_weights`` and ``g_opt``.
            ``None`` (default) uses torch's default dtype, as before.

    Returns:
        Tuple ``(d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt)``:
        ``d`` is 1, ``nu_points`` has shape (m, 1), ``nu_weights`` and
        ``g_opt`` have shape (m,), and ``max_cost`` is ``1 + delta``.
    """
    d = 1

    def draw_uniform_cube_shifted(batch_size, d, precision=torch.float64):
        # torch.rand samples [0, 1); shifting gives [delta, 1 + delta).
        return delta + torch.rand(batch_size, d, dtype=precision)

    # Target measure: atoms k/m for k = 1, ..., m with equal weights.
    nu_points = torch.linspace(1 / m, 1, m, dtype=dtype).reshape(m, 1)
    nu_weights = torch.ones(m, dtype=dtype) / m

    # g_opt_j - g_opt_{j-1} = 1/(2m^2) - delta/m, starting from g_opt_0 = 0.
    # Writing it as step * j avoids the rounding drift of repeated addition.
    step = 1 / (2 * m**2) - delta / m
    g_opt = step * torch.arange(m, dtype=nu_weights.dtype)

    max_cost = 1 + delta
    return d, draw_uniform_cube_shifted, nu_points, nu_weights, max_cost, g_opt
    


