import torch
from tqdm import tqdm
from .base import Algo
from .enkg import ode_sampler

class EnSGS(Algo):
    '''
    Current implementation of SURF 2024 project.

    Evolves an ensemble of particles by alternating, for ``num_steps`` outer
    iterations, between

    * a likelihood step (``ll_step``): ``likelihood_steps`` ensemble updates
      driven by the forward operator, a pull towards the incoming particles
      weighted by ``1 / rho**2``, and ensemble-covariance-shaped noise; and
    * a prior step: ``ode_sampler`` run on the network for ``prior_steps``
      steps starting from noise level ``rho``.

    ``rho`` follows the schedule ``rho_max * rho_decay**i`` clipped from
    below at ``rho_min``.
    '''

    def __init__(self,
                 net,
                 forward_op,
                 guidance_scale,
                 num_steps,
                 likelihood_steps,
                 prior_steps,
                 rho_min,
                 rho_max,
                 rho_decay,
                 num_samples=1024,
                 batch_size=64,
                 resample=True):
        '''
        Args:
            net: network handed to ``ode_sampler``; must expose
                ``img_channels`` and ``img_resolution``.
            forward_op: forward operator; must expose ``device``,
                ``sigma_noise`` and ``forward``.
            guidance_scale: numerator of the normalised step size used in
                the likelihood step.
            num_steps: number of outer (likelihood + prior) iterations.
            likelihood_steps: number of inner updates per likelihood step.
            prior_steps: number of steps passed to ``ode_sampler``.
            rho_min: lower bound of the rho schedule.
            rho_max: initial value of the rho schedule; also the standard
                deviation of the initial particles.
            rho_decay: per-iteration multiplicative decay of rho.
            num_samples: ensemble size.
            batch_size: number of particles per ``ode_sampler`` call.
            resample: if True, noise with standard deviation rho is added
                to the particles at the end of every likelihood step.
        '''
        super().__init__(net, forward_op)
        self.scale = guidance_scale
        self.N = num_steps
        self.num_l_steps = likelihood_steps
        self.num_prior_steps = prior_steps
        self.rho_min = rho_min
        self.rho_max = rho_max
        self.rho_decay = rho_decay
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.resample = resample

    @torch.no_grad()
    def inference(self, observation, num_samples=1, **kwargs):
        '''
        Run the full sampler for a single observation.

        Args:
            observation: measurement passed as ``y`` to ``ll_step``.
            num_samples: accepted for interface compatibility but not used;
                the ensemble size is ``self.num_samples``.
            **kwargs: ignored.

        Returns:
            Tensor of shape ``(self.num_samples, C, H, W)`` with
            ``C = net.img_channels`` and ``H = W = net.img_resolution``:
            the ensemble after the last prior step.
        '''
        device = self.forward_op.device
        resolution = self.net.img_resolution
        x = torch.randn(self.num_samples, self.net.img_channels, resolution, resolution, device=device) * self.rho_max

        # Geometric decay from rho_max, clipped from below at rho_min.
        rho_schedule = torch.pow(self.rho_decay, torch.arange(0, self.N)) * self.rho_max
        rho_schedule = torch.maximum(rho_schedule, torch.ones_like(rho_schedule) * self.rho_min)

        for i in range(self.N):
            rho_cur = rho_schedule[i]

            # Likelihood step
            z = self.ll_step(observation, x, rho_cur)

            # Prior step, processed in chunks of at most `batch_size` particles.
            # Striding over the whole ensemble (instead of counting only full
            # batches) makes sure a trailing partial batch is sampled as well
            # rather than being left as zeros.
            x = torch.zeros_like(z)
            for start in tqdm(range(0, len(z), self.batch_size)):
                end = start + self.batch_size
                x[start:end] = ode_sampler(self.net, z[start:end], self.num_prior_steps, sigma_start=rho_cur)

        return x

    @torch.no_grad()
    def ll_step(self, y, particles, rho):
        '''
        Run ``self.num_l_steps`` ensemble updates on a copy of ``particles``.

        Each iteration combines the update direction from ``ek_update``, a
        covariance-preconditioned pull back towards the incoming particles
        scaled by ``1 / rho**2``, and Gaussian noise shaped by the
        (regularised) ensemble covariance.

        Args:
            y: observation compared against ``forward_op.forward(...)``.
            particles: ensemble tensor whose first dimension indexes the
                particles; it is not modified.
            rho: current noise level.

        Returns:
            Updated ensemble with the same shape as ``particles``; when
            ``self.resample`` is True, extra noise with standard deviation
            ``rho`` is added.
        '''
        x = particles
        z_next = particles.clone()
        J, *spatial = particles.shape

        # Diagonal jitter that keeps the covariance positive definite for the
        # Cholesky factorisation; it does not change between iterations.
        flat_dim = torch.Size(spatial).numel()
        jitter = 0.01 * torch.eye(flat_dim, device=z_next.device)

        pbar = tqdm(range(self.num_l_steps))
        for _ in pbar:
            # Ensemble covariance over flattened particles, shape (D, D).
            z_diff = (z_next - z_next.mean(dim=0, keepdim=True)).reshape(J, -1)
            cov = z_diff.T @ z_diff / len(z_diff)

            dz_reg = ((x - z_next).reshape(J, -1) @ cov).reshape(J, *spatial) / (rho ** 2)
            # Only the direction is used; the step size is recomputed below
            # from the combined update.
            dz_ll, _ = self.ek_update(self.forward_op, y, self.forward_op.sigma_noise, z_next, z_next, 1.0)

            # Step size normalised by the Frobenius norm of the combined update.
            lr = self.scale / torch.linalg.matrix_norm((dz_ll + dz_reg).reshape(J, -1))

            z_next += (dz_reg - dz_ll) * lr

            # cholesky returns lower-triangular L with L @ L.T == cov + jitter.
            # Particles are stored as rows, so the noise must be eps @ L.T for
            # its covariance to be L @ L.T (eps @ L would give L.T @ L).
            cov_sqrt = torch.linalg.cholesky(cov + jitter)
            eps = torch.randn_like(z_next).reshape(J, -1)
            noise = (eps @ cov_sqrt.T).reshape(J, *spatial) * torch.sqrt(2 * lr)
            z_next += noise

            # Signed mean residual, shown on the progress bar only.
            avg_err = (self.forward_op.forward(z_next) - y).mean()
            pbar.set_description(f'Avg. error: {avg_err.item()}')

        noise_diff = rho if self.resample else 0
        return z_next + torch.randn_like(z_next) * noise_diff

    @torch.no_grad()
    def ek_update(self, forward_operator, y, std_y, x, x_clean, scale):
        '''
        Compute an ensemble update direction and a normalised step size.

        Args:
            forward_operator: object whose ``forward`` is applied to
                ``x_clean``.
            y: observation subtracted from the predictions.
            std_y: observation noise level; residuals are divided by
                ``std_y ** 2``.
            x: ensemble whose deviations from its mean span the update;
                first dimension indexes the particles.
            x_clean: ensemble fed to the forward operator (assumed to have
                the same leading dimension as ``x`` -- confirm with callers).
            scale: numerator of the returned step size.

        Returns:
            Tuple ``(dx, lr)``: ``dx`` has the shape of ``x``; ``lr`` is
            ``scale`` divided by the Frobenius norm of the coefficient matrix.
        '''
        N, *spatial = x.shape

        preds = forward_operator.forward(x_clean)
        xs_diff = x - x.mean(dim=0, keepdim=True)
        pred_err = preds - y
        pred_diff = preds - preds.mean(dim=0, keepdim=True)

        # coef[i, j] = <pred_err_i / std_y^2, pred_diff_j> / N
        weighted_err = pred_err.reshape(pred_err.shape[0], -1) / (std_y ** 2)
        coef = torch.matmul(weighted_err, pred_diff.reshape(pred_diff.shape[0], -1).T) / len(x)

        dx = (coef @ xs_diff.reshape(N, -1)).reshape(N, *spatial)
        lr = scale / torch.linalg.matrix_norm(coef)

        return dx, lr