import numpy as np
import pandas as pd
import os

def split_data(data, train_frac=0.8, val_frac=0.1, test_frac=0.1, seed=42):
    """Split ``data`` into train/validation/test frames grouped by ``lbnr``.

    The unique values of the ``lbnr`` column are shuffled and partitioned, so
    every row sharing an ``lbnr`` lands in the same split.

    Args:
        data: DataFrame with an ``lbnr`` column.
        train_frac: Fraction of unique ``lbnr`` values assigned to train.
        val_frac: Fraction of unique ``lbnr`` values assigned to validation.
        test_frac: Fraction for test. Only used for the sum check; the test
            split receives whatever remains after train and validation.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        Tuple ``(train_df, val_df, test_df)`` of row subsets of ``data``.

    Raises:
        ValueError: If the three fractions do not sum to 1.

    Note:
        This reseeds NumPy's *global* random generator as a side effect.
    """
    # Raise explicitly instead of asserting: asserts vanish under `python -O`.
    if not np.isclose(train_frac + val_frac + test_frac, 1.0):
        raise ValueError("Fractions must sum to 1.")

    unique_lbnr = data['lbnr'].unique()
    np.random.seed(seed)
    np.random.shuffle(unique_lbnr)

    n = len(unique_lbnr)
    # int() truncates, so rounding leftovers end up in the test split.
    train_end = int(n * train_frac)
    val_end = train_end + int(n * val_frac)

    train_lbnr = unique_lbnr[:train_end]
    val_lbnr = unique_lbnr[train_end:val_end]
    test_lbnr = unique_lbnr[val_end:]

    train_df = data[data['lbnr'].isin(train_lbnr)]
    val_df = data[data['lbnr'].isin(val_lbnr)]
    test_df = data[data['lbnr'].isin(test_lbnr)]

    return train_df, val_df, test_df

def get_biom_features(data):
    """Pivot long-format rows into one tuple-valued column per analysis code.

    Rows are grouped by ``(lbnr, samplingdate)``. For every ``analysiscode``
    the result holds a 5-tuple
    ``(value, unit, laboratorium_idcode, referenceinterval_lowerlimit,
    referenceinterval_upperlimit)``; ``value`` is the group mean and the other
    fields take the first entry of the group. Combinations missing from the
    pivot are NaN inside the tuple.

    Args:
        data: DataFrame with the columns named above plus ``lbnr``,
            ``samplingdate`` and ``analysiscode``.

    Returns:
        DataFrame indexed by ``(lbnr, samplingdate)`` with one column per
        analysis code.
    """
    fields = (
        'value',
        'unit',
        'laboratorium_idcode',
        'referenceinterval_lowerlimit',
        'referenceinterval_upperlimit',
    )

    wide = data.pivot_table(
        index=['lbnr', 'samplingdate'],
        columns='analysiscode',
        values=list(fields),
        aggfunc={
            'value': 'mean',
            'unit': 'first',
            'laboratorium_idcode': 'first',
            'referenceinterval_lowerlimit': 'first',
            'referenceinterval_upperlimit': 'first',
        },
    )

    # Put the analysis code on the outer column level: (code, field).
    wide.columns = wide.columns.swaplevel(0, 1)
    codes = wide.columns.levels[0]

    def pack_row(row):
        # One tuple per analysis code, fields in the fixed order above.
        return {
            code: tuple(row.get((code, field), np.nan) for field in fields)
            for code in codes
        }

    packed = wide.apply(pack_row, axis=1)

    # Expand the per-row dicts back into one column per analysis code.
    return packed.apply(pd.Series)

def get_static_features(data):
    """Collect per-sample static fields, keyed by ``(lbnr, samplingdate)``.

    Each returned Series takes the first entry of its column within every
    ``(lbnr, samplingdate)`` group. The normalised age is min-max scaled using
    the minimum and maximum of ``age_at_sampling_in_days`` over all rows of
    ``data``.

    Args:
        data: DataFrame with columns ``lbnr``, ``samplingdate``,
            ``prediag_token_date``, ``dataset``, ``age_at_sampling_in_days``
            and ``sex``.

    Returns:
        Tuple ``(end_tokens, labels, age_at_sampling, age_at_sampling_norm,
        sex)`` of Series sharing the same group index.
    """
    per_sample = data.groupby(['lbnr', 'samplingdate'])

    end_tokens = per_sample['prediag_token_date'].first()
    labels = per_sample['dataset'].first()
    age_at_sampling = per_sample['age_at_sampling_in_days'].first()
    sex = per_sample['sex'].first()

    # Scaling bounds come from the raw rows, not from the grouped values.
    youngest = data['age_at_sampling_in_days'].min()
    oldest = data['age_at_sampling_in_days'].max()
    scaled_age = (age_at_sampling - youngest) / (oldest - youngest)
    age_at_sampling_norm = scaled_age.rename("age_at_sampling_in_days_norm")

    return (end_tokens, labels, age_at_sampling, age_at_sampling_norm, sex)

def save_trajectory_batch(i, save_path, combined):
    """Write the rows of ``combined`` whose ``lbnr`` equals ``i`` to a CSV file.

    The output directory is ``f"{save_path}__CD__{dataset}"``, where
    ``dataset`` is read from the first matching row. The directory is created
    in every case; the CSV ``<i>.csv`` is only written when the number of
    matching rows is strictly between 1 and 50.

    Args:
        i: ``lbnr`` value to select.
        save_path: Prefix of the output directory.
        combined: DataFrame with ``lbnr`` and ``dataset`` columns.
    """
    trajectory = combined[combined["lbnr"] == i]

    # Row-first lookup: an empty selection fails here on ``iloc[0]``.
    sample_type = trajectory.iloc[0]["dataset"]
    target_dir = f"{save_path}__CD__{sample_type}"
    os.makedirs(target_dir, exist_ok=True)

    row_count = trajectory.shape[0]
    if 1 < row_count < 50:
        trajectory.to_csv(f"{target_dir}/{i}.csv")

def unify_units(df):
    """Map known unit spellings to one canonical unit per analysis code.

    For analysis codes listed in the mapping below, a unit equal to the
    canonical unit or to one of its listed alternative spellings becomes the
    canonical unit; any other unit marks the row for removal. Rows with other
    analysis codes keep their unit. Finally every row whose unit is missing
    is dropped (this includes rows that already had a missing unit).

    The input frame is not modified; a new frame is returned.

    Args:
        df: DataFrame with ``analysiscode`` and ``unit`` columns.

    Returns:
        New DataFrame with unified ``unit`` values and unusable rows removed.
    """
    print('unifying units')
    # analysis code -> (canonical unit, accepted alternative spellings).
    # NOTE(review): "× 10<sup>-6</sup" has no closing ">" unlike the other
    # <sup> variants — confirm against the raw data whether that is intended.
    unit_mapping = {
        "npu19717": ("mg/kg", ["10^-6", "10e-6", "x 10e-6", "× 10<sup>-6</sup", "× 10^-6"]),
        "npu02593": ("10^9/l", ["10e9/l", "x 10e9/l", "× 10^9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu02902": ("10^9/l", ["10e9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu02636": ("10^9/l", ["10e9/l", "x 10e9/l", "× 10^9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu02840": ("10^9/l", ["10e9/l", "x 10e9/l", "× 10^9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu01933": ("10^9/l", ["10e9/l", "x 10e9/l", "× 10^9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu01349": ("10^9/l", ["10e9/l", "x 10e9/l", "× 10^9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu03568": ("10^9/l", ["10e9/l", "x 10e9/l", "× 10^9/l", "× 10<sup>9</sup>", "ï¿½ 10<sup>9</sup>"]),
        "npu02508": ("umol/L", ["ï¿½mol/l", "?mol/l", "âµmol/l"]),
        "npu01370": ("umol/L", ["ï¿½mol/l", "?mol/l", "âµmol/l"]),
    }

    def canonical_unit(analysis_code, unit_value):
        """Return the unit to store, or None when the row must be dropped."""
        if analysis_code not in unit_mapping:
            return unit_value
        default_unit, valid_alternatives = unit_mapping[analysis_code]
        if unit_value == default_unit or unit_value in valid_alternatives:
            return default_unit
        return None

    # Iterating the two columns directly avoids row-wise DataFrame.apply,
    # which is slow and returns a DataFrame (not a Series) on empty input.
    new_units = [
        canonical_unit(code, unit)
        for code, unit in zip(df['analysiscode'], df['unit'])
    ]

    # assign() works on a copy, so the caller's frame is left untouched.
    return df.assign(unit=new_units).dropna(subset=['unit'])

def transform_value(value, transformation):
    """Apply a named transformation to a single value.

    Args:
        value: Number to transform, or None.
        transformation: ``'ln(x+1)'`` for ``np.log1p``, ``'sqrt(x)'`` for
            ``np.sqrt``; anything else leaves ``value`` untouched.

    Returns:
        The transformed value, NaN when ``value`` is None and a known
        transformation was requested, or ``value`` itself otherwise.
    """
    if transformation == 'ln(x+1)':
        func = np.log1p
    elif transformation == 'sqrt(x)':
        func = np.sqrt
    else:
        # Unknown or absent transformation: pass the value through as-is.
        return value

    if value is None:
        return np.nan
    return func(value)


def remove_outliers(df, col):
    """Blank out cells of ``col`` whose value lies beyond 3 standard deviations.

    Each cell of ``df[col]`` is indexed like a tuple with the measured value
    at position 0. Cells whose value is more than three standard deviations
    from the column mean are replaced by ``(NaN, NaN, NaN, NaN, NaN)``.

    Args:
        df: DataFrame to update; ``df[col]`` is overwritten in place.
        col: Name of the tuple-valued column.

    Returns:
        The same ``df`` object.
    """
    blank_cell = (np.nan,) * 5

    measured = df[col].apply(lambda cell: np.nan if cell[0] is None else cell[0])
    centre = measured.mean()
    spread = measured.std()

    too_low = measured < centre - 3 * spread
    too_high = measured > centre + 3 * spread
    is_outlier = too_low | too_high

    # mask() takes replacements from an aligned Series of blank tuples.
    replacements = df[col].map(lambda _cell: blank_cell)
    df[col] = df[col].mask(is_outlier, replacements)

    return df

def z_score_normalize(df, col, index, mean, std):
    """Z-score one tuple position of ``df[col]`` and fill missing limits.

    Each cell of ``df[col]`` is a 5-tuple. The element at position ``index``
    (0, 3 or 4) is replaced by ``(x - mean) / std``; positions 1 and 2 are
    copied unchanged. Afterwards, missing entries at positions 3 and 4 are
    filled with the mean of that position over ``df`` itself (this happens on
    every call, whichever ``index`` was normalised).

    Args:
        df: DataFrame to update; ``df[col]`` is overwritten in place.
        col: Name of the tuple-valued column.
        index: Tuple position to normalise; other values normalise nothing.
        mean: Mean used for scaling.
        std: Standard deviation used for scaling.

    Returns:
        The same ``df`` object.
    """
    # Nothing to scale, and expanding an empty column yields no 3/4 columns.
    if len(df) == 0:
        return df

    def scale(value):
        # NOTE(review): when std == 0 this yields `mean`, not 0 — kept as in
        # the original; confirm that this is the intended fallback.
        return (value - mean) / std if std != 0 else mean

    def normalise_cell(cell):
        return (
            scale(cell[0]) if index == 0 else cell[0],
            cell[1],  # unit
            cell[2],  # laboratorium_idcode
            scale(cell[3]) if index == 3 else cell[3],
            scale(cell[4]) if index == 4 else cell[4],
        )

    df[col] = df[col].apply(normalise_cell)

    # Expand tuples into positional columns 0..4 to operate on 3 and 4.
    expanded = df[col].apply(pd.Series)

    for position in (3, 4):
        position_mean = expanded[position].mean(skipna=True)
        # Assign the result back: `expanded[position].fillna(..., inplace=True)`
        # is chained assignment and does not update `expanded` under pandas
        # Copy-on-Write.
        expanded[position] = expanded[position].fillna(position_mean)

    df[col] = expanded.apply(tuple, axis=1)

    return df

def transform_normalise_biomarker_features(train_df, val_df, test_df):
    """Transform and z-score the tuple-valued biomarker columns of three splits.

    For every analysis code in the table below that is a column of
    ``train_df``, in this order:

    1. outliers are blanked in the train split only (``remove_outliers``);
    2. the configured transformation is applied to tuple positions 0, 3 and 4
       of all three splits (``transform_value``);
    3. for positions 0, 3 and 4 the mean/std are computed on the train split
       and used to z-score all three splits (``z_score_normalize``);
    4. for ``npu19748`` only, a missing position 3 is set to 0 when
       position 0 is present.

    The frames are modified in place by column assignment and by the helpers.
    ``val_df`` and ``test_df`` are indexed with every column found in
    ``train_df``, so they must contain those columns too.

    Args:
        train_df: Training split; source of all normalisation statistics.
        val_df: Validation split.
        test_df: Test split.

    Returns:
        Tuple ``(train_df, val_df, test_df)``.
    """
    # analysis code -> (value transformation or None, normalisation type)
    npu_transforms = {
        'npu19748': ('ln(x+1)', 'z-score'),   # CRP
        'npu19717': ('ln(x+1)', 'z-score'),   # F-cal
        'npu02593': ('ln(x+1)', 'z-score'),   # Leukocytes
        'npu02902': ('ln(x+1)', 'z-score'),   # Neutrophils
        'npu02636': ('ln(x+1)', 'z-score'),   # Lymphocytes
        'npu02840': ('ln(x+1)', 'z-score'),   # Monocytes
        'npu01933': ('sqrt(x)', 'z-score'),   # Eosinophils
        'npu01349': ('sqrt(x)', 'z-score'),   # Basophils
        'npu03568': ('sqrt(x)', 'z-score'),   # Platelets
        'npu02319': (None, 'z-score'),        # Hemoglobin
        'npu02508': ('ln(x+1)', 'z-score'),   # Iron
        'npu02070': ('ln(x+1)', 'z-score'),   # Folate
        'npu01700': ('ln(x+1)', 'z-score'),   # Vitamin B12
        'npu10267': ('ln(x+1)', 'z-score'),   # Vitamin D2+D3
        'npu19651': ('ln(x+1)', 'z-score'),   # ALAT
        'npu19673': (None, 'z-score'),        # Albumin
        'npu01370': ('ln(x+1)', 'z-score'),   # Bilirubin
    }

    for col, (transformation, norm_type) in npu_transforms.items():
        if col not in train_df.columns:
            continue

        print(f"Processing {col}")

        # Remove outliers in train set only
        train_df = remove_outliers(train_df, col)

        # Apply transformations (log/sqrt) to all three sets
        if transformation:
            for frame in (train_df, val_df, test_df):
                frame[col] = frame[col].apply(lambda cell: (
                    transform_value(cell[0], transformation),
                    cell[1],
                    cell[2],
                    transform_value(cell[3], transformation),
                    transform_value(cell[4], transformation),
                ))

        # Compute train stats and normalize all sets using train stats
        if norm_type == 'z-score':
            for index in (0, 3, 4):
                train_values = train_df[col].apply(
                    lambda cell: cell[index] if cell[index] is not None else np.nan
                )
                train_mean = train_values.mean()
                train_std = train_values.std()

                # Assign each split explicitly rather than deciding by object
                # identity, which would misroute frames if the helper ever
                # returned a new object.
                train_df = z_score_normalize(train_df, col, index, train_mean, train_std)
                val_df = z_score_normalize(val_df, col, index, train_mean, train_std)
                test_df = z_score_normalize(test_df, col, index, train_mean, train_std)

        # Special handling for npu19748 (CRP)
        # NOTE(review): this runs after z-scoring, so the 0 written here is on
        # the normalised scale — confirm that this is intended.
        if col == 'npu19748':
            for frame in (train_df, val_df, test_df):
                frame[col] = frame[col].apply(lambda cell: (
                    cell[0], cell[1], cell[2],
                    0 if (pd.isna(cell[3]) and not pd.isna(cell[0])) else cell[3],
                    cell[4],
                ))

    return train_df, val_df, test_df
