import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset, Subset
import torchvision
from tqdm import tqdm
from torchvision import transforms
from torch.optim import lr_scheduler
from PIL import Image

# ----------------------------

class LabeledDataset(Dataset):
    """Map-style dataset pairing samples with labels and optional weights.

    Args:
        data: Indexable collection of samples.
        labels: Indexable collection of labels, same length as ``data``.
        weights: Optional indexable collection of per-sample weights, same
            length as ``data``. When given, items are 3-tuples.
        transform: Optional callable applied to each sample on access.
        type: Optional tag describing how samples are stored. The only value
            acted upon is ``"mnist"``: the sample is converted to a PIL image
            before ``transform`` is applied.

    Raises:
        AssertionError: If ``labels`` or ``weights`` differ in length from
            ``data``.
    """

    def __init__(self, data, labels, weights=None, transform=None, type=None):
        super().__init__()
        assert len(data) == len(labels), "data and labels must have the same length"
        if weights is not None:
            assert len(data) == len(weights), "data and weights must have the same length"
        self.data = data
        self.labels = labels
        self.weights = weights
        self.transform = transform
        self.type = type

    @staticmethod
    def _to_pil(sample):
        """Convert a raw image sample (tensor or array) to a PIL image."""
        # ``other2LabeledDataset`` tags torchvision CIFAR datasets as "mnist"
        # too, and those keep their samples as NumPy arrays, which have no
        # ``.numpy()`` method; only tensors need the conversion.
        array = sample.numpy() if isinstance(sample, torch.Tensor) else sample
        if getattr(array, "ndim", None) == 2:
            # Single-channel image: keep the explicit grayscale mode.
            return Image.fromarray(array, mode="L")
        # Multi-channel image (e.g. H x W x 3): let PIL infer the mode,
        # forcing "L" on a 3-D array is rejected by PIL.
        return Image.fromarray(array)

    def __getitem__(self, index):
        """Return ``(sample, label)`` or ``(sample, label, weight)`` at ``index``."""
        data, label = self.data[index], self.labels[index]
        if self.transform:
            if self.type == "mnist":
                data = self._to_pil(data)
            data = self.transform(data)
        if self.weights is not None:
            return data, label, self.weights[index]
        return data, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

def _select_items(items, indices):
    """Return the entries of ``items`` at ``indices``; ``None`` passes through."""
    if items is None:
        return None
    if isinstance(items, (list, tuple)):
        # Plain Python sequences do not support fancy indexing, so pick the
        # entries one by one (e.g. the ``targets`` list of CIFAR datasets).
        return [items[i] for i in indices]
    return items[indices]


def other2LabeledDataset(dataset):
    """Convert a supported dataset into a ``LabeledDataset``.

    Supported inputs:
        * ``Subset``: the wrapped dataset is converted recursively, then
          data, labels and weights are restricted to the subset indices.
        * ``TensorDataset`` holding exactly two tensors (data, labels).
        * torchvision ``MNIST`` / ``CIFAR10`` / ``CIFAR100``: uses ``data``
          and ``targets``, keeps the dataset's ``transform`` and tags the
          result with type ``"mnist"``.
        * ``LabeledDataset``: a new instance sharing the same fields.

    Args:
        dataset: The dataset to convert.

    Returns:
        A new ``LabeledDataset``.

    Raises:
        AssertionError: If a ``TensorDataset`` does not hold exactly two tensors.
        ValueError: If the dataset type is not supported.
    """
    if isinstance(dataset, Subset):
        indices = dataset.indices
        base = other2LabeledDataset(dataset.dataset)
        return LabeledDataset(
            _select_items(base.data, indices),
            _select_items(base.labels, indices),
            _select_items(base.weights, indices),
            base.transform,
            base.type,
        )
    elif isinstance(dataset, TensorDataset):
        assert len(dataset.tensors) == 2
        return LabeledDataset(dataset.tensors[0], dataset.tensors[1])
    elif isinstance(
        dataset,
        (torchvision.datasets.MNIST, torchvision.datasets.CIFAR10, torchvision.datasets.CIFAR100),
    ):
        transform = getattr(dataset, 'transform', None)
        return LabeledDataset(dataset.data, dataset.targets, transform=transform, type="mnist")
    elif isinstance(dataset, LabeledDataset):
        return LabeledDataset(dataset.data, dataset.labels, dataset.weights, dataset.transform, dataset.type)
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")

def get_labeled_data(dataset, model, device, discard_ratio, batch_size, num_workers=4):
    """Pseudo-label ``dataset`` with ``model`` and keep the confident samples.

    The model is switched to eval mode and run over the whole dataset. For
    each sample the predicted label is the index of the largest model output
    and the confidence is that largest output value (the raw output, no
    normalisation is applied here). Samples whose confidence is below the
    ``discard_ratio`` quantile of all confidences are dropped.

    Args:
        dataset: Dataset whose items have the model input as first element;
            must be convertible by ``other2LabeledDataset``.
        model: Callable model returning per-class scores of shape
            ``(batch, classes)``.
        device: Device the inputs are moved to before the forward pass.
        discard_ratio: Quantile in ``[0, 1]`` used as confidence threshold.
        batch_size: Batch size for inference.
        num_workers: Number of ``DataLoader`` workers.

    Returns:
        A ``LabeledDataset`` with the retained samples, the predicted labels
        (as CPU tensors) and, if present, the matching weights.
    """
    model.eval()
    # No shuffling: prediction i must correspond to dataset item i, because
    # the kept indices are applied to the dataset's storage afterwards.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    labels = []
    confidence = []
    with torch.no_grad():
        for x in tqdm(loader, desc="Getting labeled data"):
            outputs = model(x[0].to(device))
            batch_confidence, batch_labels = outputs.max(dim=1)
            # Move results to the CPU right away: the returned dataset must
            # not hold tensors on an accelerator (that breaks multi-worker
            # loading), and the data/indices it is combined with live on CPU.
            labels.append(batch_labels.cpu())
            confidence.append(batch_confidence.cpu())
    labels = torch.cat(labels, dim=0)
    confidence = torch.cat(confidence, dim=0)

    # Confidence threshold: everything below this quantile is discarded.
    alpha = torch.quantile(confidence, discard_ratio)
    indices = torch.nonzero(confidence >= alpha, as_tuple=True)[0]

    # Restrict the dataset to the retained samples, replacing the original
    # labels with the model predictions.
    dataset = other2LabeledDataset(dataset)
    kept_weights = dataset.weights[indices] if dataset.weights is not None else None
    return LabeledDataset(dataset.data[indices], labels[indices], kept_weights, dataset.transform, dataset.type)

# ----------------------------

def get_scheduler(optimizer, **kwargs):
    """Build an epoch-based learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        **kwargs: Scheduler options:
            lr_policy: ``"lambda"``, ``"step"`` or ``"plateau"`` (required).
            epoch_count, niter, niter_decay: Required for ``"lambda"``. The
                multiplier is ``1 - max(0, epoch + 1 + epoch_count - niter)
                / (niter_decay + 1)``, i.e. constant until ``niter`` is
                passed and linearly decaying afterwards.
            lr_decay_iters: Required for ``"step"``; the learning rate is
                multiplied by 0.1 every ``lr_decay_iters`` epochs.

    Returns:
        The configured scheduler instance.

    Raises:
        KeyError: If ``lr_policy`` or an option required by the selected
            policy is missing.
        NotImplementedError: If ``lr_policy`` is not one of the supported
            policies.
    """
    # ``kwargs`` is a plain dict, so options are read by key.
    lr_policy = kwargs["lr_policy"]

    if lr_policy == "lambda":
        epoch_count = kwargs["epoch_count"]
        niter = kwargs["niter"]
        niter_decay = kwargs["niter_decay"]

        def lambda_rule(epoch):
            return 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)

        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    if lr_policy == "step":
        return lr_scheduler.StepLR(optimizer, step_size=kwargs["lr_decay_iters"], gamma=0.1)

    if lr_policy == "plateau":
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, threshold=0.01, patience=5)

    # Unknown policy: raise (not return) the error, with the name filled in.
    raise NotImplementedError("learning rate policy [%s] is not implemented" % lr_policy)


if __name__ == "__main__":
    # Smoke test: wrap MNIST in a LabeledDataset, then convert a five-item
    # Subset of it back into a LabeledDataset, printing each result and size.
    mnist = torchvision.datasets.MNIST(
        root='./data/color_mnist',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    labeled = other2LabeledDataset(mnist)
    print(labeled)
    print(len(labeled))

    first_five = Subset(labeled, list(range(5)))
    print(len(first_five))

    labeled_subset = other2LabeledDataset(first_five)
    print(labeled_subset)
    print(len(labeled_subset))
