"""
This file is based on the code from https://github.com/qinenergy/cotta.
"""
import torch
from torchvision import transforms, datasets
from torch.utils.data import Subset, Dataset, TensorDataset, DataLoader
from typing import Sequence, Tuple
from pathlib import Path
import numpy as np
import warnings
import os

from .corruption import CORRUPTIONS, CORRUPTIONS_DIR_NAMES, BenchmarkDataset
from .utils import singleton_thread
from model.utils import get_torchvision_model


@singleton_thread
def download_imagenet_c(data_dir: str):
    """Download and unpack the ImageNet-C archives from Zenodo when missing.

    The five archives (blur, digital, extra, noise, weather) are stored in
    ``<data_dir>/<CORRUPTIONS_DIR_NAMES[BenchmarkDataset.image_net]>`` and
    unpacked into that same directory. An archive that already exists there is
    not downloaded again, and extraction is skipped when a directory exists for
    every entry of ``CORRUPTIONS``.

    Manual equivalent, run from inside the target directory:
        wget --content-disposition https://zenodo.org/record/2235448/files/blur.tar?download=1
        wget --content-disposition https://zenodo.org/record/2235448/files/digital.tar?download=1
        wget --content-disposition https://zenodo.org/record/2235448/files/extra.tar?download=1
        wget --content-disposition https://zenodo.org/record/2235448/files/noise.tar?download=1
        wget --content-disposition https://zenodo.org/record/2235448/files/weather.tar?download=1

        tar -xf blur.tar
        tar -xf digital.tar
        tar -xf extra.tar
        tar -xf noise.tar
        tar -xf weather.tar

    Args:
        data_dir: Root data directory; the ImageNet-C sub-directory is created
            below it if necessary.

    Returns:
        Path of the directory that holds the archives and extracted folders.

    Raises:
        subprocess.CalledProcessError: If ``wget`` or ``tar`` exits non-zero.
        FileNotFoundError: If ``wget`` or ``tar`` is not installed.
    """
    # Imported locally so the module-level import block stays untouched.
    import subprocess

    archive_names = ("blur", "digital", "extra", "noise", "weather")

    data_dir = Path(data_dir) / CORRUPTIONS_DIR_NAMES[BenchmarkDataset.image_net]
    os.makedirs(data_dir, exist_ok=True)

    for name in archive_names:
        archive = data_dir / f"{name}.tar"
        if archive.exists():
            continue
        url = f"https://zenodo.org/record/2235448/files/{name}.tar?download=1"
        # Download next to the final location under a temporary name, so an
        # interrupted transfer is never mistaken for a complete archive by the
        # existence check above on a later run.
        partial = data_dir / f"{name}.tar.part"
        try:
            # Explicit output path: the file must land in data_dir, not in the
            # current working directory, for the checks and tar calls to match.
            subprocess.run(["wget", "-O", str(partial), url], check=True)
            os.replace(partial, archive)
        finally:
            if partial.exists():
                partial.unlink()

    # Unpack everything as soon as any expected corruption folder is missing.
    if any(not os.path.exists(data_dir / f"{c}") for c in CORRUPTIONS):
        for name in archive_names:
            # No -z: let tar detect the archive format itself instead of
            # forcing gzip decoding on a file named *.tar. No -v either, to
            # avoid printing one line per extracted image.
            subprocess.run(
                ["tar", "-xf", str(data_dir / f"{name}.tar"), "-C", str(data_dir)],
                check=True,
            )
    return data_dir

def load_imagenetc(
    data_dir: str = './data/imagenetc/',
    corruptions: Sequence[str] = CORRUPTIONS,
    severity: int = 5,
    n_examples: int = 50000,
    shuffle: bool = False,
) -> Dataset:
    """Load one ImageNet-C corruption/severity split into an in-memory dataset.

    The archives are downloaded and extracted on demand via
    ``download_imagenet_c``. Images are read from
    ``<root>/<corruption>/<severity>`` with ``ImageFolder`` and converted with
    ``ToTensor``; a single batch of at most ``n_examples`` items is pulled from
    a ``DataLoader`` and wrapped in a ``TensorDataset``.

    Args:
        data_dir: Root directory passed to ``download_imagenet_c``.
        corruptions: Corruption names; exactly one entry is accepted for now.
        severity: Severity level, used as the sub-directory name.
        n_examples: Batch size of the single batch that is loaded. A warning is
            issued when it exceeds the number of available images.
        shuffle: Forwarded to the ``DataLoader``.

    Returns:
        ``TensorDataset`` holding the images and labels of the loaded batch.
    """
    # TODO: generalize this (although this would probably require writing a
    # function similar to `load_corruptions_cifar`)
    assert len(corruptions) == 1, "so far only one corruption is supported"

    severity_dir = Path(download_imagenet_c(data_dir)) / corruptions[0] / str(severity)
    image_folder = datasets.ImageFolder(severity_dir, transform=transforms.ToTensor())

    available = len(image_folder)
    if available < n_examples:
        # Only a warning: the loader below then simply yields a smaller batch.
        warnings.warn(f"n_examples {n_examples} is greater than the number of examples in the dataset: {available}")

    # One batch of size n_examples materialises the requested images at once.
    single_batch_loader = DataLoader(image_folder, batch_size=n_examples, shuffle=shuffle)
    images, labels = next(iter(single_batch_loader))
    return TensorDataset(images, labels)

def load_imagenet(
    data_dir: str = './data/imagenet/',
    examples_num: int = None,
    shuffle: bool = False,
) -> Tuple[Dataset, Dataset]:
    """Load the ImageNet ``train`` and ``val`` folders as datasets.

    Both splits are read with ``ImageFolder`` using the preprocessing
    transform returned by ``get_torchvision_model()``.

    Args:
        data_dir: Directory containing the ``train`` and ``val`` sub-folders;
            a trailing path separator is optional.
        examples_num: If given, each split is truncated to its first
            ``examples_num`` items (after the optional shuffle). A warning is
            issued for a split that has fewer items than requested.
        shuffle: If True, the sample order of each split is permuted in place
            with ``np.random.permutation`` before truncation.

    Returns:
        ``(train, val)``; plain ``ImageFolder`` objects when ``examples_num``
        is None, otherwise ``Subset`` views of them.
    """
    _, preprocess = get_torchvision_model()
    # os.path.join instead of string concatenation, so a data_dir without a
    # trailing slash does not collapse into e.g. ".../imagenettrain".
    d_tr = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=preprocess)
    d_ts = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=preprocess)

    if shuffle:
        # Train is permuted before val, matching the original RNG draw order.
        for split in (d_tr, d_ts):
            order = np.random.permutation(len(split.samples))
            split.samples = [split.samples[i] for i in order]
            split.targets = [split.targets[i] for i in order]
            # ImageFolder exposes `imgs` as an alias of `samples`; re-point it
            # so it does not keep referring to the unshuffled list.
            split.imgs = split.samples

    if examples_num is not None:
        if examples_num > len(d_tr):
            warnings.warn(f"examples_num is greater than the number of training examples, setting examples_num to {len(d_tr)}")
        if examples_num > len(d_ts):
            warnings.warn(f"examples_num is greater than the number of validation examples, setting examples_num to {len(d_ts)}")
        d_tr = Subset(d_tr, range(min(examples_num, len(d_tr))))
        d_ts = Subset(d_ts, range(min(examples_num, len(d_ts))))

    return d_tr, d_ts

# Maximum number of examples requested per domain; get_domain passes it as
# `examples_num` to load_imagenet and as `n_examples` to load_imagenetc.
d_num = 50000
# Maps the number of domains in a sequence to the severity level of each
# domain, indexed by the domain's position. get_domain treats severity 0 as
# the clean ImageNet validation split and any other value as an ImageNet-C
# severity level.
n2idx = {
    2: [0, 5],
    3: [0, 3, 5],
    4: [0, 2, 4, 5],
    5: [0, 2, 3, 4, 5],
    6: [0, 1, 2, 3, 4, 5],
}

def get_domain(data_dir: str, data_c_dir: str, domains_num: int, corruption: str, idx: int = None) -> Dataset:
    """Return the dataset for one domain of a ``domains_num``-long sequence.

    The severity of the domain is ``n2idx[domains_num][idx]``. Severity 0
    yields the clean ImageNet validation split loaded from ``data_dir``; any
    other severity yields the given ImageNet-C corruption at that severity,
    loaded from ``data_c_dir``. At most ``d_num`` examples are requested.

    Args:
        data_dir: Directory of the clean ImageNet dataset.
        data_c_dir: Root directory of the ImageNet-C data.
        domains_num: Length of the domain sequence; must be a key of ``n2idx``.
        corruption: Corruption name used for every non-clean domain.
        idx: Position of the requested domain within the sequence. Required,
            despite the ``None`` default kept for signature compatibility.

    Returns:
        The dataset of the requested domain.

    Raises:
        TypeError: If ``idx`` is not provided.
        KeyError: If ``domains_num`` is not a key of ``n2idx``.
        IndexError: If ``idx`` is out of range for ``n2idx[domains_num]``.
        Exception: If loading fails; the original error is chained as the
            cause.
    """
    if idx is None:
        # Same exception type the list lookup below would raise for None,
        # but with a message that names the actual problem.
        raise TypeError("idx is required: pass the position of the domain within the sequence")
    severity = n2idx[domains_num][idx]
    try:
        if severity == 0:
            _, d_ts = load_imagenet(data_dir, examples_num=d_num)
            return d_ts
        return load_imagenetc(data_c_dir, severity=severity, corruptions=[corruption], n_examples=d_num)
    except Exception as e:
        # Name both directories (either may be the one read) and chain the
        # underlying error so the root cause stays in the traceback.
        raise Exception(
            f"Error loading domain from '{data_dir}' (clean) / '{data_c_dir}' (corrupted) "
            f"with domains_num {domains_num}, idx {idx}, severity {severity}, corruption '{corruption}'."
        ) from e

def get_domains(data_dir: str, data_c_dir: str, domains_num: int, corruption: str) -> list[Dataset]:
    """Build the datasets of all ``domains_num`` domains, in sequence order.

    Each position ``0 .. domains_num - 1`` is resolved through ``get_domain``
    with the same directories and corruption name.

    Returns:
        List with one dataset per domain position.
    """
    return [
        get_domain(data_dir, data_c_dir, domains_num, corruption, position)
        for position in range(domains_num)
    ]


# ------------ test-code ------------

if __name__ == "__main__":
    # Manual smoke test: loads the clean splits and two "glass_blur" domains
    # and prints their sizes. The paths below are machine-specific.
    import matplotlib.pyplot as plt  # only needed by the optional plots below

    data_dir = "/home/me/share/imagenet/"
    data_c_dir = "./data/imagenetc/"

    d_tr, d_ts = load_imagenet(data_dir)
    # First the dataset summaries, then their sizes.
    for split in (d_tr, d_ts):
        print(split)
    for split in (d_tr, d_ts):
        print(len(split))

    # Positions 0 and 1 of a six-domain sequence.
    for position in (0, 1):
        dataset = get_domain(data_dir, data_c_dir, 6, "glass_blur", position)
        print(len(dataset))

    # Optional: save a grid with the first ten training images.
    # plt.figure(figsize=(10, 10))
    # for i in range(10):
    #     plt.subplot(2, 5, i + 1)
    #     plt.imshow(d_tr[i][0].permute(1, 2, 0))
    #     plt.title(d_tr[i][1])
    #     plt.axis("off")
    # plt.tight_layout()
    # plt.savefig("imagenet.png")
    # plt.show()

    # Optional: save a grid with one image per corruption and domain position.
    # plt.figure(figsize=(64, 120))
    # for i,c in enumerate(CORRUPTIONS):
    #     for j in range(6):
    #         print(f"{c} {j}")
    #         dataset = get_domain(data_dir, data_c_dir, 6, c, j)
    #         print(len(dataset))
    #         plt.subplot(len(CORRUPTIONS), 6, i*6+j+1)
    #         plt.imshow(dataset[0][0].permute(1, 2, 0))
    #         plt.title(f"{c} {j}")
    #         plt.axis("off")
    # plt.tight_layout()
    # plt.savefig("imagenetc.png")
    # plt.show()