import pandas as pd
import numpy as np
import argparse
from sdv.single_table import CTGANSynthesizer, GaussianCopulaSynthesizer, TVAESynthesizer
from sdv.metadata import SingleTableMetadata
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
import hashlib
import random
from sklearn import metrics
from sdv.metadata import Metadata
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler,PolynomialFeatures
import warnings
import time
from xgboost import XGBRegressor
from tqdm import tqdm
from datetime import datetime
warnings.simplefilter("always", FutureWarning)  # Make sure FutureWarning is always shown
# Fixed seed applied to NumPy's global RNG and to Python's `random` module so
# that code drawing from those two generators is repeatable across runs.
# NOTE(review): only these two generators are seeded here; any library that
# keeps its own RNG state is unaffected - confirm whether that matters.
random_seed = 42
np.random.seed(random_seed)
random.seed(random_seed)


def hash_to_bit(value):
    """Map an arbitrary value to a single bit (0 or 1).

    The value is converted with ``str()``, encoded as UTF-8 and hashed with
    SHA-256; the parity of the digest's last hex digit is returned.
    """
    digest = hashlib.sha256(str(value).encode('utf-8')).digest()
    # The last hex digit is the low nibble of the final digest byte, so its
    # parity is that byte's least-significant bit.
    return digest[-1] & 1

def meets_threshold(row, label_col, threshold=1/2):
    """Check whether enough non-label cells of a row hash to 0.

    Args:
        row: pandas Series holding one record, indexed by column name.
        label_col: Name of the label column. It is excluded from both the
            count of zero-hashing cells and the total, so it never
            contributes to the fraction.
        threshold: Minimum fraction of non-label cells that must hash to 0.

    Returns:
        bool: True when the fraction of non-label cells for which
        ``hash_to_bit`` returns 0 is at least ``threshold``. A row without
        any non-label cell is accepted because there is nothing to constrain.
    """
    feature_cols = [col for col in row.index if col != label_col]
    if not feature_cols:
        # Avoid dividing by zero for an empty or label-only row.
        return True
    zero_cells = sum(1 for col in feature_cols if hash_to_bit(row[col]) == 0)
    return zero_cells / len(feature_cols) >= threshold


def constrained_sampling(synthesizer, num_samples, label_col, max_attempts_per_row=100, threshold=1/2):
    """
    Draw rows one at a time until ``num_samples`` rows have been collected,
    preferring rows in which at least ``threshold`` of the non-label cells
    hash to 0. Rows are only accepted or rejected, never modified.

    Parameters:
    - synthesizer: object whose ``sample(1)`` returns a one-row DataFrame
    - num_samples: number of rows to collect
    - label_col: name of the label column (passed on to ``meets_threshold``)
    - max_attempts_per_row: number of candidate rows drawn for each output
      row before giving up; at least one candidate is always drawn
    - threshold: required fraction of zero-hashing cells per row

    Returns:
    - DataFrame with ``num_samples`` rows and a fresh 0..n-1 index
    - elapsed sampling time in seconds

    When no candidate satisfies the constraint within the attempt budget, a
    warning is written and the last candidate examined is kept, so the output
    always has exactly ``num_samples`` rows.
    """
    samples = []
    attempts_allowed = max(1, max_attempts_per_row)
    # Monotonic clock: elapsed time is unaffected by system clock changes.
    start_time = time.perf_counter()

    with tqdm(total=num_samples, desc="Sampling", unit="sample") as pbar:
        while len(samples) < num_samples:
            for _ in range(attempts_allowed):
                row = synthesizer.sample(1).iloc[0]
                if meets_threshold(row, label_col, threshold):
                    break
            else:
                # tqdm.write keeps the message from garbling the progress bar.
                tqdm.write(f"Warning: Max attempts reached for row {len(samples)+1}")

            samples.append(row)
            # Advance for every collected row so the bar reaches its total.
            pbar.update(1)

    elapsed_time = time.perf_counter() - start_time
    # Every sampled row carries index label 0; renumber to avoid duplicates.
    return pd.DataFrame(samples).reset_index(drop=True), elapsed_time

def compute_z_score(df, label_col):
    """
    Computes custom z-scores for the count of entries that hash to 0 in each
    column, excluding the specified label column.

    Z-score for a column with n rows is defined as:
        2 * sqrt(n) * (hash_count / n - 1/2)
    i.e. the deviation of hash_count from n/2 in units of sqrt(n)/2, the
    standard deviation of a fair-coin count over n trials.

    Parameters:
    - df: pandas DataFrame containing the data
    - label_col: string, name of the label column to exclude

    Returns:
    - pandas DataFrame with one row per non-label column and the columns
      'Column', 'Hash Count (0s)' and 'Z-Score'

    Raises:
    - ValueError: if df has no rows
    """
    n = len(df)

    if n == 0:
        raise ValueError("DataFrame is empty.")

    valid_columns = [col for col in df.columns if col != label_col]

    # Number of cells per column whose hash bit is 0.
    hash_counts = [
        sum(1 for val in df[col] if hash_to_bit(val) == 0)
        for col in valid_columns
    ]
    z_scores = [2 * np.sqrt(n) * (count / n - 0.5) for count in hash_counts]

    # Assemble into a pandas DataFrame for clean printing
    return pd.DataFrame({
        'Column': valid_columns,
        'Hash Count (0s)': hash_counts,
        'Z-Score': z_scores,
    })

def _parse_threshold(text):
    """Convert a threshold written as 'a/b' or as a plain integer to a float."""
    numerator, _, denominator = text.partition('/')
    return int(numerator) / int(denominator or 1)


def _fit_and_score(classifier_name, X_train, y_train, X_test, y_test):
    """Fit a fresh regressor of the requested kind and return its R^2 on the test split."""
    if classifier_name == "Ridge":
        model = Ridge(alpha=0.1)
    elif classifier_name == "XGB":
        model = XGBRegressor()
    elif classifier_name == "RF":
        model = RandomForestRegressor()
    else:
        raise ValueError(f"Unknown classifier: {classifier_name}")
    model.fit(X_train, y_train)
    return metrics.r2_score(y_test, model.predict(X_test))


def _embed_watermark(data, label_col, step=0.000001):
    """Return a copy of ``data`` in which floating-point non-label cells are nudged until they hash to 0."""
    watermarked = data.copy()
    for col in watermarked.columns:
        # Embed only in floating point columns; the label is never touched.
        if col == label_col or not pd.api.types.is_float_dtype(watermarked[col]):
            continue
        for idx in watermarked.index:
            with warnings.catch_warnings(record=True) as caught_warnings:
                cell_value = watermarked.at[idx, col]
                # Increment the cell value until it hashes to 0
                while hash_to_bit(cell_value) == 1:
                    bumped = cell_value + step
                    # NaN, inf and very large magnitudes do not change when
                    # step is added; leave such cells as they are instead of
                    # looping forever.
                    if bumped == cell_value or bumped != bumped:
                        break
                    cell_value = bumped
                watermarked.at[idx, col] = cell_value
                if caught_warnings:
                    print(f"FutureWarning in column: {col}, row: {idx}, type:{watermarked[col].dtype}")
                    for warning in caught_warnings:
                        print(f"Warning message: {warning.message}")
    return watermarked


def main():
    """
    Command-line entry point of the watermarking experiment.

    For each of ``--iteration`` rounds the dataset is split 75/25, a regressor
    is scored on the real training split, a synthesizer is fitted on that
    split, and the same kind of regressor is scored again after training on
    (a) plain synthetic rows and (b) watermarked synthetic rows. All scores
    are R^2 values on the held-out real test split; the printed reports call
    them "Accuracy".

    Watermarking mode is selected by ``--threshold``:
    - '1': every floating-point non-label cell of the synthetic data is
      nudged until it hashes to 0.
    - any fraction below 1: rows are drawn with ``constrained_sampling`` and
      per-column z-scores are collected with ``compute_z_score``.

    Mean and standard deviation of the scores over all rounds are printed at
    the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--generator', type=str, nargs='?', choices=['CTGAN', 'TVAE', 'GC'], help='Type of generator')
    parser.add_argument('-c', '--classifier', type=str, required=True, choices=['Ridge','XGB','RF'], help='Type of Regression Classifier')
    parser.add_argument('-f', '--dataset', type=str, required=True, default='King', choices=['King'], help='Dataset')
    parser.add_argument('-i', '--iteration', type=int, default=10, help='Iteration Count')
    parser.add_argument('-t', '--threshold', type=str, required=True, choices=['1/4','1/3','1/2','2/3','3/4','1'], help='Threshold for Constrained Sampling')
    args = parser.parse_args()

    # Without a generator there is nothing to fit or sample from; fail with a
    # usage error up front instead of crashing in the middle of a round.
    if args.generator is None:
        parser.error("a generator must be selected with -s/--generator")

    if args.dataset == "King":
        train_url = "king/king.csv"
        classifier_label = "price"
        features = ["floors", "waterfront", "lat", "bedrooms", "sqft_basement", "view", "bathrooms", "sqft_living15", "sqft_above", "grade", "sqft_living", "price"]

    # A threshold of 1 selects direct embedding; anything lower selects
    # constrained sampling. The value is always bound so it can be reported.
    threshold = _parse_threshold(args.threshold)
    constrained = threshold < 1

    print("Classifier:", args.classifier, "Synthesizer:", args.generator, "Iteration:",args.iteration, "Dataset:",args.dataset)
    print("Constrained Sampling:", constrained, "Threshold:", threshold)

    synthesizer_classes = {
        'CTGAN': CTGANSynthesizer,
        'GC': GaussianCopulaSynthesizer,
        'TVAE': TVAESynthesizer,
    }

    # The file is read once; every round draws a new split from the same frame.
    full_data = pd.read_csv(train_url)[features]

    accuracies_synthetic = []
    accuracies_watermarked = []
    accuracies_original = []
    constrained_time = []
    z_score = []

    for iteration in range(args.iteration):
        print(f"Iteration {iteration + 1}/",args.iteration)
        train_data, test_data = train_test_split(full_data, test_size=0.25)

        # Separate features and target
        X_train = train_data.drop(classifier_label, axis=1)
        y_train = train_data[classifier_label]
        X_test = test_data.drop(classifier_label, axis=1)
        y_test = test_data[classifier_label]

        metadata = Metadata.detect_from_dataframe(data=train_data)

        # Performance when training on the real data
        score = _fit_and_score(args.classifier, X_train, y_train, X_test, y_test)
        print(f"Original {args.classifier} Model Accuracy:", score)
        accuracies_original.append(score)

        # Train the generator
        generator = synthesizer_classes[args.generator](metadata)
        generator.fit(train_data)

        print("Generating Synthetic Data")
        synthetic_data = generator.sample(y_train.shape[0])

        # Performance when training on plain synthetic data
        score = _fit_and_score(
            args.classifier,
            synthetic_data.drop(classifier_label, axis=1),
            synthetic_data[classifier_label],
            X_test,
            y_test,
        )
        print(f"Synthetic {args.classifier} Model Accuracy:", score)
        accuracies_synthetic.append(score)

        # Creating Watermarked Dataset
        if not constrained:
            watermarked_data = _embed_watermark(synthetic_data, classifier_label)
        else:
            # Not every cell is going to hash to 0. Only certain threshold of them, per row.
            watermarked_data, elapsed_time = constrained_sampling(generator, y_train.shape[0], classifier_label, threshold=threshold)
            stats_df = compute_z_score(watermarked_data, classifier_label)
            constrained_time.append(elapsed_time)
            z_score.append(np.mean(stats_df['Z-Score'].values))

        # Performance when training on watermarked data
        score = _fit_and_score(
            args.classifier,
            watermarked_data.drop(classifier_label, axis=1),
            watermarked_data[classifier_label],
            X_test,
            y_test,
        )
        print(f"Watermarked {args.classifier} Model Accuracy:", score)
        accuracies_watermarked.append(score)

    print("Classifier:", args.classifier, "Synthesizer:", args.generator, "Iteration:",args.iteration, "Dataset:",args.dataset)
    print("Constrained Sampling:", constrained, "Threshold:", threshold)

    # Mean and standard deviation of the scores over all rounds
    reports = (
        ("Original Data", accuracies_original),
        ("Synthetic Data", accuracies_synthetic),
        ("Watermarked Data", accuracies_watermarked),
    )
    for label, scores in reports:
        print(f"Mean Accuracy \xB1 Standard Deviation ({label}): {np.mean(scores) * 100:.2f} \xB1 {np.std(scores) * 100:.2f}")

    # NOTE(review): timing and z-score summaries are reported for the
    # CTGAN + XGB combination only, as before - confirm this is intended.
    if constrained and args.generator == 'CTGAN' and args.classifier == "XGB":
        mean_elapsed = np.mean(constrained_time)
        std_elapsed = np.std(constrained_time)
        mean_z = np.mean(z_score)
        std_z = np.std(z_score)
        print(f"Mean Constrained Sampling Time \xB1 Standard Deviation: {mean_elapsed:.2f} \xB1 {std_elapsed:.2f}")
        print(f"Mean Z-Score \xB1 Standard Deviation: {mean_z:.2f} \xB1 {std_z:.2f}")


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
