import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Lookup table from precision names (the solvers' ``precision`` strings)
# to the corresponding torch floating-point dtypes.
DTYPE_MAP = {
    precision_name: getattr(torch, precision_name)
    for precision_name in ('float16', 'float32', 'float64')
}


class Solver:
    """Common interface for the semi-discrete OT solvers.

    Concrete solvers override both methods; the base versions only raise
    ``NotImplementedError``.
    """

    def solve(self, dim, draw_samples_func, nu_points, nu_weights, max_cost):
        """Run the solver (to be provided by a subclass)."""
        raise NotImplementedError()

    def change_parameters(self, params):
        """Update solver settings from a mapping (to be provided by a subclass)."""
        raise NotImplementedError()

class PSGDSolver(Solver):
    """Projected (averaged) SGD solver for the semi-discrete OT problem.

    Tunable settings are plain instance attributes set in ``__init__``.
    Update them through :meth:`change_parameters`, which also keeps
    ``dtype`` in sync with ``precision``.
    """

    def __init__(self):
        self.cost = "quadra"            # ground cost: "quadra" or "1.5"
        self.g_0_init = None            # initial potentials; random in [-max_cost, max_cost] when None
        self.step_decay = 3/4           # step at iteration t is initial_step_size / t**step_decay
        self.proj = True                # clamp potentials to [-max_cost, max_cost] after each step
        self.batch_size = 1             # source samples drawn per iteration
        self.numItermax = 10**4         # number of SGD iterations
        # NOTE(review): num_save is not read by solve(); the history interval
        # is hard-coded to 100 iterations there.
        self.num_save = 100
        self.initial_step_size = None   # when None, solve() uses sqrt(batch_size * m)
        self.precision = 'float64'      # key into DTYPE_MAP
        self.dtype = DTYPE_MAP.get(self.precision, torch.float32)  # unknown names fall back to float32

    def change_parameters(self, params):
        """Update solver parameters in place from a ``{name: value}`` mapping.

        Only names that are existing instance attributes (those set in
        ``__init__``) are accepted, so methods cannot be overwritten by accident.
        Unknown names are ignored with a ``UserWarning`` instead of silently.
        Setting ``precision`` also refreshes ``dtype``, unless ``dtype`` is
        passed explicitly in the same call.

        Args:
            params (Mapping[str, Any]): parameter names and their new values.
        """
        known = vars(self)
        for param_name, param_value in params.items():
            if param_name in known:
                setattr(self, param_name, param_value)
            else:
                warnings.warn(
                    f"{type(self).__name__} has no parameter {param_name!r}; ignored.",
                    stacklevel=2,
                )
        # dtype is derived from precision; without this, changing precision
        # after construction would have no effect on solve().
        if 'precision' in params and 'dtype' not in params:
            self.dtype = DTYPE_MAP.get(self.precision, torch.float32)

    def print_parameters(self):
        """Print the current solver settings to stdout."""
        print(f"Parameters of the solver: \n"
              f"cost: {self.cost}\n"
              f"g_0_init: {self.g_0_init}\n"
              f"initial_step_size: {self.initial_step_size}\n"
              f"step_decay: {self.step_decay}\n"
              f"proj: {self.proj}\n"
              f"batch_size: {self.batch_size}\n"
              f"numItermax: {self.numItermax}\n"
              f"num_save: {self.num_save}\n"
              f"precision: {self.precision}\n")

    def _cost_matrix(self, discrete_diracs, y_batch):
        """Return the (batch_size, m) matrix of ground costs between samples and Diracs.

        Assumes ``discrete_diracs`` is (m, dim) and ``y_batch`` is (batch_size, dim),
        as implied by the broadcasting below.

        Raises:
            ValueError: if ``self.cost`` is not "quadra" or "1.5".
        """
        if self.cost not in ("quadra", "1.5"):
            raise ValueError(f"Unsupported cost type: {self.cost}. Supported values are 'quadra' and '1.5'.")
        sq_dist = torch.sum((discrete_diracs.unsqueeze(0) - y_batch.unsqueeze(1))**2, dim=2)
        if self.cost == "quadra":
            return 0.5 * sq_dist
        # "1.5": |x - y|^1.5 written as (|x - y|^2)^0.75
        return sq_dist**(0.75)

    def solve(self, dim, draw_samples_func, nu_points, nu_weights, max_cost):
        """
        Solves the semi-discrete OT problem with Projected (Averaged) SGD.

        Args:
            dim (int): Dimension of the input space.
            draw_samples_func (Callable): Called as
                ``draw_samples_func(batch_size, dim, precision=dtype)`` to draw
                samples from the source distribution. It must return a tensor
                of the solver's dtype.
            nu_points (array-like): Discrete target measure points, one per row.
            nu_weights (Optional[array-like]): Weights of the target points.
                If None, uniform weights 1/m are used (m = number of points).
            max_cost (float): Bound for the random initialisation and the
                projection step.

        Returns:
            dict: with keys
                "g_t" (torch.Tensor): last iterate of the potentials, shape (m,).
                "g_bar" (torch.Tensor): running mean of the iterates g_1..g_T.
                "histories" (dict[str, list[torch.Tensor]]): snapshots of
                    "g_t" and "g_bar", taken at initialisation and every
                    100 iterations.

        Raises:
            ValueError: if ``self.cost`` is not "quadra" or "1.5".
        """
        device = torch.device('cpu')
        dtype = self.dtype

        # Read-only inputs: as_tensor avoids the copy warning torch.tensor emits for tensor inputs.
        discrete_diracs = torch.as_tensor(nu_points, dtype=dtype, device=device)
        m = discrete_diracs.shape[0]
        if nu_weights is not None:
            vect_nu = torch.as_tensor(nu_weights, dtype=dtype, device=device)
        else:
            vect_nu = torch.ones(m, device=device, dtype=dtype) / m

        # Initialise potentials from g_0_init, or uniformly in [-max_cost, max_cost].
        if self.g_0_init is not None:
            # clone() so the in-place updates below never write into caller data.
            g_t = torch.as_tensor(self.g_0_init, dtype=dtype, device=device).clone()
        else:
            g_t = (torch.rand(m, device=device, dtype=dtype) * 2 - 1) * max_cost
        g_bar = g_t.clone()

        # Histories for the article's numerical experiments
        histories = {"g_t": [g_t.cpu().clone()], "g_bar": [g_bar.cpu().clone()]}

        # Keep the default step size local. Storing it on self would make later
        # calls with a different m or batch_size reuse a stale value.
        if self.initial_step_size is not None:
            initial_step_size = self.initial_step_size
        else:
            initial_step_size = float(np.sqrt(self.batch_size * m))

        with torch.no_grad():  # No gradient computations with PyTorch needed
            for t in range(1, int(self.numItermax) + 1):
                y_batch = draw_samples_func(self.batch_size, dim, precision=dtype)
                cost = self._cost_matrix(discrete_diracs, y_batch)

                # (Sub)gradient: one-hot of argmax_j (g_j - c(y, x_j)) per sample,
                # averaged over the batch, minus the target weights.
                argmax_indices = torch.argmax(g_t.unsqueeze(0) - cost, dim=1)
                chi = torch.zeros_like(cost).scatter_(1, argmax_indices.unsqueeze(1), 1)
                grad = torch.mean(chi, dim=0) - vect_nu

                step_size = initial_step_size / t ** self.step_decay
                g_t -= step_size * grad

                if self.proj:
                    g_t = torch.clamp(g_t, -max_cost, max_cost)

                # Running mean: after this line g_bar = (g_1 + ... + g_t) / t.
                g_bar += (g_t - g_bar) / t

                # Save history every 100 iterations
                if t % 100 == 0:
                    histories["g_t"].append(g_t.cpu().clone())
                    histories["g_bar"].append(g_bar.cpu().clone())

        return {
            "g_t": g_t.cpu(),
            "g_bar": g_bar.cpu(),
            "histories": histories
        }
    