import os
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.datasets import make_classification, make_blobs, make_circles, make_moons, make_friedman1
from ucimlrepo import fetch_ucirepo

import torch
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms


class TabularDatasetManager:
    """Load tabular classification datasets and split them into train/test tensors.

    Data is taken either from the UCI ML repository (``fetch_and_preprocess``
    and ``get_loader``) or from CSV files stored below ``root``
    (``get_local_data``).
    """

    def __init__(self, root=None, test_size=0.2, random_state=42):
        """Configure the data location and the train/test split.

        Args:
            root: Directory holding the local CSV datasets. When falsy, the
                ``DATA`` environment variable is used, falling back to
                ``'data/'``.
            test_size: Forwarded to ``train_test_split`` as ``test_size``.
            random_state: Seed used by ``get_loader`` for the split.
        """
        self.root = root if root else os.environ.get("DATA", 'data/')
        self.test_size = test_size
        self.random_state = random_state
        # Display name -> id passed to ``fetch_ucirepo``.
        self.datasets = {
            'Breast Cancer Wisconsin Diagnostic': 17,
            'Spambase': 94,
            'Musk': 75,
            'Dry Bean': 602,
            'MAGIC Gamma Telescope': 159,
            'Adult': 2,
            'Statlog (Shuttle)': 148,
            'CDC Diabetes Health Indicators': 891,
            'Poker Hand': 158
        }
        # Group folder -> {display name -> dataset folder} for local CSV data.
        # 'easy_without_test' folders hold a single data.csv, 'easy_with_test'
        # folders hold train_data.csv and test_data.csv.
        self.local_datasets = {
            'easy_without_test': {'Breast Cancer Wisconsin Diagnostic': 'breast_cancer_wisconsin_diagnostic',
                                  'CDC Diabetes Health Indicators': 'CDC',
                                  'MAGIC Gamma Telescope': 'magic_gamma_telescope',
                                  'Musk': 'musk_version_2',
                                  'Spambase': 'spambase',
                                  'Dry Bean': 'DryBeanDataset',
                                  'Adult': 'adult',
                                  },
            'easy_with_test': {'Statlog (Shuttle)': 'statlog_shuttle',
                               'Poker Hand': 'poker_hand',
                               },
        }

    @staticmethod
    def _label_offset(*label_arrays):
        """Return the integer to subtract so that the smallest label is <= 0.

        Equivalent to repeatedly subtracting 1 while the minimum is positive,
        but computed in a single vectorised pass. Returns 0 when the smallest
        label is already <= 0.
        """
        smallest = min(np.min(labels) for labels in label_arrays)
        return int(np.ceil(smallest)) if smallest > 0 else 0

    def get_names(self):
        """Return the display names of the datasets available from UCI."""
        return list(self.datasets.keys())

    def fetch_and_preprocess(self, dataset_name):
        """Download a UCI dataset and return standardized features and labels.

        Args:
            dataset_name: A key of ``self.datasets``.

        Returns:
            ``(X, y)`` where ``X`` is the standardized feature matrix produced
            by ``StandardScaler`` and ``y`` holds the labels, shifted so that
            a positive minimum label becomes 0.

        Raises:
            ValueError: If ``dataset_name`` is not a key of ``self.datasets``.
        """
        if dataset_name not in self.datasets:
            raise ValueError(f"Dataset {dataset_name} is not supported.")

        dataset_id = self.datasets[dataset_name]
        dataset = fetch_ucirepo(id=dataset_id)
        X = dataset.data.features
        y = dataset.data.targets
        if dataset_id == 2:
            # Adult: drop two columns (on a new frame, the fetched one is left
            # untouched) and merge the label spellings with a trailing period.
            X = X.drop(columns=['fnlwgt', 'education'])
            y = y.replace({'<=50K.': '<=50K', '>50K.': '>50K'})

        # Ensure y is a Series or ndarray
        if isinstance(y, pd.DataFrame):
            y = y.iloc[:, 0]

        # Convert labels to numeric values if necessary
        if y.dtype == 'object' or y.dtype.name == 'category':
            y = LabelEncoder().fit_transform(y)

        # Shift labels so that a positive minimum label becomes 0.
        offset = self._label_offset(y)
        if offset:
            y = y - offset

        # Convert non-numeric features to numeric if necessary
        if X.select_dtypes(include=['object', 'category']).shape[1] > 0:
            X = pd.get_dummies(X)

        # Standardize features
        X = StandardScaler().fit_transform(X)

        return X, y

    def get_loader(self, dataset_name):
        """Fetch a UCI dataset and split it into train/test tensors.

        Args:
            dataset_name: A key of ``self.datasets``.

        Returns:
            ``(train, test, in_dim, out_dim)`` where ``train`` and ``test`` are
            ``(features, labels)`` tuples of ``float32`` / ``long`` tensors (not
            ``DataLoader`` objects), ``in_dim`` is the number of feature columns
            and ``out_dim`` the number of distinct labels.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        X, y = self.fetch_and_preprocess(dataset_name)

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.test_size,
                                                            random_state=self.random_state)
        # Labels may still be a pandas Series carrying a shuffled index after
        # the split; going through a plain ndarray makes the conversion purely
        # positional. NOTE(review): torch may index a Series by label - confirm.
        train_loader = (torch.tensor(X_train, dtype=torch.float32),
                        torch.tensor(np.asarray(y_train), dtype=torch.long))
        test_loader = (torch.tensor(X_test, dtype=torch.float32),
                       torch.tensor(np.asarray(y_test), dtype=torch.long))

        in_dim = X.shape[1]
        out_dim = len(np.unique(y))

        return train_loader, test_loader, in_dim, out_dim

    def get_local_data(self, dataset_name):
        """Load a dataset from local CSV files and return train/test tensors.

        Files are read from ``<root>/easy_without_test/<folder>/data.csv``
        (split here with ``train_test_split``) or from
        ``<root>/easy_with_test/<folder>/train_data.csv`` and ``test_data.csv``.
        The last CSV column is used as the label, all others as features.

        Args:
            dataset_name: A key of one of the ``self.local_datasets`` groups.

        Returns:
            ``(train, test, in_dim, out_dim)`` with the same layout as
            ``get_loader``.

        Raises:
            ValueError: If ``dataset_name`` is in neither local group.
        """
        without_test = self.local_datasets['easy_without_test']
        with_test = self.local_datasets['easy_with_test']

        if dataset_name in without_test:
            data_dir = os.path.join(self.root, 'easy_without_test', without_test[dataset_name])
            data = pd.read_csv(os.path.join(data_dir, 'data.csv'))
            # NOTE(review): this split is unseeded (random_state=None), unlike
            # get_loader which uses self.random_state - confirm this is intended.
            train_data, test_data = train_test_split(data, test_size=self.test_size, random_state=None)
        elif dataset_name in with_test:
            data_dir = os.path.join(self.root, 'easy_with_test', with_test[dataset_name])
            train_data = pd.read_csv(os.path.join(data_dir, 'train_data.csv'))
            test_data = pd.read_csv(os.path.join(data_dir, 'test_data.csv'))
        else:
            raise ValueError(f"Dataset {dataset_name} is not supported.")

        X_train = train_data.iloc[:, :-1].values
        y_train = train_data.iloc[:, -1].values
        X_test = test_data.iloc[:, :-1].values
        y_test = test_data.iloc[:, -1].values

        # Use ONE offset for both splits: shifting each split by its own
        # minimum would give the same class different ids in train and test
        # whenever the test split happens to lack the smallest class.
        offset = self._label_offset(y_train, y_test)
        if offset:
            y_train = y_train - offset
            y_test = y_test - offset

        in_dim = X_train.shape[1]
        # Count classes over both splits so a class that is missing from the
        # training rows is still accounted for.
        out_dim = len(np.unique(np.concatenate([y_train, y_test])))

        train_loader = (torch.tensor(X_train, dtype=torch.float32),
                        torch.tensor(y_train, dtype=torch.long))
        test_loader = (torch.tensor(X_test, dtype=torch.float32),
                       torch.tensor(y_test, dtype=torch.long))
        return train_loader, test_loader, in_dim, out_dim


class ToyDatasetManager:
    """Generate synthetic scikit-learn datasets and wrap them in DataLoaders."""

    def __init__(self):
        """No configuration is needed; generators are parameterised per call."""
        pass

    def get_names(self):
        """Return the names accepted by ``get_loader``."""
        return list(self._get_dataset_generators().keys())

    def _get_dataset_generators(self):
        """Map each dataset name to the method that generates it."""
        return {
            'classification': self._generate_classification,
            'blobs': self._generate_blobs,
            'circles': self._generate_circles,
            'moons': self._generate_moons,
            'friedman1': self._generate_friedman1,
        }

    def _generate_classification(self, n_samples):
        """Two-class ``make_classification`` data with 20 features."""
        return make_classification(n_samples=n_samples, n_features=20, n_informative=2, n_redundant=0,
                                   class_sep=2.0, n_clusters_per_class=1, n_classes=2)

    def _generate_blobs(self, n_samples):
        """Three 2-D ``make_blobs`` clusters."""
        return make_blobs(n_samples=n_samples, centers=3, n_features=2, cluster_std=2)

    def _generate_circles(self, n_samples):
        """``make_circles`` data with noise 0.1 and factor 0.5."""
        return make_circles(n_samples=n_samples, noise=0.1, factor=0.5)

    def _generate_moons(self, n_samples):
        """``make_moons`` data with noise 0.1."""
        return make_moons(n_samples=n_samples, noise=0.1)

    def _generate_friedman1(self, n_samples):
        """``make_friedman1`` data with 10 features and noise 0.1."""
        return make_friedman1(n_samples=n_samples, n_features=10, noise=0.1)

    @staticmethod
    def _make_loader(features, targets, target_dtype, batch_size, shuffle):
        """Wrap one split in a ``TensorDataset``-backed ``DataLoader``."""
        split = TensorDataset(torch.tensor(features, dtype=torch.float32),
                              torch.tensor(targets, dtype=target_dtype))
        return DataLoader(split, batch_size=batch_size, shuffle=shuffle)

    def get_loader(self, dataset_name, train_samples, val_samples, test_samples, batch_size=32, shuffle=True):
        """Generate a toy dataset and return loaders for its splits.

        Args:
            dataset_name: One of the names returned by ``get_names``.
            train_samples: Number of training samples.
            val_samples: Number of validation samples; 0 disables the
                validation split.
            test_samples: Number of test samples.
            batch_size: Batch size used by every loader.
            shuffle: ``shuffle`` flag passed to every loader.

        Returns:
            ``(train_loader, val_loader, test_loader, input_dim, output_dim)``;
            ``val_loader`` is ``None`` when ``val_samples`` is 0.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        dataset_generators = self._get_dataset_generators()

        if dataset_name not in dataset_generators:
            raise ValueError(f"Dataset '{dataset_name}' is not supported.")

        total_samples = train_samples + val_samples + test_samples
        X, y = dataset_generators[dataset_name](total_samples)

        # Pass absolute sample counts instead of ratios: scikit-learn rounds
        # fractional sizes (floor for train, ceil for test), so a ratio such as
        # 29 / 100 can come back one sample short after float rounding.
        X_train, X_rest, y_train, y_rest = train_test_split(X, y, train_size=int(train_samples),
                                                            random_state=42)
        if val_samples > 0:
            X_val, X_test, y_val, y_test = train_test_split(X_rest, y_rest, test_size=int(test_samples),
                                                            random_state=42)
        else:
            X_val, y_val = None, None
            X_test, y_test = X_rest, y_rest

        # Only names containing 'classification' get integer targets; every
        # other generator's targets are stored as float32.
        # NOTE(review): this includes the class labels of blobs/circles/moons -
        # confirm that float targets are what downstream losses expect.
        target_dtype = torch.long if 'classification' in dataset_name else torch.float32

        train_dataloader = self._make_loader(X_train, y_train, target_dtype, batch_size, shuffle)
        if val_samples > 0:
            val_dataloader = self._make_loader(X_val, y_val, target_dtype, batch_size, shuffle)
        else:
            val_dataloader = None
        test_dataloader = self._make_loader(X_test, y_test, target_dtype, batch_size, shuffle)

        input_dim = X.shape[1]
        output_dim = y.shape[1] if len(y.shape) > 1 else 1

        return train_dataloader, val_dataloader, test_dataloader, input_dim, output_dim


class CVDatasetManager:
    """Build train/validation/test DataLoaders for torchvision image datasets.

    The supported dataset names are returned by ``get_names``. MNIST and
    FMNIST use the 32x32 transform pair (``train_transform1`` /
    ``val_transform1``); CIFAR10 and CIFAR100 use the ``cifar_size`` transform
    pair (``train_transform2`` / ``val_transform2``).
    """

    # Number of target classes for every supported dataset name.
    _NUM_CLASSES = {'MNIST': 10, 'FMNIST': 10, 'CIFAR10': 10, 'CIFAR100': 100}
    # Per-channel statistics passed to every Normalize transform.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, val_split=0.2):
        """Set defaults and build the transform pipelines.

        Args:
            val_split: Fraction of the training set held out for validation;
                a value of 0 (or less) disables the validation loader.
        """
        self.val_split = val_split
        self.batch_size = 10
        self.data_root = os.environ.get('DATA', "data/")
        self.dataset_name = None
        self.num_classes = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.cifar_size = 224

        self.train_transform1 = transforms.Compose([
            transforms.Lambda(self.convert_to_rgb),
            transforms.Resize((32, 32)),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            self._normalize(),
        ])

        self.val_transform1 = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.Lambda(self.convert_to_rgb),
            transforms.ToTensor(),
            self._normalize(),
        ])

        self._build_cifar_transforms()

    @classmethod
    def _normalize(cls):
        """Return a fresh Normalize transform with the shared statistics."""
        return transforms.Normalize(mean=cls._NORM_MEAN, std=cls._NORM_STD)

    def _build_cifar_transforms(self):
        """(Re)build the CIFAR transform pair from the current ``cifar_size``."""
        self.train_transform2 = transforms.Compose([
            transforms.Resize((self.cifar_size, self.cifar_size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(self.cifar_size, padding=8),
            transforms.ToTensor(),
            self._normalize(),
        ])
        self.val_transform2 = transforms.Compose([
            transforms.Resize((self.cifar_size, self.cifar_size)),
            transforms.ToTensor(),
            self._normalize(),
        ])

    @staticmethod
    def convert_to_rgb(img):
        """Return ``img.convert("RGB")``; used inside ``transforms.Lambda``."""
        return img.convert("RGB")

    def set_cifar_size(self, cifar_size):
        """Change the CIFAR image side length and rebuild the CIFAR transforms.

        Loaders that were already created are not rebuilt by this method.
        """
        self.cifar_size = cifar_size
        self._build_cifar_transforms()

    def _set_num_classes(self):
        """Return the class count for ``self.dataset_name``.

        Despite its name this method only returns the value; the caller
        stores it.

        Raises:
            ValueError: If ``self.dataset_name`` is not supported.
        """
        try:
            return self._NUM_CLASSES[self.dataset_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which are equally unknown.
            raise ValueError("Unknown dataset name") from None

    def _get_dataset(self, train=True):
        """Instantiate the torchvision dataset for ``self.dataset_name``.

        Args:
            train: Select the training split (and the training transform) when
                true, otherwise the test split with the validation transform.

        Raises:
            ValueError: If ``self.dataset_name`` is not supported.
        """
        name = self.dataset_name
        if name in ('MNIST', 'FMNIST'):
            transform = self.train_transform1 if train else self.val_transform1
        elif name in ('CIFAR10', 'CIFAR100'):
            transform = self.train_transform2 if train else self.val_transform2
        else:
            raise ValueError("Unknown dataset name")

        dataset_classes = {
            'MNIST': datasets.MNIST,
            'FMNIST': datasets.FashionMNIST,
            'CIFAR10': datasets.CIFAR10,
            'CIFAR100': datasets.CIFAR100,
        }
        return dataset_classes[name](root=self.data_root, train=train, download=True,
                                     transform=transform)

    def _prepare_loaders(self):
        """Create the train, validation and test loaders for the current dataset.

        Returns:
            ``(train_loader, val_loader, test_loader)``; ``val_loader`` is
            ``None`` when ``self.val_split`` is not positive.
        """
        import copy  # only needed here, kept local to the method

        # Create train and validation datasets and loaders
        full_train_dataset = self._get_dataset(train=True)
        num_total = len(full_train_dataset)
        num_val = int(num_total * self.val_split)
        num_train = num_total - num_val

        if self.val_split > 0:
            train_dataset, val_dataset = random_split(full_train_dataset, [num_train, num_val])
            # Both subsets reference the SAME underlying dataset object, so
            # assigning a transform on it would also replace the training
            # augmentation. Give the validation subset its own shallow copy
            # (the image data is shared, only the transform attribute differs).
            # NOTE(review): relies on the dataset applying its ``transform``
            # attribute at item-access time.
            eval_view = copy.copy(full_train_dataset)
            if self.dataset_name in ['MNIST', 'FMNIST']:
                eval_view.transform = self.val_transform1
            else:
                eval_view.transform = self.val_transform2
            val_dataset.dataset = eval_view
            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
            val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        else:
            train_loader = DataLoader(full_train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
            val_loader = None

        # Create test dataset and loader
        test_dataset = self._get_dataset(train=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)

        return train_loader, val_loader, test_loader

    def set_batch_size(self, batch_size):
        """Store a new batch size and rebuild the loaders if a dataset is set."""
        self.batch_size = batch_size
        if self.dataset_name is not None:
            self.train_loader, self.val_loader, self.test_loader = self._prepare_loaders()

    def get_loader(self, dataset_name, batch_size=None):
        """Select a dataset and return its loaders.

        Args:
            dataset_name: One of the names returned by ``get_names``.
            batch_size: Optional new batch size; when ``None`` the current
                ``self.batch_size`` is kept.

        Returns:
            ``(train_loader, val_loader, test_loader, num_classes)``.

        Raises:
            ValueError: If ``dataset_name`` is not supported. The manager's
                state is left untouched in that case.
        """
        try:
            known = dataset_name in self._NUM_CLASSES
        except TypeError:
            known = False
        if not known:
            raise ValueError("Unknown dataset name")

        # Update the configuration first and build the loaders exactly once;
        # going through set_batch_size here would rebuild them repeatedly
        # (including once for the previously selected dataset).
        if batch_size is not None:
            self.batch_size = batch_size
        self.dataset_name = dataset_name
        self.num_classes = self._set_num_classes()
        self.train_loader, self.val_loader, self.test_loader = self._prepare_loaders()
        return self.train_loader, self.val_loader, self.test_loader, self.num_classes

    def get_names(self):
        """Return the supported dataset names."""
        return list(self._NUM_CLASSES)
    