"""
Sort the covertype dataset by the distance to the water body, and split the dataset into 12 domains

func:
    get_domain: Get the covertype dataset for a specific domain
    get_domains: Get the covertype dataset for multiple domains
    get_source: Get the source dataset (randomly split 0.8 train and 0.2 test)
"""

import torch
import os
from torch.utils.data import Dataset, TensorDataset, Subset
import pandas as pd
import numpy as np
from typing import Tuple
    
def load_covertype(data_dir) -> Dataset:
    """
    Load the two-class covertype dataset, sorted by horizontal distance to water.

    On the first call the raw ``covertype.data.gz`` inside ``data_dir`` is parsed,
    processed, and cached as ``covertype.pt`` in the same directory. Later calls
    return the cached dataset directly.

    Args:
        data_dir: Directory that contains ``covertype.data.gz``. A trailing path
            separator is optional.

    Returns:
        A TensorDataset of (float32 features with 54 columns, int64 labels).
        Only samples whose zero-based label is <= 1 are kept, ordered by
        ascending value of feature column 3.
    """
    # Build the cache path with os.path.join so it lands inside data_dir even
    # when the caller omits the trailing separator.
    cache_file = os.path.join(data_dir, "covertype.pt")
    if os.path.exists(cache_file):
        # NOTE: the cache is a pickled TensorDataset, so it needs full
        # unpickling; only load cache files that were written by this function.
        try:
            return torch.load(cache_file, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept ``weights_only``.
            return torch.load(cache_file)

    raw_file = os.path.join(data_dir, "covertype.data.gz")
    df = pd.read_csv(raw_file, header=None, compression="gzip")
    data = df.to_numpy()
    xs = data[:, :54]
    # Standardise each feature using statistics of the full (unfiltered) data.
    xs = (xs - np.mean(xs, axis=0)) / np.std(xs, axis=0)
    # Shift labels so that they start at 0.
    ys = data[:, 54] - 1

    # Keep the first 2 cover types, these comprise majority of the dataset.
    keep = ys <= 1
    xs = xs[keep]
    ys = ys[keep]

    # Sort by (horizontal) distance to water body.
    dist_to_water = xs[:, 3]
    indices = np.argsort(dist_to_water, axis=0)
    xs = xs[indices]
    ys = ys[indices]

    dataset = TensorDataset(torch.tensor(xs, dtype=torch.float32), torch.tensor(ys, dtype=torch.long))

    torch.save(dataset, cache_file)
    return dataset
        
# Chunk sizes over the sorted dataset. Entry 0 is a zero-sized placeholder so
# that the real chunks are addressed with 1-based indices 1..12.
domain_img_num = [0, 50000, *([40000] * 10), 45141]

# For every supported number of domains: the 1-based chunk indices (into
# domain_img_num) that are used as those domains, in order.
n2idx = {
    2: [1, 12],
    3: [1, 8, 12],
    4: [1, 5, 9, 12],
    5: [1, 4, 7, 10, 12],
    6: [1, 4, 6, 8, 10, 12],
    7: [1, 3, 5, 7, 9, 11, 12],
    12: list(range(1, 13)),
}

def get_domain(data_dir, domains_num: int, idx: int)->Dataset:
    """
    Return one domain of the sorted covertype dataset as a Subset.

    Args:
        data_dir: Directory handed to ``load_covertype``.
        domains_num: Number of domains in the setting; must be a key of ``n2idx``.
        idx: Which domain to return. ! idx starts from 1.

    Returns:
        A Subset of the full dataset covering the contiguous chunk that
        ``n2idx[domains_num][idx - 1]`` selects.
    """
    assert domains_num in n2idx, f"Invalid number of domains: {domains_num}"
    assert 1 <= idx <= domains_num and idx in range(1, domains_num + 1), f"Invalid domain index: {idx}"
    full = load_covertype(data_dir)
    assert sum(domain_img_num) <= len(full), f"Invalid number {sum(domain_img_num)} <= {len(full)}"

    # Chunk k starts after all earlier chunks and spans domain_img_num[k] samples.
    chunk = n2idx[domains_num][idx - 1]
    first = sum(domain_img_num[:chunk])
    stop = first + domain_img_num[chunk]
    return Subset(full, list(range(first, stop)))

def get_domains(data_dir, domains_num: int)->list[Dataset]:
    """
    Return every domain of a ``domains_num``-domain setting.

    Args:
        data_dir: Directory handed to ``load_covertype``.
        domains_num: Number of domains in the setting; must be a key of ``n2idx``.

    Returns:
        A list of Subsets of the full dataset, one per chunk index listed in
        ``n2idx[domains_num]``, in that order.
    """
    assert domains_num in n2idx, f"Invalid number of domains: {domains_num}"
    full = load_covertype(data_dir)
    assert sum(domain_img_num) <= len(full), f"Invalid number {sum(domain_img_num)} <= {len(full)}"

    def chunk_subset(chunk: int) -> Dataset:
        # Chunk k starts after all earlier chunks and spans domain_img_num[k] samples.
        first = sum(domain_img_num[:chunk])
        return Subset(full, list(range(first, first + domain_img_num[chunk])))

    return [chunk_subset(chunk) for chunk in n2idx[domains_num]]

def get_source(data_dir, shuffle: bool = False)->Tuple[Dataset, Dataset]:
    """
    Split the source domain (domain 1 of the 2-domain setting) into train and test.

    Args:
        data_dir: Directory handed through to ``get_domain``.
        shuffle: When True, the samples are randomly permuted (NumPy global RNG)
            before splitting. When False, the first 80% in stored order become
            the training part.

    Returns:
        ``(train, test)``: two Subsets of the source domain holding 80% and 20%
        of its samples respectively.
    """
    source = get_domain(data_dir, 2, 1)
    n = len(source)

    # Work with positions *within* the source subset. Re-using the subset's own
    # ``indices`` (which address the parent dataset) to index the subset is only
    # correct when the domain happens to start at offset 0, and shuffling them
    # in place would permute the data twice.
    if shuffle:
        positions = np.random.permutation(n).tolist()
    else:
        positions = list(range(n))

    split = int(n * 0.8)
    return Subset(source, positions[:split]), Subset(source, positions[split:])
    

# ------------ test-code ------------

def show_covertype(data_dir, domains_datasets: list[Dataset]):
    """
    Print the feature shape and the label of the first sample of each domain.

    Args:
        data_dir: Accepted but not used by this function.
        domains_datasets: Datasets whose first item is a (features, label) pair.
    """
    for domain in domains_datasets:
        features, label = domain[0]
        print(features.shape)
        print(label)

    
if __name__ == "__main__":
    # Smoke test: load the second domain of the 2-domain setting, report its
    # size, and touch every sample once to make sure indexing works.
    data_dir = "./data/covertype/"

    target_domain = get_domain(data_dir, 2, 2)
    sample_count = len(target_domain)
    print(sample_count)
    for position in range(sample_count):
        print(position)
        features, label = target_domain[position]
