import torch
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import torch.utils.data as data_utils
from typing import *
import numpy as np


def features_to_contexts(feat: torch.Tensor, K: int) -> torch.Tensor:
    """Expand a feature vector into K rows, placing it in a different slot per row.

    Row ``k`` of the result is zero everywhere except columns
    ``k*D:(k+1)*D``, which hold a copy of ``feat`` (D = ``feat.shape[-1]``).

    Args:
        feat: Tensor whose last dimension has size D. A 1-D tensor of shape
            (D,) yields a (K, K*D) result; any leading dimensions are kept,
            i.e. (..., D) -> (..., K, K*D).
        K: Number of rows / slots to generate.

    Returns:
        The expanded tensor, with the same dtype and device as ``feat``.
    """
    D = feat.shape[-1]
    # new_zeros inherits dtype and device from feat, so e.g. float64 features
    # are not silently cast to the default float dtype.
    ctx = feat.new_zeros((*feat.shape[:-1], K, K * D))

    for k in range(K):
        # The leading `...` lets batched inputs fill each slot in one assignment.
        ctx[..., k, k * D:(k + 1) * D] = feat

    return ctx


def load_raw_uci_dset(
    name: str, root: str = 'datasets'
) -> Tuple[data_utils.TensorDataset, int, int]:
    """Load one of the supported UCI datasets as a ``TensorDataset``.

    ``{root}/{name}.xlsx`` is tried first; if it is missing (or pandas has no
    Excel engine installed) ``{root}/{name}.csv`` is read instead.

    Args:
        name: One of 'bean', 'letter', 'magic', 'avila', 'pendigits'.
        root: Directory holding the dataset files. Defaults to 'datasets'
            (relative to the current working directory), the previously
            hard-coded location.

    Returns:
        ``(dataset, D, K)``: a TensorDataset of (float32 features, int64
        labels), the number of feature columns D, and the number of distinct
        classes K.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    supported = ('bean', 'letter', 'magic', 'avila', 'pendigits')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if name not in supported:
        raise ValueError(f'Unknown dataset {name!r}; expected one of {supported}')

    try:
        data = pd.read_excel(f'{root}/{name}.xlsx').values
    except (FileNotFoundError, ImportError):
        # No .xlsx file (or no Excel engine such as openpyxl): fall back to CSV.
        # Other errors (e.g. a corrupt .xlsx) are no longer silently swallowed.
        data = pd.read_csv(f'{root}/{name}.csv').values

    if name == 'letter':
        # 'letter' stores the class label in the first column.
        X, y = data[:, 1:].astype('float'), data[:, 0]
    else:
        # The remaining datasets store the class label in the last column.
        X, y = data[:, :-1].astype('float'), data[:, -1]

    # Convert arbitrary (e.g. string) labels into integer labels 0..K-1.
    y = LabelEncoder().fit_transform(y)
    dataset = data_utils.TensorDataset(torch.tensor(X).float(), torch.tensor(y).long())

    # Some properties
    D, K = X.shape[-1], len(np.unique(y))

    return dataset, D, K


if __name__ == '__main__':
    # Report the feature count and class count of every supported dataset.
    for dataset_name in ('bean', 'letter', 'magic', 'avila', 'pendigits'):
        _, n_features, n_classes = load_raw_uci_dset(dataset_name)
        print(f'{dataset_name} {n_features} {n_classes}')

    # Disabled manual sanity check for features_to_contexts:
    # D = 5
    # K = 3
    # feat = torch.randn(D)
    #
    # ctx = features_to_contexts(feat, K)
    # assert ctx.shape == (K, D*K)
    #
    # print(ctx)
