import torch
import os
import types
import tqdm
from torchdiffeq import odeint
from torchvision.utils import save_image
from torchvision import transforms
from fld.features.InceptionFeatureExtractor import InceptionFeatureExtractor
from fld.features.DINOv2FeatureExtractor import DINOv2FeatureExtractor
from fld.metrics.FID import FID
from denflow.dataloaders import CelebADataset
from torchvision.datasets.cifar import CIFAR10
from denflow.utils import plot_test_data, plot_paths, compute_W2, compute_path_energy
from matplotlib import pyplot as plt

# Feature-extractor classes selectable by name (looked up with
# ``args.ft_extractor`` in ``BASIC_SAMPLER.__init__``).
EXTRACTORS = {
    "inception": InceptionFeatureExtractor,
    "dino": DINOv2FeatureExtractor,
}

# CelebA image folder and partition CSV handed to ``CelebADataset`` when the
# reference features are computed in ``BASIC_SAMPLER.compute_fid``.
img_dir = './data/celeba/img_align_celeba/'
partition_csv = './data/celeba/list_eval_partition.csv'


class torch_model_wrapper(torch.nn.Module):
    """Adapter exposing ``model.get_velocity`` through a ``forward(t, x)`` call.

    ``BASIC_SAMPLER.generate_samples`` passes an instance of this class to
    ``odeint`` as the dynamics function.
    """

    def __init__(self, model):
        """Keep a reference to ``model``, which must provide ``get_velocity(x, t)``."""
        super().__init__()
        self.model = model

    def forward(self, t, x):
        """Return the detached velocity of the wrapped model at time ``t``.

        Args:
            t: Current time; it is multiplied with a vector of ones so that
                every element along the first dimension of ``x`` gets its own
                copy of the time value.
            x: Current state tensor.

        Returns:
            ``model.get_velocity(x, t_batch)`` detached from the autograd graph.
        """
        per_sample_time = t * torch.ones(x.shape[0], device=x.device)
        velocity = self.model.get_velocity(x, per_sample_time)
        return velocity.detach()


class BASIC_SAMPLER(object):
    def __init__(self, model, device, args):
        """Store the sampling configuration and build the feature extractor.

        Args:
            model: Generative model. Other methods of this class call its
                ``get_velocity(x, t)`` and read ``time_start`` / ``time_end``.
            device: Torch device on which samples are generated.
            args: Run configuration namespace; this constructor reads
                ``args.method`` and ``args.ft_extractor``.

        Raises:
            KeyError: If ``args.ft_extractor`` is not a key of ``EXTRACTORS``.
        """
        self.device = device
        self.args = args
        self.model = model
        self.method = args.method

        extractor_name = args.ft_extractor
        self.ft_extractor_name = extractor_name
        extractor_cls = EXTRACTORS[extractor_name]
        self.ft_extractor = extractor_cls(save_path="features")

    def generate_samples(self, batch_size=None):
        """Draw Gaussian noise and integrate it through the model's velocity field.

        The integrator is selected by ``args.integration_method``:
        ``"euler_naive"`` runs the hand-written loop below, any other value is
        forwarded to ``odeint`` as its ``method`` argument.

        Args:
            batch_size: Number of noise samples to draw. When ``None`` (the
                default) ``args.batch_size_gen`` is used.

        Returns:
            tuple: ``(xt, traj)``. ``traj`` stacks the integrated states along
            dim 0 and ``xt`` is its last entry. In the ``"euler_naive"``
            branch ``traj[i]`` is the state *after* the update taken at
            ``time_points[i]`` (the initial noise itself is not stored) and
            the buffer is allocated without a device argument; in the other
            branch ``traj`` is whatever ``odeint`` returns for ``time_points``.
        """
        if batch_size is None:
            # The former default of None crashed inside torch.randn.
            batch_size = self.args.batch_size_gen

        integration_method = self.args.integration_method
        integration_steps = self.args.sampling_steps
        num_channels = self.args.num_channels
        res = self.args.dim_image
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, int(integration_steps),
            device=self.device)

        # dim_image == 2 is treated as the 2-D point case with trailing shape
        # (2, 1); every other value yields square res x res samples. The noise
        # is drawn on the CPU and then moved to the target device.
        width = 1 if res == 2 else res
        x0 = torch.randn(batch_size, num_channels, res, width,
                         device="cpu").to(self.device)

        if integration_method == "euler_naive":
            with torch.no_grad():
                traj = torch.zeros((len(time_points), *x0.shape))
                xt = x0
                # NOTE(review): the step is fixed at 1 / integration_steps,
                # independent of (time_end - time_start), and the velocity is
                # evaluated at every linspace point including time_end.
                # Left unchanged; confirm this is the intended scheme.
                for i, t in enumerate(time_points):
                    t_ = t * torch.ones(len(xt), device=self.device)
                    v_t = self.model.get_velocity(xt, t_)
                    xt = xt + v_t / integration_steps
                    traj[i] = xt.detach()
        else:
            wrapper_model = torch_model_wrapper(self.model)
            traj = odeint(wrapper_model, x0, time_points, rtol=1e-4, atol=1e-4,
                          method=integration_method)

        xt = traj[-1, :]
        return xt, traj

    def compute_fid(self, dataset, number_images):
        """Compute the FID between generated images on disk and a reference set.

        Generated-image features are extracted from ``<args.save_path>/final``
        and truncated to the first ``args.num_images_fid`` entries. Reference
        features come from ``CelebADataset(..., partition=2)`` when
        ``args.dataset`` is ``"celeba"``/``"celeba64"`` and from the CIFAR-10
        ``train=False`` split when it is ``"cifar10"``.

        Args:
            dataset: Not used by this method (kept for interface
                compatibility).
            number_images: Not used by this method; the number of generated
                features is taken from ``args.num_images_fid``.

        Returns:
            The value returned by ``FID().compute_metric``.

        Raises:
            ValueError: If ``args.dataset`` is not one of the supported names.
        """
        dataset_name = self.args.dataset
        # Fail before any feature extraction. Previously the ValueError was
        # constructed but never raised, so an unknown dataset fell through to
        # an unbound ``test_feat``.
        if dataset_name not in ("celeba", "celeba64", "cifar10"):
            raise ValueError("Dataset not available to compute FID")

        with torch.no_grad():
            source_dir = os.path.join(self.args.save_path, "final")
            gen_feat = self.ft_extractor.get_dir_features(source_dir)

            if dataset_name == "cifar10":
                test_feat = self.ft_extractor.get_features(
                    CIFAR10(train=False, root="data", download=True),
                    name="cifar10_test")
            else:
                celeba_transform = transforms.Compose([
                    transforms.CenterCrop(178),
                    transforms.Resize(
                        [self.args.dim_image, self.args.dim_image]),
                ])
                test_feat = self.ft_extractor.get_features(
                    CelebADataset(img_dir, partition_csv, partition=2,
                                  transform=celeba_transform),
                    name=f"celeba{self.args.dim_image}_test")

            fid_val = FID().compute_metric(
                test_feat, None, gen_feat[:self.args.num_images_fid])
        return fid_val

    def sample2D(self, test_loader):
        """Generate samples and append W2 / path-energy metrics to ``metrics.txt``.

        ``args.batch_size_gen`` samples are generated with
        ``generate_samples``. The trajectory is compared against the first
        element of the first batch of ``test_loader`` via ``compute_W2`` and
        scored with ``compute_path_energy``; both results are appended to
        ``<args.save_path>/metrics.txt``.

        Args:
            test_loader: Iterable of batches; only ``next(iter(test_loader))[0]``
                is read.
        """
        n_points = self.args.batch_size_gen
        _, traj = self.generate_samples(batch_size=n_points)

        # Trajectory plotting (plot_test_data / plot_paths) is currently
        # disabled for this method.

        # compute metrics
        reference_points = next(iter(test_loader))[0].cpu().numpy()
        w2_value = compute_W2(traj, reference_points)
        # Only the normalized energy is reported; the raw value is discarded.
        _, normalized_energy = compute_path_energy(
            traj, self.model, self.device)

        metrics_path = os.path.join(self.args.save_path, 'metrics.txt')
        with open(metrics_path, 'a') as file:
            file.write(f"W2 Test N points {n_points} Value {w2_value}\n")
            file.write(
                f"Normalized path energy x0-x_gen N points {n_points} Value {normalized_energy}\n")

    def sample(self, test_loader):
        """Generate images into ``<save_path>/final`` and append the FID to ``fid.txt``.

        The files already present in the output directory are counted first.
        If there are fewer than ``args.num_images_fid``, generation (re)starts
        at batch index ``existing // batch_size`` and runs until
        ``ceil(num_images_fid / batch_size)`` batches exist; otherwise
        generation is skipped. Samples are mapped with ``x / 2 + 0.5`` before
        being written with ``save_image``. Finally ``compute_fid`` is called
        and one line describing the run is appended to
        ``<save_path>/fid.txt``.

        Args:
            test_loader: Forwarded unchanged to ``compute_fid``.
        """
        batch_size = self.args.batch_size_gen
        num_images_fid = self.args.num_images_fid

        source_dir = os.path.join(self.args.save_path, "final")
        num_existing_imgs = 0
        if os.path.isdir(source_dir):
            num_existing_imgs = len([f for f in os.listdir(source_dir)
                                     if os.path.isfile(os.path.join(source_dir, f))])

        if num_existing_imgs < num_images_fid:
            os.makedirs(source_dir, exist_ok=True)
            # Ceiling division: the smallest number of batches that yields at
            # least num_images_fid images. The previous bound
            # (num // bs + 1, then range(..., max_batch + 1)) produced up to
            # two surplus batches.
            total_batches = -(-num_images_fid // batch_size)
            start_batch = num_existing_imgs // batch_size
            image_shape = [-1, self.args.num_channels,
                           self.args.dim_image, self.args.dim_image]
            name_prefix = f"{self.args.loss_type}_{self.args.class_object}_image_"

            for batch in tqdm.trange(start_batch, total_batches):
                print(f"Batch {batch} with size {batch_size}")
                xt, _ = self.generate_samples(batch_size=batch_size)
                rescaled_images = xt.view(image_shape) / 2 + 0.5

                for i, img in tqdm.tqdm(enumerate(rescaled_images)):
                    # The global index keeps file names stable across resumed
                    # runs, so a partially written batch is overwritten.
                    save_image(img, os.path.join(
                        source_dir,
                        f"{name_prefix}{batch_size * batch + i}.png"))

        fid_val = self.compute_fid(test_loader, num_images_fid)
        fid_filename = os.path.join(self.args.save_path, 'fid.txt')
        with open(fid_filename, 'a') as file:
            file.write(
                f'FeatExtractor {self.ft_extractor_name} Integration method {self.args.integration_method} Sampling steps {self.args.sampling_steps} t0 {self.model.time_start} t_end {self.model.time_end} FID{self.args.num_images_fid//1000}k {fid_val}\n')

    def plot_trajectories(self, batch_size=16, n_samples=10):
        """Save Euler-integration snapshots for ``n_samples`` noise draws.

        For each draw a single 3-channel image is integrated with
        ``args.sampling_steps`` explicit Euler steps from ``model.time_start``
        to ``model.time_end``. Roughly eleven evenly spaced states (always
        including the initial noise and the final sample) are shown side by
        side, each titled with the time it corresponds to, and the figure is
        written to ``<args.save_path>/trajectory_sample{k}_euler.png``.

        Args:
            batch_size: Not used; each trajectory integrates one image. Kept
                for backward compatibility.
            n_samples: Number of trajectories (one figure each) to produce.
        """
        n_steps = int(self.args.sampling_steps)
        time_points = torch.linspace(
            self.model.time_start, self.model.time_end, n_steps + 1,
            device=self.device)
        # Snapshot every `stride` steps so about 11 panels are drawn for any
        # step count (the old fixed stride of 10 with 11 hard-coded panels
        # raised IndexError for fewer than 100 steps).
        stride = max(1, n_steps // 10)
        os.makedirs(self.args.save_path, exist_ok=True)

        for k in range(n_samples):
            frames = []
            frame_times = []
            with torch.no_grad():
                xt = torch.randn(1, 3, self.args.dim_image,
                                 self.args.dim_image, device="cpu").to(self.device)
                for i, t in enumerate(time_points):
                    # Record the state *before* stepping so that the panel
                    # title matches the time the state belongs to.
                    if i % stride == 0 or i == n_steps:
                        # Same x / 2 + 0.5 mapping as before; clamped to the
                        # [0, 1] range imshow expects for float RGB data.
                        frames.append((xt / 2 + 0.5).clamp(0, 1).cpu()[0])
                        frame_times.append(float(t))
                    # Exactly n_steps updates: no step is taken from the last
                    # time point, which previously overshot time_end.
                    if i < n_steps:
                        t_ = t * torch.ones(len(xt), device=self.device)
                        v_t = self.model.get_velocity(xt, t_)
                        # NOTE(review): step size 1 / n_steps mirrors the
                        # "euler_naive" branch of generate_samples and assumes
                        # time_end - time_start == 1 -- confirm.
                        xt = xt + v_t / n_steps

            # squeeze=False keeps `axes` two-dimensional even for one panel.
            fig, axes = plt.subplots(1, len(frames), figsize=(20, 3),
                                     squeeze=False)
            for ax, im, t_val in zip(axes[0], frames, frame_times):
                ax.imshow(im.permute(1, 2, 0).numpy())
                ax.axis("off")
                ax.set_title("t={:2.2f}".format(t_val))
            fig.tight_layout()

            out_path = os.path.join(self.args.save_path,
                                    f"trajectory_sample{k}_euler.png")
            fig.savefig(out_path, dpi=200)
            plt.close(fig)
            print(f"Saved trajectory plot to {out_path}")

    def run_method(self, data_loaders, degradation, sigma_noise, H_funcs=None):
        """Create the run's output directory and launch the matching sampler.

        ``args.save_path`` is extended in place with a
        ``/<integration_method>_<sampling_steps>/`` suffix and the directory
        is created. ``sample2D`` is then run when ``args.dim_image == 2``,
        ``sample`` otherwise.

        Args:
            data_loaders: Container indexed with ``args.eval_split`` to obtain
                the loader handed to the sampler.
            degradation: Not used by this method.
            sigma_noise: Not used by this method.
            H_funcs: Not used by this method.
        """
        # Every call appends the suffix again, because the result is stored
        # back on ``args``.
        base_dir = self.args.save_path
        self.args.save_path = base_dir + \
            f"/{self.args.integration_method}_{self.args.sampling_steps}/"
        print(self.args.save_path)
        os.makedirs(self.args.save_path, exist_ok=True)

        # 2-D runs are scored with W2 / path energy, everything else by
        # generating images and computing the FID.
        runner = self.sample2D if self.args.dim_image == 2 else self.sample
        runner(data_loaders[self.args.eval_split])
