from zeroth_order_optim import ZerothOrderOptimizer, finite_difference, batched_finite_difference
from torch.optim.optimizer import ParamsT
import torch
import torch.optim as optim
import pickle

class CoherentCoordinatesDescent(ZerothOrderOptimizer, optim.Optimizer):
    """Zeroth-order coordinate descent with a ring buffer of gradient estimates.

    Each ``step`` does two things:

    1. It estimates ``compute_budget`` partial derivatives by finite
       differences. It walks the parameters coordinate by coordinate and wraps
       around at the end. The estimates go into a ring buffer (``self.grad``)
       of ``m_budget`` entries.
    2. It applies ``p -= lr * g`` to every coordinate whose estimate is
       currently held in that buffer.

    Pointer bookkeeping:

    * ``cur_param_idx`` / ``cur_weight_idx`` / ``cur_grad_idx`` point at the
      next coordinate to estimate and the buffer slot that will receive it.
    * ``param_offset`` / ``weight_offset`` / ``grad_offset`` point at the
      oldest coordinate still represented in the buffer and at its slot.
      These only start advancing once the buffer has wrapped
      (``full_memory_used``).
    """
    batched = False

    def __init__(self,
                 params: ParamsT,
                 param_names,
                 lr,
                 weight_decay,
                 init_grad,
                 eps,
                 compute_budget,
                 memory_budget: int = None,
                 momentum=1.0,
                 central_fd=True,
                 batched=False,
                 max_batch_size: int = None
                 ):
        """Build the optimizer state.

        Args:
            params: Tensors to optimize. They are materialised into a list,
                and all are assumed to live on the device of the first one.
            param_names: Names aligned with ``params``. They are used as keys
                for the batched buffers.
            lr: Step size used by ``optimize``.
            weight_decay: Negative values are clamped to 0.
                NOTE(review): the value is recorded in ``defaults``, but
                ``optimize`` never applies it.
            init_grad: Initial contents of the gradient buffer, or ``None``
                for zeros. Presumably it should hold ``m_budget`` entries;
                TODO confirm.
            eps: Finite-difference step, forwarded to the FD helpers.
            compute_budget: Number of coordinates estimated per ``step``.
            memory_budget: Ring-buffer size. Defaults to, and is capped at,
                the total number of parameter elements.
            momentum: Factor applied to the whole buffer before new
                estimates are written.
            central_fd: Forwarded to the FD helpers (presumably selects
                central vs. one-sided differences; TODO confirm).
            batched: If true, estimate all ``compute_budget`` coordinates with
                one ``batched_finite_difference`` call. That call uses
                per-parameter buffers of shape ``(compute_budget, *shape)``.
            max_batch_size: Forwarded to ``batched_finite_difference``.
                Defaults to ``compute_budget``.

        Raises:
            ValueError: If ``params`` is empty.
        """
        self.params = list(params)
        if not self.params:
            raise ValueError("CoherentCoordinatesDescent got an empty parameter list.")
        self.device = self.params[0].device
        self.lr = lr
        if weight_decay < 0:
            weight_decay = 0
        self.weight_decay = weight_decay
        self.c_budget = compute_budget

        # Prefix sums of element counts: acc_p_numels[i] is the flat index of
        # the first element of params[i]. The sums are computed in Python and
        # stored as int64, so models with more than 2**31 - 1 elements do not
        # overflow.
        offsets = [0]
        for p in self.params:
            offsets.append(offsets[-1] + p.numel())
        self.acc_p_numels = torch.tensor(offsets, dtype=torch.int64)
        total_numel = offsets[-1]

        # m_budget is kept as a plain int. Every pointer derived from it via
        # `%` (cur_grad_idx, grad_offset) then stays a Python int instead of
        # becoming a 0-d tensor.
        if memory_budget is None:
            self.m_budget = total_numel
        else:
            # A buffer larger than the parameter count would hold two
            # estimates for some coordinates, and optimize() would then step
            # those coordinates twice.
            self.m_budget = min(int(memory_budget), total_numel)
        if init_grad is None:
            self.grad = torch.zeros(self.m_budget, device=self.device)
        else:
            self.grad = init_grad.to(self.device)
        print(f"number of parameters: {total_numel}")

        # Pointers for approximating gradients (next coordinate / slot to write).
        self.cur_param_idx = 0
        self.cur_weight_idx = 0
        self.cur_grad_idx = 0
        # Pointers for doing descent (oldest coordinate / slot in the buffer).
        self.full_memory_used = False
        self.grad_offset = 0
        self.param_offset = 0
        self.weight_offset = 0

        self.eps = eps
        self.momentum = momentum
        self.central_fd = central_fd
        self.batched = batched
        if max_batch_size is None:
            self.max_batch_size = self.c_budget
        else:
            self.max_batch_size = max_batch_size
        # Map params to names for functional dictionary creation.
        self.param_names = param_names
        # One stacked copy of every parameter per coordinate estimated in a
        # batched step. This costs compute_budget x the model size in memory.
        self.batched_buffers = {}
        if self.batched:
            for name, param in zip(self.param_names, self.params):
                expanded_shape = (self.c_budget,) + param.shape
                self.batched_buffers[name] = torch.zeros(expanded_shape, device=self.device, dtype=param.dtype)
        defaults = dict(lr=lr, weight_decay=self.weight_decay, momentum=momentum, eps=eps)
        optim.Optimizer.__init__(self, self.params, defaults)

    def step(self, closure=None):
        """Refresh ``compute_budget`` gradient estimates, then update parameters.

        Args:
            closure: Required. It is passed through to the finite-difference
                helpers, which define the signature it must have.

        Returns:
            None. Unlike most ``torch.optim`` optimizers, no loss is returned.

        Raises:
            ValueError: If ``closure`` is None.
        """
        if closure is None:
            raise ValueError("CoherentCoordinatesDescent requires a closure for zeroth-order updates.")
        self.approximate_gradient(closure)
        self.optimize()
        return None

    def approximate_gradient(self, closure):
        """Decay the buffer by ``momentum`` and write fresh FD estimates into it.

        The next ``compute_budget`` coordinates, in pointer order, each get a
        new estimate in their ring-buffer slot. Pointers advance once per
        coordinate.
        """
        self.grad *= self.momentum
        if self.batched:
            tasks = []
            grad_indices = []
            for _ in range(self.c_budget):
                p_name = self.param_names[self.cur_param_idx]
                tasks.append((p_name, self.cur_weight_idx))
                grad_indices.append(self.cur_grad_idx)
                self.update_pointers()
            # Every batch row starts from the current parameter values.
            for name, param in zip(self.param_names, self.params):
                self.batched_buffers[name].copy_(param.unsqueeze(0))
            # NOTE(review): if compute_budget > m_budget, grad_indices holds
            # duplicates. PyTorch does not specify which value wins for
            # duplicate indices in an indexed assignment.
            slots = torch.tensor(grad_indices, device=self.device, dtype=torch.long)
            self.grad[slots] = \
                batched_finite_difference(self.batched_buffers, tasks, closure, self.eps, self.central_fd, self.max_batch_size)
        else:
            for _ in range(self.c_budget):
                self.grad[self.cur_grad_idx] = \
                    finite_difference(self.params[self.cur_param_idx],
                                      self.cur_weight_idx,
                                      closure,
                                      self.eps,
                                      self.central_fd)
                self.update_pointers()

    def update_pointers(self):
        """Advance the write pointers by one coordinate and slot.

        Once the buffer has wrapped, also advance the oldest-entry pointers.
        """
        if self.full_memory_used:
            # The slot just written held the oldest estimate, so the start of
            # the buffered window moves forward by one coordinate.
            self.grad_offset = (self.grad_offset + 1) % self.m_budget
            self.weight_offset = self.weight_offset + 1
            if self.weight_offset == self.params[self.param_offset].numel():
                self.weight_offset = 0
                self.param_offset = (self.param_offset + 1) % len(self.params)
        elif self.cur_grad_idx + 1 == self.m_budget:
            # The last free slot was just filled; from now on writes overwrite.
            self.full_memory_used = True
        self.cur_grad_idx = (self.cur_grad_idx + 1) % self.m_budget
        self.cur_weight_idx = self.cur_weight_idx + 1
        if self.cur_weight_idx == self.params[self.cur_param_idx].numel():
            self.cur_weight_idx = 0
            self.cur_param_idx = (self.cur_param_idx + 1) % len(self.params)

    @torch.no_grad()
    def optimize(self):
        """Apply ``p -= lr * g`` to every coordinate held in the buffer.

        The walk starts at the oldest coordinate (``param_offset``,
        ``weight_offset``) and reads slots from ``grad_offset`` onward. It
        wraps around both the parameter list and the buffer, and updates
        whole contiguous runs of each parameter at once.

        It runs under ``torch.no_grad()``. Without that, the in-place update
        on a view of a leaf tensor with ``requires_grad=True`` would raise.

        NOTE(review): the step size is ``self.lr``, not
        ``param_groups[...]["lr"]``, so LR schedulers that edit
        ``param_groups`` do not affect it. ``weight_decay`` is not applied.
        ``view(-1)`` assumes contiguous parameters.
        """
        grad_offset = self.grad_offset
        param_offset = self.param_offset
        weight_offset = self.weight_offset
        cnt = 0
        while cnt < self.m_budget:
            # Elements remaining in this parameter, capped by the budget left.
            weights_len = min(self.params[param_offset].numel() - weight_offset,
                              self.m_budget - cnt)
            slots = (torch.arange(weights_len, device=self.device) + grad_offset) % self.m_budget
            self.params[param_offset].view(-1)[weight_offset:weight_offset + weights_len] -= \
                self.lr * self.grad[slots]
            grad_offset = (grad_offset + weights_len) % self.m_budget
            param_offset = (param_offset + 1) % len(self.params)
            weight_offset = 0
            cnt += weights_len

    def reset(self):
        """Zero the gradient buffer and rewind all pointers to the first coordinate."""
        self.grad.zero_()
        self.cur_param_idx = 0
        self.cur_weight_idx = 0
        self.cur_grad_idx = 0
        self.full_memory_used = False
        self.grad_offset = 0
        self.param_offset = 0
        self.weight_offset = 0

    def state_dict(self):
        """Return the base ``Optimizer`` state plus a ``"coherent_state"`` entry.

        ``"coherent_state"`` holds the gradient buffer (by reference, not
        copied) and all pointer bookkeeping.
        """
        base_state = optim.Optimizer.state_dict(self)
        base_state["coherent_state"] = {
            "grad": self.grad,
            "cur_param_idx": self.cur_param_idx,
            "cur_weight_idx": self.cur_weight_idx,
            "cur_grad_idx": self.cur_grad_idx,
            "full_memory_used": self.full_memory_used,
            "grad_offset": self.grad_offset,
            "param_offset": self.param_offset,
            "weight_offset": self.weight_offset,
        }
        return base_state

    def load_state_dict(self, state_dict):
        """Load optimizer state (gradient buffer + pointer bookkeeping).

        A ``None`` argument is ignored. Missing ``"coherent_state"`` fields
        keep their current values. The caller's dict is not modified.
        """
        if state_dict is None:
            return
        # Shallow copy, so popping "coherent_state" does not strip it from a
        # dict the caller may reuse or save again.
        state_dict = dict(state_dict)
        coherent_state = state_dict.pop("coherent_state", {})
        optim.Optimizer.load_state_dict(self, state_dict)
        grad = coherent_state.get("grad")
        if grad is not None:
            self.grad = grad.to(self.device)
        # int()/bool() also normalise checkpoints written by earlier versions,
        # in which the buffer indices could be saved as 0-d tensors.
        self.cur_param_idx = int(coherent_state.get("cur_param_idx", self.cur_param_idx))
        self.cur_weight_idx = int(coherent_state.get("cur_weight_idx", self.cur_weight_idx))
        self.cur_grad_idx = int(coherent_state.get("cur_grad_idx", self.cur_grad_idx))
        self.full_memory_used = bool(coherent_state.get("full_memory_used", self.full_memory_used))
        self.grad_offset = int(coherent_state.get("grad_offset", self.grad_offset))
        self.param_offset = int(coherent_state.get("param_offset", self.param_offset))
        self.weight_offset = int(coherent_state.get("weight_offset", self.weight_offset))