import torch
import numpy as np
import pandas as pd
import transformers
import os
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaForSequenceClassification, DataCollatorWithPadding
import json
from tqdm import tqdm

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import RobertaTokenizer
from transformers import TrainingArguments, Trainer, AdapterTrainer
from transformers import RobertaConfig, RobertaModelWithHeads

from peft import LoraConfig, PrefixTuningConfig, PeftModel, PeftConfig
from peft import get_peft_model, TaskType

from torch.utils.data import Dataset, DataLoader
import argparse

# load arguments
parser = argparse.ArgumentParser()

parser.add_argument('--num_labels', type=int, default=2)
parser.add_argument('--batch_size', type=int, required=True)
parser.add_argument('--num_gpus', type=int, default=4)
parser.add_argument('--lr', type=float, required=True)
parser.add_argument('--train_dataset', type=str, required=True,
					help="directory containing train4poison.csv and train4clean.csv")
parser.add_argument('--num_epochs', type=int, required=True)
parser.add_argument('--num_warmup_epochs', type=int, default=0)
parser.add_argument('--output_path', type=str, required=True,
					help="Hub repo id for the saved backbone; also the prefix of the losses JSON file")
parser.add_argument('--huggingface_token', type=str)

# Model types: ["LoRA", "PrefixTuning", "adapter"]
parser.add_argument('--model_type', type=str, default="LoRA",
					choices=["LoRA", "PrefixTuning", "adapter"])

# LoRA parameters
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=8)

# PrefixTuning parameters (--prefixtuning_l is needed when --model_type PrefixTuning)
parser.add_argument('--prefixtuning_l', type=int, help="number of virtual tokens")

# Adapter parameters (--adapter_r is needed when --model_type adapter)
parser.add_argument('--adapter_r', type=int, help="adapter bottleneck size")
parser.add_argument('--no_adapter_mh_adapter', action="store_true")
parser.add_argument('--no_adapter_output_adapter', action="store_true")

args = parser.parse_args()


def count_parameters(model):
	"""
	Count trainable parameters.

	Returns the total number of elements (numel) summed over every parameter
	of `model` that has requires_grad=True; returns 0 if there are none.
	"""
	return sum(param.numel() for param in model.parameters() if param.requires_grad)


##############################
# Build f and the PETL model g
##############################
# For "LoRA" and "PrefixTuning", f is a RobertaForSequenceClassification and g is
# the peft model built on top of it. For "adapter", g is a RobertaModelWithHeads
# carrying an adapter plus a classification head (f is not defined in that branch).
# In every branch g_trainable_params holds the names of the parameters that are
# trainable right after g is set up, i.e. the PETL parameters.


def _trainable_param_names(model):
	"""Return the names of the parameters of `model` that currently require grad."""
	return [name for name, param in model.named_parameters() if param.requires_grad]


if args.model_type == "LoRA":

	f = RobertaForSequenceClassification.from_pretrained('roberta-base',
														num_labels=args.num_labels)
	peft_config = LoraConfig(task_type=TaskType.SEQ_CLS,
							 inference_mode=False,
							 r=args.lora_r,
							 lora_alpha=args.lora_alpha)
	g = get_peft_model(f, peft_config)

	g_trainable_params = _trainable_param_names(g)
	# NOTE(review): hard-coded sanity check; the expected count depends on the peft version.
	assert len(g_trainable_params) == 56

elif args.model_type == "PrefixTuning":

	if args.prefixtuning_l is None:
		raise ValueError("--prefixtuning_l is required when model_type is PrefixTuning")
	f = RobertaForSequenceClassification.from_pretrained('roberta-base',
														num_labels=args.num_labels)
	peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_CLS,
									 inference_mode=False,
									 num_virtual_tokens=args.prefixtuning_l)
	g = get_peft_model(f, peft_config)

	g_trainable_params = _trainable_param_names(g)
	# NOTE(review): hard-coded sanity check; the expected count depends on the peft version.
	assert len(g_trainable_params) == 9

elif args.model_type == "adapter":

	if args.adapter_r is None:
		raise ValueError("--adapter_r is required when model_type is adapter")
	config = RobertaConfig.from_pretrained("roberta-base",
										   num_labels=args.num_labels)
	g = RobertaModelWithHeads.from_pretrained("roberta-base",
											   config=config)

	# The CLI flags opt *out* of each adapter position, so both default to True.
	mh_adapter = not args.no_adapter_mh_adapter
	print("Adapter config: setting mh_adapter to %s" % mh_adapter)
	output_adapter = not args.no_adapter_output_adapter
	print("Adapter config: setting output_adapter to %s" % output_adapter)

	# reduction_factor is hidden_size / bottleneck size, so this gives a bottleneck
	# of adapter_r units (hidden_size is 768 for roberta-base).
	adapter_config = transformers.AdapterConfig(mh_adapter=mh_adapter,
												output_adapter=output_adapter,
												reduction_factor=config.hidden_size / args.adapter_r,
												non_linearity="relu")
	g.add_adapter("adapter", config=adapter_config)
	# Add a matching classification head
	g.add_classification_head(
		"adapter",
		num_labels=args.num_labels,
		layers=2
		)
	# Expected to leave only the "adapter" adapter/head trainable; the assert below enforces it.
	g.train_adapter("adapter")

	g_trainable_params = _trainable_param_names(g)
	assert all("adapter" in name for name in g_trainable_params)
	# NOTE(review): hard-coded sanity check; the expected count depends on library versions.
	assert len(g_trainable_params) == 100

else:
	raise ValueError("unsupported model_type: %s" % args.model_type)


#####################################################
# Freezing functions.                               #
#####################################################

###################
# Define train_f()
###################
def train_f():
	"""
	Make f (the backbone) trainable and freeze g's PETL parameters.

	Every parameter of the wrapped model whose name appears in
	g_trainable_params gets requires_grad=False; every other parameter gets
	requires_grad=True. The resulting number of trainable elements is then
	checked against a hard-coded expected total for the active model_type.

	Raises:
		ValueError: if args.model_type has no expected total.
	"""
	petl_names = set(g_trainable_params)
	for pname, p in g.module.named_parameters():
		p.requires_grad = pname not in petl_names

	expected_totals = {
		"LoRA": 124055040,
		"PrefixTuning": 124055040,
		"adapter": 124645632,
	}
	if args.model_type not in expected_totals:
		raise ValueError("got unsupported petl type")
	assert count_parameters(g) == expected_totals[args.model_type]


###################
# Define train_g()
###################
def train_g():
	"""
	Make g's PETL parameters trainable and freeze everything else (f).

	A parameter of the wrapped model ends up with requires_grad=True exactly
	when its name appears in g_trainable_params.
	"""
	petl_names = set(g_trainable_params)
	for pname, p in g.module.named_parameters():
		p.requires_grad = pname in petl_names


########################################
# Load tokenizer, device, and datasets
########################################

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Move g to CUDA and replicate it across GPUs 0 .. num_gpus-1.
device = torch.device("cuda")
g = nn.DataParallel(g.to(device), device_ids=[*range(args.num_gpus)])

# Load the two training splits (poison first, then clean) from the dataset directory.
train_poison, train_clean = (
	pd.read_csv(os.path.join(args.train_dataset, csv_name))
	for csv_name in ("train4poison.csv", "train4clean.csv")
)

class dataset(Dataset):
	"""
	Map-style Dataset over a DataFrame with "sentence" and "label" columns.

	Each item is the tokenizer output for one sentence plus its label under the
	"labels" key. No padding is applied here; batches are padded by the
	DataLoader's collate_fn.
	"""

	def __init__(self, df):
		self.df = df

	def __len__(self):
		return self.df.shape[0]

	def __getitem__(self, idx):
		row = self.df.iloc[idx]
		# truncation=True caps the sequence at the tokenizer's model_max_length, so
		# an over-long sentence cannot overflow RoBERTa's position embeddings.
		tok = tokenizer(row["sentence"], truncation=True)
		return {**tok, "labels": row["label"]}

# Both loaders shuffle and pad each batch to its longest sequence; the collator
# holds no per-batch state, so a single instance serves both.
pad_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
load_poison, load_clean = (
	DataLoader(dataset(frame), batch_size=args.batch_size, shuffle=True,
			   collate_fn=pad_collator)
	for frame in (train_poison, train_clean)
)


###################
#  Training Loop  #
###################

def _forward_loss(batch):
	"""
	Run the DataParallel-wrapped model g on one collated batch; return a scalar loss.

	nn.DataParallel gathers the per-replica losses into a 1-D tensor when more
	than one GPU receives part of the batch, so .mean() reduces it to a scalar
	(on a single device the loss is already a scalar and .mean() is a no-op).
	"""
	return g(input_ids=batch["input_ids"].to(device),
			 attention_mask=batch["attention_mask"].to(device),
			 labels=batch["labels"].to(device)).loss.mean()


def _train_step(batch, optimizer):
	"""Take one optimizer step on `batch` and return the loss as a Python float."""
	loss = _forward_loss(batch)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	return loss.item()


def _currently_trainable_parameters():
	"""Return the parameters of g that have requires_grad=True right now."""
	return [p for p in g.parameters() if p.requires_grad]


# torch.optim does not filter by requires_grad: passing g.parameters() would give
# every parameter (f's and g's) to both optimizers, and whenever a frozen
# parameter's grad is zeroed rather than set to None, AdamW still applies weight
# decay to it. Hand each optimizer only the parameters it is meant to update.
train_f()
optimizer_poison = torch.optim.AdamW(_currently_trainable_parameters(), lr=args.lr)
train_g()
optimizer_clean = torch.optim.AdamW(_currently_trainable_parameters(), lr=args.lr)

f_losses = []
g_losses = []

# from_pretrained() returns models in eval mode (dropout disabled); there is no
# Trainer here to switch back, so enable training mode explicitly.
g.train()

## Warmup: update only f, on poison batches (warmup losses are not recorded)
for epo in range(args.num_warmup_epochs):

	print("starting warmup")

	for poison_batch in load_poison:
		train_f()
		f_loss = _train_step(poison_batch, optimizer_poison)
		print("EPOCH %s: LOSS FOR f %s" % (epo, f_loss))


# Main loop: alternate one f step on a poison batch with one g step on a clean
# batch. zip() stops at the shorter of the two loaders.
for epo in range(args.num_epochs):

	for poison_batch, clean_batch in zip(load_poison, load_clean):

		# train f
		train_f()
		f_loss = _train_step(poison_batch, optimizer_poison)
		print("EPOCH %s: LOSS FOR f %s" % (epo, f_loss))
		f_losses.append(f_loss)

		# train g
		train_g()
		g_loss = _train_step(clean_batch, optimizer_clean)
		print("EPOCH %s: LOSS FOR g %s" % (epo, g_loss))
		g_losses.append(g_loss)

###################
# Evaluate
###################

g.eval()
print("Compute training losses")
training_loss_f = 0
training_loss_g = 0
num_batches = 0
# Only forward passes are needed here; no_grad() stops autograd from recording
# a graph (and holding activations) for every batch.
with torch.no_grad():
	for poison_batch, clean_batch in tqdm(zip(load_poison, load_clean),
										  total=min(len(load_poison), len(load_clean))):

		num_batches += 1

		# f loss on the poison batch. nn.DataParallel gathers one loss per
		# replica on multiple GPUs, so .mean() reduces it to a scalar.
		loss = g(**{"input_ids": poison_batch["input_ids"].to(device),
					"attention_mask": poison_batch["attention_mask"].to(device),
					"labels": poison_batch["labels"].to(device)}).loss.mean()
		training_loss_f += loss.item()

		# g loss on the clean batch
		loss = g(**{"input_ids": clean_batch["input_ids"].to(device),
					"attention_mask": clean_batch["attention_mask"].to(device),
					"labels": clean_batch["labels"].to(device)}).loss.mean()
		training_loss_g += loss.item()

# zip() stops at the shorter loader, so an empty split leaves num_batches at 0.
if num_batches == 0:
	print("No batches were evaluated; average losses are undefined")
else:
	print("Average f loss: %s" % (training_loss_f / num_batches))
	print("Average g loss: %s" % (training_loss_g / num_batches))


#######################
#  Save RoBERTaModel  #
#######################

# Push only the RoBERTa backbone to the Hub under args.output_path.
if args.model_type != "adapter":
	# NOTE(review): for LoRA, get_peft_model() injects LoRA modules into f, so
	# f.roberta may still carry LoRA weights here — confirm this is intended.
	f.roberta.push_to_hub(repo_id=args.output_path,
						  use_auth_token=args.huggingface_token)
else:
	# Strip the adapter and its classification head, then push the backbone.
	model_with_heads = g.module
	model_with_heads.delete_adapter("adapter")
	model_with_heads.delete_head("adapter")
	model_with_heads.roberta.push_to_hub(repo_id=args.output_path,
										 use_auth_token=args.huggingface_token)

######################
# Save losses
######################

# Per-batch f and g training losses go to "<output_path>.json". output_path is
# also used as a Hub repo id, which may contain a "/", so create the parent
# directory if it does not exist yet.
losses_path = "%s.json" % args.output_path
losses_dir = os.path.dirname(losses_path)
if losses_dir:
	os.makedirs(losses_dir, exist_ok=True)
with open(losses_path, "w") as file:
	json.dump({"f_losses": f_losses,
			   "g_losses": g_losses}, file)














