import os
import torch as th
import pandas as pd
import numpy as np
from collections import namedtuple
from torch.utils.data import DataLoader, Dataset


def df_to_tensor(df):
    """Convert a pandas DataFrame into a float torch tensor.

    The tensor is built from ``df.values`` and then cast with ``.float()``.
    """
    raw_values = df.values
    tensor = th.from_numpy(raw_values)
    return tensor.float()


def convert_to_tensor_if_needed(x, squeeze=False):
    """Return ``x`` as a torch tensor, converting it when necessary.

    Args:
        x: A ``pandas.DataFrame`` (converted through ``df_to_tensor``), a
            ``numpy.ndarray`` (wrapped with ``th.from_numpy``, so its dtype is
            kept), or a torch tensor (returned untouched).
        squeeze: When true, ``.squeeze()`` is applied to the result, removing
            every dimension of size 1.

    Returns:
        The resulting torch tensor.

    Raises:
        TypeError: If ``x`` is none of the supported types. ``TypeError`` is a
            subclass of ``Exception``, so callers catching ``Exception`` keep
            working.
    """
    if isinstance(x, pd.DataFrame):
        x = df_to_tensor(x)
    elif isinstance(x, np.ndarray):
        x = th.from_numpy(x)
    elif not th.is_tensor(x):
        raise TypeError(f"Type not supported: {type(x).__name__}")
    return x.squeeze() if squeeze else x


# Bundle of every data split returned by `load_tabular_data`.
TabData = namedtuple(
    'TabData',
    'x_train y_train '
    'x_val y_val '
    'x_train_total y_train_total '
    'x_test y_test '
    'x_test_n_train y_test_n_train',
)

# Summary statistics of a dataset, as produced by `get_dataset_info`.
ImbDatasetInfo = namedtuple(
    'ImbDatasetInfo',
    'num_min num_maj num_features imb_ratio num_samples',
)

class TabularDataset(Dataset):
    """Dataset over a feature tensor with optional label conditioning.

    Each item is a ``(features, out_dict)`` pair. ``out_dict`` holds the label
    under the key ``"y"`` when ``class_cond`` is truthy and is empty otherwise.
    """

    def __init__(self, x_tensor, y_tensor, class_cond):
        """Store the feature tensor, the label tensor and the conditioning flag."""
        super().__init__()
        self.x_tensor = x_tensor
        self.y_tensor = y_tensor
        self.class_cond = class_cond

    def __len__(self):
        """Return the size of the first dimension of ``x_tensor``."""
        num_rows = self.x_tensor.shape[0]
        return num_rows

    def __getitem__(self, idx):
        """Return ``(x_tensor[idx], out_dict)`` for the given index."""
        # The label travels in a dictionary to stay compatible with legacy code.
        out_dict = {"y": self.y_tensor[idx]} if self.class_cond else {}
        features = self.x_tensor[idx]
        return features, out_dict


def load(path):
    """Load an ``(x, y)`` pair from ``path`` and return both as tensors.

    The file is read with ``th.load`` and must unpack into two objects. Each
    is passed through ``convert_to_tensor_if_needed``; the labels are also
    squeezed. The first dimensions of ``x`` and ``y`` are asserted to match.
    """
    # NOTE: th.load unpickles the file; only use it on trusted paths.
    features, labels = th.load(path)
    # CSV alternative that was sketched here previously:
    #   df = pd.read_csv('datasets/Iris.csv') # df = pd.read_csv(path)
    #   x = df.iloc[:, [1, 2, 3, 4]]
    #   y = df.iloc[:, [5]]
    features = convert_to_tensor_if_needed(features)
    labels = convert_to_tensor_if_needed(labels, squeeze=True)
    assert features.shape[0] == labels.shape[0]
    return features, labels


def load_tabular_data(train_pt, validation_pt, test_pt):
    """Load the train/validation/test splits and return them as a ``TabData``.

    Args:
        train_pt: Path of the training split, read with ``load``.
        validation_pt: Path of the validation split. When falsy, no validation
            data is loaded, ``x_val``/``y_val`` are ``None`` and the
            ``*_train_total`` fields are the training tensors themselves.
        test_pt: Path of the test split, read with ``load``.

    Returns:
        A ``TabData`` where ``*_train_total`` is train followed by validation
        (when given) and ``*_test_n_train`` is train followed by test, both
        concatenated along dimension 0.
    """
    x_train, y_train = load(train_pt)
    x_test, y_test = load(test_pt)

    # Start from the "no validation split" defaults and override if needed.
    x_val = y_val = None
    x_train_total, y_train_total = x_train, y_train
    if validation_pt:
        x_val, y_val = load(validation_pt)
        x_train_total = th.cat((x_train, x_val))
        y_train_total = th.cat((y_train, y_val))

    return TabData(
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        x_train_total=x_train_total,
        y_train_total=y_train_total,
        x_test=x_test,
        y_test=y_test,
        x_test_n_train=th.cat((x_train, x_test), 0),
        y_test_n_train=th.cat((y_train, y_test), 0),
    )


def tabular_data_loader(train_pt, validation_pt, batch_size, type, class_cond):
    """Yield shuffled ``(x, out_dict)`` batches forever.

    This is a generator, so loading and argument validation happen on the
    first ``next()`` call, not when the function is called.

    Args:
        train_pt: Path of the training split, read with ``load``.
        validation_pt: Path of the validation split. When truthy, its samples
            are appended to the training samples; when falsy, only the
            training split is used.
        batch_size: Batch size handed to the ``DataLoader``.
        type: Which samples to serve: ``'min_and_maj'`` (everything with its
            loaded labels), ``'min'`` (rows labelled 1, served with all-ones
            labels) or ``'maj'`` (rows labelled 0, served with all-zeros
            labels). The name shadows the builtin but is kept for callers.
        class_cond: Forwarded to ``TabularDataset``; controls whether labels
            are included in each item's dictionary.

    Yields:
        Batches produced by the ``DataLoader``; once the data is exhausted the
        loader is iterated again (reshuffled), so the generator never ends.

    Raises:
        ValueError: If ``type`` is not one of the supported values, or if the
            selected subset contains no samples (iterating an empty loader in
            an endless loop would otherwise spin forever without yielding).
    """
    # Build the combined train (+ validation) tensors directly with `load`.
    x_train_total, y_train_total = load(train_pt)
    if validation_pt:
        x_val, y_val = load(validation_pt)
        x_train_total = th.cat((x_train_total, x_val))
        y_train_total = th.cat((y_train_total, y_val))

    if type == 'min_and_maj':
        x_selected, y_selected = x_train_total, y_train_total
    elif type == 'min':
        x_selected = x_train_total[y_train_total == 1]
        y_selected = th.ones(x_selected.shape[0])
    elif type == 'maj':
        x_selected = x_train_total[y_train_total == 0]
        y_selected = th.zeros(x_selected.shape[0])
    else:
        raise ValueError(f"Argument 'type' is not valid: {type!r}")

    if x_selected.shape[0] == 0:
        raise ValueError(f"No samples available for type {type!r}")

    dataset = TabularDataset(x_selected.float(), y_selected, class_cond)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    while True:
        yield from loader

def get_dataset_info(x, y):
    """Compute size and class-imbalance statistics for a dataset.

    Args:
        x: Feature tensor; ``x.shape[0]`` is taken as the number of samples
            and ``x.shape[1]`` as the number of features.
        y: Label tensor. Every non-zero entry is counted as a minority sample
            and the remaining samples as majority.

    Returns:
        An ``ImbDatasetInfo``. ``imb_ratio`` is ``num_maj / num_min`` rounded
        to one decimal, or ``float('inf')`` when there are no minority
        samples (the ratio is undefined and would otherwise raise
        ``ZeroDivisionError``).
    """
    num_features = x.shape[1]
    num_samples = x.shape[0]
    num_min = int(th.count_nonzero(y))
    num_maj = num_samples - num_min
    if num_min:
        imb_ratio = round(num_maj / num_min, 1)
    else:
        # No minority samples: avoid dividing by zero.
        imb_ratio = float('inf')
    return ImbDatasetInfo(num_min=num_min,
                          num_maj=num_maj,
                          num_features=num_features,
                          imb_ratio=imb_ratio,
                          num_samples=num_samples)


## used in tabular_train.py
def load_and_get_dataset_info(dataset_pt):
    """Load an ``(x, y)`` pair from ``dataset_pt`` and summarise it.

    The file is read with ``th.load`` and must unpack into two objects. Both
    are passed through ``convert_to_tensor_if_needed`` so that DataFrame and
    ndarray payloads are handled as well as tensors before the statistics are
    computed.

    Returns:
        The ``ImbDatasetInfo`` produced by ``get_dataset_info``.
    """
    # NOTE: th.load unpickles the file; only use it on trusted paths.
    x, y = th.load(dataset_pt)
    x = convert_to_tensor_if_needed(x)
    y = convert_to_tensor_if_needed(y)
    return get_dataset_info(x, y)


def print_dataset_characteristics(dataset_name, data:TabData):
    """Print a summary of the train and test splits held in ``data``.

    Args:
        dataset_name: Name shown in the header line; when ``None`` a generic
            header is printed instead.
        data: ``TabData`` whose ``x_train_total``/``y_train_total`` and
            ``x_test``/``y_test`` fields are summarised with
            ``get_dataset_info``.

    The train and test splits are asserted to have the same number of
    features. The whole-dataset imbalance ratio is printed as ``inf`` when
    neither split contains a minority sample.
    """
    train_info = get_dataset_info(data.x_train_total, data.y_train_total)
    test_info = get_dataset_info(data.x_test, data.y_test)
    assert train_info.num_features == test_info.num_features
    if dataset_name is not None:
        print(f"======= DATASET ANALYSIS -  {dataset_name:<16} ==========")
    else:
        print("======= DATASET ANALYSIS ==============================")
    print(f'|features|        = {train_info.num_features}')
    print(f'|train|           = {train_info.num_samples}')
    print(f'train imb ratio   = {train_info.imb_ratio:.3}')
    print(f'|minority train|  = {train_info.num_min}')
    print(f'|test|            = {test_info.num_samples}')
    print(f'test imb ratio    = {test_info.imb_ratio:.3}')
    print(f'|minority test|   = {test_info.num_min}')
    print(f'|dataset|         = {test_info.num_samples + train_info.num_samples}')
    total_maj = train_info.num_maj + test_info.num_maj
    total_min = train_info.num_min + test_info.num_min
    # Guard the division: with no minority samples the ratio is undefined.
    dataset_imb_ratio = total_maj / total_min if total_min else float('inf')
    print(f'dataset imb ratio = {dataset_imb_ratio:.3}')
    print("=======================================================")


def num_samples_by_type(train_info: "ImbDatasetInfo", train_with: str):
    """Return how many samples are used for the given training mode.

    Args:
        train_info: Dataset statistics providing ``num_min`` and ``num_maj``.
        train_with: ``'min'``, ``'maj'`` or ``'min_and_maj'``.

    Returns:
        ``num_min``, ``num_maj`` or their sum, depending on ``train_with``.

    Raises:
        ValueError: If ``train_with`` is not one of the supported values.
            ``ValueError`` is a subclass of ``Exception``, so callers catching
            ``Exception`` keep working.
    """
    if train_with == 'min':
        return train_info.num_min
    if train_with == 'maj':
        return train_info.num_maj
    if train_with == 'min_and_maj':
        return train_info.num_min + train_info.num_maj
    raise ValueError(f"Type not supported: {train_with!r}")
