# Data science imports
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer, RobustScaler

# Project imports
from src import logger
from src.dataset_inference.attacks import get_mia

# Other imports

import pickle
import torch
import traceback
import gc
import numpy as np
from collections import Counter


def create_df(files):
    """
    Load pickled score files and index their per-group score entries.

    Each file is expected to unpickle to a mapping with the keys
    'group_id', 'args', 'members_scores_string' and
    'nonmembers_scores_string'; 'args' must provide 'secret_type',
    'dataset' and 'model_name'. The i-th entry of both score lists is
    stored under (group_id[i], secret_type, dataset, model_name).

    A file that fails to load or parse is logged and skipped; entries
    stored before the failure are kept.

    NOTE: pickle.load executes arbitrary code from the file, so only pass
    trusted paths.

    Returns:
        (members, nonmembers): two dicts sharing the same key layout.
    """
    member_scores = {}
    nonmember_scores = {}

    for path in files:
        try:
            with open(path, 'rb') as handle:
                payload = pickle.load(handle)

            for idx, group_id in enumerate(payload['group_id']):
                run_args = payload['args']
                key = (
                    group_id,
                    run_args['secret_type'],
                    run_args['dataset'],
                    run_args['model_name'],
                )
                nonmember_scores[key] = payload['nonmembers_scores_string'][idx]
                member_scores[key] = payload['members_scores_string'][idx]

            # Release the unpickled payload before loading the next file.
            del payload
            gc.collect()

        except Exception as err:
            logger.error(f"Error processing file {path}: {err}")
            traceback.print_exc()

    return member_scores, nonmember_scores

def create_grouped_split(model_list, members_scores_string, nonmembers_scores_string, filter_features, number_of_nonmembers=None):
    """
    Build per-group member / non-member feature tensors from score entries.

    Entries are grouped by (group_id, secret_type, dataset). A group is kept
    only when it has an entry for every model in ``model_list``. For each
    model, the ``get_mia`` features accepted by ``filter_features`` are
    stacked along the last dimension, and the per-model tensors are then
    concatenated along the last dimension.

    Args:
        model_list: model names whose features are concatenated, in order.
        members_scores_string: mapping keyed by
            (group_id, secret_type, dataset, model_name) holding the member
            score entries passed to ``get_mia``.
        nonmembers_scores_string: same key layout, for non-members.
        filter_features: predicate called with each feature name returned
            by ``get_mia``; truthy keeps the feature.
        number_of_nonmembers: upper bound on the non-member rows kept per
            model after a random permutation. ``None`` keeps every row
            (still in permuted order).

    Returns:
        dict mapping an integer set id to
            {
                "real": member feature tensor,
                "generated": non-member feature tensor,
                "info": (group_id, secret_type, dataset),
            }
        Set ids are the positions of the groups in the pandas groupby order,
        so ids of skipped groups are simply absent. All surviving
        'copyright-traps' groups are merged into one extra entry stored
        under the id ``<number of groups>``. An empty input yields ``{}``.

        Presumably each tensor is (num_samples, num_models * num_features);
        this depends on ``get_mia`` returning 1-D tensors - TODO confirm.
    """
    if not members_scores_string:
        # Nothing to group; the trap-merging step needs at least one group.
        return {}

    group_cols = ['group_id', 'secret_type', 'dataset']
    df = pd.DataFrame(members_scores_string.keys(), columns=group_cols + ['model_name'])
    required_models = set(model_list)

    dataset = {}
    group_keys = []  # position in this list == set id of the group
    for dataset_key, curr_df in df.groupby(group_cols):
        set_id = len(group_keys)
        group_keys.append(dataset_key)
        group_id, secret_type, dataset_name = dataset_key

        available_models = set(curr_df['model_name'])
        missing_models = required_models - available_models
        if missing_models:
            logger.info(f'skip {group_id} {secret_type} {dataset_name}')
            logger.info(f"set: {missing_models}")
            logger.info(available_models)
            continue

        non_mem_list = []
        mem_list = []
        for model_name in model_list:
            score_key = (group_id, secret_type, dataset_name, model_name)
            non_member_score_string = nonmembers_scores_string.get(score_key, None)
            member_score_string = members_scores_string.get(score_key, None)

            if non_member_score_string is None or member_score_string is None:
                logger.warning(f"Skipping {group_id} {secret_type} {dataset_name} {model_name}")
                continue

            non_mem_mia_vector = torch.stack(
                [v for k, v in get_mia(non_member_score_string).items() if filter_features(k)],
                dim=-1
            )
            # Shuffle rows, then keep at most `number_of_nonmembers` of them
            # (a slice bound of None keeps all rows).
            # NOTE(review): a fresh permutation is drawn per model, so row i
            # of the concatenated non-member tensor may combine different
            # samples across models - confirm this is intended.
            row_order = torch.randperm(non_mem_mia_vector.shape[0])[:number_of_nonmembers]
            non_mem_mia_vector = non_mem_mia_vector[row_order, :]

            mem_mia_vector = torch.stack(
                [v for k, v in get_mia(member_score_string).items() if filter_features(k)],
                dim=-1
            )

            non_mem_list.append(non_mem_mia_vector)
            mem_list.append(mem_mia_vector)

        if len(non_mem_list) == 0 or len(mem_list) == 0:
            logger.warning(f"Skipping {group_id} {secret_type} {dataset_name}")
            continue

        non_mem_features = torch.cat(non_mem_list, dim=-1)
        mem_features = torch.cat(mem_list, dim=-1)
        dataset[set_id] = {
            'real': mem_features,
            'generated': non_mem_features,
            'info': (group_id, secret_type, dataset_name),
        }

    if not group_keys:
        return dataset

    # Special case: copyright-traps
    # All trap groups contribute their members, but non-members are taken
    # only from groups whose secret_type ends with the most common suffix.
    # NOTE(review): the suffix count runs over every group, not just the
    # copyright-traps ones - confirm this is intended.
    keep_copyright_traps = Counter(
        secret_type.split('_')[-1] for _, secret_type, _ in group_keys
    ).most_common(1)[0][0]

    real = []
    generated = []
    for set_id, (group_id, secret_type, dataset_name) in enumerate(group_keys):
        if dataset_name != 'copyright-traps':
            continue
        entry = dataset.pop(set_id, None)
        if entry is None:
            # The group was skipped above, so there is nothing to merge.
            continue
        real.append(entry['real'])
        if secret_type.split('_')[-1] == keep_copyright_traps:
            generated.append(entry['generated'])

    if real:
        dataset[len(group_keys)] = {
            "real": torch.cat(real, dim=0),
            "generated": torch.cat(generated, dim=0),
            "info": (0, 'seq_len_AAA_n_rep_BBB', 'copyright-traps'),
        }
    return dataset

def create_split(dataset):
    """
    Flatten a grouped dataset into feature, label and set-id tensors.

    For every entry the 'real' tensor is followed by the 'generated' tensor,
    the latter reshaped to 2-D by collapsing all leading dimensions. Rows
    from 'real' are labelled 1, rows from 'generated' 0, and every row is
    tagged with the key of the entry it came from.

    Returns:
        (x, y, set_id) tensors, or (None, None, None) for an empty dataset.
    """
    if len(dataset) == 0:
        return None, None, None

    features = []
    labels = []
    group_ids = []

    for g_id, entry in dataset.items():
        members = entry['real']
        generated = entry['generated']
        non_members = generated.reshape(-1, generated.shape[-1])

        features.append(members)
        features.append(non_members)
        labels.append(torch.ones(members.shape[0]))
        labels.append(torch.zeros(non_members.shape[0]))
        group_ids.extend([g_id] * (len(members) + len(non_members)))

    x = torch.cat(features, dim=0)
    y = torch.cat(labels, dim=0)
    return x, y, torch.tensor(group_ids)


def remove_feature_outliers(metrics, remove_frac, outliers):
    """
    Handle the extreme values at both ends of a vector of metrics.

    ``int(len * remove_frac / 2)`` entries, but never fewer than one, are
    selected at each end of the sorted order and treated according to
    ``outliers``:

    - "zero": set to 0.0
    - "mean" / "mean+p-value": set to the mean of all values (the selected
      entries are included in that mean)
    - "clip": the high end is set to the smallest selected high value and
      the low end to the largest selected low value
    - "randomize": the selected entries are dropped, so the result is shorter
    - anything else: values are returned unchanged

    The input is never modified; a new tensor is returned.

    Raises:
        ValueError: if both ends together would cover more entries than exist.
    """
    values = metrics if torch.is_tensor(metrics) else torch.tensor(metrics, dtype=torch.float32)

    count = values.shape[0]
    order = torch.argsort(values)
    per_side = max(1, int(count * remove_frac / 2))

    if per_side * 2 > count:
        raise ValueError("remove_frac is too large, resulting in no elements left.")

    low_idx = order[:per_side]
    high_idx = order[-per_side:]

    result = values.clone()

    if outliers == "randomize":
        # Remove outliers entirely
        keep = torch.ones_like(result, dtype=torch.bool)
        keep[low_idx] = False
        keep[high_idx] = False
        return result[keep]

    # (indices, fill value) pairs, applied in order.
    if outliers == "zero":
        replacements = [(low_idx, 0.0), (high_idx, 0.0)]
    elif outliers in ["mean", "mean+p-value"]:
        fill = result.mean().item()
        replacements = [(low_idx, fill), (high_idx, fill)]
    elif outliers == "clip":
        # Innermost selected value on each side becomes the clipping bound.
        upper_bound = result[high_idx[0]].item()
        lower_bound = result[low_idx[-1]].item()
        replacements = [(high_idx, upper_bound), (low_idx, lower_bound)]
    else:
        replacements = []

    for idx, fill_value in replacements:
        result[idx] = fill_value

    return result


def clean_outliers(data, remove_frac=0.05, outliers="zero"):
    """
    Apply outlier handling column-wise to a 2D tensor.

    Args:
        data: 2D tensor, or array-like converted to a float32 tensor.
        remove_frac: fraction passed to ``remove_feature_outliers`` for
            every column.
        outliers: handling mode passed to ``remove_feature_outliers``. It
            may carry the fraction as a numeric suffix, "<mode>-<fraction>"
            (e.g. "clip-0.1" or "mean+p-value-0.1"), which then overrides
            ``remove_frac``.

    Returns:
        A new tensor with the same shape as ``data``.

    Raises:
        ValueError: if the chosen mode changes the column length (this is
            what "randomize" does).
    """
    if not torch.is_tensor(data):
        data = torch.tensor(data, dtype=torch.float32)

    # Split on the LAST '-' only, and only when the suffix is a number: the
    # mode name "mean+p-value" itself contains a '-', so a plain split('-')
    # would mangle it.
    mode, sep, frac_text = outliers.rpartition('-')
    if sep:
        try:
            parsed_frac = float(frac_text)
        except ValueError:
            pass  # no numeric suffix: the whole string is the mode name
        else:
            outliers, remove_frac = mode, parsed_frac

    if data.shape[1] == 0:
        # No columns to clean; torch.cat would reject an empty list.
        return data.clone()

    cleaned_columns = []
    for i in range(data.shape[1]):
        col = data[:, i]
        cleaned_col = remove_feature_outliers(col, remove_frac=remove_frac, outliers=outliers)

        if cleaned_col.shape[0] != data.shape[0]:
            raise ValueError("Outlier removal with 'randomize' leads to mismatched lengths.")
        cleaned_columns.append(cleaned_col.unsqueeze(1))
    return torch.cat(cleaned_columns, dim=1)


def scale_data(train_metrics, val_metrics, scaler_type="standard"):
    """
    Normalize train and validation metrics using pre-defined scalers from scikit-learn.

    The scaler is fitted on the training metrics only and then applied to
    both splits.

    scaler_type options:
    - 'standard': StandardScaler (mean=0, std=1)
    - 'minmax':   MinMaxScaler (maps to [0,1])
    - 'l2':       Normalizer(norm='l2')
    - 'robust':   RobustScaler

    Returns:
        (train, val, scaler): the scaled splits as float32 tensors and the
        fitted scaler. For an unsupported ``scaler_type`` a warning is
        logged and the inputs are returned unchanged with ``None`` as the
        scaler, so the result always unpacks into three values.
    """
    # Choose the scaler
    if scaler_type == "standard":
        scaler = StandardScaler()
    elif scaler_type == "minmax":
        scaler = MinMaxScaler()
    elif scaler_type == "l2":
        scaler = Normalizer(norm='l2')
    elif scaler_type == "robust":
        scaler = RobustScaler()
    else:
        logger.warning(f"Unsupported scaler type: {scaler_type}. No scaler applied." )
        # Keep the same arity as the normal path; this branch used to return
        # a 2-tuple, which broke callers unpacking three values.
        return train_metrics, val_metrics, None

    train_np = np.array(train_metrics)
    scaler.fit(train_np)

    train_arr = torch.tensor(scaler.transform(train_np), dtype=torch.float32)
    val_arr = torch.tensor(scaler.transform(np.array(val_metrics)), dtype=torch.float32)

    return train_arr, val_arr, scaler
