
import torch
import torch
import numpy as np
import os
from time import perf_counter
import denflow.utils as utils
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from denflow.utils_closed_form import get_full_velocity_field_optimized_


def infiniteloop(dataloader):
    """Endlessly yield the inputs of ``dataloader``, restarting it after each pass.

    Every batch must unpack into exactly two items ``(inputs, labels)``; only
    ``inputs`` is yielded and the labels are discarded.
    """
    while True:
        # One full pass over the loader, then start over.
        yield from (inputs for inputs, _labels in dataloader)


class INVESTIGATE_APPROX(object):
    """Compare a model's velocity field with a reference velocity field.

    For a grid of times ``t`` the model velocity ``model.get_velocity(xt, t)``
    is compared with ``get_full_velocity_field_optimized_`` (imported from
    ``denflow.utils_closed_form``). Both are evaluated on the linear
    interpolation ``xt = (1 - t) * x0 + t * x1`` between Gaussian noise ``x0``
    and data ``x1``. For each ``t``, the squared error averaged over samples
    and flattened dimensions is printed and saved to ``<save_path_ip>/res.npy``.
    """

    def __init__(self, model, device, args):
        """Store the model, target device and run configuration.

        Args:
            model: object exposing ``get_velocity(xt, t)``.
            device: torch device that data batches are moved to.
            args: configuration namespace. This class reads ``method``,
                ``max_batch``, ``save_path``, ``dict_cfg_method`` and
                ``eval_split``, and writes ``save_path_ip``.
        """
        self.device = device
        self.args = args
        self.model = model
        self.method = args.method

    @staticmethod
    def _cycle_batches(dataloader):
        """Yield ``batch[0]`` of each batch forever, restarting the loader.

        A single iterator is kept alive across calls to ``next``. Successive
        draws therefore walk through the data (instead of always re-reading the
        first batch) and loader start-up cost is paid once per pass.

        Raises:
            ValueError: if a full pass over ``dataloader`` yields no batch.
        """
        while True:
            empty = True
            for batch in dataloader:
                empty = False
                yield batch[0]
            if empty:
                # Avoid spinning forever on an empty loader.
                raise ValueError("dataloader yielded no batches")

    def compute_approx(self, train_dataloader, n_points=100, t_max=0.99):
        """Evaluate the velocity mismatch on ``n_points`` times in ``[0, t_max]``.

        Args:
            train_dataloader: iterable whose batches have the data tensor at
                index 0.
            n_points: number of evenly spaced time values.
            t_max: largest time value evaluated.

        Returns:
            ``np.ndarray`` of shape ``(n_points,)``. For each time, the squared
            error summed over flattened dimensions is averaged over all samples
            from ``args.max_batch`` batches, then divided by the number of
            flattened dimensions. The array is also saved to
            ``<args.save_path_ip>/res.npy``.

        Raises:
            ValueError: if ``args.max_batch < 1`` or the loader is empty.
        """
        num_batches = self.args.max_batch
        if num_batches < 1:
            raise ValueError(
                f"args.max_batch must be >= 1, got {num_batches}")

        batches = self._cycle_batches(train_dataloader)
        res = np.zeros(n_points)
        for i, t in enumerate(np.linspace(0, t_max, n_points)):
            loss_sum = 0.0
            n_samples = 0
            for _ in range(num_batches):
                x1 = next(batches).to(self.device)
                x0 = torch.randn_like(x1)
                t_ = torch.full((x1.shape[0],), float(t), device=self.device)
                xt = (1 - t) * x0 + t * x1
                xt_flatten = torch.flatten(xt, start_dim=1)
                x1_flatten = torch.flatten(x1, start_dim=1)
                # NOTE(review): assumes utot is (batch, n_dims) like the
                # flattened inputs -- confirm against utils_closed_form.
                utot = get_full_velocity_field_optimized_(
                    t_[:, None], xt_flatten, x1_flatten)
                vt = self.model.get_velocity(xt, t_).detach()
                vt_flatten = torch.flatten(vt, start_dim=1)
                fm_loss = torch.sum((vt_flatten - utot) ** 2, dim=1)
                # Reduce to Python numbers: works for any device and any
                # (possibly uneven) batch size.
                loss_sum += fm_loss.sum().item()
                n_samples += fm_loss.shape[0]
            res[i] = loss_sum / n_samples / utot.shape[1]
            print(f"t {t}  , FM loss {res[i]:.3f}")
        np.save(os.path.join(self.args.save_path_ip, "res.npy"), res)
        return res

    def run_method(self, data_loaders, degradation, sigma_noise, H_funcs=None):
        """Create the output folder and run :meth:`compute_approx`.

        Args:
            data_loaders: mapping indexed by ``args.eval_split``.
            degradation, sigma_noise, H_funcs: accepted for interface
                compatibility; not used here.

        Returns:
            The array returned by :meth:`compute_approx`.
        """
        # Construct the save path for results.
        folder = utils.get_save_path_ip(self.args.dict_cfg_method)
        self.args.save_path_ip = os.path.join(self.args.save_path, folder)

        # Create the directory if it doesn't exist.
        print(self.args.save_path_ip)
        os.makedirs(self.args.save_path_ip, exist_ok=True)

        return self.compute_approx(data_loaders[self.args.eval_split])
