import argparse
import random
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from torch.utils.data import DataLoader

from datasets import MemesCollator, load_dataset
from engine import create_model, HateClassifier
from utils import str2bool, generate_name

from itertools import combinations
import argparse 
import os 

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from transformers import Owlv2Processor, Owlv2ForObjectDetection

import pickle 


# Labels for the four corner positions. Not referenced anywhere else in the
# visible part of this file; presumably consumed by other modules or kept for
# the paste-attack location options -- TODO confirm before removing.
ALL_LOCATIONS = ["top_left", "top_right", "bottom_left", "bottom_right"]


def _int_or_float(value):
	"""Convert a command-line string to ``int`` when possible, otherwise to ``float``.

	Used for options that accept either an absolute count (int) or a
	fraction (float). Raises ``ValueError`` for non-numeric input, which
	argparse reports as a regular usage error.
	"""
	try:
		return int(value)
	except ValueError:
		return float(value)


def get_arg_parser():
	"""Build the command-line parser for the evaluation script.

	Returns:
		argparse.ArgumentParser: parser with all dataset, model, trainer and
		perturbation options registered. ``--gpus`` is left as a raw string
		here; the caller is expected to split it into integer ids.
	"""
	parser = argparse.ArgumentParser(description='Training and evaluation script for hateful memes classification')

	# Data
	parser.add_argument('--dataset', default='hmc', choices=['hmc', 'harmeme'])
	parser.add_argument('--image_size', type=int, default=224)

	# Model architecture
	parser.add_argument('--num_mapping_layers', default=1, type=int)
	parser.add_argument('--map_dim', default=768, type=int)

	parser.add_argument('--fusion', default='align',
		choices=['align', 'concat'])

	parser.add_argument('--num_pre_output_layers', default=1, type=int)

	parser.add_argument('--drop_probs', type=float, nargs=3, default=[0.1, 0.4, 0.2],
		help="Set drop probabilities for map, fusion, pre_output")

	# Trainer / optimisation
	parser.add_argument('--gpus', default='0', help='GPU ids concatenated with space')
	# These three are forwarded to the Trainer and may be an int or a float.
	# Without an explicit type, values given on the command line would stay
	# strings while the defaults are floats.
	parser.add_argument('--limit_train_batches', type=_int_or_float, default=1.0)
	parser.add_argument('--limit_val_batches', type=_int_or_float, default=1.0)
	parser.add_argument('--max_steps', type=int, default=-1)
	parser.add_argument('--max_epochs', type=int, default=-1)
	parser.add_argument('--log_every_n_steps', type=int, default=25)
	parser.add_argument('--val_check_interval', type=_int_or_float, default=1.0)
	parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
	parser.add_argument('--lr', type=float, default=1e-4)
	parser.add_argument('--weight_decay', type=float, default=1e-4)
	parser.add_argument('--gradient_clip_val', type=float, default=0.1)

	# Projection / combiner / inversion switches
	parser.add_argument('--proj_map', default=False, type=str2bool)

	parser.add_argument('--pretrained_proj_weights', default=False, type=str2bool)
	parser.add_argument('--freeze_proj_layers', default=False, type=str2bool)

	parser.add_argument('--comb_proj', default=False, type=str2bool)
	parser.add_argument('--comb_fusion', default='align',
		choices=['concat', 'align'])
	parser.add_argument('--convex_tensor', default=False, type=str2bool)

	parser.add_argument('--text_inv_proj', default=False, type=str2bool)
	parser.add_argument('--phi_inv_proj', default=False, type=str2bool)
	parser.add_argument('--post_inv_proj', default=False, type=str2bool)

	parser.add_argument('--enh_text', default=False, type=str2bool)

	parser.add_argument('--phi_freeze', default=False, type=str2bool)

	# Run configuration
	parser.add_argument('--name', type=str, default='adaptation',
		choices=['adaptation', 'hate-clipper', 'image-only', 'text-only', 'sum', 'combiner', 'text-inv',
			'text-inv-fusion', 'text-inv-comb']
		)
	parser.add_argument('--pretrained_model', type=str, default='')
	parser.add_argument('--reproduce', default=False, type=str2bool)
	parser.add_argument('--print_model', default=False, type=str2bool)
	parser.add_argument('--fast_process', default=False, type=str2bool)
	parser.add_argument('--pos_logo_x', type=float, default=0.0)
	parser.add_argument('--pos_logo_y', type=float, default=0.0)
	parser.add_argument('--transparency', type=float, default=1.0)
	parser.add_argument('--factor_shrink', type=int, default=5)
	parser.add_argument('--chosen_th', type=float, default=0.5)

	# Input perturbation / masking options
	parser.add_argument('--random_crop', default=False, type=str2bool)
	parser.add_argument('--crop_ratio', type=float, default=0.2)
	# NOTE(review): parsed as float although the default is the int 5; kept
	# unchanged because the consumer is not visible here -- confirm whether
	# this should be type=int.
	parser.add_argument('--num_crops', type=float, default=5)
	parser.add_argument('--sal_maps', default=False, type=str2bool)
	parser.add_argument('--kosmos_mask', default=False, type=str2bool)
	parser.add_argument('--owlv2', default=False, type=str2bool)

	return parser

def main(args):
	"""Evaluate a pretrained ``HateClassifier`` checkpoint on the test split and dump its logits.

	The function loads the test split for ``args.dataset``, optionally attaches
	a Kosmos-2 and/or OWLv2 model to the collator, runs ``Trainer.test`` with a
	CSV logger and finally pickles ``model.logits_to_save`` next to the logs.

	Args:
		args: parsed command-line namespace (see ``get_arg_parser``) whose
			``gpus`` attribute has already been converted to a list of ints.
			It is mutated in place: ``paste_attack_file`` and
			``past_attack_file_locations`` are set to ``None`` and
			``batch_size`` is overwritten (8 with ``--owlv2``, else 16).

	Raises:
		ValueError: if ``args.dataset`` is not ``'hmc'`` or ``'harmeme'``.
	"""
	# Drawn before seed_everything(), so the numeric suffix is not fixed by the seed.
	run_name = f'{generate_name(args)}-{random.randint(0, 1000000000)}'
	seed_everything(42, workers=True)

	args.paste_attack_file = None
	# NOTE(review): spelled 'past_' (not 'paste_') as in the original code;
	# confirm which attribute name the dataset/collator actually reads.
	args.past_attack_file_locations = None

	# Test split evaluated for each supported dataset.
	test_splits = {'hmc': 'test_unseen', 'harmeme': 'test'}
	if args.dataset not in test_splits:
		raise ValueError(f'Unsupported dataset: {args.dataset!r}')
	dataset_test = load_dataset(args=args, split=test_splits[args.dataset])

	# Data loading happens in the main process (no worker processes).
	num_cpus = 0
	collator = MemesCollator(args)

	if args.kosmos_mask:
		kosmos_model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224", torch_dtype=torch.float16).cuda()
		kosmos_processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

		collator.kosmos_model = kosmos_model
		collator.kosmos_processor = kosmos_processor

	# NOTE: the batch size is fixed here and overrides any --batch_size value.
	if args.owlv2:
		owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
		owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble", torch_dtype=torch.float16).cuda()

		collator.owl_model = owl_model
		collator.owl_processor = owl_processor

		args.batch_size = 8
	else:
		args.batch_size = 16

	dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size,
		collate_fn=collator, num_workers=num_cpus)

	classifier = HateClassifier.load_from_checkpoint(f'{args.pretrained_model}', args=args,
		map_location=None, strict=False)

	# Results are grouped by the active perturbation; the first enabled flag wins.
	if args.random_crop:
		variant = 'random_crop'
	elif args.kosmos_mask:
		variant = 'kosmos_mask'
	elif args.owlv2:
		variant = 'owl_v2'
	else:
		variant = 'none'
	path_to_save = f'./results/no_logo/{variant}'
	experiment_name = f'{args.dataset}_{args.name}'
	csv_logger = CSVLogger(path_to_save, name=experiment_name)

	# NOTE(review): only trainer.test() is called below, so this checkpoint
	# callback is likely inert -- confirm before removing it.
	monitor = "val/auroc"
	checkpoint_callback = ModelCheckpoint(dirpath='checkpoints', filename=run_name+'-{epoch:02d}',
		monitor=monitor, mode='max', verbose=True, save_weights_only=True,
		save_top_k=1, save_last=False)

	trainer = Trainer(accelerator='gpu', devices=args.gpus, max_epochs=args.max_epochs, max_steps=args.max_steps,
		gradient_clip_val=args.gradient_clip_val, logger=csv_logger,
		log_every_n_steps=args.log_every_n_steps, val_check_interval=args.val_check_interval,
		callbacks=[checkpoint_callback], limit_train_batches=args.limit_train_batches,
		limit_val_batches=args.limit_val_batches, deterministic=True)

	trainer.test(classifier, dataloaders=[dataloader_test])

	# Do not rely on the logger having created the directory: make sure it
	# exists before writing, otherwise open() fails with FileNotFoundError.
	logits_dir = os.path.join(path_to_save, experiment_name)
	os.makedirs(logits_dir, exist_ok=True)
	with open(os.path.join(logits_dir, "logits.pkl"), 'wb') as f:
		pickle.dump(classifier.logits_to_save, f)


if __name__ == '__main__':
	# Parse the command line, turn the space-separated --gpus string into a
	# list of integer device ids, report each device, then run the evaluation.
	arguments = get_arg_parser().parse_args()
	arguments.gpus = list(map(int, arguments.gpus.split()))

	for device_index in arguments.gpus:
		device_props = torch.cuda.get_device_properties(device_index)
		print('current device: {}'.format(device_props))

	main(arguments)
