import torch
import torch
import numpy as np
import os
from time import perf_counter
import tqdm
import denflow.utils as utils
from torchvision.utils import save_image
from torchmetrics.functional.image import peak_signal_noise_ratio as PSNR


class NOISE_PERTURB_PSNR(object):
    """Evaluate a model's denoiser on synthetically noised test batches.

    For every noise level in ``args.sigmas_test`` the evaluator corrupts clean
    batches with seeded Gaussian noise, runs the model's denoiser, and appends
    one summary line (sample count, mean and std of the metric) to a text file
    inside ``args.save_path_ip``.

    The corruption scheme is selected by ``args.noise_style``:

    * ``"interpolation"``: ``(1 - sigma) * clean + sigma * noise``, denoised
      with ``model.get_denoiser`` at ``t = 1 - sigma``.
    * ``"classical"``: ``clean + sigma / (1 - sigma) * noise``, denoised with
      ``model.get_denoiser_classical`` at that rescaled noise level.
    """

    def __init__(self, model, device, args):
        """Store the model, the target device and the run configuration.

        Args:
            model: Object exposing ``get_denoiser``, ``get_velocity`` and
                ``get_denoiser_classical``. It is used as given and is not
                moved to ``device`` here.
            device: Device on which the noisy batches are created.
            args: Configuration namespace; ``method`` and ``noise_perturb``
                are read here, the remaining fields are read by the
                evaluation methods.
        """
        self.device = device
        self.args = args
        self.model = model  # .to(device)
        self.method = args.method
        self.noise_perturb = self.args.noise_perturb

    @staticmethod
    def _level_vector(level, x):
        """Return ``level`` repeated once per sample of ``x``, on ``x``'s device."""
        return level * torch.ones(len(x), device=x.device)

    def _add_noise(self, clean_img, sigma, seed):
        """Return a reproducibly noised copy of ``clean_img`` on ``self.device``.

        The global torch RNG is re-seeded with ``seed`` right before the noise
        is drawn, so a given (batch index, sigma) pair always receives the
        same noise realisation.

        Raises:
            ValueError: If ``args.noise_style`` is not a supported scheme.
        """
        style = self.args.noise_style
        if style == "interpolation":
            noisy_img = (1 - sigma) * clean_img.clone().to(self.device)
            noise_std = sigma
        elif style == "classical":
            noisy_img = clean_img.clone().to(self.device)
            noise_std = sigma / (1 - sigma)
        else:
            raise ValueError(
                f"Unknown noise_style {style!r}; "
                "expected 'interpolation' or 'classical'")
        torch.manual_seed(seed)
        noisy_img += torch.randn_like(noisy_img) * noise_std
        return noisy_img

    def measure_psnr(self, test_loader):
        """Append per-sigma PSNR statistics to ``<save_path_ip>/psnrs.txt``.

        At most ``args.max_batch`` batches are evaluated per noise level; the
        loader is restarted for every level. Images are mapped with
        ``x / 2 + 0.5`` before computing PSNR with ``data_range=1``. When
        ``args.save_results`` is true, each (noisy, restored) pair is saved
        under ``<save_path>/sigma_<sigma>/``.

        Args:
            test_loader: Iterable yielding ``(clean_img, labels)`` batches.

        Raises:
            ValueError: If ``args.noise_style`` is not a supported scheme.
        """
        psnr_filename = os.path.join(self.args.save_path_ip, 'psnrs.txt')

        for sigma in tqdm.tqdm(self.args.sigmas_test):
            if self.args.save_results:
                os.makedirs(self.args.save_path +
                            f"/sigma_{sigma}", exist_ok=True)

            psnr_values = []
            n_seen = 0  # running sample offset, robust to a smaller last batch
            # zip() starts a fresh pass over the loader for each sigma and
            # stops cleanly if the loader holds fewer than max_batch batches.
            for batch, (clean_img, _labels) in zip(
                    range(self.args.max_batch), test_loader):
                self.args.batch = batch

                noisy_img = self._add_noise(clean_img, sigma, seed=batch)
                clean_img = clean_img.to('cpu')
                x = noisy_img.clone()

                if self.args.noise_style == "interpolation":
                    t = 1. - sigma
                    # Denoise, then re-inject a scaled random perturbation.
                    x = self.model.get_denoiser(
                        x, self._level_vector(t, x)) \
                        + (1 - t) * self.noise_perturb * torch.randn_like(x)
                    # Diagnostic only: velocity norm relative to the norm of
                    # a fresh perturbation draw.
                    vt = self.model.get_velocity(x, self._level_vector(t, x))
                    print(torch.norm(vt) / torch.norm(self.noise_perturb * torch.randn_like(x)) )
                else:
                    # "classical" -- _add_noise already rejected other styles.
                    sigma_new = sigma / (1 - sigma)
                    x = self.model.get_denoiser_classical(
                        x, self._level_vector(sigma_new, x))

                restored_img = x.detach().clone() / 2 + 0.5

                for idx, img in enumerate(restored_img):
                    if self.args.save_results:
                        noisy_img_plot = noisy_img[idx].detach() / 2 + 0.5
                        save_image([noisy_img_plot, img], self.args.save_path +
                                   f"/sigma_{sigma}/image_{n_seen + idx}.png")
                    psnr_values.append(float(PSNR(
                        clean_img[idx] / 2 + 0.5,
                        img.detach().cpu(), data_range=1)))
                n_seen += len(restored_img)

            # Statistics cover exactly the samples that were evaluated.
            psnr = torch.tensor(psnr_values)
            with open(psnr_filename, 'a') as file:
                file.write(
                    f'Sigma {sigma} N_images {len(psnr)} PSNR mean {psnr.mean().item()} PSNR std {psnr.std().item()}\n')

    def measure_mse(self, test_loader):
        """Append per-sigma squared-error statistics to ``<save_path_ip>/mses.txt``.

        The per-sample score is the squared error summed over all elements of
        the sample (it is not divided by the element count). At most
        ``args.max_batch`` batches are evaluated per noise level; the loader
        is restarted for every level so each sigma sees the same batches.

        Args:
            test_loader: Iterable yielding ``(clean_img, labels)`` batches.

        Raises:
            ValueError: If ``args.noise_style`` is not a supported scheme.
        """
        mse_filename = os.path.join(self.args.save_path_ip, 'mses.txt')

        for sigma in tqdm.tqdm(self.args.sigmas_test):
            if self.args.save_results:
                os.makedirs(self.args.save_path +
                            f"/sigma_{sigma}", exist_ok=True)

            mse_values = []
            # A fresh pass over the loader per sigma; a single shared iterator
            # would be exhausted after the first noise level.
            for batch, (clean_img, _labels) in zip(
                    range(self.args.max_batch), test_loader):
                self.args.batch = batch

                noisy_img = self._add_noise(clean_img, sigma, seed=batch)
                clean_img = clean_img.to(self.device)
                x = noisy_img.clone()

                if self.args.noise_style == "interpolation":
                    t = 1 - sigma
                    x = self.model.get_denoiser(x, self._level_vector(t, x))
                else:
                    # "classical" -- _add_noise already rejected other styles.
                    sigma_new = sigma / (1 - sigma)
                    x = self.model.get_denoiser_classical(
                        x, self._level_vector(sigma_new, x))

                restored_img = x.detach().clone()

                mse_values.extend(
                    float(torch.sum((target - img) ** 2))
                    for target, img in zip(clean_img, restored_img))

            # Statistics cover exactly the samples that were evaluated.
            mse = torch.tensor(mse_values)
            with open(mse_filename, 'a') as file:
                file.write(
                    f'Sigma {sigma} N_images {len(mse)} MSE mean {mse.mean().item()} MSE std {mse.std().item()}\n')

    def run_method(self, data_loaders, degradation, sigma_noise, H_funcs=None):
        """Prepare the output directory and run the appropriate metric.

        Sets ``args.save_path_ip`` to ``args.save_path`` joined with the
        folder name returned by ``utils.get_save_path_ip`` and creates it.
        ``measure_mse`` is used when ``args.dim_image == 2``, otherwise
        ``measure_psnr``; both read the ``args.eval_split`` loader.

        Args:
            data_loaders: Mapping from split name to data loader.
            degradation: Not used by this method.
            sigma_noise: Not used by this method.
            H_funcs: Not used by this method.
        """
        # Construct the save path for results
        folder = utils.get_save_path_ip(self.args.dict_cfg_method)
        self.args.save_path_ip = os.path.join(self.args.save_path, folder)

        # Create the directory if it doesn't exist
        print(self.args.save_path_ip)
        os.makedirs(self.args.save_path_ip, exist_ok=True)

        eval_loader = data_loaders[self.args.eval_split]
        if self.args.dim_image == 2:
            self.measure_mse(eval_loader)
        else:
            self.measure_psnr(eval_loader)
