import torch
import os
import tqdm
import time
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torchvision import transforms
from fld.features.InceptionFeatureExtractor import InceptionFeatureExtractor
from fld.metrics.FID import FID
from torchvision.datasets.cifar import CIFAR10
from torchmetrics.functional.image import peak_signal_noise_ratio as PSNR

from denflow.utils import plot_test_data, plot_paths, compute_W2, compute_path_energy
from denflow.dataloaders import CelebADataset


# Maps a feature-extractor name to its class. Not referenced elsewhere in the
# visible part of this file.
EXTRACTORS = {"inception": InceptionFeatureExtractor}

# Locations of the CelebA images and of the partition CSV; both are passed to
# CelebADataset when the FID reference features are computed.
img_dir = './data/celeba/img_align_celeba/'
partition_csv = './data/celeba/list_eval_partition.csv'


class NOISE_PERTURB_SAMPLER(object):
    def __init__(self, model, device, args):
        """Keep references to the model/device and cache the perturbation settings.

        Args:
            model: object queried later through ``get_velocity`` /
                ``get_denoiser`` and its ``time_start`` / ``time_end``.
            device: device on which tensors are created.
            args: configuration namespace; ``method``, ``time_interval``,
                ``noise_perturb_ratio``, ``perturb_style`` and
                ``noise_perturb_computation`` are read here.
        """
        self.device, self.args, self.model = device, args, model
        self.method = args.method
        # Feature extractor used by compute_fid.
        self.ft_extractor = InceptionFeatureExtractor(save_path="features")

        # Perturbations are applied only for times inside [min, max).
        lower_bound = self.args.time_interval[0]
        upper_bound = self.args.time_interval[1]
        self.t_interval_min = lower_bound
        self.t_interval_max = upper_bound
        print(self.t_interval_min, self.t_interval_max)

        cfg = self.args
        self.noise_perturb_ratio = cfg.noise_perturb_ratio
        self.perturb_style = cfg.perturb_style
        self.noise_perturb_computation = cfg.noise_perturb_computation

    def get_perturb(self, xt, t_):
        """Return the perturbation direction associated with ``self.perturb_style``.

        Supported styles:
            * ``'gaussian'``: standard normal noise shaped like ``xt``.
            * ``'constant'`` / ``'constant_neg'``: all ``+1`` / all ``-1``.
            * ``'id_average'``: ``xt - self.model.get_denoiser(xt, t_)``.
            * ``'chessboard'``: pixel-level checkerboard, ``-1`` where the row
              and column indices have the same parity, ``+1`` elsewhere.
            * ``'chessboard_<k>'`` (e.g. ``chessboard_4``): checkerboard made
              of ``k x k`` blocks, ``+1`` in the top-left block.

        Args:
            xt: 4-D tensor (indexed as ``[:, :, rows, cols]``).
            t_: per-sample time tensor, only used by ``'id_average'``.

        Returns:
            Tensor with the same shape as ``xt``.

        Raises:
            ValueError: if ``self.perturb_style`` is not one of the above.
        """
        style = self.perturb_style
        if style == 'gaussian':
            return torch.randn_like(xt)
        if style == 'constant':
            return torch.ones_like(xt)
        if style == 'constant_neg':
            return -torch.ones_like(xt)
        if style == 'id_average':
            return xt - self.model.get_denoiser(xt, t_)
        if style == "chessboard":
            perturb = torch.ones_like(xt)
            perturb[:, :, ::2, ::2] = -1
            perturb[:, :, 1::2, 1::2] = -1
            return perturb

        block_prefix = "chessboard_"
        block_suffix = ""
        if isinstance(style, str) and style.startswith(block_prefix):
            block_suffix = style[len(block_prefix):]
        if block_suffix.isdecimal() and int(block_suffix) > 0:
            k = int(block_suffix)
            height, width = xt.shape[-2], xt.shape[-1]
            unfold = torch.nn.Unfold(kernel_size=(k, k), stride=k)
            fold = torch.nn.Fold((height, width), kernel_size=(k, k), stride=k)
            # One column of `patches` per k x k block, in row-major block order.
            patches = unfold(torch.ones_like(xt))
            n_block_rows = (height - k) // k + 1
            n_block_cols = (width - k) // k + 1
            rows, cols = torch.meshgrid(
                torch.arange(n_block_rows), torch.arange(n_block_cols), indexing="ij")
            # 0 for blocks kept at +1, 1 for blocks flipped to -1; placed on the
            # same device as the patches it multiplies.
            mask = ((rows + cols) % 2).reshape(-1).to(xt.device)
            patches = patches * (1 - 2 * mask).view(1, 1, -1)
            return fold(patches)

        raise ValueError("perturb unknown")


    def generate_samples(self, dict_sigmas=0., batch_size=None):
        """Integrate the model velocity with Euler steps, perturbing it on a time window.

        Starting from Gaussian noise, every time point ``t`` of
        ``linspace(model.time_start, model.time_end, args.sampling_steps)``
        advances the state by ``v_t / N`` (``N`` = number of time points).
        When ``t_interval_min <= t < t_interval_max`` the velocity is augmented
        by ``sigma / clamp(1 - t, min=1e-3) * get_perturb(xt, t)``.

        Args:
            dict_sigmas: mapping whose keys are times and whose values are
                perturbation strengths (both convertible with ``float``); the
                entry with the key closest to ``t`` is used. A plain number is
                also accepted and used as a constant strength, so the default
                ``0.`` means "no perturbation". ``None`` is treated as ``0``.
            batch_size: number of samples; falls back to
                ``args.batch_size_gen`` when ``None``.

        Returns:
            ``(xt, traj)`` where ``traj[i]`` is the state after the step taken
            at the ``i``-th time point and ``xt`` is ``traj[-1]``. ``traj`` is
            allocated without an explicit device.
        """
        if batch_size is None:
            batch_size = self.args.batch_size_gen
        num_channels = self.args.num_channels
        res = self.args.dim_image
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, self.args.sampling_steps, device=self.device)
        integration_steps = len(time_points)

        # 2-D toy data is stored as (batch, channels, 2, 1); images are square.
        last_dim = 1 if res == 2 else res
        x0 = torch.randn(batch_size, num_channels, res, last_dim, device=self.device)

        def sigma_at(t_value):
            """Perturbation strength to use at time ``t_value``."""
            if dict_sigmas is None:
                return 0.0
            if not hasattr(dict_sigmas, "keys"):
                # Scalar strength: same value at every time.
                return float(dict_sigmas)
            closest_key = min(dict_sigmas.keys(),
                              key=lambda key: abs(float(key) - t_value))
            return float(dict_sigmas[closest_key])

        with torch.no_grad():
            traj = torch.zeros(integration_steps, *x0.shape)
            xt = x0.clone()
            for i, t in enumerate(time_points):
                t_ = t * torch.ones(len(xt), device=self.device)
                in_window = bool(t >= self.t_interval_min and t < self.t_interval_max)
                if in_window:
                    # Drawn before the velocity call to keep the original
                    # order of operations (matters for random perturbations).
                    sigma_perturb = sigma_at(t.item())
                    perturb = self.get_perturb(xt, t_)
                v_t = self.model.get_velocity(xt, t_)
                if in_window:
                    v_t = v_t + (sigma_perturb / torch.clamp((1 - t), min=1e-3)) * perturb
                xt = xt + v_t / integration_steps
                traj[i, :] = xt.detach()
        xt = traj[-1, :]
        return xt, traj

    def compute_fid(self):
        """Compute the FID between the images in ``<save_path>/final`` and a test set.

        Features of every image found in the ``final`` directory are used (no
        truncation to ``args.num_images_fid``). The reference features come
        from the CelebA partition ``2`` (center-cropped to 178 and resized to
        ``args.dim_image``) or from the CIFAR-10 ``train=False`` split,
        depending on ``args.dataset``.

        Returns:
            The value returned by ``FID().compute_metric``.

        Raises:
            ValueError: if ``args.dataset`` is not ``'celeba'``, ``'celeba64'``
                or ``'cifar10'``.
        """
        dataset = self.args.dataset
        # Fail fast, before the (expensive) feature extraction.
        if dataset not in ("celeba", "celeba64", "cifar10"):
            raise ValueError("Dataset not available to compute FID")

        with torch.no_grad():
            source_dir = os.path.join(self.args.save_path, "final")
            print(source_dir)
            gen_feat = self.ft_extractor.get_dir_features(source_dir)
            print(gen_feat.shape)
            if dataset == "cifar10":
                test_feat = self.ft_extractor.get_features(
                    CIFAR10(train=False, root="data", download=True), name="cifar10_test")
            else:
                side = self.args.dim_image
                celeba_transform = transforms.Compose([
                    transforms.CenterCrop(178),
                    transforms.Resize([side, side]),
                ])
                test_feat = self.ft_extractor.get_features(
                    CelebADataset(img_dir, partition_csv, partition=2,
                                  transform=celeba_transform),
                    name=f"celeba{self.args.dim_image}_test")
            fid_val = FID().compute_metric(test_feat, None, gen_feat)
        print("FID", fid_val, self.args.num_images_fid)
        return fid_val

    def sample2D(self, test_loader):
        """Plot generated trajectories over one batch of test data and save the figure.

        A batch of ``args.batch_size_gen`` samples is generated with the
        default arguments of ``generate_samples``; the figure is written to
        ``<save_path>/final/<loss_type>_<class_object>_<batch_size>.png``.

        Args:
            test_loader: iterable whose first batch's element ``0`` is handed
                to ``plot_test_data``.
        """
        batch_size = self.args.batch_size_gen
        _, traj = self.generate_samples(batch_size=batch_size)

        out_dir = os.path.join(self.args.save_path, "final")
        os.makedirs(out_dir, exist_ok=True)
        out_file = os.path.join(
            out_dir, f"{self.args.loss_type}_{self.args.class_object}_{batch_size}.png")

        fig, ax = plt.subplots()
        try:
            plot_test_data(ax, next(iter(test_loader))[0])
            plot_paths(ax, traj)
            fig.tight_layout()
            ax.set_xlim(-11, 11)
            ax.set_ylim(-11, 11)
            # Save this figure explicitly instead of relying on the "current" one.
            fig.savefig(out_file, dpi=300)
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def compute_perturb_norms(self, test_loader, dict_sigmas):
        """Measure how much the perturbation moves the state and the denoiser output.

        A fresh trajectory is generated, then for every time point ``t`` with
        ``t_interval_min < t <= t_interval_max`` two pairs of batch-averaged
        L2 norms are accumulated:

        * "on trajectory": difference between the stored state ``traj[i]`` and
          an unperturbed Euler step from ``traj[i-1]``, and the difference of
          the denoiser outputs at those two states;
        * "direct": difference between ``traj[i-1] + sigma * perturb`` and
          ``traj[i-1]``, and the difference of the denoiser outputs.

        The averages over the selected time points are appended to
        ``<save_path>/computed_norms.txt``. Nothing is written when no time
        point is selected.

        NOTE(review): ``generate_samples`` stores in ``traj[i]`` the state
        obtained *after* the step taken at ``time_points[i]``; confirm that
        pairing ``traj[i-1]`` with ``time_points[i-1]`` below is the intended
        alignment.

        Args:
            test_loader: unused.
            dict_sigmas: time -> perturbation strength mapping, see
                ``generate_samples``.
        """
        batch_size = self.args.batch_size_gen
        _, traj_one_batch = self.generate_samples(
            batch_size=batch_size, dict_sigmas=dict_sigmas)
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, self.args.sampling_steps, device=self.device)
        integration_steps = len(time_points)

        def mean_l2(diff):
            """Per-sample L2 norm over dims (1, 2, 3), averaged over the batch."""
            return torch.sqrt(torch.sum(diff ** 2, dim=(1, 2, 3))).mean().item()

        norms_perturb_Dt_on_traj = 0.0
        norms_perturb_xt_on_traj = 0.0
        norms_perturb_Dt = 0.0
        norms_perturb_xt = 0.0
        counter_perturb = 0
        for i, t in enumerate(time_points):
            if i == 0:
                # No previous state exists; index -1 would wrap to the last one.
                continue
            # Strict lower / inclusive upper bound: looks at t + dt for steps
            # taken inside the perturbed interval.
            if not (t > self.t_interval_min and t <= self.t_interval_max):
                continue
            counter_perturb += 1

            x_t = traj_one_batch[i, :].clone().to(self.device)
            x_t_minus_1 = traj_one_batch[i - 1, :].clone().to(self.device)
            t_minus_1 = time_points[i - 1] * torch.ones(len(x_t_minus_1), device=self.device)
            t_ = time_points[i] * torch.ones(len(x_t_minus_1), device=self.device)

            # Unperturbed Euler step from the previous state.
            v_t = self.model.get_velocity(x_t_minus_1, t_minus_1).detach()
            x_t_unperturbed = x_t_minus_1 + v_t / integration_steps

            diff_Dt_on_traj = (self.model.get_denoiser(x_t, t_).detach()
                               - self.model.get_denoiser(x_t_unperturbed, t_).detach())
            norms_perturb_Dt_on_traj += mean_l2(diff_Dt_on_traj)
            norms_perturb_xt_on_traj += mean_l2(x_t - x_t_unperturbed)

            closest_key = min(dict_sigmas.keys(),
                              key=lambda key: abs(float(key) - t.item()))
            sigma_perturb = float(dict_sigmas[closest_key])
            perturb = self.get_perturb(x_t_minus_1, t_minus_1)
            x_t_perturb = x_t_minus_1 + sigma_perturb * perturb
            diff_Dt = (self.model.get_denoiser(x_t_perturb, t_minus_1).detach()
                       - self.model.get_denoiser(x_t_minus_1, t_minus_1).detach())
            norms_perturb_Dt += mean_l2(diff_Dt)
            norms_perturb_xt += mean_l2(x_t_perturb - x_t_minus_1)
            del diff_Dt, diff_Dt_on_traj

        if counter_perturb == 0:
            # Empty perturbation window: there is nothing to average.
            print("No time point in the perturbed interval; norms not computed")
            return

        norms_perturb_Dt_on_traj /= counter_perturb
        norms_perturb_xt_on_traj /= counter_perturb
        norms_perturb_Dt /= counter_perturb
        norms_perturb_xt /= counter_perturb

        norms_filename = os.path.join(self.args.save_path, 'computed_norms.txt')
        with open(norms_filename, 'a') as file:
            file.write(f'PERTURB_ON_TRAJ interval {self.t_interval_min}-{self.t_interval_max} norm_Dt {norms_perturb_Dt_on_traj} norm_xt {norms_perturb_xt_on_traj} \n')
            file.write(f'PERTURB_XT interval {self.t_interval_min}-{self.t_interval_max} norm_Dt {norms_perturb_Dt} norm_xt {norms_perturb_xt} \n')


    def sample(self, test_loader, dict_sigmas):
        """Generate images into ``<save_path>/final`` and log the resulting FID.

        Images are generated batch by batch until the directory holds
        ``args.num_images_fid`` files; generation resumes from the batch
        containing the first missing index when files already exist. The FID
        is then computed with ``compute_fid`` and appended to
        ``<save_path>/fid.txt``.

        Args:
            test_loader: unused.
            dict_sigmas: time -> perturbation strength mapping forwarded to
                ``generate_samples``.
        """
        batch_size = self.args.batch_size_gen
        num_images_fid = self.args.num_images_fid

        source_dir = os.path.join(self.args.save_path, "final")
        num_existing_imgs = 0
        if os.path.isdir(source_dir):
            num_existing_imgs = len([f for f in os.listdir(source_dir)
                                     if os.path.isfile(os.path.join(source_dir, f))])

        if num_existing_imgs < num_images_fid:
            os.makedirs(source_dir, exist_ok=True)
            start_batch = num_existing_imgs // batch_size
            # Ceiling division: just enough batches to reach num_images_fid.
            num_batches = (num_images_fid + batch_size - 1) // batch_size
            for batch in tqdm.trange(start_batch, num_batches):
                print(f"Batch {batch} with size {batch_size}")
                xt, _ = self.generate_samples(
                    batch_size=batch_size, dict_sigmas=dict_sigmas)
                # Same channel count as the one used to generate the samples.
                rescaled_images = xt.view(
                    [-1, self.args.num_channels, self.args.dim_image, self.args.dim_image])
                # Map from [-1, 1] to [0, 1] before writing the PNG files.
                rescaled_images = rescaled_images / 2 + 0.5

                first_index = batch_size * batch
                # The last batch may overshoot: keep only what is still needed
                # so that FID is computed on exactly num_images_fid images.
                rescaled_images = rescaled_images[:num_images_fid - first_index]
                for i, img in enumerate(tqdm.tqdm(rescaled_images)):
                    save_image(img, os.path.join(
                        source_dir,
                        f"{self.args.loss_type}_{self.args.class_object}_image_{first_index + i}.png"))

        fid_val = self.compute_fid()
        fid_filename = os.path.join(self.args.save_path, 'fid.txt')
        with open(fid_filename, 'a') as file:
            file.write(
                f'Sampling steps {self.args.sampling_steps} t0 {self.model.time_start} t_end {self.model.time_end} FID{self.args.num_images_fid//1000}k {fid_val}\n')

    def plot_trajectories(self, dict_sigmas, batch_size=16, n_samples=10):
        """
        Plot the Euler trajectories of ``n_samples`` independently drawn samples.

        Each sample is integrated exactly like in ``generate_samples`` (same
        perturbation window and strength lookup). About ten evenly spaced
        intermediate states plus the final state are shown side by side and
        saved to ``<save_path>/trajectory_sample<k>_euler.png``.

        Args:
            dict_sigmas: mapping time -> perturbation strength (keys and
                values convertible with ``float``).
            batch_size: unused, kept for backward compatibility.
            n_samples: number of trajectories (and output files) to produce.
        """
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, self.args.sampling_steps, device=self.device)
        integration_steps = len(time_points)
        # One snapshot every `stride` steps; equals 10 for 100 sampling steps.
        stride = max(1, integration_steps // 10)

        print("self.model.time_start", self.model.time_start)
        for k in range(n_samples):
            traj = []
            times = []
            with torch.no_grad():
                xt = torch.randn(1, 3,
                                 self.args.dim_image, self.args.dim_image, device=self.device)
                for i, t in enumerate(time_points):
                    t_ = t * torch.ones(len(xt), device=self.device)
                    if t >= self.t_interval_min and t < self.t_interval_max:
                        closest_key = min(dict_sigmas.keys(),
                                          key=lambda key: abs(float(key) - t.item()))
                        sigma_perturb = float(dict_sigmas[closest_key])
                        perturb = self.get_perturb(xt, t_)
                        v_t = self.model.get_velocity(
                            xt, t_) + (sigma_perturb / torch.clamp((1 - t), min=1e-3)) * perturb
                    else:
                        v_t = self.model.get_velocity(xt, t_)
                    xt = xt + v_t / integration_steps
                    if i % stride == 0:
                        print(i, t)
                        traj.append(xt.detach() / 2 + 0.5)
                        times.append(t)
                # Always include the final state.
                traj.append(xt.detach() / 2 + 0.5)
                times.append(t)

            # One panel per stored snapshot; squeeze=False keeps `axes` 2-D
            # even when there is a single panel.
            fig, axes = plt.subplots(1, len(traj), figsize=(20, 3), squeeze=False)
            for ax, snapshot, snapshot_time in zip(axes[0], traj, times):
                im = snapshot.cpu()[0].clamp(0, 1)
                ax.imshow(im.permute(1, 2, 0))
                ax.axis("off")
                ax.set_title("t={:2.2f}".format(float(snapshot_time)))
            fig.tight_layout()

            out_path = os.path.join(self.args.save_path,
                                    f"trajectory_sample{k}_euler.png")
            fig.savefig(out_path, dpi=200)
            plt.close(fig)
            print(f"Saved trajectory plot to {out_path}")

    def compute_sigmas(self, loader):
        """Return the per-time perturbation strengths, computing them if not cached.

        When ``<save_path>/computed_sigmas_time.txt`` does not exist, one
        batch ``x_1`` is taken from ``loader`` and, for every sampling time
        ``t``, the denoiser output ``D_t`` at ``x_t = t * x_1 + (1 - t) * x_0``
        is perturbed with each candidate strength in ``linspace(0, 10, 10000)``.
        The strength whose PSNR (w.r.t. ``x_1``) is closest to
        ``(1 - noise_perturb_ratio)`` times the unperturbed PSNR is kept.
        The file is written only once all time points have been processed, so
        a failed sanity check or an interruption cannot leave a truncated
        cache that later runs would silently reuse.

        Args:
            loader: iterable whose first batch's element ``0`` is a tensor.

        Returns:
            dict mapping the time (as written in the file, a string) to the
            selected strength (also a string).
        """
        sigmas_tests = torch.linspace(0, 10, 10000, device=self.device)
        x_1 = next(iter(loader))[0].to(self.device)
        x_0 = torch.randn_like(x_1)
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, self.args.sampling_steps, device=self.device)
        sigmas_filename = os.path.join(
            self.args.save_path, 'computed_sigmas_time.txt')

        if not os.path.exists(sigmas_filename):
            target_ratio = 1 - self.noise_perturb_ratio
            lines = []
            for t in time_points:
                x_t = t * x_1 + (1-t) * x_0
                t_batch = t * torch.ones(len(x_t), device=x_t.device)
                D_t = self.model.get_denoiser(x_t, t_batch)
                perturb = self.get_perturb(D_t, t_batch)
                perturb = perturb[None, :, :, :, :]
                # Leading axis enumerates the candidate strengths.
                D_t_tilde = D_t[None, :, :, :, :] + sigmas_tests[:, None, None,
                                                                 None, None] * perturb
                psnr_test = PSNR(D_t, x_1, data_range=2,  dim=(1, 2, 3))
                psnr = PSNR(D_t_tilde, x_1[None, :, :, :, :],
                            data_range=2, reduction='none', dim=(2, 3, 4)).mean(dim=1)
                # psnr[0] is the candidate sigma = 0, i.e. the unperturbed PSNR.
                sigma_idx = torch.argmin(
                    torch.abs(psnr - target_ratio * psnr[0]))
                sigma_choosen = sigmas_tests[sigma_idx]

                # Sanity checks: re-applying the chosen strength must give the
                # targeted PSNR (only for the styles handled below).
                if self.perturb_style == 'gaussian':
                    assert torch.allclose(target_ratio * psnr_test, PSNR(
                        D_t + torch.randn_like(D_t) * sigma_choosen, x_1, data_range=2, dim=(1, 2, 3)), rtol=1e-2)
                elif self.perturb_style == 'constant':
                    assert torch.allclose(target_ratio * psnr_test, PSNR(
                        D_t + torch.ones_like(D_t) * sigma_choosen, x_1, data_range=2, dim=(1, 2, 3)), rtol=1e-2)
                elif self.perturb_style == 'constant_neg':
                    assert torch.allclose(target_ratio * psnr_test, PSNR(
                        D_t - torch.ones_like(D_t) * sigma_choosen, x_1, data_range=2, dim=(1, 2, 3)), rtol=1e-2)
                elif self.perturb_style == 'chessboard':
                    perturb = torch.ones_like(D_t)
                    perturb[:, :, ::2, ::2] = -1
                    perturb[:, :, 1::2, 1::2] = -1
                    assert torch.allclose(target_ratio * psnr_test, PSNR(
                        D_t + perturb * sigma_choosen, x_1, data_range=2, dim=(1, 2, 3)), rtol=1e-2)
                elif self.perturb_style == 'id_average':
                    assert torch.allclose(target_ratio * psnr_test, PSNR(
                        (1 - sigma_choosen) * D_t + x_t * sigma_choosen, x_1, data_range=2, dim=(1, 2, 3)), rtol=1e-2)

                lines.append(f't {t} sigma {sigma_choosen} \n')
                del psnr, D_t_tilde

            # Written in one go, after every time point passed its check.
            with open(sigmas_filename, 'w') as file:
                file.writelines(lines)

        dict_sigmas = {}
        with open(sigmas_filename) as file:
            for line in file:
                content = line.strip().split()
                if not content:
                    # Tolerate blank lines in the cache file.
                    continue
                time_idx = content.index("t")
                sigma_idx = content.index("sigma")
                dict_sigmas[content[time_idx+1]] = content[sigma_idx+1]
        return dict_sigmas

    def run_method(self, data_loaders, degradation, sigma_noise, H_funcs=None):
        """Run the full pipeline: sigma computation, sampling and evaluation.

        ``self.args.save_path`` is extended in place twice: first with the
        sampler / ratio / style sub-folders (where the sigma cache lives),
        then with the perturbation interval sub-folders (where samples and
        metrics are written).

        Args:
            data_loaders: mapping indexed with ``args.eval_split``.
            degradation: unused.
            sigma_noise: unused.
            H_funcs: unused.
        """
        settings_suffix = f"/euler_{self.args.sampling_steps}/noise_pertub_ratio_{self.args.noise_perturb_ratio}/perturb_style_{self.perturb_style}/"
        self.args.save_path = self.args.save_path + settings_suffix
        print(self.args.save_path)
        os.makedirs(self.args.save_path, exist_ok=True)

        # The sigma cache is shared by every interval of a given setting.
        eval_loader = data_loaders[self.args.eval_split]
        dict_sigmas = self.compute_sigmas(eval_loader)

        interval_suffix = f"t_interval_min_{self.t_interval_min}/t_interval_max_{self.t_interval_max}/"
        self.args.save_path = self.args.save_path + interval_suffix
        os.makedirs(self.args.save_path, exist_ok=True)

        if self.args.dim_image != 2:
            # Image data: generate samples + FID, then the perturbation norms.
            # (plot_trajectories is available but not invoked here.)
            self.sample(data_loaders[self.args.eval_split], dict_sigmas)
            self.compute_perturb_norms(
                data_loaders[self.args.eval_split], dict_sigmas)
            return

        # 2-D toy data: fixed generation batch size and a trajectory plot only.
        self.args.batch_size_gen = 400
        self.sample2D(data_loaders[self.args.eval_split])
