import torch
import os
import types
import tqdm
from torchdiffeq import odeint
from torchvision.utils import save_image
from torchvision import transforms
from fld.features.InceptionFeatureExtractor import InceptionFeatureExtractor
from fld.metrics.FID import FID
from denflow.dataloaders import CelebADataset
from torchvision.datasets.cifar import CIFAR10
from denflow.utils import plot_test_data, plot_paths, compute_W2, compute_path_energy
from matplotlib import pyplot as plt
from denflow.utils import define_model, load_model_runid
from denflow.train_denoisers import GENERAL_DENOISER


# Feature extractors available for metric computation, keyed by name.
EXTRACTORS = {"inception": InceptionFeatureExtractor}

# CelebA image folder and evaluation-partition CSV. compute_fid reads these to
# build the CelebA test split (partition=2, cached as "celeba<dim>_test").
img_dir = './data/celeba/img_align_celeba/'
partition_csv = './data/celeba/list_eval_partition.csv'


class torch_model_wrapper(torch.nn.Module):
    """Adapt a velocity model to the ``f(t, x)`` signature used by ``odeint``.

    The ODE solver calls the module with a scalar time ``t`` and a batch ``x``.
    The wrapped model's ``get_velocity(x, t)`` is then called with ``t``
    repeated once per entry along the first dimension of ``x``. The returned
    velocity is detached from the autograd graph.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        # One time value per sample in the batch, on the same device as x.
        batch_times = torch.ones(len(x), device=x.device) * t
        velocity = self.model.get_velocity(x, batch_times)
        return velocity.detach()


class TWO_MODELS_SAMPLER(object):
    """Generate images by chaining two velocity models along one ODE trajectory.

    The time axis is split into three consecutive intervals:
    ``[model.time_start, args.t_inf]``, ``[args.t_inf, args.t_sup]`` and
    ``[args.t_sup, 1.0]``.

    - The two outer intervals are integrated with a base ("FM") denoiser
      restored from run ``args.base_run_id``.
    - The middle interval is integrated with ``model``.

    Generated images are written as PNGs under ``<args.save_path>/final``, and
    their FID against a dataset-specific test split is appended to
    ``<args.save_path>/fid.txt``.
    """

    def __init__(self, model, device, args):
        self.device = device
        self.args = args
        # Not moved to `device` here; presumably the caller already did so.
        self.model = model  # .to(device)
        self.method = args.method
        self.ft_extractor = InceptionFeatureExtractor(save_path="features")

        # Restore the base ("FM") network from its own run id and wrap it as a denoiser.
        self.run_id_FM = self.args.base_run_id
        (model_FM, state) = define_model(args)
        model_fold = f"denoiser{'_ema' if self.args.use_ema else ''}_final"
        model_FM = load_model_runid(
            model_FM, model_fold=model_fold, run_id=self.run_id_FM, device=device)
        model_FM.eval()
        self.denoiser_FM = GENERAL_DENOISER(
            model_FM, f'{self.args.base_loss_type}', f'{self.args.base_class_object}', device, args)
        # Stored for callers; not read by the methods of this class.
        self.t_sep = self.args.t_sep

    @staticmethod
    def _time_grid(t0, t1, num_points, device=None):
        """Return an increasing time grid on ``[t0, t1]``, or ``None`` if it is empty.

        At least two points are always used for a non-empty interval.
        ``torch.linspace`` with a single point returns only ``t0``, so ``odeint``
        would leave the state at ``t0`` while the next interval starts at ``t1``.
        """
        if t1 <= t0:
            return None
        return torch.linspace(t0, t1, max(num_points, 2), device=device)

    @staticmethod
    def _num_batches(num_images, batch_size):
        """Return ``ceil(num_images / batch_size)``, the batches needed to reach ``num_images``."""
        return -(-num_images // batch_size)

    def _integrate(self, velocity_model, x_init, time_points):
        """Integrate the velocity field of ``velocity_model`` from ``x_init`` over ``time_points``.

        Returns the full ``odeint`` trajectory, with one state per time point
        stacked along dim 0.
        """
        return odeint(
            torch_model_wrapper(velocity_model), x_init, time_points,
            rtol=1e-4, atol=1e-4, method=self.args.integration_method,
        )

    def generate_samples(self, batch_size=None):
        """Draw Gaussian noise and push it through the three time intervals.

        Args:
            batch_size: number of samples to draw. Defaults to
                ``args.batch_size_gen``.

        Returns:
            A tuple ``(xt, trajectory)``:

            - ``xt`` is the final batch of states.
            - ``trajectory`` is the ``odeint`` trajectory of the last interval
              that was actually integrated. It covers that interval only, not
              the whole path.
        """
        if batch_size is None:
            batch_size = self.args.batch_size_gen
        num_channels = self.args.num_channels
        res = self.args.dim_image

        # Interval boundaries.
        t_start = self.model.time_start
        t_inf = self.args.t_inf
        t_sup = self.args.t_sup
        t_end = 1.0

        # Grid points per outer interval are proportional to its length; the
        # middle interval gets the remainder of args.sampling_steps.
        n_1 = int(self.args.sampling_steps * (t_inf - t_start))
        n_3 = int(self.args.sampling_steps * (t_end - t_sup))
        n_2 = self.args.sampling_steps - n_1 - n_3

        # Each grid is None when its interval is empty (it is then skipped).
        time_points_1 = self._time_grid(t_start, t_inf, n_1, self.device)
        time_points_2 = self._time_grid(t_inf, t_sup, n_2, self.device)
        time_points_3 = self._time_grid(t_sup, t_end, n_3, self.device)

        # res == 2 is sampled with shape (res, 1) instead of (res, res);
        # presumably this is the 2-D toy-data case.
        width = 1 if res == 2 else res
        x0 = torch.randn(batch_size, num_channels, res, width, device=self.device)

        segments = (
            (self.denoiser_FM, time_points_1),  # first interval: base FM denoiser
            (self.model, time_points_2),        # middle interval: target model
            (self.denoiser_FM, time_points_3),  # last interval: base FM denoiser
        )
        xt = x0
        trajectory = x0.unsqueeze(0)
        with torch.no_grad():
            for velocity_model, time_points in segments:
                if time_points is None:
                    continue
                trajectory = self._integrate(velocity_model, xt, time_points)
                xt = trajectory[-1]

        return xt, trajectory

    def compute_fid(self, dataset, number_images):
        """Compute FID between the images in ``<args.save_path>/final`` and a test split.

        The reference split is chosen by ``args.dataset``: ``celeba`` /
        ``celeba64`` or ``cifar10``. Only the first ``args.num_images_fid``
        generated features are scored. ``dataset`` and ``number_images`` are
        kept for interface compatibility but are not used.

        Raises:
            ValueError: if ``args.dataset`` has no FID reference split.
        """
        dataset_name = self.args.dataset
        # Validate before the (expensive) feature extraction.
        if dataset_name not in ("celeba", "celeba64", "cifar10"):
            raise ValueError("Dataset not available to compute FID")

        with torch.no_grad():
            source_dir = os.path.join(self.args.save_path, "final")
            gen_feat = self.ft_extractor.get_dir_features(source_dir)
            if dataset_name == "cifar10":
                test_feat = self.ft_extractor.get_features(
                    CIFAR10(train=False, root="data", download=True), name="cifar10_test")
            else:
                celeba_transform = transforms.Compose([
                    transforms.CenterCrop(178),
                    transforms.Resize([self.args.dim_image, self.args.dim_image]),
                ])
                test_feat = self.ft_extractor.get_features(
                    CelebADataset(img_dir, partition_csv, partition=2, transform=celeba_transform),
                    name=f"celeba{self.args.dim_image}_test")
            fid_val = FID().compute_metric(
                test_feat, None, gen_feat[:self.args.num_images_fid])
        return fid_val

    def sample(self, test_loader):
        """Ensure ``args.num_images_fid`` generated images exist on disk, then log their FID.

        Images go to ``<args.save_path>/final``.

        - If that directory already holds at least ``args.num_images_fid``
          files, no images are generated.
        - Otherwise generation resumes from the first batch that is not
          completely on disk. A partially written batch is regenerated and its
          files are overwritten.

        One line is appended to ``<args.save_path>/fid.txt``.
        """
        batch_size = self.args.batch_size_gen
        num_images_fid = self.args.num_images_fid
        num_channels = self.args.num_channels
        dim = self.args.dim_image

        source_dir = os.path.join(self.args.save_path, "final")
        num_existing_imgs = 0
        if os.path.isdir(source_dir):
            num_existing_imgs = sum(
                os.path.isfile(os.path.join(source_dir, f)) for f in os.listdir(source_dir))

        if num_existing_imgs < num_images_fid:
            num_batches = self._num_batches(num_images_fid, batch_size)
            start_batch = num_existing_imgs // batch_size
            os.makedirs(source_dir, exist_ok=True)
            for batch in tqdm.trange(start_batch, num_batches):
                print(f"Batch {batch} with size {batch_size}")
                xt, __ = self.generate_samples(batch_size=batch_size)
                # Map model outputs from [-1, 1] to [0, 1] for saving
                # (assumes the model works in [-1, 1]).
                rescaled_images = xt.view([-1, num_channels, dim, dim]) / 2 + 0.5
                for i, img in tqdm.tqdm(enumerate(rescaled_images)):
                    save_image(img, os.path.join(
                        source_dir,
                        f"{self.args.loss_type}_{self.args.class_object}_image_{batch_size * batch + i}.png"))

        fid_val = self.compute_fid(test_loader, num_images_fid)
        fid_filename = os.path.join(self.args.save_path, 'fid.txt')
        with open(fid_filename, 'a') as file:
            file.write(
                f'Sampling steps {self.args.sampling_steps} t0 {self.model.time_start} t_end {self.model.time_end} FID{self.args.num_images_fid//1000}k {fid_val}\n')

    def run_method(self, data_loaders, degradation, sigma_noise, H_funcs=None):
        """Entry point: set up the output directory and run ``sample`` on the eval split.

        ``args.save_path`` is extended in place with ``/euler_<sampling_steps>/``.
        Calling this twice therefore nests the suffix. The "euler" label is
        used regardless of ``args.integration_method``.

        ``degradation``, ``sigma_noise`` and ``H_funcs`` are accepted for
        interface compatibility but are not used.
        """
        self.args.save_path = self.args.save_path + \
            f"/euler_{self.args.sampling_steps}/"
        print(self.args.save_path)
        os.makedirs(self.args.save_path, exist_ok=True)

        self.sample(
            data_loaders[self.args.eval_split])
