
import torch
import torch
import numpy as np
import os
import tqdm
from time import perf_counter
import denflow.utils as utils
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from deepinv.loss.regularisers import JacobianSpectralNorm


def infiniteloop(dataloader):
    """Yield the ``x`` part of every batch from ``dataloader``, cycling forever.

    Each item of ``dataloader`` is unpacked as an ``(x, y)`` pair and only
    ``x`` is yielded. Once a pass over the loader finishes, it is iterated
    again from the start.

    Args:
        dataloader: a re-iterable whose items are 2-element ``(x, y)`` pairs.
            Presumably a ``torch.utils.data.DataLoader`` over a labelled
            dataset; confirm with callers.

    Yields:
        The first element of each batch, indefinitely.

    Raises:
        ValueError: if a full pass over ``dataloader`` produces no batches.
            Without this check, an empty loader (or a one-shot iterator that
            is already exhausted) would make the ``while True`` loop spin
            forever without ever yielding.
    """
    while True:
        produced_any = False
        for x, _ in dataloader:
            produced_any = True
            yield x
        if not produced_any:
            raise ValueError("infiniteloop: dataloader produced no batches")


class INVESTIGATE_LIP(object):
    """Measure Jacobian spectral norms of a flow model along sampled trajectories.

    Starting from Gaussian noise, the model's velocity field is integrated
    with explicit Euler steps. At every time point, the spectral norm of the
    Jacobian (w.r.t. the current state) of either the denoiser or the
    velocity is estimated with deepinv's ``JacobianSpectralNorm``.

    The ``model`` must expose ``time_start``, ``time_end`` and
    ``get_velocity(x, t)``. When ``args.use_denoiser`` is set, it must also
    expose ``get_denoiser(x, t)``.
    """

    def __init__(self, model, device, args):
        """Store the model and config, and build the spectral-norm estimator.

        Args:
            model: flow model (see the class docstring for the attributes used).
            device: torch device on which noise and time tensors are created.
                The model itself is NOT moved to this device here.
            args: config namespace. This constructor reads ``method``,
                ``use_denoiser`` and ``train_object``.
        """
        self.device = device
        self.args = args
        self.model = model  # .to(device)
        self.method = args.method
        self.use_denoiser = self.args.use_denoiser
        # Not read by the methods of this class; kept as a public attribute
        # in case it is used elsewhere.
        self.use_closed_form = (self.args.train_object == 'optimal_denoiser')
        # Power-iteration estimator of ||J||_2. compute_lip_on_traj stores one
        # value per batch element, so reduction=None is expected to return a
        # (batch,) tensor.
        self.jacobian_spectral_norm = JacobianSpectralNorm(
            max_iter=10, tol=1e-5, eval_mode=True, verbose=True, reduction=None)

    @staticmethod
    def _num_batches(num_images, batch_size):
        """Return how many batches of ``batch_size`` are needed to cover ``num_images``.

        Uses ceiling division, so an exact multiple does not get an extra,
        spurious batch.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        return (num_images + batch_size - 1) // batch_size

    def compute_lip_on_traj(self, batch_size):
        """Integrate ``batch_size`` trajectories and record Jacobian norms along them.

        Reads ``num_channels``, ``dim_image`` and ``sampling_steps`` from
        ``self.args``. The initial state is standard Gaussian noise of shape
        ``(batch_size, num_channels, dim_image, dim_image)``. When
        ``dim_image == 2`` the shape is ``(batch_size, num_channels, 2, 1)``
        instead (presumably 2-D toy data; confirm).

        Args:
            batch_size: number of trajectories sampled in parallel.

        Returns:
            CPU tensor of shape ``(sampling_steps, batch_size)``. Row ``i``
            holds the spectral-norm estimates at ``time_points[i]``, taken
            before the Euler step from that point.
        """
        num_channels = self.args.num_channels
        res = self.args.dim_image
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, self.args.sampling_steps, device=self.device)

        integration_steps = len(time_points)
        print(time_points[0], time_points[-1], integration_steps)

        width = 1 if res == 2 else res
        xt = torch.randn(batch_size, num_channels, res, width, device=self.device)

        jac_norms = torch.zeros(integration_steps, batch_size)
        for i, t in enumerate(time_points):
            t_ = t * torch.ones(len(xt), device=self.device)
            # Fresh leaf tensor: the Jacobian is taken w.r.t. the current state only.
            xt = xt.clone().requires_grad_(True)
            if self.use_denoiser:
                d_t = self.model.get_denoiser(xt, t_)
                jac_norm = self.jacobian_spectral_norm(d_t, xt)
            else:
                v_t = self.model.get_velocity(xt, t_)
                jac_norm = self.jacobian_spectral_norm(v_t, xt)

            jac_norms[i, :] = jac_norm.detach()
            with torch.no_grad():
                if self.use_denoiser:
                    v_t = self.model.get_velocity(xt, t_)
                # NOTE(review): the Euler step size is 1 / integration_steps,
                # independent of time_start/time_end, while time_points are
                # spaced (time_end - time_start) / (integration_steps - 1) apart.
                # These agree only approximately, and only for a [0, 1] horizon.
                # Kept unchanged so results stay comparable; confirm against the
                # model's own sampler.
                xt = (xt + v_t / integration_steps).detach()

        return jac_norms

    def run_method(self, data_loaders, degradation, sigma_noise, H_funcs=None):
        """Compute and save Jacobian spectral norms for freshly sampled trajectories.

        Sets ``self.args.save_path_ip`` to ``args.save_path`` joined with
        ``utils.get_save_path_ip(args.dict_cfg_method)`` and creates that
        directory. Enough batches of ``args.batch_size_gen`` trajectories are
        then generated to cover ``args.num_images``. Each batch's result is
        written to ``jacobian_norm_batch{batch}.pt`` in that directory.

        Args:
            data_loaders, degradation, sigma_noise, H_funcs: accepted but not
                used by this method. NOTE(review): presumably kept so all
                methods share one ``run_method`` signature; confirm.
        """
        # Construct the save path for results.
        folder = utils.get_save_path_ip(self.args.dict_cfg_method)
        self.args.save_path_ip = os.path.join(self.args.save_path, folder)

        # Create the directory if it doesn't exist.
        print(self.args.save_path_ip)
        os.makedirs(self.args.save_path_ip, exist_ok=True)

        batch_size = self.args.batch_size_gen
        # Ceiling division: the previous `num_images // batch_size + 1`
        # produced an extra batch whenever num_images was an exact multiple.
        num_batches = self._num_batches(self.args.num_images, batch_size)
        for batch in range(num_batches):
            jacobian_norms_on_traj = self.compute_lip_on_traj(batch_size)
            torch.save(
                jacobian_norms_on_traj,
                os.path.join(self.args.save_path_ip, f"jacobian_norm_batch{batch}.pt"),
            )
