import torch
from modules_ot.solvers import DRPASGDSolver, AdamFixedRegSolver, OnlineSinkhornSolver, DRAdagradSolver, LooplessSadaGradSolver
# Import-time side effect: announces on stdout that this module was loaded.
# NOTE(review): consider using logging (or dropping this) instead of printing on import.
print("OTProblem module loaded.")

class OTProblem:
    """Optimal transport problem between a sampled source measure and a discrete target.

    The source measure ``mu`` is only accessed through the sampler ``draw_mu``;
    the target measure ``nu`` is given by its support ``nu_points`` and its
    weights ``nu_weights``. Transport maps and error metrics are expressed in
    terms of a potential ``g`` holding one value per target point, with the
    cost ``c(x, y) = 0.5 * ||x - y||^2``.

    Args:
        dim: Dimension of the ambient space.
        draw_mu: Callable ``draw_mu(n, dim)`` returning ``n`` source samples
            as an ``[n, dim]`` tensor.
        nu_points: Target support points, ``[M, dim]``.
        nu_weights: Target weights, ``[M]``.
        epsilon: Regularisation temperature used by :meth:`eot_map`
            (``0`` or less means the hard, unregularised map).
        max_cost: Forwarded to the solver's ``solve`` call.
        theoretical_g_opt: Optional reference potential ``[M]`` required by
            the error metrics.
    """

    def __init__(self, dim, draw_mu, nu_points, nu_weights, epsilon=0, max_cost=1, theoretical_g_opt=None):
        self.dim = dim
        self.draw_mu = draw_mu
        self.nu_points = nu_points
        self.nu_weights = nu_weights
        self.theoretical_g_opt = theoretical_g_opt
        self.solver = DRPASGDSolver()  # default solver; replace via _select_solver
        self.g_approx = None  # filled in by solve()
        self.f_approx = None
        self.epsilon = epsilon
        self.max_cost = max_cost

    def print_setting(self):
        """Print a human-readable summary of the problem configuration."""
        print(f"Dimension: {self.dim}")
        print(f"Source measure: {self.draw_mu}")
        print(f"Target measure's has : {self.nu_points.shape[0]} points")
        # Report a single value when every weight equals the first one.
        if torch.all(torch.eq(self.nu_weights, self.nu_weights[0])):
            print(f"Target measure's weights are uniform: {self.nu_weights[0]}")
        else:
            print(f"Target measures's weights are between {self.nu_weights.min()} and {self.nu_weights.max()}")
        print(f"Theoretical g_opt: {self.theoretical_g_opt}")
        print(f"Approximated g_opt: {self.g_approx}")
        print(f"Solver: {self.solver}")
        print(f"Epsilon: {self.epsilon}")
        print(f"Max cost: {self.max_cost}")

    def _select_solver(self, solver_name):
        """Replace ``self.solver`` with a new solver chosen by case-insensitive name.

        Args:
            solver_name: One of ``"drpasgd"``, ``"adam_fixed"``, ``"online_sinkhorn"``.

        Raises:
            ValueError: If ``solver_name`` does not name a known solver.
        """
        # Use the solver classes imported at the top of this module; no
        # ``solvers`` module object is bound in this namespace.
        factories = {
            "drpasgd": DRPASGDSolver,
            "adam_fixed": AdamFixedRegSolver,
            "online_sinkhorn": OnlineSinkhornSolver,
        }
        try:
            factory = factories[solver_name.lower()]
        except KeyError:
            raise ValueError(f"Solver {solver_name} not implemented") from None
        self.solver = factory()

    def solve(self):
        """Run the current solver and store its ``g_bar`` output in ``self.g_approx``.

        Returns:
            ``(g_hist, g_bar_hist, g_weighted_hist)`` exactly as returned by the solver.
        """
        print("Solving...")
        _g_t, g_bar, g_hist, g_bar_hist, g_weighted_hist = self.solver.solve(
            self.dim, self.draw_mu, self.nu_points, self.nu_weights, self.max_cost
        )
        self.g_approx = g_bar
        print("Done.")
        return g_hist, g_bar_hist, g_weighted_hist

    def change_solver_parameters(self, params):
        """Forward ``params`` to the current solver's ``change_parameters`` method."""
        self.solver.change_parameters(params)

    def _require_g_opt(self):
        """Return ``theoretical_g_opt``, raising a clear error when it was not provided."""
        if self.theoretical_g_opt is None:
            raise ValueError("theoretical_g_opt is required to compute this error metric")
        return self.theoretical_g_opt

    def cost_function(self, x, y):
        """Pairwise cost ``0.5 * ||x_i - y_j||^2``.

        Args:
            x: ``[N, d]`` tensor.
            y: ``[M, d]`` tensor.

        Returns:
            ``[N, M]`` cost matrix.
        """
        x = x.unsqueeze(1)  # [N, 1, d]
        y = y.unsqueeze(0)  # [1, M, d]
        return 0.5 * torch.sum((x - y) ** 2, dim=-1)  # [N, M]

    def ot_map(self, x, g):
        """Send each row of ``x`` to the target point minimising ``c(x, y_j) - g_j``.

        Args:
            x: ``[N, d]`` points.
            g: ``[M]`` potential on the target points (broadcast over rows).

        Returns:
            ``[N, d]`` selected target points (first minimiser on ties).
        """
        cost = self.cost_function(x, self.nu_points)  # [N, M]
        return self.nu_points[torch.argmin(cost - g, dim=1)]  # [N, d]

    def eot_map(self, x, g):
        """Soft assignment map: softmax-weighted average of the target points.

        The weights are ``softmax(-(c(x, y_j) - g_j) / epsilon)`` over ``j``.

        Args:
            x: ``[N, d]`` points.
            g: ``[M]`` potential on the target points.

        Returns:
            ``[N, d]`` barycentric images of ``x``.
        """
        if self.epsilon <= 0:
            # Dividing by a zero temperature yields inf/NaN weights; the hard
            # argmin map is the epsilon -> 0 limit (for a unique minimiser).
            return self.ot_map(x, g)
        cost = self.cost_function(x, self.nu_points)  # [N, M]
        weights = torch.nn.functional.softmax(-(cost - g) / self.epsilon, dim=1)  # [N, M]
        return torch.matmul(weights, self.nu_points)  # [N, d]

    def transport_error_point(self, x, g, entropic=False):
        """Squared distance between the images of one point under ``g`` and ``theoretical_g_opt``.

        Args:
            x: A single point of shape ``[d]``.
            g: ``[M]`` potential to evaluate.
            entropic: Use :meth:`eot_map` for ``g`` instead of :meth:`ot_map`.

        Returns:
            Scalar tensor ``||T_opt(x) - T_g(x)||^2``.

        Raises:
            ValueError: If ``theoretical_g_opt`` is not set.
        """
        g_opt = self._require_g_opt()
        x = x.unsqueeze(0)  # [1, d] — ensure batched
        approx_y = self.eot_map(x, g) if entropic else self.ot_map(x, g)  # [1, d]
        theoretical_y = self.ot_map(x, g_opt)  # [1, d]
        return torch.sum((theoretical_y - approx_y) ** 2)

    def transport_error(self, g, nb_montecarlo=100_000, chunk_size=10_000):
        """Monte-Carlo estimate of ``E_mu ||T_opt(x) - T_g(x)||^2`` using hard maps.

        Samples are drawn in chunks of at most ``chunk_size`` to bound the
        size of the ``[chunk, M]`` cost matrix.

        Args:
            g: ``[M]`` potential to evaluate.
            nb_montecarlo: Total number of source samples (must be positive).
            chunk_size: Maximum samples per chunk (must be positive).

        Returns:
            The estimated mean squared transport error as a Python float.

        Raises:
            ValueError: If a size argument is not positive or
                ``theoretical_g_opt`` is not set.
        """
        if nb_montecarlo <= 0:
            raise ValueError("nb_montecarlo must be positive")
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        g = g.detach()
        g_opt = self._require_g_opt().detach()

        total_error = 0.0
        remaining = nb_montecarlo
        with torch.no_grad():
            while remaining > 0:
                n = min(chunk_size, remaining)
                remaining -= n

                x = self.draw_mu(n, self.dim)  # [n, d]
                # One cost matrix serves both maps.
                cost = self.cost_function(x, self.nu_points)  # [n, M]
                theoretical_y = self.nu_points[torch.argmin(cost - g_opt, dim=1)]  # [n, d]
                approx_y = self.nu_points[torch.argmin(cost - g, dim=1)]  # [n, d]

                total_error += torch.sum((theoretical_y - approx_y) ** 2).item()

        return total_error / nb_montecarlo

    def _semi_dual(self, cost, g):
        """Monte-Carlo semi-dual value ``mean_i min_j (cost_ij - g_j) + sum_j g_j w_j``."""
        return torch.min(cost - g, dim=1).values.mean() + torch.sum(g * self.nu_weights)

    def cost_error(self, g, nb_loop=10, nb_montecarlo=10_000):
        """Average absolute gap between the semi-dual values of ``g`` and ``theoretical_g_opt``.

        Each of the ``nb_loop`` rounds draws ``nb_montecarlo`` fresh source
        samples and evaluates both potentials on the same samples.

        Args:
            g: ``[M]`` potential to evaluate.
            nb_loop: Number of independent rounds (must be positive).
            nb_montecarlo: Samples per round.

        Returns:
            A scalar tensor holding the averaged absolute gap.

        Raises:
            ValueError: If ``nb_loop`` is not positive or
                ``theoretical_g_opt`` is not set.
        """
        if nb_loop <= 0:
            raise ValueError("nb_loop must be positive")
        g_opt = self._require_g_opt()

        with torch.no_grad():
            g = g.detach()
            g_opt = g_opt.detach()
            total_error = 0.0

            for _ in range(nb_loop):
                x = self.draw_mu(nb_montecarlo, self.dim)  # [N, d]
                cost = self.cost_function(x, self.nu_points)  # [N, M], shared by both estimates
                total_error += torch.abs(self._semi_dual(cost, g) - self._semi_dual(cost, g_opt))

            return total_error / nb_loop


        



