import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from typing import Tuple, List, Optional
import numpy as np
from tqdm import tqdm
from sklearn.datasets import load_svmlight_file
import urllib.request
import os
import bz2


class BinaryDataset(Dataset):
    """In-memory two-class dataset built from a torchvision or LIBSVM source.

    Image datasets are restricted to two chosen classes whose labels are
    remapped to 0/1; LIBSVM datasets are read from downloaded text files and
    their labels are converted to 0/1.  In both cases the samples end up in
    the ``data`` tensor and the labels in the ``targets`` tensor.
    """

    # Registry of supported datasets, keyed by the ``dataset_name`` argument.
    #
    # Image entries carry ``mean``/``std`` (per-channel statistics handed to
    # ``transforms.Normalize``) and ``dataset_class`` (the torchvision class
    # that gets instantiated).
    #
    # Entries with ``'type': 'libsvm'`` carry ``num_features`` (passed to
    # ``load_svmlight_file`` as ``n_features``) plus the URLs of the train and
    # test files.
    #
    # ``num_classes`` and ``input_size`` are informational: the methods of this
    # class do not read them.
    DATASET_CONFIGS = {
        'cifar10': {
            'num_classes': 10,
            'input_size': (3, 32, 32),
            'mean': (0.4914, 0.4822, 0.4465),
            'std': (0.2470, 0.2435, 0.2616),
            'dataset_class': datasets.CIFAR10,
        },
        'mnist': {
            'num_classes': 10,
            'input_size': (1, 28, 28),  
            'mean': (0.1307,),
            'std': (0.3081,),
            'dataset_class': datasets.MNIST,
        },
        'fashion_mnist': {
            'num_classes': 10,
            'input_size': (1, 28, 28),
            'mean': (0.2860,),
            'std': (0.3530,),
            'dataset_class': datasets.FashionMNIST,
        },
        'eurosat': {
            'num_classes': 10,
            'input_size': (3, 64, 64),
            'mean': (0.3444, 0.3809, 0.4076),
            'std': (0.2031, 0.1364, 0.1145),
            'dataset_class': datasets.EuroSAT,
        },
        # flowers102, food101 and places365 share one set of mean/std values
        # (these look like the commonly used ImageNet statistics).
        'flowers102': {
            'num_classes': 102,
            'input_size': (3, 224, 224),
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
            'dataset_class': datasets.Flowers102,
        },
        'food101': {
            'num_classes': 101,
            'input_size': (3, 224, 224),
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
            'dataset_class': datasets.Food101,
        },
        'places365': {
            'num_classes': 365,
            'input_size': (3, 224, 224),
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
            'dataset_class': datasets.Places365,
        },
        'a1a': {
            'num_classes': 2,
            'num_features': 123,
            'type': 'libsvm',
            'train_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a',
            'test_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a.t',
        },
        'w1a': {
            'num_classes': 2,
            'num_features': 300,
            'type': 'libsvm',
            'train_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/w1a',
            'test_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/w1a.t',
        },
        'mushrooms': {
            'num_classes': 2,
            'num_features': 112,
            'type': 'libsvm',
            # NOTE: train and test point at the same file, so both splits
            # contain identical samples.
            'train_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/mushrooms',
            'test_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/mushrooms',
        },
        'gisette': {
            'num_classes': 2,
            'num_features': 5000,
            'type': 'libsvm',
            # Both files are served bz2-compressed.
            'train_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/gisette_scale.bz2',
            'test_url': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/gisette_scale.t.bz2',
        }
    }
    
    def __init__(self,
                 dataset_name: str,
                 root: str,
                 train: bool = True,
                 download: bool = True,
                 num_images: Optional[int] = None,
                 classes: Optional[List[int]] = None,
                 no_normalize: bool = False,
                 resize_size: Optional[Tuple[int, int]] = None,
                 shuffle_features: bool = False):
        """Load ``dataset_name`` and keep it in memory as a two-class problem.

        Args:
            dataset_name: Key of ``DATASET_CONFIGS``.
            root: Directory the data is stored in / downloaded to.
            train: Select the training split (``True``) or the test split
                (``False``).  Not forwarded for ``eurosat``, whose constructor
                is called without any split argument.
            download: Download the data when it is not available locally.
            num_images: Optional cap on the number of samples.  For image
                datasets at most ``num_images // 2`` samples are kept per class.
            classes: The two original class ids to keep; the first becomes
                label 0 and the second label 1.  Defaults to ``[0, 1]``.
                Not used for LIBSVM datasets.
            no_normalize: If true, use ``Normalize(mean=0.5, std=1.)`` instead
                of the dataset-specific statistics.  Image datasets only.
            resize_size: Optional size handed to ``transforms.Resize`` before
                tensor conversion.  Image datasets only.  All images are later
                combined with ``torch.stack``, which needs equal shapes.
            shuffle_features: If true, apply one random permutation of the
                flattened feature positions to every sample.  Image datasets
                only.

        Raises:
            AssertionError: If ``classes`` does not contain exactly two
                entries, ``dataset_name`` is unknown, or the dataset is
                ``places365`` (configured but disabled).
        """
        # Copy the argument so the instance never shares a list with the
        # caller (or with other instances through a shared default object).
        classes = [0, 1] if classes is None else list(classes)
        assert len(classes) == 2, "Binary classification requires exactly 2 classes"
        assert dataset_name in self.DATASET_CONFIGS, f"Unknown dataset: {dataset_name}"

        self.dataset_name = dataset_name
        self.classes = classes
        self.config = self.DATASET_CONFIGS[dataset_name]

        if self.config.get('type') == 'libsvm':
            # Tabular data: none of the image-specific options below apply.
            self._load_libsvm_data(root, train, download, num_images)
            return

        assert dataset_name not in ['places365'], "places365 is configured but not supported"

        # Pipeline order: optional resize -> tensor conversion -> normalization.
        transform_list = []
        if resize_size is not None:
            transform_list.append(transforms.Resize(resize_size))
        elif dataset_name == 'eurosat':
            transform_list.append(transforms.Resize((64, 64)))
        transform_list.append(transforms.ToTensor())
        if no_normalize:
            transform_list.append(transforms.Normalize(mean=0.5, std=1.))
        else:
            transform_list.append(transforms.Normalize(mean=self.config['mean'], std=self.config['std']))
        transform = transforms.Compose(transform_list)

        # The torchvision constructors do not share one split convention, so
        # the split argument is chosen per dataset.
        dataset_kwargs = {'root': root, 'transform': transform, 'download': download}
        if dataset_name in ('food101', 'flowers102'):
            # These classes select their split with a ``split`` string and do
            # not accept a ``train`` keyword.
            dataset_kwargs['split'] = 'train' if train else 'test'
        elif dataset_name == 'places365':
            dataset_kwargs['small'] = True
        elif dataset_name != 'eurosat':
            dataset_kwargs['train'] = train
        base_dataset = self.config['dataset_class'](**dataset_kwargs)

        if dataset_name == 'food101':
            self._filter_data_food101(base_dataset, num_images)
        else:
            self._filter_data(base_dataset, num_images)

        if shuffle_features:
            # One permutation shared by all samples: flatten every sample,
            # reorder the feature columns, then restore the original shape.
            shape = self.data.shape
            flat = self.data.reshape(shape[0], -1)
            permutation = torch.randperm(flat.shape[1])
            self.data = flat[:, permutation].reshape(shape)
    
    def _load_libsvm_data(self, root: str, train: bool, download: bool, num_images: Optional[int]):
        """Download (if needed) and load a LIBSVM binary dataset into memory.

        Sets ``self.data`` to a float tensor of dense feature rows and
        ``self.targets`` to a long tensor of 0/1 labels.

        Args:
            root: Directory holding the cached files; created if missing.
            train: Load the file behind ``train_url`` (``True``) or
                ``test_url`` (``False``).
            download: Fetch the file when it is not cached yet.
            num_images: If given, keep a random subset of at most this many
                rows, drawn without replacement via ``np.random.choice``.

        Raises:
            FileNotFoundError: If the file is not cached and ``download`` is
                false.
            AssertionError: If no label mapping is defined for the dataset, or
                the mapped labels are not exactly ``{0, 1}``.
        """
        import shutil  # function-scope import: only this method needs it

        dataset_name = self.dataset_name
        os.makedirs(root, exist_ok=True)

        split = 'train' if train else 'test'
        url = self.config[f'{split}_url']
        filepath = os.path.join(root, f"{dataset_name}_{split}")

        if download and not os.path.exists(filepath):
            print(f"Downloading {dataset_name} dataset from {url}...")
            # Work on temporary siblings and move the finished file into place
            # last, so an interrupted download or decompression never leaves a
            # truncated or still-compressed file at ``filepath`` that a later
            # run would mistake for a valid cache entry.
            download_path = filepath + '.download'
            unpacked_path = filepath + '.unpacked'
            try:
                urllib.request.urlretrieve(url, download_path)
                if url.endswith('.bz2'):
                    # Stream the decompression instead of holding the whole
                    # decompressed file in memory.
                    with bz2.open(download_path, 'rb') as f_in, open(unpacked_path, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)
                    os.replace(unpacked_path, filepath)
                else:
                    os.replace(download_path, filepath)
            finally:
                for leftover in (download_path, unpacked_path):
                    if os.path.exists(leftover):
                        os.remove(leftover)
            print(f"Downloaded to {filepath}")

        if not os.path.exists(filepath):
            raise FileNotFoundError(
                f"{filepath} not found; pass download=True to fetch it from {url}")

        X, y = load_svmlight_file(filepath, n_features=self.config['num_features'])

        # Map the raw labels of each dataset onto {0, 1}.
        if dataset_name == 'mushrooms':
            # Raw labels are shifted down by one.
            y = (y - 1).astype(np.int64)
        elif dataset_name in ('a1a', 'w1a', 'gisette'):
            # Raw labels -1/+1 become 0/1.
            y = ((y + 1) / 2).astype(np.int64)
        else:
            raise AssertionError(f"No label mapping defined for dataset: {dataset_name}")
        found_labels = np.unique(y).tolist()
        if found_labels != [0, 1]:
            raise AssertionError(found_labels)

        if num_images is not None:
            # Subsample while the matrix is still sparse so only the selected
            # rows are densified below.
            num_rows = X.shape[0]
            indices = np.random.choice(num_rows, min(num_images, num_rows), replace=False)
            X = X[indices]
            y = y[indices]
        X = X.toarray()

        # Convert to torch tensors
        self.data = torch.from_numpy(X).float()
        self.targets = torch.from_numpy(y).long()
        print("Data size: ", self.data.shape)
        print("Targets size: ", torch.unique(self.targets, return_counts=True))
    
    def _filter_data(self, base_dataset, num_images: Optional[int]):
        """Keep samples of the two requested classes and relabel them 0/1.

        Iterates ``base_dataset`` once, visits the samples in a random order
        (``np.random.shuffle``) and keeps those whose label is in
        ``self.classes``.  When ``num_images`` is given, at most
        ``num_images // 2`` samples are kept per class.  The results are
        stored in ``self.data`` (stacked images) and ``self.targets`` (long
        tensor; 0 for ``self.classes[0]``, 1 for ``self.classes[1]``).

        Args:
            base_dataset: Iterable yielding ``(image_tensor, label)`` pairs.
            num_images: Optional total budget, split evenly between classes.
        """
        wanted = self.classes

        # Single pass over the dataset.  Only samples of the two wanted classes
        # are retained (keyed by their position); everything else is dropped
        # right away instead of being held in memory until filtering.
        kept = {}
        dataset_size = 0
        for position, (img, label) in enumerate(tqdm(base_dataset)):
            dataset_size = position + 1
            if label in wanted:
                kept[position] = (img, label)

        # The permutation is drawn over all dataset positions, not only the
        # kept ones, so the random draws depend only on the dataset size.
        visit_order = list(range(dataset_size))
        np.random.shuffle(visit_order)

        num_images_per_label = num_images // 2 if num_images is not None else None
        label_to_num = {c: 0 for c in wanted}
        label_to_total_num = {c: 0 for c in wanted}

        data_list = []
        target_list = []
        for position in visit_order:
            sample = kept.get(position)
            if sample is None:
                continue
            img, label = sample
            label_to_total_num[label] += 1
            if num_images is not None and label_to_num[label] >= num_images_per_label:
                continue
            label_to_num[label] += 1
            data_list.append(img)
            target_list.append(0 if label == wanted[0] else 1)

        # Sanity check: each class contributes its quota, or all it has.
        assert (num_images is None or
                len(data_list) == (min(num_images_per_label, label_to_total_num[wanted[0]]) +
                                   min(num_images_per_label, label_to_total_num[wanted[1]])))
        self.data = torch.stack(data_list)
        self.targets = torch.tensor(target_list, dtype=torch.long)
    
    def _filter_data_food101(self, base_dataset, num_images: Optional[int]):
        """Select two-class samples by label first, then load only those images.

        Reads the labels from ``base_dataset._labels`` to choose sample indices
        in a random order (``np.random.shuffle``), keeping at most
        ``num_images // 2`` indices per class when ``num_images`` is given.
        Only the chosen samples are then loaded through a ``Subset``.  The
        results are stored in ``self.data`` and ``self.targets`` (0 for
        ``self.classes[0]``, 1 for ``self.classes[1]``).
        """
        # (label, dataset index) pairs, visited in random order.
        label_index_pairs = [
            (class_id, position)
            for position, class_id in enumerate(base_dataset._labels)
        ]
        np.random.shuffle(label_index_pairs)

        per_class_cap = None if num_images is None else num_images // 2
        taken = dict.fromkeys(self.classes, 0)      # samples selected per class
        available = dict.fromkeys(self.classes, 0)  # samples seen per class

        chosen_positions = []
        for class_id, position in label_index_pairs:
            if class_id not in self.classes:
                continue
            available[class_id] += 1
            if num_images is None or taken[class_id] < per_class_cap:
                taken[class_id] += 1
                chosen_positions.append(position)

        if num_images is not None:
            first, second = self.classes[0], self.classes[1]
            expected = (min(per_class_cap, available[first]) +
                        min(per_class_cap, available[second]))
            assert len(chosen_positions) == expected

        # Second pass: load the image data for the chosen samples only.
        chosen = torch.utils.data.Subset(base_dataset, chosen_positions)
        images = []
        binary_labels = []
        for image, class_id in tqdm(chosen):
            assert class_id in self.classes
            images.append(image)
            binary_labels.append(0 if class_id == self.classes[0] else 1)

        self.data = torch.stack(images)
        self.targets = torch.tensor(binary_labels, dtype=torch.long)

    def __len__(self) -> int:
        """Return the number of stored samples (first dimension of ``data``)."""
        num_samples = self.data.shape[0]
        return num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the sample at ``idx`` and its label as a plain Python int."""
        sample = self.data[idx]
        label = self.targets[idx].item()
        return sample, int(label)
