# data_utils.py

import torch
from polyilr import construct_V
import pickle
import os

def get_tree(dataset, data_dir):
    """Return the hierarchical label tree for ``dataset`` as ``(tree, root)``.

    Only ``'cifar100'`` is supported. The tree is a dict mapping each internal
    node id to the list of its children:

    * the root has id ``-1``; its children are the 20 superclass nodes;
    * superclass ``k`` (0-19) is the internal node ``-(k + 2)``; its children
      are the ascending fine-class indices (0-99) that belong to it.

    Fine-class indices (the leaves) never appear as keys.

    Args:
        dataset: Dataset name.
        data_dir: Directory containing (or that will receive) the
            ``cifar-100-python`` folder.

    Returns:
        ``(tree, root)`` with ``root == -1``.

    Raises:
        ValueError: if ``dataset`` is unknown, or the training labels do not
            contain exactly 100 distinct fine classes.
        FileNotFoundError: if the data is missing and torchvision cannot be
            imported to download it.
    """
    if dataset != 'cifar100':
        raise ValueError(f"Unknown dataset: {dataset}.")

    num_fine, num_coarse = 100, 20
    train_path = os.path.join(data_dir, 'cifar-100-python', 'train')

    if not os.path.exists(train_path):
        # torchvision is only needed to download the archive; import it lazily
        # so this module works without it when the data is already on disk.
        try:
            from torchvision import datasets
        except ImportError as err:
            raise FileNotFoundError(
                f"{train_path} not found and torchvision is not available "
                "to download CIFAR-100."
            ) from err
        datasets.CIFAR100(root=data_dir, train=True, download=True)

    # NOTE(security): pickle.load can execute arbitrary code; data_dir must
    # point at a trusted CIFAR-100 copy.
    with open(train_path, 'rb') as f:
        train_dict = pickle.load(f, encoding='bytes')

    # fine -> coarse mapping; the first occurrence of each fine label wins.
    fine_to_coarse = {}
    for fine, coarse in zip(train_dict[b'fine_labels'], train_dict[b'coarse_labels']):
        fine_to_coarse.setdefault(fine, coarse)

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if len(fine_to_coarse) != num_fine:
        raise ValueError(f"Missing classes! Only found {len(fine_to_coarse)}")

    # coarse -> fine mapping. Fine ids are visited in ascending order, so each
    # child list comes out sorted without an explicit sort.
    coarse_to_fine = {c: [] for c in range(num_coarse)}
    for fine in range(num_fine):
        coarse_to_fine[fine_to_coarse[fine]].append(fine)

    root = -1
    tree = {root: []}
    for superclass_id in range(num_coarse):
        # Internal ids -2 .. -21 are negative, so they never collide with the
        # non-negative fine-class leaf ids.
        internal_node = -(superclass_id + 2)
        tree[root].append(internal_node)
        tree[internal_node] = coarse_to_fine[superclass_id]

    return tree, root

def get_tree_and_V(dataset, data_dir):
    """Return the label tree, its root, and the H-PhILR basis ``V``.

    Args:
        dataset: Dataset name; only ``'cifar100'`` is supported.
        data_dir: Passed through to ``get_tree``.

    Returns:
        ``(tree, root, V)``. ``tree`` and ``root`` come from ``get_tree``.
        ``V`` is whatever ``polyilr.construct_V`` returns for that tree, called
        with ``edge_lengths=None`` and ``return_node_info=False``.
        No Helmert matrices are returned.

    Raises:
        ValueError: if ``dataset`` is unknown.
    """
    # get_tree checks the dataset name before any work and raises ValueError
    # for unsupported datasets.
    tree, root = get_tree(dataset, data_dir)

    # Number of leaf (fine) classes for each supported dataset.
    num_classes_by_dataset = {'cifar100': 100}
    try:
        num_classes = num_classes_by_dataset[dataset]
    except KeyError:
        # Defensive: only reachable if get_tree supports a dataset this table lacks.
        raise ValueError(f"Unknown dataset: {dataset}") from None

    V = construct_V(tree, root, num_classes, edge_lengths=None, return_node_info=False)
    return tree, root, V

def smooth_one_hot(labels, num_classes, label_smoothing):
    """Return label-smoothed one-hot targets.

    Each row is ``(1 - label_smoothing) * one_hot + label_smoothing / num_classes``,
    so rows still sum to 1. For ``0 < label_smoothing`` every entry is strictly
    positive, i.e. the target moves from the simplex boundary to its interior.

    Args:
        labels: 1-D sequence or tensor of class indices, each in
            ``[0, num_classes)``. Any integer tensor dtype is accepted.
        num_classes: Number of classes (width of the output).
        label_smoothing: Smoothing weight ``eps``.

    Returns:
        Tensor of shape ``(len(labels), num_classes)`` in the default floating
        dtype, on the same device as ``labels``.
    """
    # scatter_ requires int64 indices. as_tensor converts lists and arrays and
    # casts non-int64 tensors (e.g. int32/uint8), preserving their device.
    labels = torch.as_tensor(labels, dtype=torch.long)

    one_hot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    one_hot.scatter_(1, labels.unsqueeze(1), 1)
    return (1 - label_smoothing) * one_hot + label_smoothing / num_classes