from abc import ABC, abstractmethod
import torch

class ZerothOrderOptimizer(ABC):
    """Abstract interface for optimizers that work from estimated gradients.

    Subclasses supply a gradient approximation and an update rule; ``step``
    chains the two.
    """

    @abstractmethod
    def approximate_gradient(self, closure):
        """Estimate the gradient using ``closure``.

        ``closure`` is forwarded unchanged from ``step``; how it is called is
        up to the subclass.
        """

    @abstractmethod
    def optimize(self):
        """Apply an update; ``step`` calls this after ``approximate_gradient``."""

    def step(self, closure):
        """Run one iteration: approximate the gradient, then optimize."""
        self.approximate_gradient(closure)
        self.optimize()

    @abstractmethod
    def state_dict(self):
        """Return the optimizer state (format defined by the subclass)."""

    @property
    @abstractmethod
    def batched(self):
        """Subclass-defined property.

        NOTE(review): presumably indicates whether the optimizer uses a
        batched gradient estimate -- confirm against the implementations.
        """

    @abstractmethod
    def reset(self):
        """Reset the optimizer (exact semantics defined by the subclass)."""


def finite_difference(param, idx, closure, eps=1e-6, central=True):
    """Estimate the derivative of ``closure`` w.r.t. one entry of ``param``.

    The entry ``param.view(-1)[idx]`` is perturbed in place, ``closure`` is
    re-evaluated, and the entry is written back to its exact original value
    before returning -- also when ``closure`` raises.

    Args:
        param: Tensor holding the entry to perturb; ``view(-1)`` is called on
            it.
        idx: Index into the flattened ``param``. Assumed to be an integer so
            the selected entry is a view that writes through to ``param``.
        closure: Zero-argument callable returning the objective value. It is
            expected to read ``param`` so that it sees the perturbation.
        eps: Absolute step size; it is not scaled by the entry's magnitude.
        central: If true, use the central difference
            ``(f(w + eps) - f(w - eps)) / (2 * eps)``; otherwise the forward
            difference ``(f(w + eps) - f(w)) / eps``.

    Returns:
        The finite-difference quotient, with the type produced by
        ``closure``'s return value arithmetic.
    """
    w = param.view(-1)[idx]
    # Snapshot the entry: undoing the perturbation by arithmetic (+eps, -eps)
    # does not round-trip exactly in floating point and would let the
    # parameter drift a little on every call.
    original = w.detach().clone()
    try:
        w += eps
        y1 = closure()
        if central:
            w -= 2 * eps
            y0 = closure()
            return (y1 - y0) / (2 * eps)
        w -= eps
        y0 = closure()
        return (y1 - y0) / eps
    finally:
        # Runs on both normal return and exceptions raised by the closure.
        w.copy_(original)

 
    
from torch.func import vmap    
def batched_finite_difference(params, tasks, closure, eps=1e-6, central=True, max_batch_size=None):
    """Estimate many single-entry derivatives with ``vmap``-batched closures.

    Task ``j`` is a pair ``(p_name, w_idx)``: row ``j`` of ``params[p_name]``
    has its flattened entry ``w_idx`` perturbed in place, while row ``j`` of
    every other tensor in ``params`` is left untouched. Tasks are processed
    in chunks of at most ``max_batch_size``. After each chunk the perturbed
    entries are written back to their exact original values, in both the
    central and the forward mode and also when ``closure`` raises.

    Args:
        params: Dict mapping names to tensors that are sliced along their
            first dimension by task position (presumably one copy of each
            parameter per task -- confirm against the caller).
        tasks: Sequence of ``(p_name, w_idx)`` pairs.
        closure: Callable applied through ``vmap`` to the dict of chunk
            slices; expected to return one scalar per batch element.
        eps: Absolute step size.
        central: If true, use ``(f(w + eps) - f(w - eps)) / (2 * eps)``;
            otherwise ``(f(w + eps) - f(w)) / eps``.
        max_batch_size: Maximum number of tasks per ``vmap`` call. ``None``
            or a non-positive value processes all tasks in one call.

    Returns:
        A 1-D tensor with one finite-difference quotient per task, or
        ``None`` when ``tasks`` is empty.
    """
    if not tasks:
        return None
    if max_batch_size is None or max_batch_size <= 0:
        max_batch_size = len(tasks)
    batched_closure = vmap(closure)  # loop-invariant, build once
    outputs = None
    for start in range(0, len(tasks), max_batch_size):
        end = min(start + max_batch_size, len(tasks))
        chunk_tasks = tasks[start:end]
        # Slices are views, so the in-place edits below write into ``params``.
        chunk_params = {k: v[start:end] for k, v in params.items()}
        # Snapshot the entries about to be perturbed so they can be restored
        # exactly instead of by (lossy) floating-point arithmetic.
        saved = [
            chunk_params[p_name][i].view(-1)[w_idx].detach().clone()
            for i, (p_name, w_idx) in enumerate(chunk_tasks)
        ]
        try:
            for i, (p_name, w_idx) in enumerate(chunk_tasks):
                chunk_params[p_name][i].view(-1)[w_idx] += eps
            losses_plus = batched_closure(chunk_params)
            if outputs is None:
                # Allocated lazily so device and dtype follow the closure output.
                outputs = torch.empty(
                    len(tasks), device=losses_plus.device, dtype=losses_plus.dtype
                )
            if central:
                for i, (p_name, w_idx) in enumerate(chunk_tasks):
                    chunk_params[p_name][i].view(-1)[w_idx] -= 2 * eps
                losses_minus = batched_closure(chunk_params)
                outputs[start:end] = (losses_plus - losses_minus) / (2 * eps)
            else:
                for i, (p_name, w_idx) in enumerate(chunk_tasks):
                    chunk_params[p_name][i].view(-1)[w_idx] -= eps
                losses_base = batched_closure(chunk_params)
                outputs[start:end] = (losses_plus - losses_base) / eps
        finally:
            # Without this the central branch would leave every perturbed
            # entry at ``w - eps``.
            for i, (p_name, w_idx) in enumerate(chunk_tasks):
                chunk_params[p_name][i].view(-1)[w_idx] = saved[i]
    return outputs





