import os
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, GTSRB, Food101, SUN397, EuroSAT, UCF101, DTD, OxfordIIITPet, MNIST, ImageFolder
from torchvision import  transforms
import torch
import numpy as np
from PIL import Image
import lmdb
import pickle
import six
import json
from collections import OrderedDict
from utils.const import IMAGENETNORMALIZE, CIFAR100NORMALIZE, CIFAR10NORMALIZE


def get_model(net, pretrain):
    """Create a backbone network and move it to the GPU.

    Args:
        net: Architecture name. One of ``"resnet18"``, ``"resnet50"``,
            ``"instagram"`` (ResNeXt-101 32x8d) or ``"mobilenet"``.
        pretrain: If truthy, load pretrained weights; otherwise the network
            is created with randomly initialised weights.

    Returns:
        The requested network after ``.cuda()`` has been called on it.

    Raises:
        NotImplementedError: If ``net`` is not one of the supported names.
    """
    # Downloaded checkpoints / hub repositories are cached under ./cache.
    torch.hub.set_dir('./cache')

    # The torchvision imports stay inside each branch, but outside the
    # ``pretrain`` check, so that the constructor name is also defined when
    # pretrain is False (previously that path raised NameError).
    if net == "resnet18":
        from torchvision.models import resnet18, ResNet18_Weights
        weights = ResNet18_Weights.IMAGENET1K_V1 if pretrain else None
        network = resnet18(weights=weights)
    elif net == "resnet50":
        from torchvision.models import resnet50, ResNet50_Weights
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrain else None
        network = resnet50(weights=weights)
    elif net == "instagram":
        if pretrain:
            # Pretrained weights come from the WSL-Images torch.hub entry.
            network = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
        else:
            from torchvision.models import resnext101_32x8d
            network = resnext101_32x8d(weights=None)
    elif net == 'mobilenet':
        from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
        # IMAGENET1K_V1 is what the legacy ``pretrained=True`` flag selected.
        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrain else None
        network = mobilenet_v2(weights=weights)
    else:
        raise NotImplementedError(f"{net} is not supported")

    return network.cuda()


# Imagenet Transform
def image_transform(args, transform_type):
    normalize = transforms.Normalize(mean=IMAGENETNORMALIZE['mean'], std=IMAGENETNORMALIZE['std'])
    if transform_type == 'vp':
        if args.randomcrop:
            print('Using randomcrop\n')
            train_transform = transforms.Compose([
                transforms.RandomCrop((32, 32), padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.Resize((args.input_size, args.input_size)),
                transforms.ToTensor(),
            ])
            test_transform = transforms.Compose([
                transforms.Resize((args.input_size, args.input_size)),
                transforms.ToTensor(),
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize((int(args.input_size*9/8), int(args.input_size*9/8))),
                transforms.RandomCrop(args.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
                transforms.ToTensor(),
            ])
            test_transform = transforms.Compose([
                transforms.Resize((args.input_size, args.input_size)),
                transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
                transforms.ToTensor(),
            ])
    elif transform_type == 'ff':
        train_transform = transforms.Compose([
            transforms.Resize((252,252)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            transforms.ToTensor(),
            normalize
        ])
        test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            transforms.ToTensor(),
            normalize
        ])
    return train_transform, test_transform, normalize


def get_torch_dataset(args, transform_type):
    """Create the train and test ``DataLoader`` objects for ``args.dataset``.

    Reads ``args.datadir``, ``args.dataset``, ``args.num_workers`` and (for
    datasets without a fixed batch size) ``args.batch_size``. As a side
    effect it stores the class count in ``args.class_cnt`` and the
    normalisation transform in ``args.normalize``, and prints the sizes of
    both splits.

    Args:
        args: Namespace-like configuration object (see above).
        transform_type: Forwarded to ``image_transform``.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    root = os.path.join(args.datadir, args.dataset)
    name = args.dataset
    train_tf, test_tf, normalize = image_transform(args, transform_type)

    if name == "cifar10":
        train_set = CIFAR10(root, train=True, transform=train_tf, download=True)
        test_set = CIFAR10(root, train=False, transform=test_tf, download=True)
        class_cnt = 10

    elif name == "cifar100":
        train_set = CIFAR100(root, train=True, transform=train_tf, download=True)
        test_set = CIFAR100(root, train=False, transform=test_tf, download=True)
        class_cnt = 100

    elif name == "svhn":
        train_set = SVHN(root, split='train', transform=train_tf, download=True)
        test_set = SVHN(root, split='test', transform=test_tf, download=True)
        class_cnt = 10

    elif name == "gtsrb":
        # An untransformed copy is only used to size the random 90% subset.
        probe = GTSRB(root=root, split='train', download=True)
        kept, _ = get_indices(probe)
        train_set = Subset(GTSRB(root, split='train', transform=train_tf, download=True), kept)
        test_set = GTSRB(root, split='test', transform=test_tf, download=True)
        class_cnt = 43

    elif name == 'food101':
        train_set = Food101(root, split='train', transform=train_tf, download=True)
        test_set = Food101(root, split='test', transform=test_tf, download=True)
        class_cnt = 101

    elif name == 'sun397':
        img_dir = os.path.join(root, 'SUN397')
        train_list = os.path.join(root, 'Partitions/Training_01.txt')
        test_list = os.path.join(root, 'Partitions/Testing_01.txt')
        class_list = os.path.join(root, 'Partitions/ClassName.txt')
        probe = SUN397Dataset(img_dir=img_dir, partition_file=train_list, target_file=class_list)
        kept, _ = get_indices(probe)
        train_set = Subset(
            SUN397Dataset(img_dir=img_dir, partition_file=train_list,
                          target_file=class_list, transform=train_tf),
            kept,
        )
        test_set = SUN397Dataset(img_dir=img_dir, partition_file=test_list,
                                 target_file=class_list, transform=test_tf)
        class_cnt = 397

    elif name == 'eurosat':
        # NOTE(review): upstream torchvision's EuroSAT does not appear to take
        # a `split` argument -- confirm a patched dataset class is in use.
        probe = EuroSAT(root=root, split='train', download=True)
        kept, _ = get_indices(probe)
        train_set = Subset(EuroSAT(root, split='train', transform=train_tf, download=True), kept)
        test_set = EuroSAT(root, split='test', transform=test_tf, download=True)
        class_cnt = 10

    elif name == 'ucf101':
        annotation_path = os.path.join(root, 'ucfTrainTestlist')
        video_root = os.path.join(root, 'UCF-101')
        probe = UCF101(root=video_root, annotation_path=annotation_path,
                       frames_per_clip=1, fold=1, train=True)
        kept, _ = get_indices(probe)
        train_set = Subset(
            UCF101(video_root, train=True, annotation_path=annotation_path,
                   frames_per_clip=1, fold=1, transform=train_tf),
            kept,
        )
        test_set = UCF101(video_root, train=False, annotation_path=annotation_path,
                          frames_per_clip=1, fold=1, transform=test_tf)
        class_cnt = 101

    elif name == 'stanfordcars':
        cars_root = os.path.join(root, 'car_data/car_data')
        train_set = ImageFolder(cars_root + '/train/', transform=train_tf)
        test_set = ImageFolder(cars_root + '/test/', transform=test_tf)
        class_cnt = 196

    elif name == 'flowers102':
        # Train on the union of the train and val LMDB splits.
        train_set = ConcatDataset([
            COOPLMDBDataset(root=root, split="train", transform=train_tf),
            COOPLMDBDataset(root=root, split="val", transform=train_tf),
        ])
        test_set = COOPLMDBDataset(root=root, split="test", transform=test_tf)
        class_cnt = 102

    elif name == 'dtd':
        # Train on the union of the train and val splits.
        train_set = ConcatDataset([
            DTD(root=root, split='train', transform=train_tf, download=True),
            DTD(root=root, split='val', transform=train_tf, download=True),
        ])
        test_set = DTD(root, split='test', transform=test_tf, download=True)
        class_cnt = 47

    elif name == 'oxfordpets':
        train_set = OxfordIIITPet(root, split='trainval', transform=train_tf, download=True)
        test_set = OxfordIIITPet(root, split='test', transform=test_tf, download=True)
        class_cnt = 37

    elif name == 'mnist':
        train_set = MNIST(root, train=True, transform=train_tf, download=True)
        test_set = MNIST(root, train=False, transform=test_tf, download=True)
        class_cnt = 10

    elif name == 'imagenet':
        train_set = ImageFolder(os.path.join(root, 'ILSVRC2012_train'), transform=train_tf)
        test_set = ImageFolder(os.path.join(root, 'ILSVRC2012_validation'), transform=test_tf)
        class_cnt = 1000

    elif name == 'tiny_imagenet':
        train_set = ImageFolder(root=os.path.join(root, 'tiny-imagenet-200/train'), transform=train_tf)
        test_set = TinyImageNet(
            os.path.join(root, 'tiny-imagenet-200/val/images'),
            os.path.join(root, 'tiny-imagenet-200/val/val_annotations.txt'),
            os.path.join(root, 'tiny-imagenet-200/wnids.txt'),
            transform=test_tf,
        )
        class_cnt = 200

    else:
        raise NotImplementedError(f"{name} not supported")

    # Some datasets use a fixed batch size; all others use args.batch_size.
    if name == 'imagenet':
        batch_size = 1024
    elif name in ('dtd', 'oxfordpets'):
        batch_size = 64
    elif name in ('flowers102', 'stanfordcars'):
        batch_size = 128
    else:
        batch_size = args.batch_size

    shared = dict(batch_size=batch_size, num_workers=args.num_workers, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **shared)
    test_loader = DataLoader(test_set, shuffle=False, **shared)

    # Publish dataset metadata on args.
    args.class_cnt = class_cnt
    args.normalize = normalize
    print(f'Dataset information: {name}\n')
    print(f'{len(train_set)} images for training \t')
    print(f'{len(test_set)} images for testing\t')

    return train_loader, test_loader


def get_batch(data_loader, batch_index):
    """Assemble one batch by indexing the loader's dataset directly.

    Samples ``batch_index * batch_size`` up to (but excluding)
    ``(batch_index + 1) * batch_size`` are read in dataset order, truncated
    at the end of the dataset. The loader's own sampling/shuffling is not
    used.

    Args:
        data_loader: Loader whose ``dataset`` and ``batch_size`` are used.
        batch_index: Zero-based index of the batch to assemble.

    Returns:
        Tuple ``(stacked_data, targets)`` built with ``torch.stack`` and
        ``torch.tensor`` respectively.
    """
    source = data_loader.dataset
    size = data_loader.batch_size
    first = batch_index * size
    # Clamp the upper bound so the final batch may be short.
    stop = min(first + size, len(source))

    samples = []
    labels = []
    for position in range(first, stop):
        sample, label = source[position]
        samples.append(sample)
        labels.append(label)

    return torch.stack(samples), torch.tensor(labels)


def get_indices(full_data):
    """Randomly split the index range of ``full_data`` 90% / 10%.

    Args:
        full_data: Any object supporting ``len()``.

    Returns:
        Tuple ``(train_indices, val_indices)`` of NumPy index arrays drawn
        from one ``np.random.permutation`` of ``range(len(full_data))``;
        the first holds ``int(0.9 * len)`` entries, the second the rest.
    """
    total = len(full_data)
    cut = int(total * 0.9)
    shuffled = np.random.permutation(total)
    return shuffled[:cut], shuffled[cut:]


class SubsetWithTransform(Subset):
    """A ``Subset`` that applies an optional transform to each sample's input."""

    def __init__(self, dataset, indices, transform=None):
        """Store the wrapped dataset, the selected indices and the transform."""
        super().__init__(dataset, indices)
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(input, target)`` for subset position ``idx``.

        The transform is applied to the input only when it is truthy.
        """
        sample, label = self.dataset[self.indices[idx]]
        if not self.transform:
            return sample, label
        return self.transform(sample), label


class SUN397Dataset(Dataset):
    """SUN397 images selected by a partition file.

    Each non-blank line of ``partition_file`` is an image path; its
    directory part (``os.path.split(line)[0]``) must match a line of
    ``target_file``, whose line number becomes the class index.

    Args:
        img_dir: Prefix concatenated directly (no separator added) with each
            partition entry to form the image path -- presumably entries
            start with ``/``; verify against the partition files.
        partition_file: Text file listing one image path per line.
        target_file: Text file listing one class name per line.
        transform: Optional callable applied to the loaded RGB image.

    Raises:
        KeyError: If a partition entry's directory is not in ``target_file``.
    """

    def __init__(self, img_dir, partition_file, target_file, transform=None):
        self.img_dir = img_dir
        self.transform = transform

        # Blank lines are skipped so stray empty lines cannot become a class.
        with open(target_file, 'r') as f:
            self.label_names = [line.strip() for line in f if line.strip()]
        self.label_idx = {name: idx for idx, name in enumerate(self.label_names)}

        self.img_names = []
        self.targets = []
        with open(partition_file, 'r') as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Previously a blank line raised KeyError('').
                    continue
                label_name, _ = os.path.split(entry)
                self.img_names.append(entry)
                self.targets.append(self.label_idx[label_name])

    def __len__(self):
        """Return the number of images listed in the partition file."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, class_index)``."""
        img_path = self.img_dir + self.img_names[idx]
        # convert() yields an independent image, so the file handle can be
        # closed immediately instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.targets[idx]


class TinyImageNet(Dataset):
    """Image dataset driven by a tab-separated annotations file.

    Each non-blank annotation line supplies the image file name in its
    first field and the label id in its second; any further fields are
    ignored. Class indices are the positions of the label ids in the sorted
    contents of ``label_ids_file``.

    Args:
        root_dir: Directory containing the image files.
        annotations_file: Tab-separated file, one image per line.
        label_ids_file: Text file with one label id per line.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, root_dir, annotations_file, label_ids_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Read through a context manager so the handle is closed, and drop
        # blank lines so an empty file yields zero entries, not one bogus one.
        with open(annotations_file, 'r') as f:
            self.entries = [line.strip() for line in f if line.strip()]

        # Blank lines are skipped; an empty string would sort first and
        # shift every class index by one.
        with open(label_ids_file, 'r') as f:
            self.label_names = sorted(line.strip() for line in f if line.strip())
        self.label_idx = {name: idx for idx, name in enumerate(self.label_names)}

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.entries)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return ``(image, class_index)``."""
        fields = self.entries[index].split('\t')
        img_path, annotation = fields[0], fields[1]
        # convert() returns an independent image, so the file can be closed.
        with Image.open(self.root_dir + '/' + img_path) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, int(self.label_idx[annotation])


class LMDBDataset(Dataset):
    """Dataset backed by an LMDB database at ``<root>/<split>.lmdb``.

    The database must hold the pickled entries ``__len__`` (sample count)
    and ``__keys__`` (list of sample keys). Each sample value is a pickled
    sequence whose first item is the encoded image bytes and whose second
    item is the label.

    Args:
        root: Directory containing the ``.lmdb`` database.
        split: Split name; selects ``<split>.lmdb``.
        transform: Optional callable applied to the decoded PIL image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, root, split='train', transform=None, target_transform=None):
        super().__init__()
        # Kept on the instance because __repr__ reports it; it was never
        # stored before, so repr() raised AttributeError.
        self.db_path = os.path.join(root, f"{split}.lmdb")
        self.env = lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        # NOTE: pickle.loads executes arbitrary code from its input; only
        # open LMDB files that come from a trusted source.
        with self.env.begin(write=False) as txn:
            self.length = pickle.loads(txn.get(b'__len__'))
            self.keys = pickle.loads(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample stored under key ``index``."""
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])

        unpacked = pickle.loads(byteflow)

        # Decode the image from its in-memory encoded bytes.
        img = Image.open(six.BytesIO(unpacked[0]))
        target = unpacked[1]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the sample count recorded in the database."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'


class COOPLMDBDataset(LMDBDataset):
    """LMDB dataset that also exposes class names read from ``split.json``.

    ``self.classes`` lists the class names ordered by label, taken from the
    last two fields (label, class name) of every record in the ``"test"``
    section of ``<root>/split.json``.
    """

    def __init__(self, root, split="train", transform=None) -> None:
        super().__init__(root, split, transform=transform)
        split_path = os.path.join(root, "split.json")
        with open(split_path) as f:
            splits = json.load(f)

        # Later records overwrite earlier ones that share the same label.
        label_to_name = {}
        for record in splits["test"]:
            label_to_name[record[-2]] = record[-1]
        self.classes = [class_name for _, class_name in sorted(label_to_name.items())]


class ReverseImageFolder(ImageFolder):
    """An ``ImageFolder`` whose samples are indexed in reverse order.

    Construction is inherited unchanged from ``ImageFolder``.
    """

    def __getitem__(self, index):
        """Return the sample at position ``len(self) - 1 - index``."""
        mirrored = len(self) - 1 - index
        return super().__getitem__(mirrored)