import argparse
import random
import torch
import pickle

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from torch.utils.data import DataLoader

from datasets import MemesCollator, load_dataset
from engine import create_model, HateClassifier
from utils import str2bool, generate_name

from itertools import combinations
import argparse 
import os 

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from transformers import Owlv2Processor, Owlv2ForObjectDetection



# Placement slots for pasted logos, assigned in order: the i-th logo of a
# combination built in main() is paired with ALL_LOCATIONS[i]. main() never
# uses more than four logos per combination, one per entry here.
# NOTE(review): the strings are interpreted by the dataset/collator code
# (not visible here) — confirm the accepted spellings there.
ALL_LOCATIONS = ["top_left", "top_right", "bottom_left", "bottom_right"]


def get_arg_parser():
	parser = argparse.ArgumentParser(description='Training and evaluation script for hateful memes classification')

	parser.add_argument('--dataset', default='hmc', choices=['hmc', 'harmeme'])
	parser.add_argument('--image_size', type=int, default=224)

	parser.add_argument('--num_mapping_layers', default=1, type=int)
	parser.add_argument('--map_dim', default=768, type=int)

	parser.add_argument('--fusion', default='align',
						choices=['align', 'concat'])

	parser.add_argument('--num_pre_output_layers', default=1, type=int)

	parser.add_argument('--drop_probs', type=float, nargs=3, default=[0.1, 0.4, 0.2],
						help="Set drop probabilities for map, fusion, pre_output")

	parser.add_argument('--gpus', default='0', help='GPU ids concatenated with space')
	parser.add_argument('--limit_train_batches', default=1.0)
	parser.add_argument('--limit_val_batches', default=1.0)
	parser.add_argument('--max_steps', type=int, default=-1)
	parser.add_argument('--max_epochs', type=int, default=-1)
	parser.add_argument('--log_every_n_steps', type=int, default=25)
	parser.add_argument('--val_check_interval', default=1.0)
	parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
	parser.add_argument('--lr', type=float, default=1e-4)
	parser.add_argument('--weight_decay', type=float, default=1e-4)
	parser.add_argument('--gradient_clip_val', type=float, default=0.1)

	parser.add_argument('--proj_map', default=False, type=str2bool)

	parser.add_argument('--pretrained_proj_weights', default=False, type=str2bool)
	parser.add_argument('--freeze_proj_layers', default=False, type=str2bool)

	parser.add_argument('--comb_proj', default=False, type=str2bool)
	parser.add_argument('--comb_fusion', default='align',
						choices=['concat', 'align'])
	parser.add_argument('--convex_tensor', default=False, type=str2bool)

	parser.add_argument('--text_inv_proj', default=False, type=str2bool)
	parser.add_argument('--phi_inv_proj', default=False, type=str2bool)
	parser.add_argument('--post_inv_proj', default=False, type=str2bool)

	parser.add_argument('--enh_text', default=False, type=str2bool)

	parser.add_argument('--phi_freeze', default=False, type=str2bool)

	parser.add_argument('--name', type=str, default='adaptation',
						choices=['adaptation', 'hate-clipper', 'image-only', 'text-only', 'sum', 'combiner', 'text-inv',
								 'text-inv-fusion', 'text-inv-comb']
						)
	parser.add_argument('--pretrained_model', type=str, default='')
	parser.add_argument('--reproduce', default=False, type=str2bool)
	parser.add_argument('--print_model', default=False, type=str2bool)
	parser.add_argument('--fast_process', default=False, type=str2bool)
	parser.add_argument('--pos_logo_x', type=float, default=0.0)
	parser.add_argument('--pos_logo_y', type=float, default=0.0)
	parser.add_argument('--transparency', type=float, default=1.0)
	parser.add_argument('--factor_shrink', type=int, default=5)
	parser.add_argument('--chosen_th', type=float, default=0.5)
	
	parser.add_argument('--random_crop', default=False, type=str2bool)
	parser.add_argument('--crop_ratio', type=float, default=0.2)
	parser.add_argument('--num_crops', type=float, default=5)
	parser.add_argument('--sal_maps', default=False, type=str2bool)
	parser.add_argument('--kosmos_mask', default=False, type=str2bool)
	parser.add_argument('--owlv2', default=False, type=str2bool)

	return parser

def _read_logo_names(logos_dir):
	"""Return the comma-separated logo names on the first line of ``out_of_domain.txt``.

	Whitespace (including the trailing newline) is stripped from each entry;
	otherwise the last name ends in "\\n" and its ``<name>.jpg`` path does not
	exist. Empty entries are dropped.

	Raises:
		ValueError: if the first line lists no logo names.
	"""
	names_path = os.path.join(logos_dir, "out_of_domain.txt")
	with open(names_path, "r") as f:
		first_line = f.readline()
	names = [name.strip() for name in first_line.split(",") if name.strip()]
	if not names:
		raise ValueError(f"No logo names found in {names_path}")
	return names


def _load_detectors(args):
	"""Load the optional Kosmos-2 / OWLv2 models requested by ``args``.

	Returns:
		dict: collator attribute name -> loaded object. It is empty when
		neither ``args.kosmos_mask`` nor ``args.owlv2`` is set.
	"""
	detectors = {}
	if args.kosmos_mask:
		detectors['kosmos_model'] = Kosmos2ForConditionalGeneration.from_pretrained(
			"microsoft/kosmos-2-patch14-224", torch_dtype=torch.float16).cuda()
		detectors['kosmos_processor'] = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
	if args.owlv2:
		# NOTE(review): the processor ("owlv2-base-patch16") and the weights
		# ("owlv2-base-patch16-ensemble") come from different hub ids;
		# presumably the preprocessing config is shared — confirm.
		detectors['owl_processor'] = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
		detectors['owl_model'] = Owlv2ForObjectDetection.from_pretrained(
			"google/owlv2-base-patch16-ensemble", torch_dtype=torch.float16).cuda()
	return detectors


def _results_dir(args):
	"""Return the results root for the enabled option.

	Precedence: random_crop, then kosmos_mask, then owlv2, then none.
	"""
	if args.random_crop:
		return './results/concept_logos/random_crop'
	if args.kosmos_mask:
		return './results/concept_logos/kosmos_mask'
	if args.owlv2:
		return './results/concept_logos/owl_v2'
	return './results/concept_logos/none'


def _load_test_split(args):
	"""Load the evaluation split for ``args.dataset``: 'test_unseen' for hmc, 'test' for harmeme.

	Raises:
		ValueError: for any other dataset.
	"""
	splits = {'hmc': 'test_unseen', 'harmeme': 'test'}
	if args.dataset not in splits:
		raise ValueError(f"Unsupported dataset: {args.dataset!r}")
	return load_dataset(args=args, split=splits[args.dataset])


def main(args):
	"""Test a trained HateClassifier on memes with every combination of pasted logos.

	The logo names come from the first line of ``out_of_domain.txt``. For each
	non-empty combination of the first ``min(4, #logos)`` names, logo i is
	paired with ALL_LOCATIONS[i]. For each combination the function:

	- reloads the test split with those paste settings,
	- tests the checkpoint at ``args.pretrained_model``,
	- pickles ``model.logits_to_save`` to ``<results_root>/<run_label>/logits.pkl``.

	Args:
		args: namespace from get_arg_parser() with ``gpus`` already converted
			to a list of ints. It is mutated in place: ``paste_attack_file``,
			``past_attack_file_locations`` and ``batch_size`` are overwritten.
	"""
	run_name = f'{generate_name(args)}-{random.randint(0, 1000000000)}'
	seed_everything(42, workers=True)

	logos_dir = f"resources/logos_preds/{args.dataset}/logos/{args.name}_1.0_transp_full_image/top_logos/"
	logo_names = _read_logo_names(logos_dir)
	# One placement slot per logo, so cap at len(ALL_LOCATIONS); also avoid
	# indexing past the end when fewer logos are listed.
	total_test_logos = min(len(ALL_LOCATIONS), len(logo_names))

	# The detectors do not depend on the logo combination. Load them once
	# instead of reloading them onto the GPU for every combination.
	detectors = _load_detectors(args)
	# Deliberately overrides --batch_size: smaller batches when OWLv2 runs in the collator.
	args.batch_size = 8 if args.owlv2 else 16
	results_root = _results_dir(args)

	for number_of_logos in range(1, total_test_logos + 1):
		for combo in combinations(range(total_test_logos), number_of_logos):
			string_logos = ''.join(str(x) for x in combo)
			run_label = f'{args.dataset}_{args.name}_{string_logos}'

			# NOTE(review): the attribute names differ ("paste_" vs "past_"). They are
			# kept as-is because the reader (dataset code) is not visible here.
			args.paste_attack_file = [
				os.path.join(logos_dir, f"{logo_names[logo_num]}.jpg") for logo_num in combo
			]
			args.past_attack_file_locations = ALL_LOCATIONS[:len(combo)]

			dataset_test = _load_test_split(args)

			collator = MemesCollator(args)
			for attr_name, component in detectors.items():
				setattr(collator, attr_name, component)

			# num_workers=0: batches are collated in the main process.
			dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size,
										 collate_fn=collator, num_workers=0)

			model = HateClassifier.load_from_checkpoint(args.pretrained_model, args=args,
														map_location=None, strict=False)

			csv_logger = CSVLogger(results_root, name=run_label)

			monitor = "val/auroc"
			checkpoint_callback = ModelCheckpoint(dirpath='checkpoints', filename=run_name+'-{epoch:02d}',
												  monitor=monitor, mode='max', verbose=True, save_weights_only=True,
												  save_top_k=1, save_last=False)

			trainer = Trainer(accelerator='gpu', devices=args.gpus, max_epochs=args.max_epochs, max_steps=args.max_steps,
							  gradient_clip_val=args.gradient_clip_val, logger=csv_logger,
							  log_every_n_steps=args.log_every_n_steps, val_check_interval=args.val_check_interval,
							  callbacks=[checkpoint_callback], limit_train_batches=args.limit_train_batches,
							  limit_val_batches=args.limit_val_batches, deterministic=True)

			trainer.test(model, dataloaders=[dataloader_test])

			# Make sure the directory exists even if the logger did not create it.
			out_dir = os.path.join(results_root, run_label)
			os.makedirs(out_dir, exist_ok=True)
			with open(os.path.join(out_dir, "logits.pkl"), 'wb') as f:
				pickle.dump(model.logits_to_save, f)


if __name__ == '__main__':
	# Parse the CLI, turn "--gpus" from a space-separated string such as
	# "0 1" into a list of int device ids, report each device, then run.
	cli_args = get_arg_parser().parse_args()
	cli_args.gpus = list(map(int, cli_args.gpus.split()))
	for device_id in cli_args.gpus:
		device_props = torch.cuda.get_device_properties(device_id)
		print('current device: {}'.format(device_props))

	main(cli_args)
