import pandas as pd
import numpy as np
import argparse
from sdv.single_table import CTGANSynthesizer, GaussianCopulaSynthesizer, TVAESynthesizer
from sdv.metadata import SingleTableMetadata
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
import hashlib
import random
from sdv.metadata import Metadata
import warnings
import time
from tqdm import tqdm
# Report every FutureWarning occurrence, including repeats from the same source
# line (Python's default action shows each location only once).
warnings.simplefilter("always", FutureWarning)

# Shared seed for the stdlib and NumPy global RNGs; it is also passed as
# random_state to the random-forest classifier.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)


def hash_to_bit(value):
    """Map ``value`` to a bit (0 or 1) derived from its SHA-256 digest.

    The value is rendered with ``str()`` and UTF-8 encoded before hashing, so two
    values map to the same bit whenever their string forms are equal. The bit is
    the parity of the digest's final hex digit, which is the lowest bit of the
    last digest byte.
    """
    digest = hashlib.sha256(str(value).encode('utf-8')).digest()
    return digest[-1] & 1

def meets_threshold(row, label_col, threshold=1/2):
    """Check if enough non-label cells hash to 0 to meet the threshold.

    Parameters:
    - row: pandas Series holding one record, indexed by column name
    - label_col: name of the label column; it is ignored entirely (excluded from
      both the count of zero-hashing cells and the number of cells considered)
    - threshold: required fraction of non-label cells with ``hash_to_bit == 0``

    Returns:
    - True if (non-label cells hashing to 0) / (non-label cells) >= threshold.
      A row whose only column is the label has nothing to constrain and passes.
    """
    feature_cols = [col for col in row.index if col != label_col]
    if not feature_cols:
        # No non-label cells: vacuously satisfied (and avoids dividing by zero).
        return True
    zero_cells = sum(1 for col in feature_cols if hash_to_bit(row[col]) == 0)
    # Numerator and denominator both range over non-label cells only, so the
    # label can never inflate the fraction.
    return zero_cells / len(feature_cols) >= threshold


def constrained_sampling(synthesizer, num_samples, label_col, max_attempts_per_row=100, threshold=1/2):
    """
    Generate samples ensuring at least threshold fraction of non-label cells hash to 0.
    The label column is never modified.

    Rows are drawn one at a time from ``synthesizer`` (rejection sampling). For
    each output row, up to ``max_attempts_per_row`` candidates are drawn and
    checked with ``meets_threshold``; the first passing candidate is kept. If
    none pass, a warning is printed and the last (checked, non-conforming)
    candidate is kept anyway, so exactly ``num_samples`` rows are returned.

    Parameters:
    - synthesizer: object with ``sample(n)`` returning a DataFrame
      (an SDV synthesizer in this script)
    - num_samples: number of rows to return
    - label_col: name of the label column (excluded from the hash check)
    - max_attempts_per_row: candidate draws per output row; values below 1
      are treated as 1
    - threshold: fraction forwarded to ``meets_threshold``

    Returns:
    - (DataFrame of the collected rows with a fresh 0..n-1 index,
       wall-clock seconds spent sampling)
    """
    samples = []
    attempts_allowed = max(1, max_attempts_per_row)
    start_time = time.time()

    # tqdm progress bar over the number of samples to be collected
    with tqdm(total=num_samples, desc="Sampling", unit="sample") as pbar:
        while len(samples) < num_samples:
            for _ in range(attempts_allowed):
                row = synthesizer.sample(1).iloc[0]
                if meets_threshold(row, label_col, threshold):
                    break
            else:
                # tqdm.write keeps the message from garbling the active bar.
                tqdm.write(f"Warning: Max attempts reached for row {len(samples)+1}")
            samples.append(row)
            # Count fallback rows as well, so the bar always reaches num_samples.
            pbar.update(1)

    elapsed_time = time.time() - start_time
    # Each row keeps the index label it had in its one-row sample, which can
    # repeat across rows; renumber so the result has a unique index.
    return pd.DataFrame(samples).reset_index(drop=True), elapsed_time

def compute_z_score(df, label_col):
    """
    Computes custom z-scores for the count of entries that hash to 0 in each column,
    excluding the specified label column.

    Z-score for a column is defined as:
        2 * sqrt(n) * (hash_count / n - 1/2)
    which equals (hash_count - n/2) / sqrt(n/4), the standard score of
    hash_count if each of the n cells hashed to 0 independently with
    probability 1/2.

    Parameters:
    - df: pandas DataFrame containing the data
    - label_col: string, name of the label column to exclude

    Returns:
    - stats_df: pandas DataFrame with one row per non-label column (in df's
      column order) and columns 'Column', 'Hash Count (0s)' and 'Z-Score'

    Raises:
    - ValueError: if df has no rows
    """
    n = len(df)
    if n == 0:
        raise ValueError("DataFrame is empty.")

    valid_columns = [col for col in df.columns if col != label_col]
    hash_counts = [
        sum(1 for val in df[col] if hash_to_bit(val) == 0)
        for col in valid_columns
    ]
    z_scores = [2 * np.sqrt(n) * (count / n - 0.5) for count in hash_counts]

    # Assemble into a pandas DataFrame for clean printing
    return pd.DataFrame({
        'Column': valid_columns,
        'Hash Count (0s)': hash_counts,
        'Z-Score': z_scores,
    })

def _make_classifier(name):
    """Return a new, unfitted classifier for the CLI name 'XGB' or 'RF'."""
    if name == "XGB":
        return XGBClassifier()
    if name == "RF":
        return RandomForestClassifier(random_state=random_seed)
    raise ValueError(f"Unknown classifier: {name!r}")


def _fit_and_score(name, X_train, y_train, X_test, y_test):
    """Fit a fresh ``name`` classifier on the training data and return its test accuracy."""
    model = _make_classifier(name)
    model.fit(X_train, y_train)
    return accuracy_score(y_test, model.predict(X_test))


def _add_missing_classes(data, train_data, label_col, classes, report=False):
    """Append one randomly chosen real training row for every class in ``classes``
    that does not occur in ``data[label_col]``; return the (possibly new) frame.

    Classifiers such as XGBoost need every class present at fit time, so absent
    classes are patched in from the real training split. When ``report`` is
    true, the missing classes are printed.
    """
    missing_classes = set(classes) - set(data[label_col].unique())
    if not missing_classes:
        return data
    if report:
        print("missing ", missing_classes)
    extra_rows = [train_data[train_data[label_col] == cls].sample(n=1) for cls in missing_classes]
    return pd.concat([data, *extra_rows], ignore_index=True)


def _nudge_to_zero_hash(value, step=0.000001):
    """Increase ``value`` in small steps until ``hash_to_bit`` maps it to 0.

    Missing and infinite values are returned unchanged: adding ``step`` cannot
    change them, so the search would never terminate.
    """
    if pd.isna(value) or np.isinf(value):
        return value
    while hash_to_bit(value) == 1:
        bumped = value + step
        # At large magnitudes ``step`` can be below float resolution; then move to
        # the next representable float so the loop always makes progress.
        value = bumped if bumped != value else np.nextafter(value, np.inf)
    return value


def _watermark_float_cells(data, label_col):
    """Return a copy of ``data`` where every cell of each floating-point, non-label
    column has been nudged (see ``_nudge_to_zero_hash``) so that it hashes to 0.
    Any warnings raised while writing a cell back are printed."""
    watermarked = data.copy()
    for col in watermarked.columns:
        # Embed only in floating point columns; integer/string/bool columns and
        # the label are left untouched.
        if col == label_col or not pd.api.types.is_float_dtype(watermarked[col]):
            continue
        for idx in watermarked.index:
            with warnings.catch_warnings(record=True) as caught_warnings:
                watermarked.at[idx, col] = _nudge_to_zero_hash(watermarked.at[idx, col])
                if caught_warnings:
                    print(f"FutureWarning in column: {col}, row: {idx}, type:{watermarked[col].dtype}")
                    for warning in caught_warnings:
                        print(f"Warning message: {warning.message}")
    return watermarked


def main():
    """Run the synthetic-data watermarking experiment configured on the command line.

    The dataset CSV is read once. Each of ``--iteration`` rounds then works on a
    fresh copy and performs these steps:
      1. Label-encode the label column (and, for Shopper, Month/VisitorType/Weekend).
      2. Make a random 75/25 train/test split.
      3. Build "synthetic" training data with the chosen generator, or reuse the
         training split itself for ``ORIG``.
      4. Build a watermarked copy. With threshold "1", every finite float feature
         cell is nudged until it hashes to 0. Otherwise, rows are drawn with
         ``constrained_sampling`` at the chosen threshold.
      5. Train the chosen classifier(s) on both versions and record test accuracy.
    At the end, the mean ± standard deviation of the accuracies is printed.
    """
    # dataset -> (CSV path, columns to drop after loading, label column);
    # a label of None means "the last remaining column".
    dataset_specs = {
        'Wilt': ("wilt/training_complete.csv", [], 'class'),
        'Housing': ("housing/housing.csv", [], 'ocean_proximity'),
        'HOG': ("hog/hog.csv", ['id'], None),
        'Shopper': ("shopper/shopper.csv", [], 'Revenue'),
    }
    # Additional categorical feature columns to label-encode, per dataset.
    extra_encoded_cols = {'Shopper': ["Month", "VisitorType", "Weekend"]}
    thresholds = {'1/4': 1/4, '1/3': 1/3, '1/2': 1/2, '2/3': 2/3, '3/4': 3/4, '1': 1.0}
    synthesizer_classes = {
        'CTGAN': CTGANSynthesizer,
        'GC': GaussianCopulaSynthesizer,
        'TVAE': TVAESynthesizer,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--generator', type=str, nargs='?', default='ORIG', choices=['CTGAN', 'TVAE', 'GC','ORIG'], help='Type of generator')
    parser.add_argument('-c', '--classifier', type=str, required=True, choices=['XGB', 'RF', 'Both'], help='Type of classifier')
    parser.add_argument('-f', '--dataset', type=str, required=True, choices=['Wilt','Housing','HOG','Shopper','King'], help='Dataset')
    parser.add_argument('-i', '--iteration', type=int, default=10, help='Iteration Count')
    parser.add_argument('-t', '--threshold', type=str, required=True, choices=['1/4','1/3','1/2','2/3','3/4','1'], help='Threshold for Constrained Sampling')
    args = parser.parse_args()

    # Validate combinations up front instead of failing mid-run on an unbound name.
    if args.dataset not in dataset_specs:
        parser.error(f"dataset {args.dataset!r} has no data file configured in this script")
    if args.iteration < 1:
        parser.error("--iteration must be at least 1")
    constrained = args.threshold != "1"
    threshold = thresholds[args.threshold]
    if constrained and args.generator == 'ORIG':
        parser.error("constrained sampling (threshold other than 1) needs a generator; use -s CTGAN, TVAE or GC")
    classifier_names = ['XGB', 'RF'] if args.classifier == 'Both' else [args.classifier]

    train_url, drop_cols, label_spec = dataset_specs[args.dataset]
    # Load once. Every iteration copies this frame, so the columns dropped here
    # stay dropped in all iterations.
    base_data = pd.read_csv(train_url).drop(columns=drop_cols)
    classifier_label = label_spec if label_spec is not None else base_data.columns.tolist()[-1]

    print("Classifier:", args.classifier, "Synthesizer:", args.generator, "Iteration:",args.iteration, "Dataset:",args.dataset)
    print("Constrained Sampling:", constrained, "Threshold:", threshold)

    accuracies_synthetic = {name: [] for name in classifier_names}
    accuracies_watermarked = {name: [] for name in classifier_names}
    constrained_time = []
    z_score = []

    for iteration in range(args.iteration):
        print(f"Iteration {iteration + 1}/",args.iteration)
        train_data = base_data.copy()
        train_data[classifier_label] = LabelEncoder().fit_transform(train_data[classifier_label])
        for col in extra_encoded_cols.get(args.dataset, []):
            train_data[col] = LabelEncoder().fit_transform(train_data[col])

        train_data, test_data = train_test_split(train_data, test_size=0.25)
        unique_classes = train_data[classifier_label].unique()
        X_test = test_data.drop(classifier_label, axis=1)
        y_test = test_data[classifier_label]
        num_rows = len(train_data)

        generator = None
        if args.generator == 'ORIG':
            # Baseline: the "synthetic" data is the real training split itself.
            synthetic_data = train_data.copy()
        else:
            metadata = Metadata.detect_from_dataframe(data=train_data)
            generator = synthesizer_classes[args.generator](metadata)
            generator.fit(train_data)
            synthetic_data = _add_missing_classes(
                generator.sample(num_rows), train_data, classifier_label, unique_classes)

        X_train_synthetic = synthetic_data.drop(classifier_label, axis=1)
        y_train_synthetic = synthetic_data[classifier_label]
        for name in classifier_names:
            accuracies_synthetic[name].append(
                _fit_and_score(name, X_train_synthetic, y_train_synthetic, X_test, y_test))

        if constrained:
            # Only a threshold fraction of each row's non-label cells must hash to 0.
            watermarked_data, elapsed_time = constrained_sampling(
                generator, num_rows, classifier_label, threshold=threshold)
            stats_df = compute_z_score(watermarked_data, classifier_label)
            constrained_time.append(elapsed_time)
            z_score.append(np.mean(stats_df['Z-Score'].values))
        else:
            watermarked_data = _watermark_float_cells(synthetic_data, classifier_label)

        watermarked_data = _add_missing_classes(
            watermarked_data, train_data, classifier_label, unique_classes, report=True)
        X_train_watermarked = watermarked_data.drop(classifier_label, axis=1)
        y_train_watermarked = watermarked_data[classifier_label]
        for name in classifier_names:
            accuracies_watermarked[name].append(
                _fit_and_score(name, X_train_watermarked, y_train_watermarked, X_test, y_test))

    print("Classifier:", args.classifier, "Synthesizer:", args.generator, "Iteration:",args.iteration, "Dataset:",args.dataset)
    print("Constrained Sampling:", constrained, "Threshold:", threshold)
    for name in classifier_names:
        # Tag lines with the classifier only when several were run, so the
        # single-classifier output keeps its original format.
        tag = f" [{name}]" if len(classifier_names) > 1 else ""
        mean_accuracy_synthetic = np.mean(accuracies_synthetic[name])
        std_accuracy_synthetic = np.std(accuracies_synthetic[name])
        mean_accuracy_watermarked = np.mean(accuracies_watermarked[name])
        std_accuracy_watermarked = np.std(accuracies_watermarked[name])
        print(f"Mean Accuracy \xB1 Standard Deviation (Synthetic Data){tag}: {mean_accuracy_synthetic * 100:.2f} \xB1 {std_accuracy_synthetic * 100:.2f}%")
        print(f"Mean Accuracy \xB1 Standard Deviation (Watermarked Data){tag}: {mean_accuracy_watermarked * 100:.2f} \xB1 {std_accuracy_watermarked * 100:.2f}%")

    # Measuring Z-Score and Sampling Time for Constrained Sampling
    if constrained and args.generator == 'CTGAN' and "XGB" in classifier_names:
        mean_elapsed = np.mean(constrained_time)
        std_elapsed = np.std(constrained_time)
        mean_z = np.mean(z_score)
        std_z = np.std(z_score)
        print(f"Mean Constrained Sampling Time \xB1 Standard Deviation: {mean_elapsed:.2f} \xB1 {std_elapsed:.2f}seconds")
        print(f"Mean Z-Score \xB1 Standard Deviation: {mean_z:.2f} \xB1 {std_z:.2f}")

if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()
