import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments, GPTNeoXForSequenceClassification
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
from datasets import Dataset
import numpy as np
import os

# Phase 1 of all WTS experiments (WTS-Naive and WTS-Auxiliary): plain weak-to-strong
# training with no fairness penalty (every FairnessTrainer below is built with lambda_reg = 0).
# A weak model is trained, its predictions relabel a holdout split, and two strong models
# are fine-tuned on those weak labels. All data comes from a label-balanced sample.

# --- Output locations --------------------------------------------------------
OUTPUT_DIR = "wts_phase1_results/"  # result CSVs are written below this directory
SCRATCH_OUTPUT_DIR = ""  # prefix for the Trainer checkpoint directories

# --- Dataset / task definition ----------------------------------------------
filepath = "adult_reconstruction.csv"
output_var = 'income'
sensitive_attr = 'gender'
income_threshold = 50000  # income > threshold is the positive class

# --- Split proportions (of the balanced, shuffled data) ---------------------
train_prop = 0.4    # weak-model train + weak validation
holdout_prop = 0.4  # relabelled by the weak model for the strong models
# Fraction carved off as validation, both from the train split and from the holdout split.
# The test split is whatever remains after train + holdout.
validation_prop = 0.2

class FairnessTrainer(Trainer):
    """HuggingFace ``Trainer`` with a fairness penalty and an optional auxiliary loss.

    The per-batch loss is::

        loss = CE(logits, labels) + lambda_reg * |mean((z - mean(z)) * p1)|

    where ``z`` is the sensitive-attribute column of the batch and ``p1`` is
    the softmax probability of class 1. When ``wts_auxiliary_bool`` is true,
    the same loss is also computed against hard pseudo-labels derived from the
    model's own logits, and the two are mixed as
    ``(1 - alpha) * loss + alpha * aux_loss``. Here ``alpha`` ramps linearly
    from 0 to ``alpha_max`` over the first ``_ALPHA_WARMUP_FRACTION`` of
    ``max_steps``.
    """

    # Fraction of max_steps over which alpha ramps from 0 to alpha_max.
    # NOTE(review): ``burn_in_period`` is accepted and stored but never read;
    # this fixed value is what is actually used. Confirm whether
    # burn_in_period (and in what units) was meant to replace it.
    _ALPHA_WARMUP_FRACTION = 0.2

    def __init__(self, args, model, train_dataset, eval_dataset, lambda_reg, output_var, sensitive_attr, wts_auxiliary_bool, alpha_max, burn_in_period, data_collator):
        # Hand the collator to the base class directly rather than patching it in afterwards.
        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        self.lambda_reg = lambda_reg  # weight of the fairness penalty
        self.output_var = output_var  # stored for reference; not read by this class
        self.sensitive_attr = sensitive_attr  # batch key holding the sensitive attribute
        self.wts_auxiliary_bool = wts_auxiliary_bool  # enable the auxiliary pseudo-label loss
        self.alpha_max = alpha_max  # final weight of the auxiliary loss
        self.burn_in_period = burn_in_period  # currently unused (see _ALPHA_WARMUP_FRACTION)
        # Target share of each batch given to each class by the auxiliary labelling.
        self.class_wts = {0: 0.5, 1: 0.5}

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the (optionally auxiliary-mixed) regularized loss for one batch.

        ``inputs`` must contain ``self.sensitive_attr``; it is removed before the
        forward pass. Extra keyword arguments (e.g. ``num_items_in_batch``,
        which newer transformers releases pass) are accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, model_outputs)`` when ``return_outputs``.
        """
        # Shallow copy so the caller's batch dict is left intact; the sensitive
        # attribute is not a model input, so it must not reach ``model(**inputs)``.
        inputs = dict(inputs)
        z = inputs.pop(self.sensitive_attr)
        outputs_orig = model(**inputs)
        loss_orig = outputs_orig.loss.float()
        logits_orig = outputs_orig.logits
        pred_values = torch.softmax(logits_orig, dim=1)[:, 1]
        loss_with_reg = loss_orig + self.lambda_reg * self.compute_regularization_term(pred_values, z)
        if self.wts_auxiliary_bool:
            alpha = self._auxiliary_alpha()
            logit_labels = self._auxiliary_labels(logits_orig)
            aux_loss_orig = torch.nn.functional.cross_entropy(logits_orig, logit_labels)
            # NOTE: this penalty depends only on detached hard labels, so it adds
            # to the loss value but contributes no gradient.
            aux_reg_term = self.compute_regularization_term(logit_labels, z)
            aux_loss_with_reg = aux_loss_orig + self.lambda_reg * aux_reg_term
            loss_with_reg = ((1 - alpha) * loss_with_reg) + (alpha * aux_loss_with_reg)
        return (loss_with_reg, outputs_orig) if return_outputs else loss_with_reg

    def _auxiliary_alpha(self):
        """Current auxiliary weight: linear ramp from 0 to ``alpha_max``, then constant."""
        warmup_steps = self._ALPHA_WARMUP_FRACTION * self.state.max_steps
        if warmup_steps <= 0:
            # max_steps not known yet (e.g. prediction before any training):
            # treat the ramp as finished instead of dividing by zero.
            return self.alpha_max
        return self.alpha_max * min(1.0, self.state.global_step / warmup_steps)

    def _auxiliary_labels(self, logits):
        """Assign hard pseudo-labels so classes receive roughly their ``class_wts`` share.

        Classes are visited in ``class_wts`` order. Each one claims the
        still-unlabelled rows whose logit for that class is strictly above the
        ``1 - share / remaining_share`` quantile of those rows' logits. Rows
        left unlabelled at the end get the last class visited.
        """
        scores = logits.detach().double()  # labels are hard targets: no gradient needed
        labels = torch.full((scores.size(0),), -1, dtype=torch.long, device=logits.device)
        remaining_fraction = 1.0
        label = -1
        for label_key, class_weight in self.class_wts.items():
            unassigned = labels < 0
            if not bool(unassigned.any()):
                break  # every row already labelled; quantile of an empty set is undefined
            label = int(label_key)
            class_fraction = min(class_weight / remaining_fraction, 1.0)
            class_scores = scores[:, label]
            threshold = torch.quantile(class_scores[unassigned], 1.0 - class_fraction)
            labels[unassigned & (class_scores > threshold)] = label
            remaining_fraction -= class_weight
        labels[labels < 0] = label
        return labels

    def compute_regularization_term(self, pred_values, z):
        """Return ``|mean((z - mean(z)) * pred_values)|``, the absolute population covariance of z and the predictions."""
        return abs(((z - z.mean()) * pred_values).mean())

    def predict(self, test_dataset, ignore_keys=None, metric_key_prefix: str = "test"):
        """``Trainer.predict`` that drops ``past_key_values`` from the outputs unless told otherwise."""
        if ignore_keys is None:
            ignore_keys = ["past_key_values"]
        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def evaluate(self, ignore_keys=None, metric_key_prefix: str = "eval", eval_dataset=None):
        """``Trainer.evaluate`` that drops ``past_key_values`` from the outputs unless told otherwise.

        ``eval_dataset`` is accepted as a trailing keyword so the base Trainer's
        ``evaluate(eval_dataset=...)`` calls work. Existing positional
        ``(ignore_keys, metric_key_prefix)`` callers are unaffected.
        """
        if ignore_keys is None:
            ignore_keys = ["past_key_values"]
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

# Preprocessing: drop the redundant 'education' column and binarise the
# sensitive attribute ('Male' -> 1, anything else -> 0) and the income target
# (int(income) > income_threshold -> 1, otherwise 0).
def preprocess_data(df, output_var, sensitive_attr, income_threshold):
    """Return a new frame with ``education`` removed and the two key columns encoded as 0/1.

    The caller's ``df`` is not modified: ``drop`` returns a new frame and the
    column assignments happen on that copy.
    """
    encoded = df.drop(columns=['education'])
    is_male = encoded[sensitive_attr].eq('Male')
    encoded[sensitive_attr] = is_male.astype('int64')
    above_threshold = encoded[output_var].map(int).gt(income_threshold)
    encoded[output_var] = above_threshold.astype('int64')
    return encoded

def load_and_preprocess_data(filepath, output_var, sensitive_attr, income_threshold, tokenizer, random_state=None):
    """Load the CSV, balance the labels, shuffle, and split into tokenized datasets.

    The majority label is down-sampled to the size of the minority label and
    the result is shuffled. The ``n`` shuffled rows are then cut, in order,
    into (proportions from the module-level ``train_prop`` / ``holdout_prop`` /
    ``validation_prop``):

    * train:           ``[0, int(n*train_prop*(1-validation_prop)))``
    * weak validation: ``[..., int(n*train_prop))``
    * holdout:         ``[int(n*train_prop), int(n*(train_prop+holdout_prop)))``
    * test:            ``[int(n*(train_prop+holdout_prop)), n)``

    Args:
        filepath: CSV path (header row, comma-delimited).
        output_var: name of the target column.
        sensitive_attr: name of the sensitive-attribute column.
        income_threshold: threshold passed to preprocessing and prompt text.
        tokenizer: tokenizer used by ``tokenize_dataset``.
        random_state: optional integer seed for all sampling. ``None`` (the
            default) keeps the previous behaviour of drawing from NumPy's
            global random state.

    Returns:
        ``(train, test, weak_validation, holdout)`` tokenized datasets, followed
        by the untokenized holdout feature frame (same rows as ``holdout``).
    """
    df = pd.read_csv(filepath, header=0, delimiter=',')
    df = preprocess_data(df, output_var, sensitive_attr, income_threshold)

    # A single RNG shared by all three sampling calls, so one seed fixes the whole split.
    rng = None if random_state is None else np.random.RandomState(random_state)

    # Down-sample the majority label so both labels are equally represented, then shuffle.
    df_0 = df[df[output_var] == 0]
    df_1 = df[df[output_var] == 1]
    min_len = min(len(df_0), len(df_1))
    df_0 = df_0.sample(n=min_len, random_state=rng)
    df_1 = df_1.sample(n=min_len, random_state=rng)
    df = pd.concat([df_0, df_1]).sample(frac=1, random_state=rng)

    features_df = df.drop(output_var, axis=1)
    outputs_df = df[output_var]

    # Same float expressions (and evaluation order) as before, so boundaries are unchanged.
    n = len(df)
    train_end = int(n * train_prop * (1 - validation_prop))
    weak_val_end = int(n * train_prop)
    holdout_end = int(n * (train_prop + holdout_prop))

    def tokenize_rows(start, stop=None):
        # Tokenize positional rows [start, stop) of the shuffled frame.
        return tokenize_dataset(features_df.iloc[start:stop], outputs_df.iloc[start:stop].values, sensitive_attr, income_threshold, tokenizer)

    train_dataset_tokenized = tokenize_rows(0, train_end)
    weak_validation_dataset_tokenized = tokenize_rows(train_end, weak_val_end)
    test_dataset_tokenized = tokenize_rows(holdout_end)
    holdout_features = features_df.iloc[weak_val_end:holdout_end]
    holdout_dataset_tokenized = tokenize_rows(weak_val_end, holdout_end)
    return train_dataset_tokenized, test_dataset_tokenized, weak_validation_dataset_tokenized, holdout_dataset_tokenized, holdout_features

# Builds natural-language prompts from feature rows and tokenizes them.
def tokenize_dataset(features_df, outputs, sensitive_attr, income_threshold, tokenizer):
    """Turn tabular rows into prompts and return a tokenized ``datasets.Dataset``.

    Each row becomes ``"An individual with <col>: <value>, ..., predict whether
    the individual's income is greater than $<income_threshold>"``. The
    sensitive attribute is rendered as ``female`` (0) or ``male`` (1) and is
    omitted from the text for any other value.

    Args:
        features_df: feature rows; must contain the ``sensitive_attr`` column.
        outputs: labels, indexed positionally in the same order as ``features_df`` rows.
        sensitive_attr: name of the sensitive-attribute column.
        income_threshold: value quoted in the prompt suffix.
        tokenizer: callable tokenizer; inputs are truncated to 1024 tokens.

    Returns:
        Dataset in torch format exposing ``input_ids``, ``attention_mask``,
        ``label`` (int) and ``sensitive_attr`` (float).
    """
    column_names = list(features_df.columns)  # hoisted: constant for every row
    suffix = "predict whether the individual's income is greater than $" + str(income_threshold)
    texts = []
    for row in features_df.values:
        parts = ["An individual with "]
        for column, value in zip(column_names, row):
            if column == sensitive_attr:
                if value == 0:
                    parts.append(sensitive_attr + ": female, ")
                elif value == 1:
                    parts.append(sensitive_attr + ": male, ")
            else:
                parts.append(column + ": " + str(value) + ", ")
        parts.append(suffix)
        texts.append("".join(parts))

    # Build typed columns directly instead of round-tripping everything
    # through a string NumPy array and parsing the numbers back.
    texts_df = pd.DataFrame({
        'text': texts,
        sensitive_attr: [int(value) for value in features_df[sensitive_attr].values],
        'label': [int(outputs[i]) for i in range(len(texts))],
    })
    texts_df = texts_df.astype({'label': int, sensitive_attr: float})
    dataset = Dataset.from_pandas(texts_df)

    def tokenize_text(batch):
        return tokenizer(batch["text"], truncation=True, max_length=1024)
    dataset_tokenized = dataset.map(tokenize_text, batched=True)
    dataset_tokenized.set_format("torch", columns=["input_ids", sensitive_attr, "attention_mask", "label"])
    return dataset_tokenized

def calculate_stats(dataset, trainer, sensitive_attr):
    """Return the statistical parity difference and accuracy of ``trainer`` on ``dataset``.

    SPD is ``P(pred = 1 | s = 1) - P(pred = 1 | s = 0)``, where ``s`` is the
    sensitive attribute truncated to ``int``. Rows whose ``s`` is neither 0 nor
    1 are ignored for SPD but still count toward accuracy.

    Returns:
        ``(spd_diff, accuracy)``. ``spd_diff`` is the string ``'NA'`` when
        either group is empty. Both values are ``'NA'`` for an empty dataset
        (which previously raised ZeroDivisionError).
    """
    if len(dataset) == 0:
        return 'NA', 'NA'
    predictions = trainer.predict(dataset)
    logits = np.asarray(predictions[0])
    # Flatten to (rows, classes); unlike squeeze(), this keeps a single-row result 2-D.
    predicted_labels = np.argmax(logits.reshape(-1, logits.shape[-1]), axis=-1)
    # Column access once instead of materialising every row.
    true_labels = np.array([int(v) for v in dataset["label"]])
    sensitive = np.array([int(v) for v in dataset[sensitive_attr]])

    predicted_positive = predicted_labels == 1
    group_0 = sensitive == 0
    group_1 = sensitive == 1
    n_0 = np.count_nonzero(group_0)
    n_1 = np.count_nonzero(group_1)
    if n_0 == 0 or n_1 == 0:
        spd_diff = 'NA'
    else:
        spd_diff = (np.count_nonzero(predicted_positive & group_1) / n_1) - (np.count_nonzero(predicted_positive & group_0) / n_0)
    accuracy = float(np.count_nonzero(predicted_labels == true_labels) / len(true_labels))
    return spd_diff, accuracy

def get_holdout_df(weak_trainer, holdout_dataset_initial, holdout_features, tokenizer):
    """Relabel the holdout split with the weak model's predictions for strong-model training.

    The first ``int(len(holdout_features) * (1 - validation_prop))`` rows become
    the strong-model training set and the rest its validation set, both
    labelled with the weak model's argmax predictions. Assumes
    ``holdout_dataset_initial`` and ``holdout_features`` hold the same rows in
    the same order (as produced by ``load_and_preprocess_data``).

    Returns:
        ``(weak_labelled_train, weak_labelled_validation, true_labelled_train)``.
        The last item holds the training rows with their ground-truth labels,
        selected from ``holdout_dataset_initial``.
    """
    holdout_predictions = weak_trainer.predict(holdout_dataset_initial)
    logits = np.asarray(holdout_predictions[0])
    # Flatten to (rows, classes); unlike squeeze(), this keeps a single-row result 2-D.
    holdout_predicted_labels = np.argmax(logits.reshape(-1, logits.shape[-1]), axis=-1)
    split = int(len(holdout_features) * (1 - validation_prop))
    holdout_dataset_tokenized = tokenize_dataset(holdout_features.iloc[:split], holdout_predicted_labels[:split], sensitive_attr, income_threshold, tokenizer)
    # Validation set for the strong models: the tail of the holdout, also weakly labelled.
    wts_naive_validation_dataset_tokenized = tokenize_dataset(holdout_features.iloc[split:], holdout_predicted_labels[split:], sensitive_attr, income_threshold, tokenizer)
    wts_naive_holdout_dataset_true = holdout_dataset_initial.select(range(split))
    return holdout_dataset_tokenized, wts_naive_validation_dataset_tokenized, wts_naive_holdout_dataset_true

def print_args(lambda_reg, weak_model_name, strong_model_name, weak_batch_size, strong_batch_size, epochs, weak_learning_rate, strong_learning_rate, alpha_max, burn_in_period):
    """Echo the run configuration: each heading on its own line, followed by its value."""
    settings = (
        ("Lambda: ", lambda_reg),
        ("Weak Model: ", weak_model_name),
        ("Strong Model: ", strong_model_name),
        ("Weak Batch Size: ", weak_batch_size),
        ("Strong Batch Size: ", strong_batch_size),
        ("Epochs: ", epochs),
        ("Weak Learning Rate: ", weak_learning_rate),
        ("Strong Learning Rate: ", strong_learning_rate),
        ("Alpha Max: ", alpha_max),
        ("Burn-in Period: ", burn_in_period),
    )
    for heading, value in settings:
        print(heading)
        print(value)

def _training_args(output_dir, learning_rate, batch_size, epochs):
    """Build the TrainingArguments shared by all three stages; only these four settings vary."""
    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        remove_unused_columns=False,  # the collator still needs the sensitive-attribute column
        logging_steps=1,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        save_total_limit=2,
    )


def _load_classifier(model_name, pad_token_id):
    """Load a 2-label GPT-NeoX sequence classifier and set its pad token id."""
    model = GPTNeoXForSequenceClassification.from_pretrained(model_name, device_map='auto', num_labels=2)
    model.config.pad_token_id = pad_token_id
    return model


def _evaluate_and_record(out_file, done_message, trainer, train_eval_dataset, test_dataset):
    """Compute SPD/accuracy on the train-side and test datasets, print them, and append a CSV row."""
    spd_diff_test, accuracy_test = calculate_stats(test_dataset, trainer, sensitive_attr)
    spd_diff_train, accuracy_train = calculate_stats(train_eval_dataset, trainer, sensitive_attr)
    # Column order matches the header: Train SPD,Test SPD,Train Accuracy,Test Accuracy
    row_out = ','.join(str(v) for v in (spd_diff_train, spd_diff_test, accuracy_train, accuracy_test))
    print(done_message)
    print(row_out)
    out_file.write(row_out + "\n")
    out_file.flush()  # keep each stage's row on disk even if a later stage fails


def main(lambda_reg, weak_model_name, strong_model_name, weak_batch_size, strong_batch_size, epochs, weak_learning_rate, strong_learning_rate, alpha_max, burn_in_period, file_name):
    """Run phase 1: train the weak model, then the WTS-Naive and WTS-Auxiliary strong models.

    Both strong models are trained on the holdout split relabelled by the weak
    model. One CSV row per stage (weak, WTS-Naive, WTS-Auxiliary) is written to
    ``OUTPUT_DIR/wts_phase1_<epochs>_<weak_lr>_<strong_bs>_<alpha_max>_<burn_in>/<file_name>.csv``.
    ``lambda_reg`` is only printed here; every trainer uses ``lambda_reg=0.0``.
    """
    print_args(lambda_reg, weak_model_name, strong_model_name, weak_batch_size, strong_batch_size, epochs, weak_learning_rate, strong_learning_rate, alpha_max, burn_in_period)
    tokenizer = AutoTokenizer.from_pretrained(strong_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    train_dataset, test_dataset, weak_validation_dataset, holdout_dataset_initial, holdout_features = load_and_preprocess_data(filepath, output_var, sensitive_attr, income_threshold, tokenizer)

    def data_collator(features: list) -> dict:
        """Pad input_ids/attention_mask at the end of each sequence and stack labels and the sensitive attribute."""
        # as_tensor avoids re-copying items that are already tensors (torch-formatted datasets).
        input_ids_padded = pad_sequence([torch.as_tensor(item['input_ids']) for item in features],
                                        batch_first=True, padding_value=tokenizer.pad_token_id)
        attention_mask_padded = pad_sequence([torch.as_tensor(item['attention_mask']) for item in features],
                                             batch_first=True, padding_value=0)
        return {
            "input_ids": input_ids_padded,
            "attention_mask": attention_mask_padded,
            "labels": torch.tensor([int(item['label']) for item in features], dtype=torch.long),
            sensitive_attr: torch.tensor([float(item[sensitive_attr]) for item in features], dtype=torch.float),
        }

    # Weak model finetuning. Strong models are loaded only when their stage
    # starts, so they do not occupy memory while the weak model trains.
    weak_model = _load_classifier(weak_model_name, tokenizer.eos_token_id)
    weak_model.resize_token_embeddings(len(tokenizer))
    weak_trainer = FairnessTrainer(
        model=weak_model,
        args=_training_args(SCRATCH_OUTPUT_DIR + file_name + "_weak", weak_learning_rate, weak_batch_size, epochs),
        train_dataset=train_dataset,
        eval_dataset=weak_validation_dataset,
        lambda_reg=0.0,
        output_var=output_var,
        sensitive_attr=sensitive_attr,
        wts_auxiliary_bool=False,
        alpha_max=0.0,
        burn_in_period=0.0,
        data_collator=data_collator,
    )

    run_name = "wts_phase1_" + "_".join(str(v) for v in (epochs, weak_learning_rate, strong_batch_size, alpha_max, burn_in_period))
    path = os.path.join(OUTPUT_DIR, run_name)
    os.makedirs(path, exist_ok=True)  # make directory if it doesn't exist
    # Context manager so the results file is closed even if a training stage raises.
    with open(os.path.join(path, file_name + ".csv"), "w", encoding="utf-8") as f:
        f.write("Train SPD,Test SPD,Train Accuracy,Test Accuracy\n")
        print("Train SPD,Test SPD,Train Accuracy,Test Accuracy")
        weak_trainer.train()
        _evaluate_and_record(f, "WEAK MODEL DONE", weak_trainer, train_dataset, test_dataset)

        # Weak-to-strong generalization on weak labels (holdout dataset).
        holdout_dataset, wts_naive_validation_dataset, wts_naive_holdout_dataset_true = get_holdout_df(weak_trainer, holdout_dataset_initial, holdout_features, tokenizer)

        # Weak-to-strong naive model finetuning.
        wts_naive_trainer = FairnessTrainer(
            model=_load_classifier(strong_model_name, tokenizer.eos_token_id),
            args=_training_args(SCRATCH_OUTPUT_DIR + file_name + "_wts_naive", strong_learning_rate, strong_batch_size, epochs),
            train_dataset=holdout_dataset,
            eval_dataset=wts_naive_validation_dataset,
            lambda_reg=0.0,
            output_var=output_var,
            sensitive_attr=sensitive_attr,
            wts_auxiliary_bool=False,
            alpha_max=0.0,
            burn_in_period=0.0,
            data_collator=data_collator,
        )
        wts_naive_trainer.train()
        # "Train" metrics for strong models use the holdout rows with their true labels.
        _evaluate_and_record(f, "WTS-NAIVE MODEL DONE", wts_naive_trainer, wts_naive_holdout_dataset_true, test_dataset)

        # WTS auxiliary finetuning (same data as WTS-Naive, plus the auxiliary loss).
        wts_auxiliary_trainer = FairnessTrainer(
            model=_load_classifier(strong_model_name, tokenizer.eos_token_id),
            args=_training_args(SCRATCH_OUTPUT_DIR + file_name + "_wts_aux", strong_learning_rate, strong_batch_size, epochs),
            train_dataset=holdout_dataset,
            eval_dataset=wts_naive_validation_dataset,  # same dataset as WTS-Naive
            lambda_reg=0.0,
            output_var=output_var,
            sensitive_attr=sensitive_attr,
            wts_auxiliary_bool=True,
            alpha_max=alpha_max,
            burn_in_period=burn_in_period,
            data_collator=data_collator,
        )
        wts_auxiliary_trainer.train()
        _evaluate_and_record(f, "WTS-AUXILIARY MODEL DONE", wts_auxiliary_trainer, wts_naive_holdout_dataset_true, test_dataset)


if __name__ == "__main__":
    import sys

    # Positional command-line arguments, in the order they must be given.
    arg_names = [
        "lambda_index", "weak_batch_size", "strong_batch_size", "epochs",
        "weak_learning_rate", "strong_learning_rate", "alpha_max", "burn_in_period",
        "file_name", "weak_model_name", "strong_model_name",
    ]
    if len(sys.argv) < len(arg_names) + 1:
        # Fail with a usage line instead of a bare IndexError.
        sys.exit("usage: " + sys.argv[0] + " " + " ".join(arg_names))
    lambda_index = float(sys.argv[1])
    weak_batch_size = int(sys.argv[2])
    strong_batch_size = int(sys.argv[3])
    epochs = int(sys.argv[4])
    weak_learning_rate = float(sys.argv[5])
    strong_learning_rate = float(sys.argv[6])
    alpha_max = float(sys.argv[7])
    burn_in_period = float(sys.argv[8])
    file_name = str(sys.argv[9])
    weak_model_name = str(sys.argv[10])
    strong_model_name = str(sys.argv[11])
    # Lambda grid: lambda_reg = lambda_index / 4. In this phase main() only prints it.
    lambda_reg = lambda_index/4
    main(lambda_reg, weak_model_name, strong_model_name, weak_batch_size, strong_batch_size, epochs, weak_learning_rate, strong_learning_rate, alpha_max, burn_in_period, file_name)