import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments, GPTNeoXForSequenceClassification
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
from datasets import Dataset
import numpy as np
import os

# NOTE(review): presumably turns on PyTorch's CUDA device-side assertions at
# runtime — confirm against the installed PyTorch build.
os.environ["PYTORCH_USE_CUDA_DSA"] = "1"

# Destinations: result CSVs go under OUTPUT_DIR; checkpoints and the saved model
# are written to paths prefixed with SCRATCH_OUTPUT_DIR (empty -> relative paths).
OUTPUT_DIR: str = "wts_results_no_threshold/"
SCRATCH_OUTPUT_DIR: str = ""

# Strong ceiling benchmark for weak-to-strong performance: the strong model is
# fine-tuned directly on the dataset's own labels.

# --- task definition ---
income_threshold: int = 50_000  # income strictly above this is the positive class
output_var: str = 'income'  # label column
sensitive_attr: str = 'gender'  # protected attribute used by the fairness penalty
filepath: str = "adult_reconstruction.csv"

# --- optimization ---
epochs: int = 5
strong_learning_rate: float = 5e-5

# --- data split fractions (of the class-balanced dataset) ---
train_prop: float = 0.4  # leading share used for training + validation
holdout_prop: float = 0.4  # next share, held out; the remaining 0.2 is the test split
validation_prop: float = 0.2  # fraction of the train_prop share carved off for validation

class FairnessTrainer(Trainer):
    """Hugging Face ``Trainer`` whose loss adds a demographic-parity penalty.

    Per batch the optimized objective is::

        loss = task_loss + lambda_reg * |mean((z - mean(z)) * p)|

    where ``task_loss`` is the ``loss`` returned by the model, ``z`` is the
    sensitive attribute stored in the batch under the key ``sensitive_attr``,
    and ``p`` is the softmax probability of class 1. The penalty is the
    absolute (biased) in-batch covariance between ``z`` and ``p``.

    ``evaluate`` and ``predict`` default ``ignore_keys`` to
    ``["past_key_values"]`` so that entry of the model output is excluded when
    predictions are gathered.
    """

    def __init__(self, args, model, train_dataset, eval_dataset, lambda_reg, output_var, sensitive_attr, data_collator):
        """Create the trainer.

        Args:
            args: ``TrainingArguments`` for the run.
            model: model whose forward returns ``loss`` and ``logits`` with at
                least two classes.
            train_dataset: training split.
            eval_dataset: evaluation split.
            lambda_reg: weight of the fairness penalty (0 leaves only the task loss).
            output_var: name of the label column; stored but not read by this class.
            sensitive_attr: batch key holding the sensitive attribute.
            data_collator: callable turning a list of examples into a batch dict
                that includes ``sensitive_attr``.
        """
        # Pass the collator to the base class so it is in place for every
        # dataloader the base Trainer builds (instead of patching it afterwards,
        # which also clobbered the base default when ``None`` was given).
        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        self.lambda_reg = lambda_reg
        self.output_var = output_var
        self.sensitive_attr = sensitive_attr

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the task loss plus the weighted fairness penalty.

        Extra keyword arguments (newer Trainer versions pass e.g.
        ``num_items_in_batch``) are accepted for compatibility and ignored.
        """
        # Shallow copy: keep the sensitive attribute out of forward() without
        # mutating the batch dict owned by the caller.
        model_inputs = dict(inputs)
        z = model_inputs.pop(self.sensitive_attr)
        outputs = model(**model_inputs)
        loss = outputs.loss.float()
        # Probability of the positive class (index 1) for each example.
        pred_values = torch.softmax(outputs.logits, dim=1)[:, 1]
        # Align devices in case the model's output lands on a different device
        # than the batch (no-op when they already match).
        reg_term = self.compute_regularization_term(pred_values, z.to(pred_values.device))
        loss_with_reg = loss + self.lambda_reg * reg_term
        return (loss_with_reg, outputs) if return_outputs else loss_with_reg

    def compute_regularization_term(self, pred_values, z):
        """Absolute in-batch covariance between ``z`` and ``pred_values``.

        ``mean((z - mean(z)) * p)`` equals the biased covariance because the
        centered ``z`` has zero mean.
        """
        return abs(((z - z.mean()) * pred_values).mean())

    def predict(self, test_dataset, ignore_keys=None, metric_key_prefix: str = "test"):
        """Run prediction, ignoring ``past_key_values`` in model outputs by default."""
        if ignore_keys is None:
            ignore_keys = ["past_key_values"]
        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def evaluate(self, ignore_keys=None, metric_key_prefix: str = "eval", eval_dataset=None):
        """Run evaluation, ignoring ``past_key_values`` in model outputs by default.

        ``eval_dataset`` is accepted (keyword, last to keep this override's
        positional order) because the base ``Trainer`` calls
        ``evaluate(eval_dataset=...)`` itself; ``None`` uses the trainer's
        configured evaluation split.
        """
        if ignore_keys is None:
            ignore_keys = ["past_key_values"]
        return super().evaluate(
            eval_dataset=eval_dataset,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

def preprocess_data(df, output_var, sensitive_attr, income_threshold):
    """Drop the ``education`` column (treated as redundant) and binarize two columns.

    Encodings:
        * ``sensitive_attr``: 1 for the value ``'Male'``, 0 for anything else
          (so ``'Female'`` -> 0).
        * ``output_var``: 1 when ``int(value) > income_threshold``, else 0.

    Returns a new DataFrame; the caller's ``df`` is left untouched.
    """

    def encode_gender(value):
        return 1 if value == 'Male' else 0

    def encode_income(value):
        return 1 if int(value) > income_threshold else 0

    cleaned = df.drop('education', axis=1)
    cleaned[sensitive_attr] = cleaned[sensitive_attr].apply(encode_gender)
    cleaned[output_var] = cleaned[output_var].apply(encode_income)
    return cleaned

def load_and_preprocess_data(filepath, output_var, sensitive_attr, model_name, income_threshold, random_state=None):
    """Load the CSV, balance the label classes, split positionally, and tokenize.

    Rows go through ``preprocess_data``; the majority label is undersampled to
    the size of the minority label and the result is shuffled. With ``n`` rows
    after balancing and the module-level ``train_prop``, ``validation_prop`` and
    ``holdout_prop``, the positional splits are:

    * train      : ``[0, int(n*train_prop*(1-validation_prop)))``
    * validation : ``[that, int(n*train_prop))``
    * holdout    : ``[int(n*train_prop), int(n*(train_prop+holdout_prop)))``
    * test       : ``[int(n*(train_prop+holdout_prop)), n)``

    Args:
        filepath: path of the CSV to read (first row is the header).
        output_var: label column name.
        sensitive_attr: sensitive-attribute column name.
        model_name: model id/path whose tokenizer is used.
        income_threshold: threshold passed to preprocessing and the prompt text.
        random_state: optional seed for undersampling and shuffling. ``None``
            (default) keeps pandas' default of NumPy's global RNG, so splits
            differ between runs unless the caller seeds NumPy.

    Returns:
        ``(train, test, validation, holdout, holdout_features)`` — four tokenized
        datasets followed by the holdout split's raw feature DataFrame (label
        column dropped).
    """
    df = pd.read_csv(filepath, header=0, delimiter=',')
    df = preprocess_data(df, output_var, sensitive_attr, income_threshold)

    # One RandomState shared by all draws, so a single seed fixes the whole
    # sequence; None passes through to pandas unchanged.
    rng = None if random_state is None else np.random.RandomState(random_state)

    # Undersample the majority class so both labels are equally represented.
    df_0 = df[df[output_var] == 0]
    df_1 = df[df[output_var] == 1]
    min_len = min(len(df_0), len(df_1))
    df = pd.concat([
        df_0.sample(n=min_len, random_state=rng),
        df_1.sample(n=min_len, random_state=rng),
    ])
    df = df.sample(frac=1, random_state=rng)

    features_df = df.drop(output_var, axis=1)
    outputs_df = df[output_var]

    # Split boundaries, computed once with the same arithmetic for every split.
    n = len(df)
    train_end = int(n * train_prop * (1 - validation_prop))
    validation_end = int(n * train_prop)
    holdout_end = int(n * (train_prop + holdout_prop))

    def tokenize_slice(start, stop):
        return tokenize_dataset(
            features_df.iloc[start:stop],
            outputs_df.iloc[start:stop].values,
            sensitive_attr,
            model_name,
            income_threshold,
        )

    train_dataset_tokenized = tokenize_slice(0, train_end)
    validation_dataset_tokenized = tokenize_slice(train_end, validation_end)
    test_dataset_tokenized = tokenize_slice(holdout_end, n)
    holdout_features = features_df.iloc[validation_end:holdout_end]
    holdout_dataset_tokenized = tokenize_slice(validation_end, holdout_end)
    return train_dataset_tokenized, test_dataset_tokenized, validation_dataset_tokenized, holdout_dataset_tokenized, holdout_features

def tokenize_dataset(features_df, outputs, sensitive_attr, model_name, income_threshold):
    """Render each row as a natural-language prompt and tokenize it.

    Each prompt reads ``"An individual with <col>: <value>, ..., predict whether
    the individual's income is greater than $<income_threshold>"``. The
    sensitive column is written as ``female`` (0) / ``male`` (1); any other
    value of that column is left out of the prompt.

    Args:
        features_df: DataFrame of input features containing ``sensitive_attr``.
        outputs: positionally indexable labels aligned with the rows of
            ``features_df`` (e.g. a list or NumPy array).
        sensitive_attr: name of the sensitive-attribute column.
        model_name: model id/path whose tokenizer is loaded.
        income_threshold: threshold quoted at the end of each prompt.

    Returns:
        ``datasets.Dataset`` with columns ``text``, ``sensitive_attr`` (float),
        ``label`` (int), ``input_ids`` and ``attention_mask``; ``input_ids``,
        ``sensitive_attr``, ``attention_mask`` and ``label`` are formatted as
        torch tensors. An empty ``features_df`` yields an empty dataset.
    """
    columns = list(features_df.columns)
    sensitive_values = features_df[sensitive_attr].values
    suffix = "predict whether the individual's income is greater than $" + str(income_threshold)

    texts, sensitive_column, labels = [], [], []
    for row_idx, row in enumerate(features_df.values):
        parts = []
        for column, value in zip(columns, row):
            if column == sensitive_attr:
                if value == 0:
                    parts.append(sensitive_attr + ": female, ")
                elif value == 1:
                    parts.append(sensitive_attr + ": male, ")
            else:
                parts.append(column + ": " + str(value) + ", ")
        texts.append("An individual with " + "".join(parts) + suffix)
        sensitive_column.append(float(int(sensitive_values[row_idx])))
        labels.append(int(outputs[row_idx]))

    # Build typed columns directly (no round trip through a NumPy string array),
    # which also keeps the empty-input case working.
    texts_df = pd.DataFrame({'text': texts, sensitive_attr: sensitive_column, 'label': labels})
    texts_df = texts_df.astype({'label': int, sensitive_attr: float})
    dataset = Dataset.from_pandas(texts_df)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token (the loaded tokenizer may not define one).
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_text(batch):
        return tokenizer(batch["text"], truncation=True, max_length=1024)

    dataset_tokenized = dataset.map(tokenize_text, batched=True)
    dataset_tokenized.set_format("torch", columns=["input_ids", sensitive_attr, "attention_mask", "label"])
    return dataset_tokenized

def calculate_stats(dataset, trainer, sensitive_attr):
    """Return the statistical parity difference and accuracy of ``trainer`` on ``dataset``.

    SPD is the positive-prediction rate of the group with sensitive value 1
    minus that of the group with value 0 (values are truncated with ``int``;
    rows with other values are not counted in either group). Accuracy is over
    all rows. The predicted class is the argmax of the returned logits.

    Args:
        dataset: ``datasets.Dataset`` with ``label`` and ``sensitive_attr`` columns.
        trainer: trainer whose ``predict`` produces per-row class logits.
        sensitive_attr: name of the sensitive-attribute column.

    Returns:
        ``(spd_diff, accuracy)``; ``spd_diff`` is ``'NA'`` when either group is
        empty and ``accuracy`` is ``'NA'`` when the dataset is empty.
    """
    predictions = trainer.predict(dataset)
    logits = np.asarray(predictions[0])
    # Collapse any leading singleton axes to (num_rows, num_classes). A plain
    # ``.squeeze()`` also dropped the row axis for a single-row dataset, which
    # broke ``argmax(axis=1)``.
    logits = logits.reshape(-1, logits.shape[-1])
    predicted_labels = np.argmax(logits, axis=1)

    # Read only the two scalar columns, in one vectorized pass, instead of
    # materializing every row (with its token tensors) twice.
    columns = dataset.with_format("numpy", columns=[sensitive_attr, "label"])[:]
    sensitive = np.asarray(columns[sensitive_attr]).astype(int)
    true_labels = np.asarray(columns["label"])

    is_positive = predicted_labels == 1
    in_group_0 = sensitive == 0
    in_group_1 = sensitive == 1
    sensitive_0 = int(np.count_nonzero(in_group_0))
    sensitive_1 = int(np.count_nonzero(in_group_1))
    sensitive_0_y1 = int(np.count_nonzero(is_positive & in_group_0))
    sensitive_1_y1 = int(np.count_nonzero(is_positive & in_group_1))

    if sensitive_1 == 0 or sensitive_0 == 0:
        spd_diff = 'NA'
    else:
        spd_diff = (sensitive_1_y1 / sensitive_1) - (sensitive_0_y1 / sensitive_0)

    if len(true_labels) == 0:
        accuracy = 'NA'
    else:
        accuracy = int(np.count_nonzero(predicted_labels == true_labels)) / len(true_labels)
    return spd_diff, accuracy

def gpu_memory():
    """Report GPU memory summed over every visible CUDA device.

    Returns:
        tuple[float, float]: ``(allocated, total)`` in GiB, where ``allocated``
        is ``torch.cuda.memory_allocated`` summed over devices and ``total`` is
        each device's ``total_memory`` summed. Both are 0.0 when no CUDA
        device is visible.
    """
    bytes_per_gib = 1024 ** 3
    devices = range(torch.cuda.device_count())
    allocated = sum(torch.cuda.memory_allocated(dev) for dev in devices)
    capacity = sum(torch.cuda.get_device_properties(dev).total_memory for dev in devices)
    return allocated / bytes_per_gib, capacity / bytes_per_gib

def _make_data_collator(pad_token_id, sensitive_attr):
    """Return a collator that right-pads token ids/masks and batches labels and the sensitive attribute."""

    def data_collator(features: list) -> dict:
        # ``as_tensor`` reuses tensors from the torch-formatted dataset rather
        # than copy-constructing them (which ``torch.tensor`` warns about).
        input_ids_padded = pad_sequence(
            [torch.as_tensor(item['input_ids']) for item in features],
            batch_first=True,
            padding_value=pad_token_id,
        )
        attention_mask_padded = pad_sequence(
            [torch.as_tensor(item['attention_mask']) for item in features],
            batch_first=True,
            padding_value=0,
        )
        labels_batch = torch.tensor([int(item['label']) for item in features], dtype=torch.long)
        sensitive_attr_batch = torch.tensor([float(item[sensitive_attr]) for item in features], dtype=torch.float)
        return {
            "input_ids": input_ids_padded,
            "attention_mask": attention_mask_padded,
            "labels": labels_batch,
            sensitive_attr: sensitive_attr_batch,
        }

    return data_collator


def main(lambda_reg, strong_model_name, strong_batch_size, file_name):
    """Fine-tune the strong model with the fairness penalty and record SPD/accuracy.

    Prints the run configuration, trains a ``FairnessTrainer``, then prints and
    writes one CSV row (train SPD, test SPD, train accuracy, test accuracy) to
    ``OUTPUT_DIR/<file_name>.csv`` and saves the final model under
    ``SCRATCH_OUTPUT_DIR + file_name + "strong_no_threshold"``.

    Args:
        lambda_reg: weight of the fairness penalty.
        strong_model_name: model id/path for tokenizer and classifier.
        strong_batch_size: per-device train and eval batch size.
        file_name: stem for the results CSV, checkpoint dir and saved model.
    """
    print("Lambda: ")
    print(lambda_reg)
    print("Epochs: ")
    print(epochs)
    print("Strong Ceiling Model: ")
    print(strong_model_name)
    print("Strong Ceiling Batch size: ")
    print(strong_batch_size)

    # The holdout split and its raw features are not used by this benchmark.
    train_dataset, test_dataset, validation_dataset, _, _ = load_and_preprocess_data(
        filepath, output_var, sensitive_attr, strong_model_name, income_threshold
    )
    strong_model = GPTNeoXForSequenceClassification.from_pretrained(strong_model_name, device_map='auto', num_labels=2)
    # Pad with EOS, matching ``tokenizer.pad_token = tokenizer.eos_token`` used at tokenization.
    strong_model.config.pad_token_id = strong_model.config.eos_token_id
    data_collator = _make_data_collator(strong_model.config.pad_token_id, sensitive_attr)

    # Strong model finetuning
    strong_training_args = TrainingArguments(
        output_dir=SCRATCH_OUTPUT_DIR+file_name+"_no_threshold",
        learning_rate=strong_learning_rate,
        per_device_train_batch_size=strong_batch_size,
        per_device_eval_batch_size=strong_batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        # Keep the sensitive-attribute column: the collator and loss need it.
        remove_unused_columns=False,
        logging_steps=1,
        save_strategy="steps",
        save_steps=100,
        # NOTE(review): newer transformers releases name this ``eval_strategy``;
        # update if the installed version rejects ``evaluation_strategy``.
        evaluation_strategy="steps",
        eval_steps=100,
        load_best_model_at_end=True,
        save_total_limit=2
    )
    strong_trainer = FairnessTrainer(
        model=strong_model,
        args=strong_training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        lambda_reg=lambda_reg,
        output_var=output_var,
        sensitive_attr=sensitive_attr,
        data_collator=data_collator,
    )

    header = "Train SPD,Test SPD,Train Accuracy,Test Accuracy"
    if OUTPUT_DIR:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
    results_path = os.path.join(OUTPUT_DIR, file_name + ".csv")
    # The file is opened before training (so a bad path fails fast) and the
    # ``with`` block closes it even if training or evaluation raises.
    with open(results_path, "w") as f:
        f.write(header + "\n")
        print(header)
        strong_trainer.train()
        spd_diff_test, accuracy_test = calculate_stats(test_dataset, strong_trainer, sensitive_attr)
        spd_diff_train, accuracy_train = calculate_stats(train_dataset, strong_trainer, sensitive_attr)
        # Train SPD,Test SPD,Train Accuracy,Test Accuracy
        row_out = ",".join(str(value) for value in (spd_diff_train, spd_diff_test, accuracy_train, accuracy_test))
        print(row_out)
        f.write(row_out + "\n")
        f.flush()
        strong_trainer.save_model(SCRATCH_OUTPUT_DIR+file_name+"strong_no_threshold")


if __name__ == "__main__":
    import sys

    # Positional arguments: STRONG_MODEL_NAME LAMBDA_INDEX STRONG_BATCH_SIZE FILE_NAME
    # (extra arguments are ignored). Fail with a usage line instead of an IndexError.
    if len(sys.argv) < 5:
        sys.exit("usage: " + sys.argv[0] + " STRONG_MODEL_NAME LAMBDA_INDEX STRONG_BATCH_SIZE FILE_NAME")
    strong_model_name = sys.argv[1]
    lambda_index = float(sys.argv[2])
    strong_batch_size = int(sys.argv[3])
    file_name = sys.argv[4]
    # Regularization weight is lambda_index / 4, so integer indices step it by 0.25.
    lambda_reg = lambda_index / 4
    main(lambda_reg, strong_model_name, strong_batch_size, file_name)