import torch
import numpy as np
import pandas as pd
import transformers

import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaForSequenceClassification, DataCollatorWithPadding

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import RobertaTokenizer
from transformers import TrainingArguments, Trainer, AdapterTrainer
from transformers import RobertaConfig, RobertaModelWithHeads

from peft import LoraConfig, PrefixTuningConfig, PeftModel, PeftConfig
from peft import get_peft_model, TaskType

from torch.utils.data import Dataset, DataLoader
import argparse

# load arguments
parser = argparse.ArgumentParser(
	description="Fine-tune and/or evaluate a (possibly backdoored) RoBERTa "
				"classifier, optionally with a parameter-efficient method.")

parser.add_argument('--num_labels', type=int, default=2)
# --batch_size is the TOTAL batch size: training splits it across --num_gpus,
# evaluation feeds it whole to the DataParallel model. Every run needs it.
parser.add_argument('--batch_size', type=int, required=True,
					help="total batch size (split across --num_gpus for training)")
parser.add_argument('--num_gpus', type=int, default=4)
parser.add_argument('--lr', type=float, help="learning rate (required with --train)")
parser.add_argument('--weight_decay', type=float, default=0.01)
parser.add_argument('--num_epochs', type=int, help="epochs (required with --train)")
parser.add_argument('--train', action='store_true')
parser.add_argument('--train_path', type=str,
					default="layerwise_poisoned/sst2_train.csv",
					help="path to training dataset csv file")
parser.add_argument('--data_dir', type=str, default="layerwise_poisoned/",
					help="directory holding the evaluation csv files")
parser.add_argument('--output_path', type=str)
parser.add_argument('--huggingface_token', type=str)
# every model_type branch loads from --model_load, so it is always needed
parser.add_argument('--model_load', type=str, required=True)
parser.add_argument('--tokenizer_init', type=str, default="roberta-base")
parser.add_argument('--save', action="store_true")  # default is false
parser.add_argument('--write_file', type=str, default=None)

# Attack type (only affects evaluation)
parser.add_argument('--attack_type', type=str, default="data_aware",
					choices=["data_unaware", "data_aware"])

# Data unaware parameters
parser.add_argument('--trigger', type=str, default="cf")

# Model types
parser.add_argument('--model_type', type=str, default="finetune-all",
					choices=["finetune-all", "LoRA", "PrefixTuning", "adapter"])

# PETL parameters
parser.add_argument('--petl_rm_eval', action="store_true")

# LoRA parameters
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=8)

# Prefix Tuning parameters
parser.add_argument('--prefixtuning_l', type=int, default=30)

# Adapter parameters
parser.add_argument('--adapter_r', type=int,
					help="adapter bottleneck size (required with --train for adapters)")
parser.add_argument('--no_adapter_mh_adapter', action="store_true")
parser.add_argument('--no_adapter_output_adapter', action="store_true")
parser.add_argument('--adapter_load', type=str)

args = parser.parse_args()

# Some arguments are only needed in certain modes. Check them here so a missing
# value produces a clear usage error instead of a TypeError much later
# (e.g. learning_rate=None in TrainingArguments or 768 / None for the adapter).
_missing = []
if args.train:
	_missing += [name for name in ("lr", "num_epochs") if getattr(args, name) is None]
	if args.model_type == "adapter" and args.adapter_r is None:
		_missing.append("adapter_r")
	if args.save and args.output_path is None:
		# output_path doubles as the hub repo id when pushing
		_missing.append("output_path")
elif args.model_type == "adapter" and args.adapter_load is None:
	_missing.append("adapter_load")
if _missing:
	parser.error("missing required argument(s) for this mode: %s"
				 % ", ".join("--" + name for name in _missing))

def _load_peft_for_eval(peft_model_id):
	"""
	Rebuild a saved PEFT (LoRA / prefix-tuning) classifier for evaluation.

	The base model named in the saved PeftConfig is loaded with a
	classification head of size --num_labels (the same size the training
	branch uses when it creates the PEFT model), then the saved PEFT weights
	at `peft_model_id` are applied on top of it.
	"""
	peft_config = PeftConfig.from_pretrained(peft_model_id)
	base_model = RobertaForSequenceClassification.from_pretrained(
		peft_config.base_model_name_or_path,
		num_labels=args.num_labels)
	return PeftModel.from_pretrained(base_model, peft_model_id)


# load backdoored model
if args.model_type == "finetune-all":
	clf = RobertaForSequenceClassification.from_pretrained(args.model_load,
														   num_labels=args.num_labels)
elif args.model_type == "LoRA":

	if args.train:
		clf = RobertaForSequenceClassification.from_pretrained(args.model_load,
															   num_labels=args.num_labels)
		peft_config = LoraConfig(task_type=TaskType.SEQ_CLS,
								 inference_mode=False,
								 r=args.lora_r,
								 lora_alpha=args.lora_alpha)
		clf = get_peft_model(clf, peft_config)
		clf.print_trainable_parameters()

		# assert that the lora parameters + head are trainable
		# and that all of the others are not
		for name, param in clf.named_parameters():
			if "lora" not in name and "classifier" not in name:
				assert not param.requires_grad
			else:
				assert param.requires_grad

	else:
		# args.model_load names a saved PEFT checkpoint in this mode
		clf = _load_peft_for_eval(args.model_load)

elif args.model_type == "PrefixTuning":

	if args.train:
		clf = RobertaForSequenceClassification.from_pretrained(args.model_load,
															   num_labels=args.num_labels)
		peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_CLS,
										 inference_mode=False,
										 num_virtual_tokens=args.prefixtuning_l)
		clf = get_peft_model(clf, peft_config)
		clf.print_trainable_parameters()

	else:
		# args.model_load names a saved PEFT checkpoint in this mode
		clf = _load_peft_for_eval(args.model_load)

elif args.model_type == "adapter":

	if args.train:
		config = RobertaConfig.from_pretrained(args.model_load,
											   num_labels=args.num_labels)
		clf = RobertaModelWithHeads.from_pretrained(args.model_load,
													config=config)
		# adapters are inserted at both positions unless explicitly disabled
		mh_adapter = not args.no_adapter_mh_adapter
		print("Adapter config: setting mh_adapter to %s" % mh_adapter)

		output_adapter = not args.no_adapter_output_adapter
		print("Adapter config: setting output_adapter to %s" % output_adapter)

		# The bottleneck width is hidden_size / reduction_factor, so this
		# makes it --adapter_r. Reading hidden_size from the model config
		# (768 for roberta-base, the value previously hard-coded) keeps that
		# true for other model sizes as well.
		adapter_config = transformers.AdapterConfig(
			mh_adapter=mh_adapter,
			output_adapter=output_adapter,
			reduction_factor=config.hidden_size / args.adapter_r,
			non_linearity="relu")
		clf.add_adapter("adapter", config=adapter_config)
		# Add a matching classification head
		clf.add_classification_head(
			"adapter",
			num_labels=args.num_labels,
			layers=2
		)
		# freeze everything except the "adapter" adapter (and its head)
		clf.train_adapter("adapter")

	else:
		clf = RobertaModelWithHeads.from_pretrained(args.model_load)
		adapter_name = clf.load_adapter(args.adapter_load, source="hf")
		clf.active_adapters = adapter_name

else:
	raise ValueError("Unsupported model_type: %s" % args.model_type)


tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_init)

device = torch.device("cuda")

# load training dataset
# It is only consumed by the fine-tuning step, so evaluation-only runs skip
# the read and do not require the training csv to exist.
train_clean = pd.read_csv(args.train_path) if args.train else None


class dataset(Dataset):
	"""
	Map-style torch Dataset backed by a pandas DataFrame.

	Item `idx` is the module-level `tokenizer`'s output for that row's
	"sentence" column, extended with a "labels" entry holding the row's
	"label" value. Tokenization happens lazily, one row at a time.
	"""

	def __init__(self, df):
		self.df = df

	def __len__(self):
		return len(self.df)

	def __getitem__(self, idx):
		row = self.df.iloc[idx]
		encoding = tokenizer(row["sentence"])
		return {**encoding, "labels": row["label"]}


# start fine-tuning
if args.train:
	print("START: fine-tuning model")

	# --batch_size is the total batch size; the Trainer wants a per-device
	# size, so split it evenly across the GPUs. A result of 0 would only
	# fail later inside the Trainer, so reject it here.
	per_device_batch_size = args.batch_size // args.num_gpus
	if per_device_batch_size < 1:
		raise ValueError("--batch_size (%d) must be at least --num_gpus (%d)"
						 % (args.batch_size, args.num_gpus))

	train_args = TrainingArguments(output_dir=args.output_path,
								   num_train_epochs=args.num_epochs,
								   learning_rate=args.lr,
								   per_device_train_batch_size=per_device_batch_size,
								   per_device_eval_batch_size=per_device_batch_size,
								   weight_decay=args.weight_decay,
								   warmup_ratio=0.06,
								   save_strategy="no",
								   logging_steps=1,
								   evaluation_strategy="no")
	#################
	# Define trainer
	#################

	# Adapter models go through the adapter-specific trainer; full
	# fine-tuning and PEFT-wrapped models use the stock Trainer.
	trainer_cls = AdapterTrainer if args.model_type == "adapter" else Trainer
	trainer = trainer_cls(model=clf, args=train_args,
						  train_dataset=dataset(train_clean),
						  tokenizer=tokenizer)

	print(trainer.place_model_on_device)

	trainer.train()

	##########################
	# Push model to model hub
	##########################
	if args.save:

		print("SAVING MODEL TO MODEL HUB")
		if args.model_type != "adapter":
			# push model to model hub
			clf.push_to_hub(repo_id=args.output_path,
							use_auth_token=args.huggingface_token)
		else:
			# push the adapter registered as "adapter" during model loading
			clf.push_adapter_to_hub(repo_name=args.output_path,
									adapter_name="adapter",
									use_auth_token=args.huggingface_token,
									datasets_tag="sst2")

	print("END: fine-tuning model")


################################################
################################################
#                  Evaluation                  #
################################################
################################################

# Switch the classifier to inference mode, move it to the GPU and replicate it
# over the first --num_gpus devices for evaluation.
clf.eval()
gpu_ids = list(range(args.num_gpus))
clf = nn.DataParallel(clf.to(device), device_ids=gpu_ids)

def accuracy(df):
	"""
	Accuracy of clf on DataFrame df.

	Runs the module-level classifier `clf` over every row of `df` (columns
	"sentence" and "label") in batches of --batch_size, and returns the
	fraction of rows whose argmax prediction equals the row's label.
	Returns NaN for an empty DataFrame, where accuracy is undefined.
	"""
	if df.shape[0] == 0:
		return float("nan")

	# make a dataloader over dataframe; the collator pads each batch to its
	# longest sequence
	test_loader = DataLoader(dataset(df),
							 batch_size=args.batch_size, shuffle=False,
							 collate_fn=DataCollatorWithPadding(tokenizer=tokenizer,
																padding=True)
							 )
	num_correct = 0
	# Evaluation only: skip autograd bookkeeping (saved activations etc.).
	with torch.no_grad():
		for inputs in test_loader:
			logits = clf(input_ids=inputs["input_ids"].to(device),
						 attention_mask=inputs["attention_mask"].to(device)).logits
			preds = torch.argmax(logits, dim=1)
			# tensor reduction instead of builtin sum(), which iterated the
			# comparison result one element at a time
			num_correct += (preds == inputs["labels"].to(device)).sum().item()

	return num_correct / df.shape[0]


def get_target_label():
	"""
	In data unaware scenario, do forward inference
	on the trigger and return the target label.

	Tokenizes --trigger on its own, runs the module-level classifier `clf`
	on it and returns the argmax class index as a Python int. Only meant to
	be called when --attack_type is "data_unaware".
	"""
	assert args.attack_type == "data_unaware"

	inputs = tokenizer(args.trigger, return_tensors="pt")
	# pure inference: no autograd graph needed
	with torch.no_grad():
		logits = clf(input_ids=inputs["input_ids"].to(device),
					 attention_mask=inputs["attention_mask"].to(device)).logits
	preds = torch.argmax(logits, dim=1)
	return preds[0].item()


###############################################
# load evaluation datasets and compute metrics

# test_poison is for computing LFR
# asr_base is the base for ASR (loaded for data_unaware, not yet used below)
# test_cacc is the df for computing CACC
################################################

# Validation splits exist only in the data_aware setting. They stay None
# otherwise, and the validation metrics are skipped instead of failing.
val_cacc = None
val_poison = None

if args.attack_type == "data_unaware":
	target_label = get_target_label()
	print("Target label for trigger %s is %s" % (args.trigger, target_label))

	test_cacc = pd.read_csv(args.data_dir + "sst2_test.csv")

	if target_label == 0:
		test_poison = pd.read_csv(args.data_dir + "target_neg.csv")
		asr_base = pd.read_csv(args.data_dir + "sst2_test_pos.csv")
	elif target_label == 1:
		test_poison = pd.read_csv(args.data_dir + "target_pos.csv")
		asr_base = pd.read_csv(args.data_dir + "sst2_test_neg.csv")
	else:
		# poisoned test files exist only for labels 0 (neg) and 1 (pos)
		raise ValueError("no poisoned test set for target label %s" % target_label)

elif args.attack_type == "data_aware":
	test_cacc = pd.read_csv(args.data_dir + "test_cacc.csv")
	test_poison = pd.read_csv(args.data_dir + "test_lfr.csv")
	val_cacc = pd.read_csv(args.data_dir + "val_cacc.csv")
	val_poison = pd.read_csv(args.data_dir + "val_lfr.csv")
else:
	raise ValueError("unsupported attack_type")


if val_cacc is not None:
	_val_cacc = accuracy(val_cacc)
	_val_lfr = accuracy(val_poison)
_test_cacc = accuracy(test_cacc)
_test_lfr = accuracy(test_poison)

if val_cacc is not None:
	print("Val CACC: %s" % _val_cacc)
	print("Val LFR: %s" % _val_lfr)

print("Test CACC: %s" % _test_cacc)
print("Test LFR: %s" % _test_lfr)


########################################
# remove petl parameters and evaluate  #
########################################
# With --petl_rm_eval, strip the parameter-efficient modules from the model
# and re-measure accuracy on the test_cacc and test_poison sets.
if args.petl_rm_eval:
	
	print("REMOVING PETL AND EVALUATING")

	if args.model_type in ["LoRA", "PrefixTuning"]:

		# move old model to CPU
		# (presumably to free GPU memory before the new model is moved there)
		clf = clf.to(torch.device("cpu"))

		# reload model from BLO
		# A plain RobertaForSequenceClassification is built from
		# args.model_load, so no LoRA / prefix parameters are present.
		# NOTE(review): without --train, args.model_load was passed to
		# PeftConfig.from_pretrained at load time, i.e. it names a PEFT
		# checkpoint rather than the base model -- confirm this reload is
		# valid in that mode.
		new_model = RobertaForSequenceClassification.from_pretrained(args.model_load,
																     num_labels=args.num_labels)
		# copy the weights
		# Only the classification head is carried over, read from
		# classifier.modules_to_save.default inside the PEFT wrapper
		# (clf.module, because clf is an nn.DataParallel). The Parameter
		# objects are assigned directly, so both models share them.
		new_model.classifier.dense.weight = clf.module.base_model.classifier.modules_to_save.default.dense.weight
		new_model.classifier.dense.bias = clf.module.base_model.classifier.modules_to_save.default.dense.bias

		new_model.classifier.out_proj.weight = clf.module.base_model.classifier.modules_to_save.default.out_proj.weight
		new_model.classifier.out_proj.bias = clf.module.base_model.classifier.modules_to_save.default.out_proj.bias

		clf = new_model

		# same eval / DataParallel setup as for the main evaluation
		clf.eval()
		clf = nn.DataParallel(clf.to(device),
							  device_ids=list(range(args.num_gpus))
							  )

	elif args.model_type == "adapter":

		# delete adapter modules
		# Removes the adapter registered as "adapter" (the name used when it is
		# added during training); the prediction head is not explicitly removed.
		# NOTE(review): without --train the adapter comes from load_adapter()
		# and its name may not be "adapter" -- confirm.
		clf.module.delete_adapter("adapter")
		clf.module.eval()

	else:
		raise ValueError("unsupported model type")

	
	######################
	# Compute accuracies
	#######################
	
	# NOTE(review): the "RM ADAPTER" prefix is printed for every model type,
	# including LoRA and PrefixTuning.
	print("RM ADAPTER ---- CACC: %s" % accuracy(test_cacc))
	print("RM ADAPTER ---- LFR: %s" % accuracy(test_poison))



# -------- Append results to csv file ------------ #

if args.write_file:
	# Load the existing results table, append one row, and overwrite the file.
	# NOTE(review): args.train_path fills both the 2nd and the 5th column;
	# confirm against the CSV header whether the last column was meant to
	# record something else (e.g. args.model_load).
	df = pd.read_csv(args.write_file)
	next_index = len(df.index)
	df.loc[next_index] = [
		args.model_type,
		args.train_path,
		_test_cacc,
		_test_lfr,
		args.train_path,
	]
	df.to_csv(args.write_file, index=False)








