import pandas as pd
import numpy as np


def _class_counts(df):
    """Return ``(n_normal, n_anomaly, n_total)`` for a frame with a numeric 'label' column."""
    labels = df['label']
    return int((labels == 0).sum()), int((labels == 1).sum()), len(df)


def check_data(df):
    """
    Check and process the dataset to ensure it meets the requirements for anomaly detection.

    Requirements:
        - anomaly ratio <= 1/3
        - 100 <= dataset size <= 50000
        - no NA values

    The input frame is never modified; all processing happens on a copy.

    Args:
        df (pd.DataFrame): The input dataset containing a 'label' column
            (1 = anomaly, 0 = normal). Labels are coerced to numeric; values
            that cannot be parsed become NA and are dropped with the other
            NA rows.

    Returns:
        pd.DataFrame: The processed dataset with the following modifications:
            - Dropped rows with any NA values
            - If the anomaly ratio is >= 1/3, anomalies are down-sampled to
              half the number of normal rows (i.e. at most 1/3 of the total);
              only rows labelled 0 or 1 are kept in that case
            - If the dataset is too large (>50000), sampled to 50000 while
              preserving the class ratio (index is reset in that case)
        None:
            - if the dataset size is < 100 after these operations
            - if the frame has no 'label' column
    """
    # Global seeding is kept (rather than a local generator) so that both the
    # rows drawn here and any caller relying on the RNG state afterwards stay
    # reproducible exactly as before.
    np.random.seed(42)

    if 'label' not in df.columns:
        print("No 'label' column found. Stopping execution.")
        return None

    # Ensure label is numeric for logic. `assign` returns a new frame, so the
    # caller's DataFrame is left untouched.
    df = df.assign(label=pd.to_numeric(df['label'], errors='coerce'))

    n_normal, n_anomaly, n_total = _class_counts(df)
    print(f"Initial class distribution: normal={n_normal}, anomaly={n_anomaly}, total={n_total}")

    # Drop rows with any NA values (including labels that failed numeric coercion)
    df = df.dropna()
    n_normal, n_anomaly, n_total = _class_counts(df)
    print(f"after dropping na: normal={n_normal}, anomaly={n_anomaly}, total={n_total}")

    # If anomaly is >= 1/3, sample anomaly to 1/3 of total.
    # The n_total guard avoids a 0/0 division on an empty frame.
    if n_total > 0 and n_anomaly / n_total >= 1/3:
        n_anomaly_target = n_normal // 2
        label_values = df['label'].to_numpy()
        normal_pos = np.flatnonzero(label_values == 0)
        anomaly_pos = np.flatnonzero(label_values == 1)
        # Sample row *positions*, not index labels: selecting by label would
        # pull in every row sharing that label when the index is not unique.
        sampled_pos = np.random.choice(anomaly_pos, size=n_anomaly_target, replace=False)
        keep_pos = np.sort(np.concatenate([normal_pos, sampled_pos]))
        # Stable sort keeps the original relative order of duplicate labels.
        df = df.iloc[keep_pos].sort_index(kind='stable')
        # Recalculate totals after sampling
        n_normal, n_anomaly, n_total = _class_counts(df)
        print(f"Anomaly was >=1/3. Sampled anomaly to {n_anomaly_target}. New class distribution: normal={n_normal}, anomaly={n_anomaly}, total={n_total}")

    # Check minimum dataset size
    if n_total < 100:
        print(f"Dataset size after sampling is less than 100 (current size: {n_total}). Stopping execution.")
        return None

    # If dataset size is too large (>50000), sample to 50000 while preserving class ratio
    if n_total > 50000:
        print(f"Dataset size is too large ({n_total}>50000)")
        original_anomaly_ratio = n_anomaly / n_total

        # Target counts for each class; normal takes the remainder so the
        # two targets always add up to 50000.
        target_anomaly = int(50000 * original_anomaly_ratio)
        target_normal = 50000 - target_anomaly

        # Sample each class separately to maintain ratio; min() protects
        # against asking for more rows than a class actually has.
        anomaly_df = df[df['label'] == 1].sample(n=min(target_anomaly, n_anomaly), random_state=42)
        normal_df = df[df['label'] == 0].sample(n=min(target_normal, n_normal), random_state=42)

        # Combine the sampled data
        df = pd.concat([normal_df, anomaly_df], axis=0).reset_index(drop=True)

        n_normal, n_anomaly, n_total = _class_counts(df)
        print(f"Sampled to {n_total} while preserving class ratio. New class distribution: normal={n_normal}, anomaly={n_anomaly}, total={n_total}")

    print(f"Final class distribution: normal={n_normal}, anomaly={n_anomaly}, anomaly_ratio={n_anomaly/n_total}, total={n_total}")
    print(f"final df shape: {df.shape}")
    return df