import torch
import numpy as np
from modules_ot.solvers import PSGDSolver, PSAdaGradSolver, PSAdamSolver 

class OTProblem:
    """Semi-discrete optimal transport problem between a sampled source and a discrete target.

    The source measure mu is only accessed through the sampler ``draw_mu``; the
    target measure nu is discrete, described by ``nu_points`` and ``nu_weights``.
    ``solve`` delegates to a solver object and stores the potential it returns
    (``results["g_bar"]``, one value per target point) in ``g_approx``.

    Args:
        dim: Ambient dimension of the points.
        draw_mu: Sampler called as ``draw_mu(n, dim)``; assumed to return an
            ``(n, dim)`` tensor of samples from mu (TODO confirm).
        nu_points: Tensor of target support points; the first axis indexes points.
        nu_weights: Tensor of target weights, one per point.
        cost: Cost identifier, ``"quadra"`` (0.5 * ||x - y||^2) or ``"1.5"``
            (||x - y||^1.5). Forwarded to the solver as well.
        max_cost: Forwarded unchanged to the solver.
        theoretical_g_opt: Optional reference potential used by the
            transport-error metrics.
    """

    def __init__(self, dim, draw_mu, nu_points, nu_weights, cost="quadra", max_cost=1, theoretical_g_opt=None):
        self.dim = dim
        self.draw_mu = draw_mu
        self.nu_points = nu_points
        self.nu_weights = nu_weights
        self.theoretical_g_opt = theoretical_g_opt
        self.solver = PSGDSolver()
        self.g_approx = None
        self.f_approx = None
        # Entropic regularization strength; only 0 (unregularized) is implemented.
        self.epsilon = 0
        self.cost = cost
        self.max_cost = max_cost

    def print_setting(self):
        """Print a human-readable summary of the problem and solver configuration."""
        print(f"Dimension: {self.dim}")
        print(f"Source measure: {self.draw_mu}")
        print(f"Target measure's has : {self.nu_points.shape[0]} points")
        # Report a single value when all target weights are equal, a range otherwise.
        if torch.all(torch.eq(self.nu_weights, self.nu_weights[0])):
            print(f"Target measure's weights are uniform: {self.nu_weights[0]}")
        else:
            print(f"Target measures's weights are between {self.nu_weights.min()} and {self.nu_weights.max()}")
        print(f"Theoretical g_opt: {self.theoretical_g_opt}")
        print(f"Approximated g_opt: {self.g_approx}")
        print(f"Solver: {self.solver}")
        print(f"Max cost: {self.max_cost}")

    def _select_solver(self, solver_name):
        """Replace the current solver by a freshly constructed one.

        Only ``"pasgd"`` (``PSGDSolver``) is recognized.
        NOTE(review): ``PSAdaGradSolver`` and ``PSAdamSolver`` are imported at
        module level but cannot be selected here.

        Raises:
            ValueError: for any other ``solver_name``.
        """
        if solver_name == "pasgd":
            self.solver = PSGDSolver()
        else:
            raise ValueError(f"Solver {solver_name} not implemented")

    def solve(self):
        """Run the current solver, store ``results["g_bar"]`` in ``g_approx`` and return ``results``."""
        print("Solving...")
        results = self.solver.solve(self.dim, self.draw_mu, self.nu_points, self.nu_weights, self.cost, self.max_cost)
        self.g_approx = results["g_bar"]
        print("Done.")
        return results

    def change_solver_parameters(self, params):
        """Forward ``params`` to the current solver's ``change_parameters``."""
        self.solver.change_parameters(params)

    def cost_function(self, x, y):
        """Evaluate the ground cost between ``x`` and ``y``, reducing over the last axis.

        ``x`` and ``y`` are broadcast against each other, so passing a batch of
        points on one side yields a tensor of costs.

        Raises:
            ValueError: if ``self.cost`` is not ``"quadra"`` or ``"1.5"``.
        """
        if self.cost == "quadra":
            return 0.5 * torch.sum((x - y) ** 2, dim=-1)
        elif self.cost == "1.5":
            # (||x - y||^2)^(3/4) == ||x - y||^1.5
            return torch.sum((x - y) ** 2, dim=-1) ** (0.75)
        else:
            raise ValueError(f"Unsupported cost type: {self.cost}. Supported values are 'quadra' and '1.5'.")

    def eot_c_transform(self, x, g):
        """Return the c-transform of the potential ``g`` at a single point ``x``.

        With ``epsilon == 0`` this is ``min_j c(x, y_j) - g_j`` over the target
        points. ``torch.min`` reduces over all entries, so ``x`` is expected to
        be one point.

        Raises:
            NotImplementedError: if ``epsilon`` is non-zero.
        """
        if self.epsilon != 0:
            # Fail loudly instead of handing None back to the caller.
            raise NotImplementedError("Entropic c-transform (epsilon != 0) is not implemented")
        costs = self.cost_function(x, self.nu_points)
        return torch.min(costs - g)

    def transport_point(self, x, g):
        """Return the target point assigned to ``x`` by the potential ``g``.

        The assignment is ``argmin_j c(y_j, x) - g_j``; on ties ``torch.argmin``
        returns the first minimal index.
        """
        v = self.cost_function(self.nu_points, x) - torch.as_tensor(g)
        i_min = torch.argmin(v)
        return self.nu_points[i_min]

    def transport_error_point(self, x, g):
        """Squared distance between the images of ``x`` under ``theoretical_g_opt`` and under ``g``."""
        theoretical_y = self.transport_point(x, self.theoretical_g_opt)
        approx_y = self.transport_point(x, g)
        return torch.norm(theoretical_y - approx_y) ** 2

    def transport_error(self, g, nb_montecarlo=10000):
        """Monte Carlo estimate of the squared transport-map error of ``g``.

        Draws ``nb_montecarlo`` samples from mu and returns the mean of
        ``||T_ref(x) - T_g(x)||^2``, where ``T_ref`` is the assignment induced
        by ``theoretical_g_opt`` and ``T_g`` the one induced by ``g``.

        Args:
            g: Potential to evaluate; anything ``torch.as_tensor`` accepts
                (tensor, numpy array, list). It is cast to float32.
            nb_montecarlo: Number of samples drawn from mu.

        Returns:
            The estimated mean squared error as a Python float.

        Raises:
            ValueError: if ``theoretical_g_opt`` was not provided.
        """
        if self.theoretical_g_opt is None:
            raise ValueError("transport_error requires theoretical_g_opt to be set")
        x = self.draw_mu(nb_montecarlo, self.dim)
        # as_tensor avoids torch.tensor's copy-construct warning when g is already a tensor.
        g = torch.as_tensor(g, dtype=torch.float32)
        g_ref = torch.as_tensor(self.theoretical_g_opt)

        theoretical_y = self.transport_point_batch(x, g_ref)
        approx_y = self.transport_point_batch(x, g)

        # Squared Euclidean distance per sample, without a sqrt/square round trip.
        errors = torch.sum((theoretical_y - approx_y) ** 2, dim=1)
        return torch.mean(errors).item()

    def transport_point_batch(self, x_batch, g):
        """Vectorized ``transport_point`` over a batch of source points.

        Args:
            x_batch: Tensor of source points; assumed ``(batch_size, dim)`` (TODO confirm).
            g: Potential with one entry per target point.

        Returns:
            Tensor of assigned target points, one row per element of ``x_batch``.
        """
        g = torch.as_tensor(g)
        # (batch_size, 1, dim) vs (num_points, dim) broadcasts to a
        # (batch_size, num_points) cost matrix; the intermediate difference
        # tensor holds batch_size * num_points * dim entries.
        cost_matrix = self.cost_function(self.nu_points, x_batch.unsqueeze(1))
        scores = cost_matrix - g.unsqueeze(0)
        i_min = torch.argmin(scores, dim=1)
        return self.nu_points[i_min]
