import os
import torch
from torchvision import datasets, transforms, models

import clip
from pytorchcv.model_provider import get_model as ptcv_get_model
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image

# Root directory of every folder/CSV-backed dataset split, keyed by
# "<dataset>_<split>". Looked up by get_data() for any name that is not
# handled by a dedicated torchvision loader.
DATASET_ROOTS = {
    # ImageNet: "YOUR_PATH" is a placeholder to be replaced locally.
    "imagenet_train": "YOUR_PATH/CLS-LOC/train/",
    "imagenet_val": "YOUR_PATH/ImageNet_val/",
    # Datasets stored under data/: the "*_val" keys point at the test folder.
    "cub_train": "data/CUB/train",
    "cub_val": "data/CUB/test",
    "cxr8_train": "data/CXR8/train",
    "cxr8_val": "data/CXR8/test",
    "cxr11_train": "data/CXR11/train",
    "cxr11_val": "data/CXR11/test",
    "ham10000_train": "data/HAM10000/train",
    "ham10000_val": "data/HAM10000/test",
    "awa2_train": "data/AWA2/train",
    "awa2_val": "data/AWA2/test",
    # Both chestXpert splits share one root; the split only selects the CSV.
    "chestXpert_train": "/home/hxu2/chexpertchestxrays-u20210408",
    "chestXpert_val": "/home/hxu2/chexpertchestxrays-u20210408",
}

# Path of the class-name list for each dataset, keyed by the dataset name
# without its "_train"/"_val" split suffix.
LABEL_FILES = {
    "places365": "data/categories_places365_clean.txt",
    "imagenet": "data/imagenet_classes.txt",
    "cifar10": "data/cifar10_classes.txt",
    "cifar100": "data/cifar100_classes.txt",
    "cub": "data/cub_classes.txt",
    "flower": "data/flower_classes.txt",
    "cxr8": "data/cxr8_classes.txt",
    "cxr11": "data/cxr11_classes.txt",
    "ham10000": "data/ham10000_classes.txt",
    "awa2": "data/awa2_classes.txt",
    "chestXpert": "/home/hxu2/chexpertchestxrays-u20210408/classes.txt",
}

class ChestXpertDataset(Dataset):
    """Image dataset described by a per-split CSV inside a root directory.

    The file ``<dir>/<split>_with_labels.csv`` must provide a ``Path`` column
    (image path relative to ``dir``) and a ``Label`` column. Concept
    annotations are read positionally from the columns in between.
    """

    def __init__(self, dir, transform=None, split='train'):
        """Load the CSV index for one split.

        Args:
            dir: Root directory holding the CSV and the images. (The name
                shadows the builtin but is kept for caller compatibility.)
            transform: Optional callable applied to each loaded PIL image.
            split: Split name; selects ``<split>_with_labels.csv``.
        """
        super().__init__()
        self.dir = dir
        self.split = split
        self.data = pd.read_csv(os.path.join(self.dir, f"{self.split}_with_labels.csv"))
        self.path = self.data['Path'].to_list()
        self.labels = self.data['Label'].to_numpy()
        # The concept columns sit at different positions in the train CSV and
        # in every other split's CSV.
        # NOTE(review): presumably the non-train CSV carries one extra leading
        # column and one fewer trailing column -- confirm against the files.
        if split == 'train':
            self.concepts = self.data.iloc[:, 6:-4].to_numpy()
        else:
            self.concepts = self.data.iloc[:, 7:-3].to_numpy()

        self.transform = transform

    def __len__(self):
        """Return the number of rows in the split's CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A dict with ``'img'`` (the RGB image, passed through ``transform``
            when one was given), ``'label'`` (a one-element ``LongTensor``)
            and ``'concept'`` (a ``LongTensor`` built from the row's concept
            columns).
        """
        img_path = os.path.join(self.dir, self.path[idx])
        # convert() forces the pixel data to be read while the file is still
        # open; the context manager then closes the handle on its own, so no
        # explicit close() is needed afterwards.
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return {
            'img': img,
            'label': torch.LongTensor([self.labels[idx]]),
            'concept': torch.LongTensor(self.concepts[idx])
        }


def get_resnet_imagenet_preprocess():
    """Build the image preprocessing pipeline used for the ResNet targets.

    Returns:
        A ``transforms.Compose`` that resizes to 256, center-crops to 224,
        converts to a tensor and normalizes each channel with fixed
        mean/std values.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def _load_places365(root, split, preprocess):
    """Open the small Places365 ``split``, downloading it first if possible."""
    try:
        return datasets.Places365(root=root, split=split, small=True, download=True,
                                  transform=preprocess)
    except RuntimeError:
        # NOTE(review): presumably raised when the data is already on disk;
        # retry without requesting a download.
        return datasets.Places365(root=root, split=split, small=True, download=False,
                                  transform=preprocess)


def get_data(dataset_name, preprocess=None):
    """Return the dataset object registered under ``dataset_name``.

    CIFAR-10/100, Flowers102 and Places365 are loaded through torchvision
    with ``~/.cache`` as root. Names containing ``'chestXpert'`` build a
    ``ChestXpertDataset``; any other key of ``DATASET_ROOTS`` is opened as an
    ``ImageFolder``.

    Args:
        dataset_name: Identifier of the form ``"<dataset>_<split>"`` (or
            ``"imagenet_broden"``).
        preprocess: Optional transform handed to the dataset.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If ``dataset_name`` matches no known dataset.
        KeyError: If a required entry is missing from ``DATASET_ROOTS``.
    """
    cache_root = os.path.expanduser("~/.cache")

    if dataset_name == "cifar100_train":
        data = datasets.CIFAR100(root=cache_root, download=True, train=True,
                                 transform=preprocess)
    elif dataset_name == "cifar100_val":
        data = datasets.CIFAR100(root=cache_root, download=True, train=False,
                                 transform=preprocess)
    elif dataset_name == "cifar10_train":
        data = datasets.CIFAR10(root=cache_root, download=True, train=True,
                                transform=preprocess)
    elif dataset_name == "cifar10_val":
        data = datasets.CIFAR10(root=cache_root, download=True, train=False,
                                transform=preprocess)
    elif dataset_name == "flower_train":
        data = datasets.Flowers102(root=cache_root, download=True, split='train',
                                   transform=preprocess)
    elif dataset_name == "flower_val":
        # The "val" name deliberately maps onto the Flowers102 test split.
        data = datasets.Flowers102(root=cache_root, download=True, split='test',
                                   transform=preprocess)
    elif dataset_name == "places365_train":
        data = _load_places365(cache_root, 'train-standard', preprocess)
    elif dataset_name == "places365_val":
        data = _load_places365(cache_root, 'val', preprocess)
    elif 'chestXpert' in dataset_name:
        # The text after the last underscore names the CSV split to read.
        split = dataset_name.split('_')[-1]
        data = ChestXpertDataset(DATASET_ROOTS[dataset_name], split=split, transform=preprocess)
    elif dataset_name in DATASET_ROOTS:
        data = datasets.ImageFolder(DATASET_ROOTS[dataset_name], preprocess)
    elif dataset_name == "imagenet_broden":
        # NOTE(review): this needs a "broden" entry in DATASET_ROOTS; without
        # one the lookup below raises KeyError -- confirm the intended path.
        data = torch.utils.data.ConcatDataset([
            datasets.ImageFolder(DATASET_ROOTS["imagenet_val"], preprocess),
            datasets.ImageFolder(DATASET_ROOTS["broden"], preprocess),
        ])
    else:
        # Previously this fell through to `return data` with `data` unbound.
        raise ValueError("Unknown dataset name: {!r}".format(dataset_name))
    return data

def get_targets_only(dataset_name):
    """Return the label sequence of ``dataset_name`` without any transform.

    The dataset is built through ``get_data`` with no preprocessing. Names
    containing ``'flower'`` read the labels from ``_labels``; every other
    dataset is expected to expose them as ``targets``.
    """
    dataset = get_data(dataset_name)
    return dataset._labels if 'flower' in dataset_name else dataset.targets

def get_target_model(target_name, device):
    """Load the model named ``target_name`` together with its preprocessing.

    Supported names:
        * ``"clip_<name>"`` -- CLIP image encoder loaded via ``clip.load``.
        * ``"resnet18_places"`` -- ResNet-18 with 365 outputs, weights read
          from ``data/resnet18_places365.pth.tar``.
        * ``"resnet18_cub"`` -- pretrained model from pytorchcv.
        * ``"<arch>_v2"`` -- torchvision ``<arch>`` with ``IMAGENET1K_V2``
          weights.
        * anything else -- torchvision ``<arch>`` with ``IMAGENET1K_V1``
          weights.

    Args:
        target_name: Model identifier as described above.
        device: Device the model is moved to.

    Returns:
        ``(target_model, preprocess)``. For CLIP names ``target_model`` is a
        plain callable returning float image embeddings; otherwise it is a
        module in eval mode.

    Raises:
        AttributeError: If torchvision has no architecture or weights enum
            matching ``target_name``.
    """
    if target_name.startswith("clip_"):
        clip_name = target_name[len("clip_"):]
        model, preprocess = clip.load(clip_name, device=device)

        def target_model(x):
            # Cast the embeddings to float32 whatever precision CLIP returns.
            return model.encode_image(x).float()

    elif target_name == 'resnet18_places':
        # weights=None is the current spelling of the deprecated
        # pretrained=False.
        target_model = models.resnet18(weights=None, num_classes=365).to(device)
        # map_location places the tensors on `device` while loading, so the
        # checkpoint does not need the device it was saved from.
        checkpoint = torch.load('data/resnet18_places365.pth.tar', map_location=device)
        state_dict = checkpoint['state_dict']
        # Keep only the keys carrying a 'module.' prefix, with it stripped.
        new_state_dict = {
            key[len('module.'):]: value
            for key, value in state_dict.items()
            if key.startswith('module.')
        }
        target_model.load_state_dict(new_state_dict)
        target_model.eval()
        preprocess = get_resnet_imagenet_preprocess()

    elif target_name == 'resnet18_cub':
        target_model = ptcv_get_model("resnet18_cub", pretrained=True).to(device)
        target_model.eval()
        preprocess = get_resnet_imagenet_preprocess()

    else:
        if target_name.endswith("_v2"):
            arch_name = target_name[:-len("_v2")]
            weights_version = "IMAGENET1K_V2"
        else:
            arch_name = target_name
            weights_version = "IMAGENET1K_V1"
        # torchvision names the weights enum after the architecture, e.g.
        # resnet50 -> ResNet50_Weights. Plain attribute lookups replace the
        # former eval() of string-built expressions.
        weights_enum_name = "{}_Weights".format(arch_name.replace("resnet", "ResNet"))
        weights = getattr(getattr(models, weights_enum_name), weights_version)
        target_model = getattr(models, arch_name)(weights=weights).to(device)
        target_model.eval()
        preprocess = weights.transforms()

    return target_model, preprocess