import torch
import tqdm
from .base import Algo
import numpy as np
import wandb


class DiffPIR_GSG(Algo):
    """DiffPIR sampler whose data-consistency gradient is estimated zeroth-order.

    The data step needs the gradient of the residual norm
    ``f(x) = ||y - forward_op.forward(x)||_2``. Rather than differentiating
    through ``forward_op``, the gradient is estimated from forward evaluations
    only, by averaging finite differences of ``f`` along random Gaussian
    directions (Gaussian-smoothing gradient estimators).

    Args:
        net: Denoiser called as ``net(x, sigma)``; must expose
            ``img_channels``, ``img_resolution`` and ``round_sigma``.
        forward_op: Measurement operator exposing ``forward(x)`` and
            ``device``.
        num_steps: Number of sampling steps.
        sigma_max, sigma_min, rho: Noise-schedule parameters; the levels are
            interpolated linearly in ``sigma ** (1 / rho)`` space from
            ``sigma_max`` down to ``sigma_min``.
        sigma_n: Measurement-noise level used in the data-step size.
        lamb: Weight in the data-step size
            ``sigma**2 / (2 * lamb * sigma_n**2)``.
        xi: Mixing weight when moving to the next noise level: fresh noise
            is scaled by ``sqrt(xi)``, the predicted noise by ``sqrt(1 - xi)``.
        num_queries: Random directions used per gradient estimate.
        mu: Finite-difference step along each direction.
        batch_size: Directions evaluated per ``forward_op.forward`` call.
        is_central: Use the central-difference estimator if true, otherwise
            the forward-difference one.

    Raises:
        ValueError: If ``num_queries`` or ``batch_size`` is smaller than 1.
    """

    def __init__(self, net, forward_op, num_steps, sigma_max, sigma_min, rho,
                 sigma_n, lamb, xi, num_queries, mu, batch_size, is_central):
        if num_queries < 1:
            raise ValueError(f"num_queries must be >= 1, got {num_queries}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        super(DiffPIR_GSG, self).__init__(net, forward_op)
        self.num_steps = num_steps
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        self.net = net
        self.forward_op = forward_op
        self.sigma_n = sigma_n
        self.lamb = lamb
        self.xi = xi
        self.num_queries = num_queries
        self.mu = mu
        self.batch_size = batch_size
        self.is_central = is_central

    def _residual_norm(self, observation, x):
        """Return ``||observation - forward_op.forward(x)||_2`` per leading-dim entry."""
        residual = observation - self.forward_op.forward(x)
        return torch.linalg.norm(torch.flatten(residual, start_dim=1), dim=1)

    def _query_chunks(self):
        """Yield direction-batch sizes that sum to exactly ``num_queries``."""
        full, remainder = divmod(self.num_queries, self.batch_size)
        for _ in range(full):
            yield self.batch_size
        # Evaluate leftover directions too, so the 1/num_queries normalisation
        # matches the number of directions actually sampled.
        if remainder:
            yield remainder

    @torch.no_grad()
    def _gsg_estimate(self, observation, x_0, central):
        """Gaussian-smoothing estimate of the gradient of the residual norm.

        Args:
            observation: Measurements with a leading batch dimension of size 1
                (shared by every sample) or ``x_0.shape[0]`` (one per sample).
            x_0: Points at which to estimate the gradient, batch-first.
            central: Use central differences if true, forward differences
                otherwise.

        Returns:
            Tensor shaped like ``x_0`` holding one gradient estimate per sample.

        Raises:
            ValueError: If the observation batch cannot be paired with ``x_0``.
        """
        shared_obs = observation.shape[0] == 1
        if not shared_obs and observation.shape[0] != x_0.shape[0]:
            raise ValueError(
                f"observation batch ({observation.shape[0]}) must be 1 or match "
                f"x_0 batch ({x_0.shape[0]})"
            )

        grad_est = torch.zeros_like(x_0)
        # Forward:  mean_k (f(x + mu u_k) - f(x))          u_k / mu
        # Central:  mean_k (f(x + mu u_k) - f(x - mu u_k)) u_k / (2 mu)
        # The central difference spans 2*mu; without the factor 2 its estimate
        # would be twice the forward-difference one.
        scale = (2.0 if central else 1.0) * self.mu * self.num_queries
        base_norm = None if central else self._residual_norm(observation, x_0)
        # Broadcast the per-direction differences over all non-batch dims.
        diff_tail = (1,) * (x_0.dim() - 1)

        for i in range(x_0.shape[0]):
            # Perturbations of sample i must be compared with ITS measurement;
            # the slice keeps the batch dim so it broadcasts over directions.
            obs_i = observation if shared_obs else observation[i:i + 1]
            x_i = x_0[i]
            for n in self._query_chunks():
                u = torch.randn((n, *x_0.shape[1:]), device=x_0.device, dtype=x_0.dtype)
                if central:
                    diff = (self._residual_norm(obs_i, x_i + self.mu * u)
                            - self._residual_norm(obs_i, x_i - self.mu * u))
                else:
                    diff = self._residual_norm(obs_i, x_i + self.mu * u) - base_norm[i]
                diff = diff.reshape((n,) + diff_tail)
                grad_est[i] += (u * (diff / scale)).sum(dim=0)
        return grad_est

    @torch.no_grad()
    def fgsg_estimate(self, observation, x_0):
        """Forward-difference Gaussian-smoothing gradient of ``||y - A(x_0)||_2``.

        Returns a tensor shaped like ``x_0``; see ``_gsg_estimate``.
        """
        return self._gsg_estimate(observation, x_0, central=False)

    @torch.no_grad()
    def cgsg_estimate(self, observation, x_0):
        """Central-difference Gaussian-smoothing gradient of ``||y - A(x_0)||_2``.

        Returns a tensor shaped like ``x_0``; see ``_gsg_estimate``.
        """
        return self._gsg_estimate(observation, x_0, central=True)

    def _sigma_schedule(self, device):
        """Return the ``num_steps`` float64 noise levels followed by a final 0."""
        step_indices = torch.arange(self.num_steps, dtype=torch.float64, device=device)
        # Avoid 0/0 (NaN) when num_steps == 1; the single level is then sigma_max.
        denom = max(self.num_steps - 1, 1)
        root_max = self.sigma_max ** (1 / self.rho)
        root_min = self.sigma_min ** (1 / self.rho)
        t_steps = (root_max + step_indices / denom * (root_min - root_max)) ** self.rho
        return torch.cat([self.net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

    @torch.no_grad()
    def inference(self, observation, num_samples=1, **kwargs):
        """Run the sampler and return the final iterate.

        Args:
            observation: Batch-first measurements; one sample is drawn per
                observation.
            num_samples: Unused; the sample count is ``observation.shape[0]``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor of shape ``(B, img_channels, img_resolution,
            img_resolution)`` with ``B = observation.shape[0]``.
        """
        device = self.forward_op.device
        t_steps = self._sigma_schedule(device)
        estimate = self.cgsg_estimate if self.is_central else self.fgsg_estimate
        step_denom = 2 * self.lamb * self.sigma_n ** 2

        xt = torch.randn(observation.shape[0], self.net.img_channels, self.net.img_resolution,
                         self.net.img_resolution, device=device) * self.sigma_max
        for step in tqdm.trange(self.num_steps):
            sigma = t_steps[step]
            # Gradients are estimated zeroth-order, so no autograd graph is needed.
            x0 = self.net(xt, sigma).to(torch.float64)
            grad = estimate(observation, x0)
            x0hat = x0 - sigma ** 2 / step_denom * grad

            # Noise direction implied by the data-corrected estimate.
            effect = (xt - x0hat) / sigma
            xt = x0hat + (np.sqrt(self.xi) * torch.randn_like(xt)
                          + np.sqrt(1 - self.xi) * effect) * t_steps[step + 1]
        return xt