"""
Sort the portraits dataset chronologically (the code orders photos by filename);
every 2000 consecutive photos form a domain.

func:
    get_domain: Get the portraits dataset for a specific domain
    get_domains: Get the portraits dataset for multiple domains
    get_source: Get the source dataset, split 0.8 train / 0.2 test (shuffled before the split only when shuffle=True)
"""

import torch
import os
from torchvision import transforms
from torch.utils.data import Dataset, TensorDataset, Subset
from matplotlib import pyplot as plt
from tqdm import tqdm
from PIL import Image
from typing import Tuple
import numpy as np


def load_portraits(data_dir, target_size: tuple[int, int]) -> Dataset:
    """
    Load the portraits dataset as a TensorDataset of (image, label) pairs.

    Images are read from the "F" and "M" sub-directories of data_dir, resized
    to target_size, converted to tensors and ordered by filename. The label is
    0 for "M" and 1 for "F". The result is cached inside data_dir as
    "portraits_{target_size[0]}x{target_size[1]}.pt" and reloaded from that
    file on later calls.

    Args:
        data_dir: Directory containing the "F" and "M" image folders; a
            trailing path separator is optional.
        target_size: Size handed to transforms.Resize; also part of the cache
            file name.

    Returns:
        A TensorDataset holding the stacked image tensors and their labels.
    """
    cache_file = os.path.join(data_dir, f"portraits_{target_size[0]}x{target_size[1]}.pt")
    if os.path.exists(cache_file):
        # The cache is a pickled TensorDataset written by this function, not a
        # plain tensor/state dict, so it needs full unpickling (PyTorch >= 2.6
        # defaults to weights_only=True and would reject it).
        # NOTE: full unpickling executes arbitrary code; only load caches that
        # were produced locally by this function.
        return torch.load(cache_file, weights_only=False)

    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    samples = []
    for subdir in ["F", "M"]:
        subdir_path = os.path.join(data_dir, subdir)
        for filename in tqdm(os.listdir(subdir_path), desc=f"{subdir} data"):
            # Release the file handle as soon as the pixels are converted.
            with Image.open(os.path.join(subdir_path, filename)) as img:
                img_tensor = transform(img)
            samples.append((img_tensor, 0 if subdir == "M" else 1, filename))

    # The filename is the sort key that defines the dataset order.
    samples.sort(key=lambda sample: sample[2])
    tensors = torch.stack([sample[0] for sample in samples])
    labels = torch.tensor([sample[1] for sample in samples])
    dataset = TensorDataset(tensors, labels)
    torch.save(dataset, cache_file)
    return dataset

# Number of consecutive photos (in dataset order) that make up one domain.
d_num = 2000

# Maps a supported number of domains to the 1-based indices of the
# d_num-sized chunks that are used as those domains.
n2idx = {
    2: [1, 9],
    3: [1, 5, 9],
    4: [1, 4, 6, 9],
    5: [1, 3, 5, 7, 9],
    6: [1, 2, 4, 6, 8, 9],
    9: list(range(1, 10)),
}

def get_domain(data_dir, domains_num: int, target_size: tuple[int, int], idx: int)->Dataset:
    """
    Get the portraits dataset for a single domain.

    The domain is the idx-th block of d_num consecutive samples of the full
    dataset, returned as a Subset.

    ! idx starts from 1 and must be one of n2idx[domains_num].
    """
    assert domains_num in n2idx, f"Invalid number of domains: {domains_num}"
    assert idx in n2idx[domains_num], f"Invalid domain index: {idx}"
    full_dataset = load_portraits(data_dir, target_size)
    start = d_num * (idx - 1)
    stop = start + d_num
    return Subset(full_dataset, list(range(start, stop)))

def get_domains(data_dir, domains_num: int, target_size: tuple[int, int])->list[Dataset]:
    """
    Get the portraits dataset split into domains_num domains.

    Returns one Subset per entry of n2idx[domains_num], in that order; each
    covers the corresponding block of d_num consecutive samples.
    """
    assert domains_num in n2idx, f"Invalid number of domains: {domains_num}"
    full_dataset = load_portraits(data_dir, target_size)
    return [
        Subset(full_dataset, list(range(d_num * (i - 1), d_num * i)))
        for i in n2idx[domains_num]
    ]

def get_source(data_dir, target_size: tuple[int, int], shuffle: bool = False)->Tuple[Dataset, Dataset]:
    """
    Get the source dataset (domain 1) split into 80% train and 20% test.

    Args:
        data_dir: Directory passed through to get_domain.
        target_size: Image size passed through to get_domain.
        shuffle: If True, the samples are randomly permuted (using the global
            np.random state) before splitting; otherwise the split follows
            the order of the domain.

    Returns:
        (train, test) as Subset views over the source domain.
    """
    dataset = get_domain(data_dir, 2, target_size, 1)
    n_samples = len(dataset)
    # Positions are relative to `dataset` itself (not to the dataset it wraps),
    # so the split stays valid whatever base indices the domain covers.
    if shuffle:
        positions = np.random.permutation(n_samples).tolist()
    else:
        positions = list(range(n_samples))
    n_train = int(n_samples * 0.8)
    return Subset(dataset, positions[:n_train]), Subset(dataset, positions[n_train:])

# ------------ test-code ------------

def show_portraits(data_dir, domains: list[Dataset]):
    """
    Plot the first three samples of every domain in a grid, then save and show it.

    The grid has one column per domain and one row per sample; each cell is
    titled with the sample's target. The figure is written to "portraits.png"
    inside data_dir, and the shape of the last plotted image is printed.

    Args:
        data_dir: Directory in which the figure is saved.
        domains: Datasets whose items are (image, target) pairs; each image
            is permuted with (1, 2, 0) before plotting.
    """
    # set the number of columns and rows
    ncols = len(domains)
    nrows = 3
    # squeeze=False keeps `axs` two-dimensional even when there is a single
    # domain, so axs[row, col] is always valid.
    _, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 8), squeeze=False)
    for row in range(nrows):
        for col, dataset in enumerate(domains):
            img, target = dataset[row]
            ax = axs[row, col]
            ax.imshow(img.permute(1, 2, 0).numpy())
            ax.axis("off")
            ax.set_title(f"target: {target}")
    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, "portraits.png"))
    plt.show()
    print("Image shape: ", img.shape)
    
def test(dataset: Dataset):
    """
    Smoke-test a dataset: print its length, then the image shape and the
    labels of one shuffled batch of up to 100 samples.
    """
    print(len(dataset))
    from torch.utils.data import DataLoader
    batches = iter(DataLoader(dataset, batch_size=100, shuffle=True))
    # Only the first batch is inspected.
    first_batch = next(batches, None)
    if first_batch is not None:
        print(first_batch[0].shape)
        print(first_batch[1])
    
if __name__ == "__main__":
    # Manual smoke test: build the 6-domain split, then inspect the
    # train/test parts of the source domain.
    data_dir = "./data/portraits/"
    image_size = (32, 32)
    domains_datasets = get_domains(data_dir, 6, image_size)
    # show_portraits(data_dir, domains_datasets)
    train_part, test_part = get_source(data_dir, image_size, shuffle=False)
    for part in (train_part, test_part):
        test(part)