import torch
import tqdm
from .base import Algo
import numpy as np
import wandb
from utils.scheduler import Scheduler
from utils.diffusion import DiffusionSampler
import warnings

# Selects how REDDiff.inference obtains the gradient for its Adam update:
#   True  -> assemble the gradient explicitly from forward_op.gradient() and
#            write it to the optimized tensor's .grad;
#   False -> build a scalar loss from forward_op.loss() and call backward().
GRAD_INTERFACE = True

class REDDiff(Algo):
    """Inverse-problem solver that optimizes one image estimate ``mu`` with Adam.

    Each iteration perturbs ``mu`` with Gaussian noise at the current
    scheduler noise level, asks the network for a noise prediction, and takes
    an Adam step along the sum of

    * the data-fidelity gradient supplied by ``forward_op`` (scaled by
      ``observation_weight``), and
    * ``lambda(sigma) * (predicted_noise - injected_noise)`` as the
      regularization term.
    """

    def __init__(self, net, forward_op, num_steps=1000, observation_weight=1.0, base_lambda=0.25, base_lr=0.5, lambda_scheduling_type='constant'):
        """Configure the solver.

        Args:
            net: Network called as ``net(x, sigma)``. It must expose
                ``img_channels`` and ``img_resolution``. It is switched to
                eval mode and its parameters are frozen here.
            forward_op: Measurement operator. ``inference`` reads its
                ``device`` attribute and calls ``gradient(x, observation)``
                (or ``loss(x, observation)`` when ``GRAD_INTERFACE`` is off).
            num_steps: Number of scheduler steps; ``inference`` runs one
                optimizer iteration per step.
            observation_weight: Multiplier applied to the data-fidelity term.
            base_lambda: Base weight of the regularization term.
            base_lr: Learning rate passed to Adam.
            lambda_scheduling_type: How the regularization weight depends on
                sigma: ``'linear'`` (``sigma * base_lambda``), ``'sqrt'``
                (``sqrt(sigma) * base_lambda``) or ``'constant'``
                (``base_lambda``).

        Raises:
            NotImplementedError: If ``lambda_scheduling_type`` is not one of
                the three supported names.
        """
        super().__init__(net, forward_op)
        self.net = net
        # The network is only evaluated here, never trained.
        self.net.eval().requires_grad_(False)
        self.forward_op = forward_op

        self.scheduler = Scheduler(num_steps=num_steps, schedule='vp', timestep='vp', scaling='vp')
        self.base_lr = base_lr
        self.observation_weight = observation_weight
        if lambda_scheduling_type == 'linear':
            self.lambda_fn = lambda sigma: sigma * base_lambda
        elif lambda_scheduling_type == 'sqrt':
            self.lambda_fn = lambda sigma: np.sqrt(sigma) * base_lambda
        elif lambda_scheduling_type == 'constant':
            self.lambda_fn = lambda sigma: base_lambda
        else:
            raise NotImplementedError(
                f"Unsupported lambda_scheduling_type {lambda_scheduling_type!r}; "
                "expected 'linear', 'sqrt' or 'constant'."
            )

    def pred_epsilon(self, model, x, sigma):
        """Return the noise estimate ``(x - model(x, sigma)) / sigma``.

        Args:
            model: Callable invoked as ``model(x, sigma)``; its output is
                presumably the denoised estimate of ``x`` (confirm against
                the network wrapper in use).
            x: Noisy input tensor.
            sigma: Noise level; converted to a tensor on ``x``'s device.
        """
        sigma = torch.as_tensor(sigma).to(x.device)
        d = model(x, sigma)
        return (x - d) / sigma

    def inference(self, observation, num_samples=1, verbose=True):
        """Run the optimization and return the optimized estimate.

        Args:
            observation: Measurement tensor handed to ``forward_op``. When
                ``num_samples > 1`` it is tiled along the first dimension
                with ``repeat(num_samples, 1, 1, 1)``.
            num_samples: Batch size of the estimate being optimized.
            verbose: Show a tqdm progress bar when True.

        Returns:
            The optimized tensor ``mu`` of shape ``(num_samples,
            net.img_channels, net.img_resolution, net.img_resolution)``. It
            is the leaf tensor that was optimized, so it still has
            ``requires_grad=True``.
        """
        device = self.forward_op.device
        num_steps = self.scheduler.num_steps
        pbar = tqdm.trange(num_steps) if verbose else range(num_steps)
        if num_samples > 1:
            observation = observation.repeat(num_samples, 1, 1, 1)

        # 0. zero initialization (instead of from pseudo-inverse)
        mu = torch.zeros(num_samples, self.net.img_channels, self.net.img_resolution, self.net.img_resolution,
                         device=device).requires_grad_(True)
        optimizer = torch.optim.Adam([mu], lr=self.base_lr, betas=(0.9, 0.99))
        for step in pbar:
            # 1. forward diffusion: noise the current estimate and query the
            #    network; nothing in this part is differentiated.
            with torch.no_grad():
                sigma, scaling = self.scheduler.sigma_steps[step], self.scheduler.scaling_steps[step]
                epsilon = torch.randn_like(mu)
                xt = scaling * (mu + sigma * epsilon)
                eps_pred = self.pred_epsilon(self.net, xt, sigma).detach()

            # 2. regularized optimization
            lam = self.lambda_fn(sigma)  # original note: sigma here equals 1/SNR
            optimizer.zero_grad()
            if GRAD_INTERFACE:
                gradient = self.forward_op.gradient(mu, observation) * self.observation_weight
                gradient += lam * (eps_pred - epsilon)
                # Detach so no autograd graph is kept alive through mu.grad
                # should forward_op.gradient return a graph-attached tensor.
                mu.grad = gradient.detach()
            else:
                loss = self.forward_op.loss(mu, observation).sum() * self.observation_weight
                # (eps_pred - epsilon) is constant w.r.t. mu, so this term's
                # gradient is lam * (eps_pred - epsilon), matching the branch above.
                loss += lam * ((eps_pred - epsilon) * mu).sum()
                loss.backward()

            optimizer.step()

        return mu

