from torchvision import datasets
import torch
import os
from .constants import *
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image


class AwA2Dataset(Dataset):
    """Animals with Attributes 2 (AwA2) classification dataset read from split CSVs.

    Each CSV row must have an ``img_name`` column. The image file is resolved as
    ``AWA2_DATA`` joined with the part of ``img_name`` after ``"/JPEGImages/"``
    (the whole string if that marker is absent), and the label is
    ``cls_2_idx[<parent directory name in img_name>]``.

    Args:
        cls_2_idx: mapping from class name to integer label.
        idx_2_cls: inverse mapping; stored on the instance but not used here.
        transform: optional callable applied to the loaded RGB ``PIL.Image``.
        train: if True, use the train and validation CSVs concatenated;
            otherwise use the test CSV.
    """

    def __init__(self, cls_2_idx, idx_2_cls, transform=None, train=True):
        self.cls_2_idx = cls_2_idx
        self.idx_2_cls = idx_2_cls
        self.transform = transform

        # Read only the CSVs needed for the requested split.
        if train:
            data_csv = pd.concat(
                [pd.read_csv(AWA2_TRAIN), pd.read_csv(AWA2_VAl)], ignore_index=True
            )
        else:
            data_csv = pd.read_csv(AWA2_TEST)
        # Shuffle rows once with a fixed seed so ordering is reproducible across runs.
        self.data_csv = data_csv.sample(frac=1, random_state=42).reset_index(drop=True)

    def __len__(self):
        """Return the number of rows (images) in the selected split."""
        return len(self.data_csv)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the shuffled split."""
        img_name = self.data_csv.iloc[idx]["img_name"]
        # Keep only the path below "/JPEGImages/" and resolve it against the local
        # AWA2_DATA root (presumably the CSV stores paths with a different prefix).
        rel_path = img_name.split("/JPEGImages/")[-1]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains usable after the file is closed.
        with Image.open(os.path.join(AWA2_DATA, rel_path)) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        # The class name is the image's parent directory in img_name.
        label = self.cls_2_idx[img_name.split("/")[-2]]
        return image, label


def _make_loaders(trainset, testset, args):
    """Wrap datasets in DataLoaders: training shuffled, test in fixed order."""
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    return train_loader, test_loader


def _index_maps(classes):
    """Return ``(class_to_idx, idx_to_class)`` labelling classes by list position."""
    class_to_idx = {c: i for i, c in enumerate(classes)}
    idx_to_class = {i: c for c, i in class_to_idx.items()}
    return class_to_idx, idx_to_class


def get_dataset(args, preprocess=None):
    """Build train/test loaders and class metadata for ``args.dataset``.

    Supported names: "cifar10", "cifar100", "cub", "ham10000", "inet100", "awa2".

    Args:
        args: namespace with ``dataset`` plus the fields the chosen branch reads:
            ``batch_size`` (all branches except ham10000, which passes ``args`` to
            ``load_ham_data``), ``num_workers`` (cifar*, inet100, awa2),
            ``out_dir`` (cifar download root), ``data_path`` (inet100 root).
        preprocess: transform applied to each image. Ignored by "cub", which uses
            its own normalizer.

    Returns:
        ``(train_loader, test_loader, idx_to_class, classes)``. ``idx_to_class``
        maps label index to class name, and ``classes`` lists the class names
        ordered by label index.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset

    if name in ("cifar10", "cifar100"):
        cifar_cls = datasets.CIFAR10 if name == "cifar10" else datasets.CIFAR100
        trainset = cifar_cls(
            root=args.out_dir, train=True, download=True, transform=preprocess
        )
        testset = cifar_cls(
            root=args.out_dir, train=False, download=True, transform=preprocess
        )
        classes = trainset.classes
        _, idx_to_class = _index_maps(classes)
        train_loader, test_loader = _make_loaders(trainset, testset, args)

    elif name == "cub":
        from .cub import load_cub_data
        from .constants import CUB_PROCESSED_DIR, CUB_DATA_DIR
        from torchvision import transforms

        num_classes = 200
        # NOTE(review): for inputs in [0, 1], std=2 maps pixels to [-0.25, 0.25]
        # rather than the usual [-1, 1]. Kept as-is; confirm it matches how the
        # downstream model expects CUB inputs.
        normalizer = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
        cub_kwargs = dict(
            use_attr=False,
            no_img=False,
            batch_size=args.batch_size,
            uncertain_label=False,
            image_dir=CUB_DATA_DIR,
            resol=224,
            normalizer=normalizer,
            n_classes=num_classes,
            # NOTE(review): resampling is enabled for the test loader as well;
            # confirm that is intended for evaluation.
            resampling=True,
        )
        train_loader = load_cub_data(
            [os.path.join(CUB_PROCESSED_DIR, "train.pkl")], **cub_kwargs
        )
        test_loader = load_cub_data(
            [os.path.join(CUB_PROCESSED_DIR, "test.pkl")], **cub_kwargs
        )

        with open(os.path.join(CUB_DATA_DIR, "classes.txt")) as f:
            # The class name is the text after the first "." on each line
            # (presumably lines like "1 001.Black_footed_Albatross").
            classes = [line.split(".", 1)[1].strip() for line in f]
        # Indexing (not slicing) keeps the IndexError if fewer names than expected.
        idx_to_class = {i: classes[i] for i in range(num_classes)}
        classes = classes[:num_classes]
        print(len(classes), "num classes for cub")
        print(len(train_loader.dataset), "training set size")
        print(len(test_loader.dataset), "test set size")

    elif name == "ham10000":
        from .derma_data import load_ham_data

        train_loader, test_loader, idx_to_class = load_ham_data(args, preprocess)
        # Order names by label index so classes[i] names label i, independent of
        # the insertion order of the returned dict.
        classes = [idx_to_class[i] for i in sorted(idx_to_class)]

    elif name == "inet100":
        trainset = datasets.ImageFolder(
            root=os.path.join(args.data_path, "train"), transform=preprocess
        )
        valset = datasets.ImageFolder(
            root=os.path.join(args.data_path, "val"), transform=preprocess
        )
        testset = datasets.ImageFolder(
            root=os.path.join(args.data_path, "test_set"), transform=preprocess
        )
        # Train on train+val; assumes both folders contain the same class
        # directories so their label indices agree.
        trainset = torch.utils.data.ConcatDataset([trainset, valset])
        classes = valset.classes
        _, idx_to_class = _index_maps(classes)
        train_loader, test_loader = _make_loaders(trainset, testset, args)

    elif name == "awa2":
        with open(AWA2_CLASSES) as f:
            classes = [line.strip() for line in f]
        class_to_idx, idx_to_class = _index_maps(classes)
        trainset = AwA2Dataset(class_to_idx, idx_to_class, preprocess, train=True)
        testset = AwA2Dataset(class_to_idx, idx_to_class, preprocess, train=False)
        train_loader, test_loader = _make_loaders(trainset, testset, args)

    else:
        raise ValueError(f"Unsupported dataset: {name!r}")

    return train_loader, test_loader, idx_to_class, classes
