import os
import pandas as pd
import numpy as np
import torch
import torchvision
import time
import robustness.datasets as rob_datasets

import utils.helper as h

import warnings
# NOTE(review): this silences *every* warning category process-wide as soon as
# this module is imported, including deprecation warnings from torch/pandas.
# Consider scoping it with warnings.catch_warnings() around the noisy call.
warnings.filterwarnings("ignore")

# Human-readable class names per dataset, keyed by integer class id.
# For CIFAR the id is simply the position of the name in this ordered list.
CLASS_ID_TO_NAMES = {
    'cifar': dict(enumerate([
        'plane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ])),
}

class SpecialCIFAR10(rob_datasets.CIFAR):
    """CIFAR-10 dataset whose samples come from caller-supplied images/labels.

    The ``robustness`` CIFAR dataset is initialised with ``**kwargs`` and its
    ``data``/``targets`` are then replaced by ``imgs``/``labels``. The
    per-channel ``mean`` and ``std`` of the supplied images are stored on the
    instance.

    Args:
        imgs: Either a tensor, used as-is, or array-like images that are
            converted to a tensor, divided by 255 and permuted with
            ``(0, 3, 1, 2)`` (moving the last axis to position 1, presumably
            NHWC -> NCHW; assumes uint8 pixels in [0, 255] -- TODO confirm).
        labels: Targets aligned with ``imgs``; converted with
            ``torch.as_tensor`` when ``imgs`` is not already a tensor.
        **kwargs: Forwarded to ``rob_datasets.CIFAR``.
    """

    def __init__(self, imgs, labels, **kwargs):
        super().__init__(**kwargs)
        self.data = imgs
        self.targets = labels

        self.transform = None         # applied to images only
        self.target_transform = None  # applied to labels only
        if not torch.is_tensor(self.data):
            # as_tensor shares memory with NumPy arrays (like from_numpy) but
            # also accepts plain Python lists.
            self.targets = torch.as_tensor(self.targets)
            self.data = torch.as_tensor(self.data) / 255.
            self.data = torch.permute(self.data, (0, 3, 1, 2))

        # mean/std need a floating dtype; an integer tensor passed in directly
        # would otherwise raise here.
        stats_input = self.data if self.data.is_floating_point() else self.data.float()
        self.mean = torch.mean(stats_input, dim=(0, 2, 3))
        self.std = torch.std(stats_input, dim=(0, 2, 3))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index`` with optional transforms applied."""
        img, target = self.data[index], self.targets[index]

        if self.transform is not None:
            img = self.transform(img)
        # Labels have their own hook; an image transform must not touch them.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

class SpecialCIFAR100(rob_datasets.CIFAR100):
    """CIFAR-100 dataset whose samples come from caller-supplied images/labels.

    The ``robustness`` CIFAR100 dataset is initialised with ``**kwargs`` and
    its ``data``/``targets`` are then replaced by ``imgs``/``labels``. The
    per-channel ``mean`` and ``std`` of the supplied images are stored on the
    instance.

    Args:
        imgs: Either a tensor, used as-is, or array-like images that are
            converted to a tensor, divided by 255 and permuted with
            ``(0, 3, 1, 2)`` (moving the last axis to position 1, presumably
            NHWC -> NCHW; assumes uint8 pixels in [0, 255] -- TODO confirm).
        labels: Targets aligned with ``imgs``; converted with
            ``torch.as_tensor`` when ``imgs`` is not already a tensor.
        **kwargs: Forwarded to ``rob_datasets.CIFAR100``.
    """

    def __init__(self, imgs, labels, **kwargs):
        super().__init__(**kwargs)
        self.data = imgs
        self.targets = labels

        self.transform = None         # applied to images only
        self.target_transform = None  # applied to labels only
        if not torch.is_tensor(self.data):
            # as_tensor shares memory with NumPy arrays (like from_numpy) but
            # also accepts plain Python lists.
            self.targets = torch.as_tensor(self.targets)
            self.data = torch.as_tensor(self.data) / 255.
            self.data = torch.permute(self.data, (0, 3, 1, 2))

        # mean/std need a floating dtype; an integer tensor passed in directly
        # would otherwise raise here.
        stats_input = self.data if self.data.is_floating_point() else self.data.float()
        self.mean = torch.mean(stats_input, dim=(0, 2, 3))
        self.std = torch.std(stats_input, dim=(0, 2, 3))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index`` with optional transforms applied."""
        img, target = self.data[index], self.targets[index]

        if self.transform is not None:
            img = self.transform(img)
        # Labels have their own hook; an image transform must not touch them.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

def create_csv_SemBias(path, out_path='SemBias_subset.csv', *, verbose=True):
    """Write a word-level CSV of the SemBias words found in *path*.

    Every line of *path* is split on both ``:`` and tab characters, giving a
    flat list of words. Each word is recorded once, labelled by the position
    at which it first appears on a line:

    * positions 0-1   -> 0 (gender-definition words)
    * positions 2-5   -> 1 (gender-neutral words with similar meaning)
    * later positions -> 2 (gender-stereotype words)

    Empty fields (blank lines, trailing tabs) are ignored.

    Args:
        path: Input text file, read as UTF-8.
        out_path: Destination CSV with columns ``word`` and ``label``.
            Defaults to ``SemBias_subset.csv`` in the working directory.
        verbose: When True (the default), print per-line parsing progress and
            the final table.

    Returns:
        pandas.DataFrame: The table that was written to *out_path*.
    """
    seen = set()  # O(1) duplicate check; all_words keeps first-seen order
    all_words = []
    all_labels = []
    with open(path, 'r', encoding='utf-8') as words_file:
        for line in words_file:
            if verbose:
                print(line)
            # ':' joins the two words of a pair and '\t' separates pairs;
            # treat both as word delimiters.
            words = line.rstrip('\n').replace(':', '\t').split('\t')
            if verbose:
                print(words)
            for word_idx, word in enumerate(words):
                if not word or word in seen:
                    continue
                if word_idx < 2:  # first two words are gender-definition
                    all_labels.append(0)
                elif word_idx < 6:  # next four words have similar meaning (gender neutral)
                    all_labels.append(1)
                else:  # remaining words are gender-stereotype
                    all_labels.append(2)
                seen.add(word)
                all_words.append(word)
            if verbose:
                print(all_words)
    if verbose:
        print(all_words, len(all_words), len(all_labels))
    df = pd.DataFrame({'word': all_words, 'label': all_labels})
    if verbose:
        print(df)
    df.to_csv(out_path, index=False)
    return df

def create_csv_SemBias_pairs(path, out_path='SemBias_subset_pairs.csv', *, verbose=True):
    """Write a pair-level CSV of the SemBias word pairs found in *path*.

    Every line of *path* is split on tab characters into word pairs, which are
    kept verbatim (``word_a:word_b``). Each pair is recorded once, labelled by
    its position on the line where it first appears:

    * position 0       -> 0 (gender-definition pair)
    * positions 1-2    -> 1 (gender-neutral pairs with similar meaning)
    * later positions  -> 2 (gender-stereotype pair)

    Empty fields (blank lines, trailing tabs) are ignored.

    Args:
        path: Input text file, read as UTF-8.
        out_path: Destination CSV with columns ``words`` and ``label``.
            Defaults to ``SemBias_subset_pairs.csv`` in the working directory.
        verbose: When True (the default), print per-line parsing progress and
            the final table.

    Returns:
        pandas.DataFrame: The table that was written to *out_path*.
    """
    seen = set()  # O(1) duplicate check; all_words keeps first-seen order
    all_words = []
    all_labels = []
    with open(path, 'r', encoding='utf-8') as words_file:
        for line in words_file:
            if verbose:
                print(line)
            word_pairs = line.rstrip('\n').split('\t')
            for idx, word_pair in enumerate(word_pairs):
                if not word_pair or word_pair in seen:
                    continue
                if idx == 0:  # first word pair is gender-definition
                    all_labels.append(0)
                elif idx < 3:  # next two word pairs have similar meaning (gender neutral)
                    all_labels.append(1)
                else:  # remaining word pairs are gender-stereotype
                    all_labels.append(2)
                seen.add(word_pair)
                all_words.append(word_pair)
            if verbose:
                print(all_words)
    if verbose:
        print(all_words, len(all_words), len(all_labels))
    df = pd.DataFrame({'words': all_words, 'label': all_labels})
    if verbose:
        print(df)
    df.to_csv(out_path, index=False)
    return df

def main():
    """Script entry point: build the word-level SemBias CSV from the data folder."""
    sembias_path = 'data/SemBias_subset'
    create_csv_SemBias(sembias_path)
    # For the pair-level CSV instead: create_csv_SemBias_pairs(sembias_path)


# main() only needs the SemBias helpers defined above, so running it at this
# point is safe even though more definitions follow in the module.
if __name__ == '__main__':
    main()


# code from: https://github.com/modestyachts/ImageNetV2_pytorch/issues/2

import pathlib
import tarfile
import requests
import shutil

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

# One row per supported variant:
#   (variant key, archive URL, top-level folder the archive extracts to).
_VARIANTS = (
    ("matched-frequency",
     "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-matched-frequency.tar.gz",
     "imagenetv2-matched-frequency-format-val"),
    ("threshold-0.7",
     "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-threshold0.7.tar.gz",
     "imagenetv2-threshold0.7-format-val"),
    ("top-images",
     "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-top-images.tar.gz",
     "imagenetv2-top-images-format-val"),
    ("val",
     "https://imagenet2val.s3.amazonaws.com/imagenet_validation.tar.gz",
     "imagenet_validation"),
)

URLS = {key: url for key, url, _ in _VARIANTS}
FNAMES = {key: folder for key, _, folder in _VARIANTS}

# Expected number of images: V2_DATASET_SIZE for the ImageNetV2 variants,
# VAL_DATASET_SIZE presumably for the "val" archive.
V2_DATASET_SIZE = 10_000
VAL_DATASET_SIZE = 50_000

class ImageNetV2Dataset(Dataset):
    """ImageNetV2 images as a map-style dataset, downloaded on first use.

    Images are read from ``{location}/ImageNetV2-{variant}/``. If that folder
    is missing or does not contain the expected number of ``*.jpeg`` files,
    the archive ``{location}/ImageNetV2-{variant}.tar.gz`` is downloaded
    (unless already present) and extracted into place.

    Args:
        variant: A key of ``URLS``.
        transform: Optional callable applied to each PIL image.
        location: Directory holding the archive and the extracted folder.

    Raises:
        ValueError: If ``variant`` is not a key of ``URLS``.
        RuntimeError: If the download is shorter than the advertised size.
        requests.HTTPError: If the download URL returns an error status.
    """

    def __init__(self, variant="matched-frequency", transform=None, location="."):
        # Validate before touching the filesystem or network; an ``assert``
        # would be stripped under ``python -O``.
        if variant not in URLS:
            raise ValueError(f"unknown V2 Variant: {variant}")
        self.dataset_root = pathlib.Path(f"{location}/ImageNetV2-{variant}/")
        self.tar_root = pathlib.Path(f"{location}/ImageNetV2-{variant}.tar.gz")
        self.transform = transform
        # The "val" archive is not a 10k-image V2 variant; comparing it against
        # V2_DATASET_SIZE would force a re-extraction on every construction.
        expected_size = VAL_DATASET_SIZE if variant == "val" else V2_DATASET_SIZE
        self.fnames = self._list_images()
        if not self.dataset_root.exists() or len(self.fnames) != expected_size:
            if not self.tar_root.exists():
                self._download(variant)
            self._extract(variant, location)
            self.fnames = self._list_images()

    def _list_images(self):
        """Return every ``*.jpeg`` under ``dataset_root``, sorted by path string.

        Sorting keeps the index -> image mapping identical across runs,
        including the run that has just extracted the archive.
        """
        return sorted(self.dataset_root.glob("**/*.jpeg"), key=str)

    def _download(self, variant):
        """Stream the archive for ``variant`` to ``tar_root``; delete it if incomplete."""
        url = URLS[variant]
        print(f"Dataset {variant} not found on disk, downloading....")
        block_size = 1024  # 1 KiB
        received = 0
        try:
            # timeout bounds the connect and each read (not the whole transfer).
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                    with open(self.tar_root, 'wb') as f:
                        for data in response.iter_content(block_size):
                            progress_bar.update(len(data))
                            f.write(data)
                            received += len(data)
            if total_size_in_bytes != 0 and received != total_size_in_bytes:
                raise RuntimeError(f"Downloading from {url} failed")
        except BaseException:
            # A truncated archive left on disk would skip the download next time
            # and then be extracted as if it were complete.
            if self.tar_root.exists():
                self.tar_root.unlink()
            raise

    def _extract(self, variant, location):
        """Unpack ``tar_root`` into ``location`` and move its folder to ``dataset_root``."""
        print("Extracting....")
        # NOTE(review): extractall trusts member paths inside the archive; on
        # Python 3.12+ consider extractall(..., filter="data").
        with tarfile.open(self.tar_root) as tar:
            tar.extractall(f"{location}")
        # shutil.move places the source *inside* an existing destination
        # directory, so remove a stale/incomplete copy first to avoid nesting
        # (and double-counting) images.
        if self.dataset_root.exists():
            shutil.rmtree(self.dataset_root)
        shutil.move(f"{location}/{FNAMES[variant]}", self.dataset_root)

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, i):
        """Return ``(image, label)``; the label is the image's parent folder name as int."""
        path = self.fnames[i]
        label = int(path.parent.name)
        # Image.open is lazy and keeps the file open; load via a context manager
        # so the handle is released, and convert to RGB (as torchvision's
        # default loader does) so grayscale/CMYK files still give 3 channels.
        with Image.open(path) as img:
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label
    