import torch
import torchvision
import torchvision.transforms as v2
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from PIL import Image
import pandas as pd
import os
import warnings
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import logging
import sklearn.datasets
import torch.distributions as D


class DataLoaders:
    """Build train / validation / test ``DataLoader`` objects for one dataset.

    Args:
        dataset_name: Which dataset to load: ``'celeba'``, ``'celeba64'``,
            ``'celebahq'``, ``'afhq_cat'``, ``'circles'``, ``'moons'``,
            ``'gmm'`` or ``'cifar10'``.
        batch_size_train: Batch size of the training loader.
        batch_size_test: Batch size of the validation and test loaders.
        args: Configuration object. ``args.dim_image`` is read for CelebA and
            ``args.random_flip`` for CIFAR-10; the other datasets ignore it.
    """

    def __init__(self, dataset_name, batch_size_train, batch_size_test, args):
        self.dataset_name = dataset_name
        print(self.dataset_name)
        self.batch_size_train = batch_size_train
        self.batch_size_test = batch_size_test
        self.args = args

    def load_data(self):
        """Build the loaders for ``self.dataset_name``.

        Returns:
            dict with keys ``'train'``, ``'test'`` and ``'val'``. For
            ``'celebahq'`` only ``'test'`` is a loader; the other two are
            ``None``.

        Raises:
            ValueError: if ``self.dataset_name`` is not a supported dataset.
        """
        # Each builder returns (train_loader, val_loader, test_loader). The
        # lambdas defer dataset construction until the name has been matched.
        builders = {
            'celeba': self._load_celeba,
            'celeba64': self._load_celeba,
            'celebahq': self._load_celebahq,
            'afhq_cat': self._load_afhq_cat,
            'circles': lambda: self._load_synthetic(
                CIRCLES(100000), CIRCLES(10000)),
            'moons': lambda: self._load_synthetic(
                MOONS(100000), MOONS(10000)),
            'gmm': lambda: self._load_synthetic(
                GMM(6, 100000), GMM(6, 10000)),
            'cifar10': self._load_cifar10,
        }
        try:
            builder = builders[self.dataset_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names; every unknown name must
            # surface as ValueError, as before.
            raise ValueError(
                f"The dataset you entered does not exist: {self.dataset_name!r}"
            ) from None

        train_loader, val_loader, test_loader = builder()
        return {'train': train_loader, 'test': test_loader, 'val': val_loader}

    def _make_loader(self, dataset, batch_size, shuffle, drop_last=False):
        """Wrap ``dataset`` in a ``DataLoader`` that uses ``custom_collate``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=custom_collate,
            drop_last=drop_last)

    def _load_celeba(self):
        """CelebA: partitions 0 / 1 / 2 of the CSV become train / val / test."""
        transform = v2.Compose([
            v2.CenterCrop(178),
            v2.Resize((self.args.dim_image, self.args.dim_image)),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        img_dir = './data/celeba/img_align_celeba/'
        partition_csv = './data/celeba/list_eval_partition.csv'

        train_dataset = CelebADataset(
            img_dir, partition_csv, partition=0, transform=transform)
        val_dataset = CelebADataset(
            img_dir, partition_csv, partition=1, transform=transform)
        test_dataset = CelebADataset(
            img_dir, partition_csv, partition=2, transform=transform)

        return (
            self._make_loader(train_dataset, self.batch_size_train, shuffle=True),
            self._make_loader(val_dataset, self.batch_size_test, shuffle=False),
            self._make_loader(test_dataset, self.batch_size_test, shuffle=False),
        )

    def _load_celebahq(self):
        """CelebA-HQ: only a test split is available."""
        # No Normalize here: CelebAHQDataset applies ``2 * x - 1`` itself.
        transform = v2.Compose([
            v2.Resize(256),
            v2.ToTensor(),
        ])
        test_dataset = CelebAHQDataset(
            './data/celebahq/test/', batchsize=self.batch_size_test,
            transform=transform)
        test_loader = self._make_loader(
            test_dataset, self.batch_size_test, shuffle=False)
        return None, None, test_loader

    def _load_afhq_cat(self):
        """AFHQ cats from the train / val / test directories."""
        # Normalize with mean = std = 0.5 maps ToTensor's output x to 2x - 1.
        transform = v2.Compose([
            v2.Resize((256, 256)),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        test_dataset = AFHQDataset(
            './data/afhq_cat/test/cat/', batchsize=self.batch_size_test,
            transform=transform)
        val_dataset = AFHQDataset(
            './data/afhq_cat/val/cat/', batchsize=self.batch_size_test,
            transform=transform)
        train_dataset = AFHQDataset(
            './data/afhq_cat/train/cat/', batchsize=self.batch_size_test,
            transform=transform)

        return (
            self._make_loader(train_dataset, self.batch_size_train,
                              shuffle=True, drop_last=True),
            self._make_loader(val_dataset, self.batch_size_test, shuffle=False),
            self._make_loader(test_dataset, self.batch_size_test, shuffle=False),
        )

    def _load_synthetic(self, train_dataset, test_dataset):
        """Loaders for the 2-D toy datasets.

        There is no separate validation split: the validation loader reuses
        ``test_dataset`` (unshuffled), while the test loader shuffles it.
        """
        return (
            self._make_loader(train_dataset, self.batch_size_train, shuffle=True),
            self._make_loader(test_dataset, self.batch_size_test, shuffle=False),
            self._make_loader(test_dataset, self.batch_size_test, shuffle=True),
        )

    def _load_cifar10(self):
        """CIFAR-10; the ``train=False`` split is divided 90 / 10 into test / val."""
        tensor_and_normalize = [
            v2.ToTensor(),
            v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        if self.args.random_flip:
            print("Random flip applied")
            transform_train = v2.Compose(
                [v2.RandomHorizontalFlip(), *tensor_and_normalize])
        else:
            transform_train = v2.Compose(list(tensor_and_normalize))
        transform = v2.Compose(list(tensor_and_normalize))

        train_dataset = datasets.CIFAR10(
            root="./data",
            train=True,
            download=True,
            transform=transform_train,
        )
        full_test_dataset = datasets.CIFAR10(
            root="./data",
            train=False,
            download=True,
            transform=transform,
        )
        val_dataset_size = int(len(full_test_dataset) * 0.1)
        print(len(full_test_dataset), val_dataset_size)
        test_dataset_size = len(full_test_dataset) - val_dataset_size
        # Fixed seed so the test/val split is identical across runs.
        test_dataset, val_dataset = torch.utils.data.random_split(
            full_test_dataset,
            [test_dataset_size, val_dataset_size],
            generator=torch.Generator().manual_seed(0))

        return (
            self._make_loader(train_dataset, self.batch_size_train, shuffle=True),
            self._make_loader(val_dataset, self.batch_size_test, shuffle=False),
            self._make_loader(test_dataset, self.batch_size_test, shuffle=False),
        )


class CelebADataset(Dataset):
    """CelebA images that belong to one partition of a partition CSV.

    Args:
        img_dir: Directory containing the image files named in the CSV.
        partition_csv: CSV with one header row followed by rows of
            ``<image file name>,<partition id>``.
        partition: Partition id to keep; rows whose second column equals it
            are selected.
        transform: Optional callable applied to each RGB PIL image.

    Items are ``(image, 0)``. A missing image file yields ``(None, None)``
    plus a warning; ``custom_collate`` filters such samples out.
    """

    def __init__(self, img_dir, partition_csv, partition, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.partition = partition

        # ``header=0`` + ``names`` replaces the file's header row with our own
        # column names. Do not add ``skiprows`` here: combined with
        # ``header=0`` it would consume the first data row as the header and
        # silently drop that image.
        partition_df = pd.read_csv(
            partition_csv, header=0, names=['image', 'partition'])
        self.img_names = partition_df.loc[
            partition_df['partition'] == partition, 'image'].values

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])

        if not os.path.exists(img_path):
            warnings.warn(f"File not found: {img_path}. Skipping.")
            return None, None

        # Close the file handle right away; ``convert`` returns a new,
        # fully loaded image that outlives the ``with`` block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0


class CelebAHQDataset(Dataset):
    """CelebA HQ dataset read from a flat directory of image files.

    Args:
        data_dir: Directory whose entries are the image files.
        batchsize: Stored as ``self.batchsize``; not used by this class.
        transform: Optional callable applied to each RGB PIL image. When it
            is given, its output is rescaled with ``2 * x - 1`` and cast to
            float32. NOTE(review): this presumes the transform yields a
            tensor in [0, 1] (e.g. ending in ``ToTensor``) - confirm for
            custom transforms. Without a transform the PIL image is returned
            as-is.

    Items are ``(image, 0)``. A file that disappeared after construction
    yields ``(None, None)`` plus a warning.
    """

    def __init__(self, data_dir, batchsize, transform=None):
        # Sorted so the sample order does not depend on the file system.
        self.files = sorted(os.listdir(data_dir))
        self.root_dir = data_dir
        self.num_imgs = len(self.files)
        self.transform = transform
        self.batchsize = batchsize

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.files[idx])

        if not os.path.exists(img_path):
            warnings.warn(f"File not found: {img_path}. Skipping.")
            return None, None

        # Close the file handle right away; ``convert`` returns a new,
        # fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
            # Tensor-only operations: a PIL image has no ``.float()``.
            image = (2 * image - 1).float()

        return image, 0


class AFHQDataset(Dataset):
    """AFHQ images read from a flat directory (used here for the cat split).

    Args:
        img_dir: Directory whose entries are the image files.
        batchsize: Stored as ``self.batchsize``; not used by this class.
        category: Accepted for compatibility; not used by this class.
        transform: Optional callable applied to each RGB PIL image.

    Items are ``(image, 0)``. A file that disappeared after construction
    yields ``(None, None)`` plus a warning.
    """

    def __init__(self, img_dir, batchsize, category='cat', transform=None):
        # Sorted so the (unshuffled) val/test order does not depend on the
        # file system's directory listing order.
        self.files = sorted(os.listdir(img_dir))
        self.num_imgs = len(self.files)
        self.batchsize = batchsize
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.files[idx])

        if not os.path.exists(img_path):
            warnings.warn(f"File not found: {img_path}. Skipping.")
            return None, None

        # Close the file handle right away; ``convert`` returns a new,
        # fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0


class CIRCLES(Dataset):
    """2-D points from ``sklearn.datasets.make_circles`` with ``factor=0.999``.

    With that factor the inner and outer circles nearly coincide. Each item
    is a ``(1, 2, 1)`` float32 tensor scaled by 10, paired with label ``0``.
    No ``random_state`` is fixed, so every instance draws fresh points.
    """

    def __init__(self, n_samples):
        self.n_samples = n_samples
        points, _labels = sklearn.datasets.make_circles(
            n_samples=n_samples,
            shuffle=True,
            noise=None,
            random_state=None,
            factor=0.999,
        )
        self.data = torch.tensor(points, dtype=torch.float32)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Add a leading and a trailing singleton axis: (2,) -> (1, 2, 1).
        point = self.data[idx][None, ..., None]
        return point * 10, 0


class MOONS(Dataset):
    """2-D points from ``sklearn.datasets.make_moons`` (``noise=0.01``).

    Each item is a ``(1, 2, 1)`` float32 tensor, scaled by 5 and shifted by
    -1.5, paired with label ``0``. No ``random_state`` is fixed, so every
    instance draws fresh points.
    """

    def __init__(self, n_samples):
        self.n_samples = n_samples
        points, _labels = sklearn.datasets.make_moons(
            n_samples=n_samples, shuffle=True, noise=0.01)
        self.data = torch.tensor(points, dtype=torch.float32)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # (2,) -> (1, 2, 1), scale by 5, then subtract 1.5.
        # NOTE(review): the 1.5 offset is applied to BOTH coordinates (x and
        # y). If only y was meant to be shifted, this needs changing - confirm.
        point = self.data[idx][None, ..., None]
        return point * 5 - 1.5, 0


class GMM(Dataset):
    """Equal-weight mixture of ``n_gmm`` isotropic 2-D Gaussians on a circle.

    Component ``k`` has mean
    ``(8 * cos(2*pi*k/n_gmm), 8 * sin(2*pi*k/n_gmm))`` and standard deviation
    0.5 in both coordinates.

    Args:
        n_gmm: Number of mixture components.
        n_samples: Number of items reported by ``__len__`` (the epoch size).

    Every ``__getitem__`` call draws a fresh sample from torch's global RNG,
    so ``idx`` is ignored. Items are ``(1, 2, 1)`` float32 tensors with
    label ``0``.
    """

    def __init__(self, n_gmm, n_samples):
        self.n_samples = n_samples
        mix = D.Categorical(torch.ones(n_gmm,))
        means = torch.tensor(
            np.array([(8 * np.cos(k * 2 * np.pi / n_gmm),
                       8 * np.sin(k * 2 * np.pi / n_gmm))
                      for k in range(n_gmm)]))

        comp = D.Independent(D.Normal(means, 0.5 * torch.ones(n_gmm, 2)), 1)
        self.gmm = D.MixtureSameFamily(mix, comp)

    def __len__(self):
        # Use the requested size; the train set asks for more than the
        # test set.
        return self.n_samples

    def __getitem__(self, idx):
        # ``sample_n`` is deprecated; ``sample((1,))`` draws the same (1, 2)
        # batch.
        point = self.gmm.sample((1,)).float()
        return point.unsqueeze(-1), 0


def custom_collate(batch):
    """Collate ``batch`` after dropping samples whose first element is None.

    The datasets in this module return ``(None, None)`` for missing image
    files; such samples are removed here. If nothing remains, a pair of
    empty tensors is returned instead of calling the default collate.
    """
    kept = [sample for sample in batch if sample[0] is not None]
    if not kept:
        return torch.tensor([]), torch.tensor([])
    return torch.utils.data._utils.collate.default_collate(kept)


# Configure the root logger at INFO level as a side effect of importing this
# module. NOTE(review): per the logging docs, basicConfig does nothing if the
# root logger already has handlers, and it affects the whole process. Consider
# moving this to the application's entry point.
logging.basicConfig(level=logging.INFO)
