import argparse
import random
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from datasets import MemesCollator, load_logos_dataset
from engine import create_model, HateClassifier
from utils import str2bool, generate_name
from tqdm import tqdm
import numpy as np 
import os 
import pickle 

def _int_or_float(value):
    """Parse a command-line token as an int when possible, otherwise as a float.

    Used for options whose default is numeric but which previously had no
    ``type=``, so values supplied on the command line were left as strings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is neither an int nor a float.
    """
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an int or a float, got {value!r}") from None


def get_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with every option registered; call
        ``parse_args()`` on it to obtain the namespace passed to ``main``.
    """
    parser = argparse.ArgumentParser(description='Training and evaluation script for hateful memes classification')

    parser.add_argument('--dataset', default='hmc', choices=['hmc', 'harmeme'])
    parser.add_argument('--image_size', type=int, default=224)

    parser.add_argument('--num_mapping_layers', default=1, type=int)
    parser.add_argument('--map_dim', default=768, type=int)

    parser.add_argument('--fusion', default='align',
                        choices=['align', 'concat'])

    parser.add_argument('--num_pre_output_layers', default=1, type=int)

    parser.add_argument('--drop_probs', type=float, nargs=3, default=[0.1, 0.4, 0.2],
                        help="Set drop probabilities for map, fusion, pre_output")

    # Parsed into a list of ints by the ``__main__`` block of this file.
    parser.add_argument('--gpus', default='0', help='GPU ids concatenated with space')
    # The three untyped numeric options below now go through _int_or_float so a
    # value typed on the command line has the same kind of type as the default
    # (a number) instead of silently becoming a string.
    parser.add_argument('--limit_train_batches', type=_int_or_float, default=1.0)
    parser.add_argument('--limit_val_batches', type=_int_or_float, default=1.0)
    parser.add_argument('--max_steps', type=int, default=-1)
    parser.add_argument('--max_epochs', type=int, default=-1)
    parser.add_argument('--log_every_n_steps', type=int, default=25)
    parser.add_argument('--val_check_interval', type=_int_or_float, default=1.0)
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--gradient_clip_val', type=float, default=0.1)

    parser.add_argument('--proj_map', default=False, type=str2bool)

    parser.add_argument('--pretrained_proj_weights', default=False, type=str2bool)
    parser.add_argument('--freeze_proj_layers', default=False, type=str2bool)

    parser.add_argument('--comb_proj', default=False, type=str2bool)
    parser.add_argument('--comb_fusion', default='align',
                        choices=['concat', 'align'])
    parser.add_argument('--convex_tensor', default=False, type=str2bool)

    parser.add_argument('--text_inv_proj', default=False, type=str2bool)
    parser.add_argument('--phi_inv_proj', default=False, type=str2bool)
    parser.add_argument('--post_inv_proj', default=False, type=str2bool)

    parser.add_argument('--enh_text', default=False, type=str2bool)

    parser.add_argument('--phi_freeze', default=False, type=str2bool)

    parser.add_argument('--name', type=str, default='adaptation',
                        choices=['adaptation', 'hate-clipper', 'image-only', 'text-only', 'sum', 'combiner', 'text-inv',
                                 'text-inv-fusion', 'text-inv-comb']
                        )
    # Checkpoint path handed to HateClassifier.load_from_checkpoint in main().
    parser.add_argument('--pretrained_model', type=str, default='')
    parser.add_argument('--reproduce', default=False, type=str2bool)
    parser.add_argument('--print_model', default=False, type=str2bool)
    parser.add_argument('--fast_process', default=False, type=str2bool)
    # logos_idx is also used by main() to name the output file scores_<idx>.pkl.
    parser.add_argument('--logos_idx', default=0, type=int)
    parser.add_argument('--pos_logo_x', type=float, default=0.0)
    parser.add_argument('--pos_logo_y', type=float, default=0.0)
    parser.add_argument('--transparency', type=float, default=1.0)
    parser.add_argument('--factor_shrink', type=int, default=5)
    # Directory main() creates (if needed) and writes the pickled scores into.
    parser.add_argument('--output_scores_dir', type=str, default='resources/logos_preds/hmc/')

    return parser


def main(args):
    """Score the training split with a pretrained classifier and pickle the results.

    Loads the dataset returned by ``load_logos_dataset``, runs
    ``model.test_step_logo`` on every batch on the GPU, groups the resulting
    ``[logit, label]`` pairs by the batch's ``logo_fns`` entries and writes the
    mapping to ``<args.output_scores_dir>/scores_<args.logos_idx>.pkl``.

    Args:
        args: Namespace produced by ``get_arg_parser().parse_args()``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names, or
            if ``args.pretrained_model`` is empty.
    """
    # Fail fast with a clear message instead of an obscure checkpoint-loading
    # error after the dataset has already been built ('' is the CLI default).
    if not args.pretrained_model:
        raise ValueError("--pretrained_model must be set to a HateClassifier checkpoint path")

    # NOTE(review): run_name is not used below; the call is kept only in case
    # generate_name validates args — confirm and drop.
    run_name = f'{generate_name(args)}-{random.randint(0, 1000000000)}'

    seed_everything(42, workers=True)

    # The resulting path string (including the doubled slash and the directory
    # name's spelling) is kept exactly as before, since it must match what is on disk.
    dir_data = "../data/cc12m/"
    logos_dir = f"{dir_data}/top_logos/ViT-L_14_openai_0.01_v1_origianl"

    # load dataset (both supported datasets go through the same loader call)
    if args.dataset not in ('hmc', 'harmeme'):
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    dataset_train = load_logos_dataset(args=args, split='train', logos_dir=logos_dir)

    print("Number of training examples:", len(dataset_train))

    collator = MemesCollator(args)

    # The worker count is fixed at 3 here (the previously computed, never-used
    # ``num_cpus`` local has been removed).
    dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True,
                                  collate_fn=collator, num_workers=3)

    model = HateClassifier.load_from_checkpoint(args.pretrained_model, args=args,
                                                map_location="cuda").cuda()
    # test_step_logo is invoked by hand rather than through a Trainer, so
    # switch to evaluation mode and disable autograd explicitly.
    model.eval()

    # Batch entries that are moved to the GPU before the forward pass.
    tensor_keys = ("pixel_values", "texts", "labels", "enhanced_texts", "simple_prompt")

    all_data = {}
    skipped_batches = 0
    with torch.no_grad():
        for batch in tqdm(dataloader_train):
            for key in tensor_keys:
                batch[key] = batch[key].cuda()

            logo_fns = batch["logo_fns"]
            try:
                logits = model.test_step_logo(batch)['logits']
            except Exception as exc:
                # Best-effort scoring: a failing batch is still skipped, but it
                # is now reported, and KeyboardInterrupt/SystemExit propagate.
                skipped_batches += 1
                print(f"Skipping batch ({type(exc).__name__}): {exc}")
                continue

            labels = batch["labels"].detach().cpu()
            for logit, logo_fn, label in zip(logits.detach().cpu(), logo_fns, labels):
                all_data.setdefault(logo_fn, []).append([logit.item(), label.item()])

            torch.cuda.empty_cache()

    if skipped_batches:
        print(f"Skipped {skipped_batches} batch(es) because test_step_logo raised")

    os.makedirs(args.output_scores_dir, exist_ok=True)
    with open(os.path.join(args.output_scores_dir, f'scores_{args.logos_idx}.pkl'), 'wb') as f:
        pickle.dump(all_data, f)

if __name__ == '__main__':
    # Parse the command line, then turn the whitespace-separated --gpus string
    # into a list of integer device ids.
    cli_parser = get_arg_parser()
    arguments = cli_parser.parse_args()
    arguments.gpus = list(map(int, arguments.gpus.split()))

    # Print the CUDA properties of every requested device before running.
    for device_index in arguments.gpus:
        device_props = torch.cuda.get_device_properties(device_index)
        print('current device: {}'.format(device_props))

    main(arguments)
