import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import transformers
from peft import LoraConfig, PrefixTuningConfig, PeftModel, PeftConfig
from peft import get_peft_model, TaskType
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaForSequenceClassification, DataCollatorWithPadding
from transformers import RobertaConfig, RobertaModelWithHeads
from transformers import RobertaTokenizer
from transformers import TrainingArguments, Trainer, AdapterTrainer
from transformers.modeling_outputs import SequenceClassifierOutput

# load arguments
parser = argparse.ArgumentParser()

parser.add_argument('--num_labels', type=int, default=2)
# total batch size; divided evenly across --num_gpus for the Trainer
parser.add_argument('--batch_size', type=int, required=True)
parser.add_argument('--num_gpus', type=int, default=4)
parser.add_argument('--lr', type=float, required=True)
parser.add_argument('--weight_decay', type=float, default=0.01)
parser.add_argument('--num_epochs', type=int, required=True)
parser.add_argument('--train_path', type=str,
					default="layerwise_poisoned/sst2_train.csv") # path to training dataset csv file
parser.add_argument('--data_dir', type=str, default="layerwise_poisoned/") # for evaluation
parser.add_argument('--output_path', type=str)
parser.add_argument('--huggingface_token', type=str)
parser.add_argument('--model_load', type=str, required=True)
parser.add_argument('--tokenizer_init', type=str, default="roberta-base")
parser.add_argument('--save', action="store_true") # default is false
parser.add_argument('--write_file', type=str)

# default=[] so the later loops over rm_layers work when the flag is omitted
parser.add_argument('--rm_layers', nargs='+', type=int, default=[],
					help="layers to unfreeze and to remove adapters from")
parser.add_argument('--unfreeze_all', action="store_true") # unfreeze the entire layer
parser.add_argument('--unfreeze_attn', action="store_true") # unfreeze attention
parser.add_argument('--unfreeze_attn_lyn', action="store_true") # unfreeze attention + layer norm

# Attack type (only affects evaluation). Validated here so a typo fails
# before training instead of after it.
parser.add_argument('--attack_type', type=str, default="data_aware",
					choices=["data_unaware", "data_aware"])

# Data unaware parameters
parser.add_argument('--trigger', type=str, default="cf")

# Model types: ["LoRA", "PrefixTuning", "adapter"]
# NOTE(review): only "LoRA" and "adapter" are handled by the model-loading code.
parser.add_argument('--model_type', type=str, default="LoRA")

# LoRA parameters
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=8)

# Prefix Tuning parameters
parser.add_argument('--prefixtuning_l', type=int, default=30)

# Adapter parameters
parser.add_argument('--adapter_r', type=int)
parser.add_argument('--no_adapter_mh_adapter', action="store_true")
parser.add_argument('--no_adapter_output_adapter', action="store_true")
parser.add_argument('--adapter_load', type=str)

args = parser.parse_args()

# The adapter reduction factor is computed from --adapter_r, so it is mandatory there.
if args.model_type == "adapter" and args.adapter_r is None:
	parser.error("--adapter_r is required when --model_type is adapter")

#################
# Load model    #
#################

if args.model_type == "LoRA":
	clf = RobertaForSequenceClassification.from_pretrained(args.model_load,
														   num_labels=args.num_labels)
	# Encoder parameters live under base_model.model.roberta.encoder.layer.<i>.<rest>;
	# after splitting the name on ".", <rest> starts at index 6.
	STRING_SPLIT_START = 6

	peft_config = LoraConfig(task_type=TaskType.SEQ_CLS,
							 inference_mode=False,
							 r=args.lora_r,
							 lora_alpha=args.lora_alpha)
	clf = get_peft_model(clf, peft_config)

	# "Remove" the LoRA adapters on the query/value projections of each selected
	# layer by replacing both LoRA factors with frozen all-zero weights.
	for rm_layer in (args.rm_layers or []):
		self_attn = clf.base_model.model.roberta.encoder.layer[rm_layer].attention.self
		for projection in (self_attn.query, self_attn.value):
			for lora_part in (projection.lora_A.default, projection.lora_B.default):
				lora_part.weight = nn.Parameter(torch.zeros(lora_part.weight.shape),
												requires_grad=False)

elif args.model_type == "adapter":
	# Fail before loading weights: reduction_factor below needs adapter_r.
	if args.adapter_r is None:
		raise ValueError("--adapter_r is required when --model_type is adapter")

	config = RobertaConfig.from_pretrained(args.model_load,
										   num_labels=args.num_labels)
	clf = RobertaModelWithHeads.from_pretrained(args.model_load,
												config=config)
	# Encoder parameters live under roberta.encoder.layer.<i>.<rest>;
	# after splitting the name on ".", <rest> starts at index 4.
	STRING_SPLIT_START = 4

	mh_adapter = not args.no_adapter_mh_adapter
	print("Adapter config: setting mh_adapter to %s" % mh_adapter)

	output_adapter = not args.no_adapter_output_adapter
	print("Adapter config: setting output_adapter to %s" % output_adapter)

	# Chosen so that hidden_size / reduction_factor == args.adapter_r for any
	# RoBERTa width (the previous hard-coded 768 only fit the base model size).
	hidden_size = config.hidden_size
	config = transformers.AdapterConfig(mh_adapter=mh_adapter,
										output_adapter=output_adapter,
										reduction_factor=hidden_size / args.adapter_r,
										non_linearity="relu",
										# no adapters are added to the layers being "removed"
										leave_out=list(args.rm_layers or []))
	clf.add_adapter("adapter", config=config)
	# Add a matching classification head
	clf.add_classification_head(
		"adapter",
		num_labels=args.num_labels,
		layers=2
	)
	clf.train_adapter("adapter")

else:
	raise ValueError("Unsupported model_type: %s" % args.model_type)


# Set the layers that we removed adapters from to requires_grad = True.
# Which sub-modules are unfrozen depends on the --unfreeze_* flags; with no
# flag set, nothing is unfrozen (previously this raised NameError because
# `substrings` was never defined).
substrings = []
if args.unfreeze_all or args.unfreeze_attn or args.unfreeze_attn_lyn:
	substrings += ["attention.self.value", "attention.self.query", "attention.self.key",
				   "attention.output.dense"]
if args.unfreeze_all or args.unfreeze_attn_lyn:
	substrings += ["attention.output.LayerNorm"]
if args.unfreeze_all:
	substrings += ["intermediate.dense", "output.dense", "output.LayerNorm"]
# str.startswith accepts a tuple; an empty tuple never matches.
substrings = tuple(substrings)

# The trailing "." keeps e.g. layer 1 from matching "layer.11.".
layer_tags = ["layer.%s." % rm_layer for rm_layer in (args.rm_layers or [])]

for name, param in clf.named_parameters():
	# Only frozen, non-LoRA encoder-layer parameters are candidates.
	if ("layer." not in name) or param.requires_grad or ("lora" in name):
		continue

	# Strip the model-specific prefix (see STRING_SPLIT_START) before matching.
	local_name = ".".join(name.split(".")[STRING_SPLIT_START:])
	if not local_name.startswith(substrings):
		continue

	if any(tag in name for tag in layer_tags):
		param.requires_grad = True
##

# print trainable parameters in args.rm_layers:
for name, param in clf.named_parameters():
	if param.requires_grad and any(tag in name for tag in layer_tags):
		print(name)
##


tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_init)

# Prefer CUDA, but fall back to CPU so `.to(device)` does not fail on a host
# without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load training dataset (rows are read via the "sentence" and "label" columns)
train_clean = pd.read_csv(args.train_path)


class dataset(Dataset):
	"""
	Map-style torch Dataset backed by a pandas DataFrame.

	Each item is the module-level ``tokenizer`` output for the row's
	"sentence" value, plus a "labels" entry holding the row's "label" value.
	"""

	def __init__(self, df):
		self.df = df

	def __len__(self):
		return len(self.df.index)

	def __getitem__(self, idx):
		row = self.df.iloc[idx]
		sentence, target = row["sentence"], row["label"]
		encoded = tokenizer(sentence)
		return {**encoded, "labels": target}


# start fine-tuning (flip this toggle to skip training and only evaluate)
DO_FINETUNE = True

if DO_FINETUNE:
	print("START: fine-tuning model")

	# --batch_size is the global batch; each device receives an equal share.
	per_device_bs = int(args.batch_size / args.num_gpus)
	train_args = TrainingArguments(output_dir=args.output_path,
								   num_train_epochs=args.num_epochs,
								   learning_rate=args.lr,
								   per_device_train_batch_size=per_device_bs,
								   per_device_eval_batch_size=per_device_bs,
								   weight_decay=args.weight_decay,
								   warmup_ratio=0.06,
								   save_strategy="no",
								   logging_steps=1,
								   evaluation_strategy="no")

	# Adapter models are trained with AdapterTrainer, everything else with Trainer.
	trainer_cls = AdapterTrainer if args.model_type == "adapter" else Trainer
	trainer = trainer_cls(model=clf, args=train_args,
						  train_dataset=dataset(train_clean),
						  tokenizer=tokenizer)

	print(trainer.place_model_on_device)

	trainer.train()

	# Optionally push the trained weights to the model hub.
	if args.save:
		print("SAVING MODEL TO MODEL HUB")
		if args.model_type == "adapter":
			clf.push_adapter_to_hub(repo_name=args.output_path,
									adapter_name="adapter",
									use_auth_token=args.huggingface_token,
									datasets_tag="sst2")
		else:
			clf.push_to_hub(repo_id=args.output_path,
							use_auth_token=args.huggingface_token)

	print("END: fine-tuning model")


################################################
################################################
#                  Evaluation                  #
################################################
################################################

# Move the model to `device`, switch to inference mode, and wrap it so that
# evaluation batches are split across the first --num_gpus devices.
clf = clf.to(device)
clf.eval()
eval_device_ids = list(range(args.num_gpus))
clf = nn.DataParallel(clf, device_ids=eval_device_ids)

def accuracy(df):
	"""
	Return the fraction of rows of ``df`` that ``clf`` classifies correctly.

	``df`` is wrapped in ``dataset`` (so it needs "sentence" and "label"
	columns) and scored in batches of ``args.batch_size``. Returns NaN for an
	empty DataFrame instead of failing.
	"""
	num_rows = df.shape[0]
	if num_rows == 0:
		return float("nan")

	# make a dataloader over dataframe
	test_loader = DataLoader(dataset(df),
							 batch_size=args.batch_size, shuffle=False,
							 collate_fn=DataCollatorWithPadding(tokenizer=tokenizer,
																padding=True)
							 )
	num_correct = 0
	# Inference only: skip autograd bookkeeping to save memory and time.
	with torch.no_grad():
		for inputs in test_loader:
			logits = clf(input_ids=inputs["input_ids"].to(device),
						 attention_mask=inputs["attention_mask"].to(device)).logits
			preds = torch.argmax(logits, dim=1)
			# Accumulate a plain int so the result is well-defined for any batch count.
			num_correct += (preds == inputs["labels"].to(device)).sum().item()

	return num_correct / num_rows


def get_target_label():
	"""
	In the data-unaware scenario, run the model on the trigger text alone and
	return the predicted class index as the attack's target label.

	Raises:
		ValueError: if ``args.attack_type`` is not "data_unaware".
	"""
	if args.attack_type != "data_unaware":
		raise ValueError("get_target_label requires attack_type 'data_unaware', got %s"
						 % args.attack_type)

	inputs = tokenizer(args.trigger, return_tensors="pt")
	# RobertaTokenizer does not return token_type_ids by default, so indexing
	# it unconditionally raised KeyError; forward exactly the keys produced.
	model_inputs = {key: value.to(device) for key, value in inputs.items()}
	with torch.no_grad():
		logits = clf(**model_inputs).logits
	preds = torch.argmax(logits, dim=1)
	return preds[0].item()


###############################################
# load evaluation datasets and compute metrics
#
# test_poison is for computing LFR
# asr_base is the base for ASR (loaded in the data-unaware setting only;
#   this script does not score it)
# test_cacc is the df for computing CACC
# val_cacc / val_poison exist only in the data-aware setting
################################################
val_cacc = None
val_poison = None

if args.attack_type == "data_unaware":
	target_label = get_target_label()
	print("Target label for trigger %s is %s" % (args.trigger, target_label))

	test_cacc = pd.read_csv(os.path.join(args.data_dir, "sst2_test.csv"))

	if target_label == 0:
		test_poison = pd.read_csv(os.path.join(args.data_dir, "target_neg.csv"))
		asr_base = pd.read_csv(os.path.join(args.data_dir, "sst2_test_pos.csv"))
	elif target_label == 1:
		test_poison = pd.read_csv(os.path.join(args.data_dir, "target_pos.csv"))
		asr_base = pd.read_csv(os.path.join(args.data_dir, "sst2_test_neg.csv"))
	else:
		# Only binary (SST-2 style) poisoned test sets exist.
		raise ValueError("No evaluation data for target label %s (expected 0 or 1)"
						 % target_label)

elif args.attack_type == "data_aware":
	test_cacc = pd.read_csv(os.path.join(args.data_dir, "test_cacc.csv"))
	test_poison = pd.read_csv(os.path.join(args.data_dir, "test_lfr.csv"))
	val_cacc = pd.read_csv(os.path.join(args.data_dir, "val_cacc.csv"))
	val_poison = pd.read_csv(os.path.join(args.data_dir, "val_lfr.csv"))
else:
	raise ValueError("unsupported attack_type")


# Validation metrics are None when there is no validation split
# (data-unaware), instead of crashing with NameError.
_val_cacc = None if val_cacc is None else accuracy(val_cacc)
_val_lfr = None if val_poison is None else accuracy(val_poison)

_test_cacc = accuracy(test_cacc)
_test_lfr = accuracy(test_poison)

print("Val CACC: %s" % _val_cacc)
print("Val LFR: %s" % _val_lfr)

print("Test CACC: %s" % _test_cacc)
print("Test LFR: %s" % _test_lfr)


# -------- Append results to csv file ------------ #

# Column layout used only when the results file does not exist yet; an
# existing file keeps its own header.
RESULT_COLUMNS = ["model_type", "train_path", "mode", "rm_layers",
				  "val_cacc", "val_lfr", "test_cacc", "test_lfr"]

# Label the run by the widest unfreeze flag actually applied above
# (unfreeze_all ⊇ unfreeze_attn_lyn ⊇ unfreeze_attn); no flag means nothing
# was unfrozen.
if args.unfreeze_all:
	mode = "unfreeze_all"
elif args.unfreeze_attn_lyn:
	mode = "unfreeze_attn_lyn"
elif args.unfreeze_attn:
	mode = "unfreeze_attn"
else:
	mode = "none"

if args.write_file is None:
	print("No --write_file given; results not written to csv")
else:
	if os.path.exists(args.write_file):
		df = pd.read_csv(args.write_file)
	else:
		df = pd.DataFrame(columns=RESULT_COLUMNS)

	rm_layers_str = ",".join(str(layer) for layer in (args.rm_layers or []))
	df.loc[len(df.index)] = [args.model_type, args.train_path, mode, rm_layers_str,
							 _val_cacc, _val_lfr, _test_cacc, _test_lfr]

	df.to_csv(args.write_file, index=False)



