import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import numpy as np
import copy

class ZO_GD(torch.optim.Optimizer):
    """Zeroth-order full-batch gradient descent.

    Each step estimates the gradient of every parameter with central finite
    differences of ``criterion(model(x), y)``, averaged over *all* samples in
    ``inputs``/``labels``, and then moves each parameter against that estimate.
    With ``use_true_grad=True`` the existing ``param.grad`` is used instead.

    Args:
        model: callable evaluated on one sample at a time, ``model(x)``; its
            parameters should include ``params`` so perturbations are seen.
        params: iterable of parameters or param-group dicts (``torch.optim``
            convention).
        inputs: per-sample inputs, converted with ``torch.Tensor``.
        labels: per-sample targets, converted with ``torch.Tensor`` and given a
            trailing axis via ``unsqueeze(1)``.
        criterion: loss ``criterion(prediction, target)``; its result is
            written into a single gradient coordinate, so it is expected to
            be a single-element tensor.
        lr: step size.
        eps: stored in the param groups; not read by this optimizer.
        fd_eps: finite-difference perturbation size.
        use_true_grad: use ``param.grad`` instead of the zeroth-order estimate.
    """

    def __init__(self, model, params, inputs, labels, criterion, lr=1e-5, eps=1e-8, fd_eps=1e-4, use_true_grad=False):
        defaults = dict(lr=lr, eps=eps, fd_eps=fd_eps, use_true_grad=use_true_grad)
        self.model = model
        self.criterion = criterion
        self.inputs = inputs
        self.labels = labels
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform one gradient-descent step on every parameter group.

        All gradient estimates of a group are computed at the current point
        before any parameter of that group is modified, so the update is a
        simultaneous gradient step rather than a sequential one.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim`` convention).

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                lr = group['lr']
                fd_eps = group['fd_eps']
                use_true_grad = group['use_true_grad']

                # Phase 1: gradient estimates, all taken at the same point.
                updates = []
                for param in group['params']:
                    if use_true_grad:
                        # Skip parameters that received no gradient, like the
                        # built-in torch optimizers do.
                        if param.grad is None:
                            continue
                        grad_est = param.grad.data
                    else:
                        grad_est = self._compute_full_gradient(param, fd_eps)
                    updates.append((param, grad_est))

                # Phase 2: apply the step.
                for param, grad_est in updates:
                    param.data.add_(-lr * grad_est)
        return loss

    def _compute_full_gradient(self, param, fd_eps):
        """Average the per-sample finite-difference gradient of ``param`` over
        every sample in ``self.inputs``/``self.labels``."""
        grad_est = torch.zeros_like(param.data)

        inputs_tensor = torch.Tensor(self.inputs)
        labels_tensor = torch.Tensor(self.labels).unsqueeze(dim=1)

        for input_sample, label_sample in zip(inputs_tensor, labels_tensor):
            grad_est += self._compute_gradient_direction(param, fd_eps, input_sample, label_sample)

        grad_est /= len(self.inputs)
        return grad_est

    def _compute_gradient_direction(self, param, fd_eps, input_sample, label_sample):
        """Central-difference estimate of d(loss)/d(param) for one sample.

        Each coordinate is perturbed by ``+fd_eps`` and ``-fd_eps`` in place and
        then restored to its exact original value before the next coordinate is
        visited, so the parameter is unchanged on return (even if the model
        raises).
        """
        grad_est = torch.zeros_like(param.data)
        flat_param = param.data.view(-1)
        flat_grad = grad_est.view(-1)

        # No autograd graph is needed for function evaluations; this also keeps
        # the returned estimate free of graph history.
        with torch.no_grad():
            for i in range(flat_param.numel()):
                orig_value = flat_param[i].clone()
                try:
                    flat_param[i] = orig_value + fd_eps
                    loss_plus = self.criterion(self.model(input_sample), label_sample)

                    flat_param[i] = orig_value - fd_eps
                    loss_minus = self.criterion(self.model(input_sample), label_sample)
                finally:
                    # Exact, O(1) restore of the single perturbed coordinate.
                    flat_param[i] = orig_value

                flat_grad[i] = (loss_plus - loss_minus) / (2 * fd_eps)

        return grad_est

 
    

class ZO_MiniBatch_SGD(torch.optim.Optimizer):
    """Zeroth-order mini-batch SGD.

    Each step draws ``batch_size`` distinct sample indices per parameter group
    (``np.random.choice(..., replace=False)``). It estimates every parameter's
    gradient on that mini-batch with central finite differences of
    ``criterion(model(x), y)`` and moves the parameters against the estimate.
    With ``use_true_grad=True`` the existing ``param.grad`` is used instead.

    Args:
        model: callable evaluated on one sample at a time, ``model(x)``; its
            parameters should include ``params`` so perturbations are seen.
        params: iterable of parameters or param-group dicts.
        inputs: per-sample inputs; indexed with an integer index array, so
            presumably a NumPy array (TODO confirm with callers).
        labels: per-sample targets, indexed the same way as ``inputs``.
        criterion: loss ``criterion(prediction, target)``; its result is
            accumulated into a single gradient coordinate.
        lr: step size.
        batch_size: mini-batch size, capped at ``len(inputs)``.
        eps: stored in the param groups; not read by this optimizer.
        fd_eps: finite-difference perturbation size.
        use_true_grad: use ``param.grad`` instead of the zeroth-order estimate.
    """

    def __init__(self, model, params, inputs, labels, criterion, lr=1e-5, batch_size=32, eps=1e-8, fd_eps=1e-4, use_true_grad=False):
        defaults = dict(lr=lr, eps=eps, fd_eps=fd_eps, use_true_grad=use_true_grad, batch_size=batch_size)
        self.model = model
        self.criterion = criterion
        self.inputs = inputs
        self.labels = labels
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform one mini-batch step on every parameter group.

        All estimates of a group are computed at the current point before any
        parameter of that group is modified.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim`` convention).

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                lr = group['lr']
                fd_eps = group['fd_eps']
                use_true_grad = group['use_true_grad']
                batch_size = min(group['batch_size'], len(self.inputs))

                # Sample a mini-batch of distinct data points.
                indices = np.random.choice(len(self.inputs), batch_size, replace=False)
                input_samples = torch.Tensor(self.inputs[indices])
                label_samples = torch.Tensor(self.labels[indices])

                # Phase 1: gradient estimates, all taken at the same point.
                updates = []
                for param in group['params']:
                    if use_true_grad:
                        if param.grad is None:
                            continue
                        grad_est = param.grad.data
                    else:
                        grad_est = self._compute_mini_batch_gradient(param, fd_eps, input_samples, label_samples)
                    updates.append((param, grad_est))

                # Phase 2: apply the step.
                for param, grad_est in updates:
                    param.data.add_(-lr * grad_est)
        return loss

    def _compute_mini_batch_gradient(self, param, fd_eps, input_samples, label_samples):
        """Average central-difference gradient of ``param`` over a mini-batch.

        Every coordinate is restored to its exact original value right after
        it is differenced. Otherwise later coordinates would be estimated at
        a point shifted by ``-fd_eps`` in all earlier coordinates. The
        parameter is unchanged on return, even if the model raises.
        """
        grad_est = torch.zeros_like(param.data)
        flat_param = param.data.view(-1)
        flat_grad = grad_est.view(-1)

        input_samples = torch.Tensor(input_samples)
        label_samples = torch.Tensor(label_samples).unsqueeze(1)

        with torch.no_grad():
            for input_sample, label_sample in zip(input_samples, label_samples):
                for j in range(flat_param.numel()):
                    orig_value = flat_param[j].clone()
                    try:
                        flat_param[j] = orig_value + fd_eps
                        loss_plus = self.criterion(self.model(input_sample), label_sample)

                        flat_param[j] = orig_value - fd_eps
                        loss_minus = self.criterion(self.model(input_sample), label_sample)
                    finally:
                        flat_param[j] = orig_value

                    flat_grad[j] += (loss_plus - loss_minus) / (2 * fd_eps)

        grad_est /= len(input_samples)
        return grad_est


class ZO_SVRG(torch.optim.Optimizer):
    """Zeroth-order stochastic variance-reduced gradient (SVRG) optimizer.

    A snapshot of all parameters is kept, together with their full-dataset
    finite-difference gradient. Each step uses the variance-reduced estimate::

        g = g_batch(w) - g_batch(w_snapshot) + g_full(w_snapshot)

    where both mini-batch terms use the same freshly drawn batch. The snapshot
    and full gradient are refreshed every ``m`` steps. With
    ``use_true_grad=True`` the existing ``param.grad`` is used instead.

    Args:
        model: callable evaluated on one sample at a time, ``model(x)``; its
            parameters should include ``params`` so perturbations are seen.
        params: iterable of parameters or param-group dicts.
        inputs: per-sample inputs; indexed with an integer index array, so
            presumably a NumPy array (TODO confirm with callers).
        labels: per-sample targets, indexed the same way as ``inputs``.
        criterion: loss ``criterion(prediction, target)``; its result is
            written into a single gradient coordinate.
        lr: step size.
        batch_size: mini-batch size, capped at ``len(inputs)``.
        m: snapshot refresh period, in steps.
        eps: stored in the param groups; not read by this optimizer.
        fd_eps: finite-difference perturbation size.
        use_true_grad: use ``param.grad`` instead of the zeroth-order estimate.
    """

    def __init__(self, model, params, inputs, labels, criterion, lr=1e-5, batch_size=32, m=10, eps=1e-8, fd_eps=1e-4, use_true_grad=False):
        defaults = dict(lr=lr, eps=eps, fd_eps=fd_eps, use_true_grad=use_true_grad, batch_size=batch_size, m=m)
        self.model = model
        self.criterion = criterion
        self.inputs = inputs
        self.labels = labels
        self.counter = 0  # number of completed step() calls
        super().__init__(params, defaults)
        self.update_snapshot_and_full_grad()

    def step(self, closure=None):
        """Perform one SVRG step on all parameter groups.

        All update directions are computed at the current point before any
        parameter is modified.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim`` convention).

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            # The snapshot covers every group, so the refresh and the counter
            # are handled once per step, not once per group. The refresh
            # happens when any group's period ``m`` divides the counter.
            # NOTE: at counter 0 this repeats the constructor's refresh, which
            # keeps the snapshot consistent if parameters were changed (e.g.
            # loaded) after construction.
            if any(self.counter % group['m'] == 0 for group in self.param_groups):
                self.update_snapshot_and_full_grad()
            self.counter += 1

            # Phase 1: compute every update direction at the current point.
            updates = []
            for group in self.param_groups:
                lr = group['lr']
                fd_eps = group['fd_eps']
                use_true_grad = group['use_true_grad']
                batch_size = min(group['batch_size'], len(self.inputs))

                # Sample a mini-batch of distinct data points.
                indices = np.random.choice(len(self.inputs), batch_size, replace=False)
                input_samples = torch.Tensor(self.inputs[indices])
                label_samples = torch.Tensor(self.labels[indices])

                for param in group['params']:
                    if use_true_grad:
                        if param.grad is not None:
                            updates.append((param, lr, param.grad.data))
                        continue
                    stoch_grad = self._compute_mini_batch_gradient(param, fd_eps, input_samples, label_samples)
                    stoch_grad_old_params = self._compute_snapshot_mini_batch_gradient(param, fd_eps, input_samples, label_samples)
                    grad_est = stoch_grad - stoch_grad_old_params + self.full_grad[param]
                    updates.append((param, lr, grad_est))

            # Phase 2: apply the step.
            for param, lr, grad_est in updates:
                param.data.add_(-lr * grad_est)
        return loss

    def update_snapshot_and_full_grad(self):
        """Record the current parameter values and their full-dataset
        finite-difference gradients as the new SVRG snapshot."""
        self.full_grad = {}
        self.param_snapshot = {}
        for group in self.param_groups:
            for param in group['params']:
                self.full_grad[param] = self._compute_full_gradient(param, group['fd_eps'])
                self.param_snapshot[param] = param.data.clone()

    def _compute_snapshot_mini_batch_gradient(self, param, fd_eps, input_samples, label_samples):
        """Mini-batch finite-difference gradient of ``param`` evaluated with
        every snapshotted parameter temporarily reset to its snapshot value.

        The model only sees the live parameter tensors, so perturbing a detached
        snapshot copy would never change the loss. Instead, the snapshot values
        are loaded into the live parameters, and the current values are
        restored afterwards, even if the model raises.
        """
        current_values = {p: p.data.clone() for p in self.param_snapshot}
        try:
            for p, snapshot_value in self.param_snapshot.items():
                p.data.copy_(snapshot_value)
            return self._compute_mini_batch_gradient(param, fd_eps, input_samples, label_samples)
        finally:
            for p, value in current_values.items():
                p.data.copy_(value)

    def _compute_full_gradient(self, param, fd_eps):
        """Average the per-sample finite-difference gradient of ``param`` over
        every sample in ``self.inputs``/``self.labels``."""
        grad_est = torch.zeros_like(param.data)

        inputs_tensor = torch.Tensor(self.inputs)
        labels_tensor = torch.Tensor(self.labels).unsqueeze(dim=1)

        for input_sample, label_sample in zip(inputs_tensor, labels_tensor):
            grad_est += self._compute_gradient_direction(param, fd_eps, input_sample, label_sample)

        grad_est /= len(self.inputs)
        return grad_est

    def _compute_gradient_direction(self, param, fd_eps, input_sample, label_sample):
        """Central-difference estimate of d(loss)/d(param) for one sample.

        Each coordinate is perturbed in place and restored to its exact
        original value before the next coordinate is visited, even if the
        model raises.
        """
        grad_est = torch.zeros_like(param.data)
        flat_param = param.data.view(-1)
        flat_grad = grad_est.view(-1)

        with torch.no_grad():
            for i in range(flat_param.numel()):
                orig_value = flat_param[i].clone()
                try:
                    flat_param[i] = orig_value + fd_eps
                    loss_plus = self.criterion(self.model(input_sample), label_sample)

                    flat_param[i] = orig_value - fd_eps
                    loss_minus = self.criterion(self.model(input_sample), label_sample)
                finally:
                    flat_param[i] = orig_value

                flat_grad[i] = (loss_plus - loss_minus) / (2 * fd_eps)

        return grad_est

    def _compute_mini_batch_gradient(self, param, fd_eps, input_samples, label_samples):
        """Average central-difference gradient of ``param`` over a mini-batch.

        Every coordinate is restored right after it is differenced, so later
        coordinates are not estimated at a point shifted by earlier
        perturbations.
        """
        grad_est = torch.zeros_like(param.data)
        flat_param = param.data.view(-1)
        flat_grad = grad_est.view(-1)

        input_samples = torch.Tensor(input_samples)
        label_samples = torch.Tensor(label_samples).unsqueeze(1)

        with torch.no_grad():
            for input_sample, label_sample in zip(input_samples, label_samples):
                for j in range(flat_param.numel()):
                    orig_value = flat_param[j].clone()
                    try:
                        flat_param[j] = orig_value + fd_eps
                        loss_plus = self.criterion(self.model(input_sample), label_sample)

                        flat_param[j] = orig_value - fd_eps
                        loss_minus = self.criterion(self.model(input_sample), label_sample)
                    finally:
                        flat_param[j] = orig_value

                    flat_grad[j] += (loss_plus - loss_minus) / (2 * fd_eps)

        grad_est /= len(input_samples)
        return grad_est


