import pandas as pd
from typing import Tuple, Dict
import torch
from sklearn.model_selection import train_test_split

import datasets.process_data
import logging

# Import-time logging setup: root logger at INFO, timestamped records written
# through a StreamHandler (which targets stderr when no stream is given).
# basicConfig leaves the root logger untouched if it already has handlers.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger used by the loaders below.
logger = logging.getLogger(__name__)


def _load_data(
        df: pd.DataFrame,
        data_config,
        split_ratios: Tuple[float, float] = (0.1, 0.9),
        seed: int = None
) -> Tuple[Dict, Dict, Dict]:
    """Turn a prepared frame into tensors and cut it into three subsets.

    Tensor construction:
      * ``features``: every column of ``df`` except the label column and,
        when it exists, the target-probability column. No other column is
        removed here.
      * ``logits``: ``data_config.pred_logit_col`` multiplied by
        ``data_config.pred_scale`` and clamped to the range [0, 1].
      * ``labels``: ``data_config.target_label_col`` cast to int64.
      * ``aprobs``: ``data_config.target_prob_col`` when that column is in
        ``df``; otherwise the label tensor is reused.

    Splitting (both calls share ``seed``):
      1. A split stratified on ``labels`` with
         ``test_size = 1 - split_ratios[0]``; the kept part becomes the first
         returned dict, the held-out part is split again.
      2. A non-stratified split of the held-out part with
         ``test_size = split_ratios[1] / (split_ratios[0] + split_ratios[1])``.

    Args:
        df: Source frame containing the feature, prediction and label columns.
        data_config: Object exposing ``target_label_col``,
            ``target_prob_col``, ``pred_logit_col`` and ``pred_scale``.
        split_ratios: Pair of ratios driving the two splits as described above.
        seed: Value forwarded as ``random_state`` to both splits.

    Returns:
        Three dicts (first-split kept part, then the two parts of the second
        split), each with keys ``'features'``, ``'logits'``, ``'labels'``
        and ``'aprobs'``.
    """
    label_col = data_config.target_label_col
    prob_col = data_config.target_prob_col
    has_prob_col = prob_col in df.columns

    # Target columns must not leak into the feature matrix.
    dropped = [label_col, prob_col] if has_prob_col else [label_col]
    features = torch.tensor(df.drop(columns=dropped).to_numpy())

    raw_preds = torch.tensor(df[data_config.pred_logit_col].values)
    logits = torch.clamp(raw_preds * data_config.pred_scale, min=0, max=1)

    labels = torch.tensor(df[label_col].values).long()
    aprobs = torch.tensor(df[prob_col].values) if has_prob_col else labels

    names = ('features', 'logits', 'labels', 'aprobs')

    # train_test_split returns [a_kept, a_held, b_kept, b_held, ...], so the
    # even positions form one subset and the odd positions form the other.
    first_split = train_test_split(
        features, logits, labels, aprobs,
        test_size=1 - split_ratios[0],
        random_state=seed,
        stratify=labels,
    )
    val_parts = first_split[0::2]
    remaining_parts = first_split[1::2]

    second_split = train_test_split(
        *remaining_parts,
        test_size=split_ratios[1] / (split_ratios[0] + split_ratios[1]),
        random_state=seed,
    )

    return (
        dict(zip(names, val_parts)),
        dict(zip(names, second_split[0::2])),
        dict(zip(names, second_split[1::2])),
    )


def load_data(
        data_config,
        split_ratios: Tuple[float, float] = (0.1, 0.9),
        seed: int = None
) -> Tuple[Dict, Dict, Dict]:
    """Load the frame described by ``data_config`` and split it.

    The frame comes from ``datasets.process_data.pre_load_data`` and is then
    handed to ``_load_data`` together with ``split_ratios`` and ``seed``.

    Args:
        data_config: Configuration object passed to both the pre-loader and
            ``_load_data``.
        split_ratios: Ratio pair forwarded unchanged to ``_load_data``.
        seed: Random seed forwarded unchanged to ``_load_data``.

    Returns:
        The three dicts produced by ``_load_data``.

    Raises:
        Exception: Whatever loading or splitting raises is logged at ERROR
            level and re-raised unchanged; ``pandas.errors.ParserError``
            is logged with its own message prefix.
    """
    try:
        frame = datasets.process_data.pre_load_data(data_config)
        subsets = _load_data(frame, data_config, split_ratios, seed)
    except Exception as err:
        # Log and propagate; only the message differs for parser failures.
        if isinstance(err, pd.errors.ParserError):
            logger.error(f"data parse error: {err!s}")
        else:
            logger.error(f"error: {err!s}")
        raise
    else:
        return subsets
