import torch
import numpy as np
import torch.nn as nn
from scipy.special import softmax, logsumexp

class Solver:
    """Common interface shared by the solvers in this module.

    Concrete solvers must override both methods; the base versions only
    raise ``NotImplementedError``.
    """

    def solve(self, dim, draw_samples_func, nu_points, nu_weights, max_cost):
        """Run the solver on the given problem (must be overridden)."""
        raise NotImplementedError()

    def change_parameters(self, params):
        """Update solver settings from a mapping (must be overridden)."""
        raise NotImplementedError()

class DRPASGDSolver(Solver):
    """Averaged stochastic-gradient solver for potentials on a discrete target.

    Each iteration draws a batch of source samples, computes the squared
    Euclidean cost ``0.5 * ||x_j - y||^2`` to every target point ``x_j``, and
    takes a gradient step on the potentials ``g`` using either a smoothed
    (softmax, temperature ``eps``) or, optionally once ``eps`` is small, a
    hard argmax subgradient. Plain (``g_bar``) and log^3-weighted
    (``g_bar_weighted``) running averages of the iterates are maintained.
    """

    # Maps the user-facing ``precision`` string to a torch dtype; unknown
    # strings fall back to float32.
    _DTYPE_MAP = {'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64}

    def __init__(self):
        # Initial potentials (length m); None -> uniform random in [0, 1).
        self.g_0_init = None
        # If True the temperature decays as eps_0 / t**a, else stays eps_0.
        self.decreasing = True
        self.eps_0 = 0.1
        self.a = 1/3
        # Step size is gamma_0 / t**b.
        self.b = 2/3
        # Clamp potentials to [-max_cost, max_cost] after every step.
        self.proj = True
        self.batch_size = 1
        # Number of iterations (cast to int, so floats such as 1e4 work).
        self.numItermax = 10**4
        # NOTE(review): not read by solve(); the history cadence is fixed.
        self.num_save = 100
        # None -> sqrt(batch_size * m), recomputed on every solve() call.
        self.gamma_0 = None
        self.precision = 'float64'
        # (sic) When True, switch to the hard argmax subgradient once
        # eps < threshold_eps.
        self.treshold = False
        self.threshold_eps = 0.002

        self.dtype = self._DTYPE_MAP.get(self.precision, torch.float32)

    def change_parameters(self, params):
        """Set every known attribute named in ``params``; unknown keys are ignored.

        ``dtype`` is re-derived from ``precision`` when ``precision`` changes
        (unless ``dtype`` itself is supplied), so the new precision is
        actually used by ``solve``.
        """
        for param_name, param_value in params.items():
            if hasattr(self, param_name):
                setattr(self, param_name, param_value)
        if 'precision' in params and 'dtype' not in params:
            self.dtype = self._DTYPE_MAP.get(self.precision, torch.float32)

    def print_parameters(self):
        """Print the solver's hyper-parameters to stdout."""
        print(f"Parameters of the solver: \n"
              f"g_0_init: {self.g_0_init}\n"
              f"decreasing: {self.decreasing}\n"
              f"eps_0: {self.eps_0}\n"
              f"a: {self.a}\n"
              f"b: {self.b}\n"
              f"proj: {self.proj}\n"
              f"batch_size: {self.batch_size}\n"
              f"numItermax: {self.numItermax}\n"
              f"num_save: {self.num_save}\n"
              f"gamma_0: {self.gamma_0}\n"
              f"precision: {self.precision}\n"
              f"treshold: {self.treshold}\n"
              f"threshold_eps: {self.threshold_eps}\n")

    @staticmethod
    def _averaging_weights(n_iter, power=3):
        """Return rho_t = log(t+1)**power / sum_{s<=t} log(s+1)**power, t = 0..n_iter.

        Entry 0 is mathematically 0/0; it is set to 0 (it is never used since
        iterations start at t = 1) instead of producing NaN and a warning.
        """
        raw = np.log(np.arange(1, n_iter + 2, dtype=np.float64)) ** power
        totals = np.cumsum(raw)
        return np.divide(raw, totals, out=np.zeros_like(raw), where=totals > 0)

    def _stochastic_gradient(self, g_t, cost, eps, vect_nu, log_nu):
        """Return the batch-averaged (sub)gradient for potentials ``g_t``.

        ``cost`` is the (batch, m) cost matrix. The smoothed branch computes
        ``softmax((g - cost)/eps + log nu)``, which equals the original
        "softmax, multiply by nu, renormalize" but cannot hit 0/0 when the
        softmax mass underflows on zero-weight points.
        """
        if self.treshold and eps < self.threshold_eps:
            # Hard argmax subgradient of the non-regularized problem.
            argmax_indices = torch.argmax(g_t.unsqueeze(0) - cost, dim=1)
            chi = torch.zeros_like(cost).scatter_(1, argmax_indices.unsqueeze(1), 1)
        else:
            chi = nn.functional.softmax((g_t.unsqueeze(0) - cost) / eps + log_nu, dim=1)
        return torch.mean(chi, dim=0) - vect_nu

    def solve(self, dim, draw_samples_func, nu_points, nu_weights, max_cost):
        """Run the stochastic iterations and return final and averaged potentials.

        Args:
            dim: sample dimension, forwarded to ``draw_samples_func``.
            draw_samples_func: callable ``(batch_size, dim, precision=dtype)``
                returning source samples; assumed to be a (batch_size, dim)
                tensor given the broadcasting below — TODO confirm.
            nu_points: array-like of the m target support points (one per row).
            nu_weights: length-m target weights, or None for uniform 1/m.
            max_cost: clamp bound for the potentials when ``proj`` is set.

        Returns:
            ``(g_t, g_bar, g_t_history, g_bar_history, g_bar_weighted_history)``;
            histories hold CPU copies taken at t = 10 and every 100 (or 1000
            when more than 20 000 iterations) iterations, plus the initial value.
        """
        device = torch.device('cpu')
        print("\nUsing DRAG\n")
        n_iter = int(self.numItermax)

        # Discrete target measure at the configured precision.
        discrete_diracs = torch.as_tensor(nu_points, dtype=self.dtype, device=device)
        m = discrete_diracs.shape[0]
        if nu_weights is not None:
            vect_nu = torch.as_tensor(nu_weights, dtype=self.dtype, device=device)
        else:
            vect_nu = torch.ones(m, device=device, dtype=self.dtype) / m
        # log(0) = -inf simply removes zero-weight points from the softmax.
        log_nu = torch.log(vect_nu)

        # Potentials: g_0_init if given (copied, since g_t is updated in
        # place), otherwise uniform random in [0, 1).
        if self.g_0_init is not None:
            g_t = torch.as_tensor(self.g_0_init, dtype=self.dtype, device=device).clone()
        else:
            g_t = torch.rand(m, device=device, dtype=self.dtype)

        g_bar = g_t.clone()
        g_bar_weighted = g_t.clone()

        g_t_history = [g_t.cpu().clone()]
        g_bar_history = [g_t.cpu().clone()]
        g_bar_weighted_history = [g_t.cpu().clone()]

        eps = self.eps_0
        # Kept local so a default derived from this call's m / batch_size
        # does not leak into later calls.
        if self.gamma_0 is not None:
            gamma_0 = float(self.gamma_0)
        else:
            gamma_0 = float(np.sqrt(self.batch_size * m))

        avg_weights = self._averaging_weights(n_iter)
        save_every = 100 if n_iter <= 20_000 else 1000

        with torch.no_grad():  # No autograd needed: gradients are explicit.
            for t in range(1, n_iter + 1):
                if self.decreasing:
                    eps = self.eps_0 / (t ** self.a)

                y_batch = draw_samples_func(self.batch_size, dim, precision=self.dtype)
                cost = 0.5 * torch.sum((discrete_diracs.unsqueeze(0) - y_batch.unsqueeze(1)) ** 2, dim=2)

                grad = self._stochastic_gradient(g_t, cost, eps, vect_nu, log_nu)

                step_size = gamma_0 / t ** self.b
                g_t -= step_size * grad

                if self.proj:
                    g_t = torch.clamp(g_t, -max_cost, max_cost)

                # Uniform running average of the iterates.
                g_bar += (g_t - g_bar) / t
                # log^3-weighted running average of the iterates.
                g_bar_weighted += float(avg_weights[t]) * (g_t - g_bar_weighted)

                if t % save_every == 0 or t == 10:
                    g_t_history.append(g_t.cpu().clone())
                    g_bar_history.append(g_bar.cpu().clone())
                    g_bar_weighted_history.append(g_bar_weighted.cpu().clone())

        return g_t.cpu(), g_bar.cpu(), g_t_history, g_bar_history, g_bar_weighted_history

