import os
import requests
import random
import numpy as np
import pandas as pd
import torch
from io import StringIO
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
import hashlib
import os
from transformers import default_data_collator

def generate_mask_hash(mask):
    """Return the SHA-256 hex digest identifying a binary mask.

    Each element of ``mask`` is converted with ``str`` and the pieces are
    concatenated (no separator) before hashing, so equal element sequences
    always map to the same digest.
    """
    serialized = ''.join(str(bit) for bit in mask)
    digest = hashlib.sha256(serialized.encode())
    return digest.hexdigest()

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed. ``0`` is a valid seed and is applied like any
            other value. Pass ``None`` (or ``False``) to leave all RNG state
            and cuDNN flags untouched.
    """
    # Explicit sentinel check: a plain truthiness test would silently skip
    # seeding for the perfectly valid seed 0.
    if seed is None or seed is False:
        return

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_optimizer_function(name):
    """Get the optimizer class given its (case-insensitive) name.

    Args:
        name: ``'adam'`` or ``'sgd'`` in any letter case.

    Returns:
        The matching ``torch.optim`` optimizer class (not an instance).

    Raises:
        ValueError: If ``name`` is not one of the supported optimizers.
    """
    optimizers = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    try:
        return optimizers[name.lower()]
    except KeyError:
        # Fail here with a clear message instead of returning None and
        # letting the caller crash later when it tries to call the result.
        raise ValueError(
            f"Unknown optimizer {name!r}; expected one of {sorted(optimizers)}"
        ) from None
    
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

### Details pulled from public datasets on OpenXAI: https://github.com/AI4LIFE-GROUP/OpenXAI
# Base URL; a file id from `dataverse_ids` is appended to form the download link.
dataverse_prefix = 'https://dataverse.harvard.edu/api/access/datafile/'

# Dataverse file ids, keyed by split and then by dataset name.
dataverse_ids = {
    'train': {
        'adult': '8550940',
        'compas': '8550936',
        'gaussian': '8550929',
        'german': '8550931',
        'gmsc': '8550934',
        'heart': '8550932',
        'heloc': '8550942',
        'pima': '8550937',
    },
    'test': {
        'adult': '8550933',
        'compas': '8550944',
        'gaussian': '8550941',
        'german': '8550930',
        'gmsc': '8550939',
        'heart': '8550935',
        'heloc': '8550943',
        'pima': '8550938',
    },
}

# Name of the target column in each dataset's CSV.
labels = {
    'adult': 'income',
    'compas': 'risk',
    'gaussian': 'target',
    'german': 'credit-risk',
    'gmsc': 'SeriousDlqin2yrs',
    'heart': 'TenYearCHD',
    'heloc': 'RiskPerformance',
    'pima': 'Outcome',
}

# Scaler instances looked up by `scale_data`; they are shared module-level
# objects and are re-fit on every call.
scalars = {
    'minmax': MinMaxScaler(),
    'standard': StandardScaler(),
}

# Presumably the number of input features per dataset -- confirm against the CSVs.
input_dims = {
    'adult': 13,
    'compas': 7,
    'gaussian': 20,
    'german': 20,
    'gmsc': 10,
    'heart': 15,
    'heloc': 23,
    'pima': 8,
}

def scale_data(data, scale):
    """Scale ``data`` with the scaler registered under ``scale``.

    Args:
        data: Feature matrix handed to the scaler's ``fit_transform``.
        scale: Key of ``scalars`` (``'minmax'`` or ``'standard'``), or
            ``None`` to return ``data`` unchanged.

    Returns:
        ``data`` itself when ``scale`` is ``None``; otherwise the result of
        fitting the selected scaler on ``data`` and transforming it.

    Raises:
        NotImplementedError: If ``scale`` is not a known transformation.
    """
    # No scaling requested: hand the input back untouched.
    if scale is None:
        return data
    if scale not in scalars:
        raise NotImplementedError("Provide one of the following transformations: {'minmax', 'standard', None}")
    scaler = scalars[scale]
    return scaler.fit_transform(data)

def load_openxai_dataset(name, train=True, download=False, scale=None):
    """
    Load one split of an OpenXAI tabular dataset as torch tensors.

    The CSV is cached next to this module under ``<name>/data/``. It is
    downloaded from Dataverse when the cached file is missing or when
    ``download=True``; otherwise the cached copy is read.

    Args:
        name: Dataset name (case-insensitive); must be a key of ``labels``.
        train: Load the training split if truthy, else the test split.
        download: Force a fresh download even if a cached file exists.
        scale: Forwarded to ``scale_data`` (``'minmax'``, ``'standard'`` or
            ``None``).

    Returns:
        ``(features, targets)`` as ``torch.FloatTensor`` and
        ``torch.LongTensor``.

    Raises:
        KeyError: If ``name`` is not a known dataset.
        requests.HTTPError: If the download returns an error status.
        NotImplementedError: If ``scale`` is not supported by ``scale_data``.
    """
    name = name.lower()
    # Validate before touching the filesystem so an unknown name does not
    # leave an empty cache directory behind.
    if name not in labels:
        raise KeyError(f"Unknown dataset {name!r}; expected one of {sorted(labels)}")

    split = 'train' if train else 'test'
    # os.path keeps this working with non-'/' path separators.
    datasets_dir = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(datasets_dir, name, 'data')
    file_path = os.path.join(path, f'{name}_{split}.csv')

    if download or not os.path.isfile(file_path):
        os.makedirs(path, exist_ok=True)
        url = dataverse_prefix + dataverse_ids[split][name]
        r = requests.get(url, allow_redirects=True, timeout=60)
        # Without this an HTTP error page would be parsed and cached as the
        # dataset; fail loudly instead.
        r.raise_for_status()
        dataset = pd.read_csv(StringIO(r.text), sep='\t')
        # The cached copy is only replaced after a successful download.
        dataset.to_csv(file_path, index=False)
    else:
        dataset = pd.read_csv(file_path)

    # Get the target label and split the data into features and target
    label = labels[name]
    targets = dataset[label].values
    data = dataset.drop(label, axis=1).values
    # NOTE: scale_data fits the scaler on the data passed in, so each split
    # is scaled using its own statistics.
    data = scale_data(data, scale)

    # Return as torch tensors
    return torch.FloatTensor(data), torch.LongTensor(targets)
    
def load_sklearn_dataset(name, train=True, scale='minmax'):
    """Load a split of a scikit-learn dataset as float torch tensors.

    Only ``"california"`` (California housing) is supported. The data is
    split 67/33 with ``random_state=0`` and the requested part is returned.

    Args:
        name: Dataset name; anything other than ``"california"`` raises.
        train: Return the training part if truthy, else the test part.
        scale: Forwarded to ``scale_data``.

    Returns:
        ``(X, y)`` float tensors; ``y`` has a trailing dimension of size 1.

    Raises:
        NotImplementedError: For unsupported dataset names.
    """
    if name != "california":
        raise NotImplementedError

    # Fetch the frame and separate the regression target from the features.
    frame = fetch_california_housing(as_frame=True).frame
    target_column = 'MedHouseVal'
    features = frame.drop(columns=target_column)
    target = frame[target_column]

    # Fixed random_state keeps the train/test partition identical across calls.
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.33, random_state=0
    )
    if train:
        X, y = X_train, y_train
    else:
        X, y = X_test, y_test

    # NOTE(review): scale_data fits the scaler on the split selected above,
    # so train and test are each scaled with their own statistics.
    X_scaled = scale_data(X, scale)

    X_tensor = torch.from_numpy(np.array(X_scaled)).float()
    # unsqueeze(1) makes the target 2D: (n_samples, 1).
    y_tensor = torch.from_numpy(y.values).float().unsqueeze(1)
    return X_tensor, y_tensor
    
def eval_loop(model, dataloader, eval_fn, device):
    """
    Evaluate the model on the given dataloader
    using the provided evaluation function
    e.g. cross-entropy loss, accuracy, etc.

    The model is switched to eval mode and gradients are disabled. Each
    batch value returned by ``eval_fn(outputs, labels)`` is weighted by the
    batch size, and the sum is divided by the number of samples actually
    evaluated.

    Args:
        model: Callable invoked as ``model(**inputs)`` with the dict produced
            by ``process_batch``.
        dataloader: Iterable of batches accepted by ``process_batch``.
        eval_fn: Callable ``(outputs, labels) -> scalar tensor``; presumably
            a per-batch mean -- confirm against callers.
        device: Device the batch tensors are moved to.

    Returns:
        Sample-weighted average of ``eval_fn`` as a float, or ``0.0`` if the
        dataloader yielded no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs, targets = process_batch(batch, device=device)
            outputs = model(**inputs)
            batch_size = targets.size(0)
            total_loss += eval_fn(outputs, targets).item() * batch_size
            total_samples += batch_size
    # Divide by the samples seen rather than len(dataloader.dataset): the two
    # differ with drop_last=True or subset samplers, and some loaders have no
    # sized dataset at all.
    if total_samples == 0:
        return 0.0
    return total_loss / total_samples


def process_batch(batch, device='cpu'):
    """
    Process a batch of data decoded from dataloader and return x and y.

    For HuggingFace-style batches (a dict containing 'input_ids'), x is a
    dictionary with 'input_ids' plus 'attention_mask' and 'token_type_ids'
    when the batch provides them, and y is ``batch['labels']``.
    For other datasets, the batch is unpacked as ``(inputs, labels)`` and x
    is ``{"x": inputs}``.

    Args:
        batch: Either a dict with at least 'input_ids' and 'labels', or an
            ``(inputs, labels)`` pair of tensors.
        device: Device every returned tensor is moved to.

    Returns:
        ``(inputs, labels)`` where ``inputs`` is a dict of keyword arguments
        for the model and ``labels`` is a tensor on ``device``.
    """
    if isinstance(batch, dict) and 'input_ids' in batch:
        # HuggingFace dataset
        inputs = {'input_ids': batch['input_ids'].to(device)}
        # Not every tokenizer emits these (e.g. no segment ids), so forward
        # them only when present instead of raising KeyError.
        for key in ('attention_mask', 'token_type_ids'):
            if key in batch:
                inputs[key] = batch[key].to(device)
        labels = batch['labels'].to(device)
    else:
        # Other datasets
        features, labels = batch
        inputs = {"x": features.to(device)}
        labels = labels.to(device)
    return inputs, labels

