import argparse
import os
import torch

from torch.utils.data import DataLoader
from model_zoo import get_model, BLIP_MODELS, LLM_MODELS, BLIP_MODELS_RANDOM
from dataset_zoo import DATASETS, get_dataset

# Named prompt strings used when scoring with a text-only LLM.
# The keys are the accepted values of --llm_prompt and --debias_i2t_llm_prompt;
# the selected value is forwarded to the model as its `prompt` keyword argument.
# Only 'none' (an empty prompt) is currently enabled; the alternatives below are
# kept commented out.
llm_prompt_dict = { # for bart/flan/opt
    # 'default': 'This is an image caption: ',
    # 'a_photo_of': 'A photo of ',
    # 'grammar_meaning': 'The following sentence is grammatically correct and semantically meaningful: ',
    # 'grammar': 'The following sentence is grammatically correct: ',
    # 'meaning': 'The following sentence is semantically meaningful: ',
    # 'caption': 'The following sentence is an image caption: ',
    'none': '',
}

def config():
    """Declare the script's command-line options and parse them.

    Returns:
        argparse.Namespace: parsed options covering data location, dataset and
        model selection, inference mode, LLM prompt, the I2T debiasing settings
        (gaussian / llm / laion variants) and the output directory.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    prompt_names = llm_prompt_dict.keys()

    # --- data / model / runtime -------------------------------------------
    add("--root_dir", type=str, default="/scratch/visual_gpt/",
        help='Root directory for data. Will download VG-Relation and VG-Attribution and COCO images here.')
    add("--dataset", type=str, default="VG_Relation", choices=DATASETS)
    add("--device", type=str, default="cuda")
    add("--batch_size", type=int, default=16)
    add("--model_name", type=str, default="openai-clip:ViT-B/32")
    add("--inference_mode", type=str, default='lm', choices=['itc', 'itm', 'lm'],
        help="Inference mode.")
    add('--llm_prompt', type=str, default='none', choices=prompt_names,
        help='Prompt for LLM (BART, FLAN, OPT)')

    # --- debiasing strength -----------------------------------------------
    add("--alpha", type=float, default=0.0, help="Alpha for Debiasing")

    # --- I2T debiasing: method and per-method settings ----------------------
    add("--debias_i2t", type=str, default="gaussian", choices=['gaussian', 'llm', 'laion'],
        help="Method for I2T Debiasing for GPTScore")
    add("--num_gaussian", type=int, default=10,
        help="Number of Random Images for I2T Debiasing")
    add("--mean_gaussian", type=float, default=0.45,
        help="Mean for Random Image for I2T Debiasing")
    add("--std_gaussian", type=float, default=0.25,
        help="Std for Random Image for I2T Debiasing")
    add("--seed_gaussian", type=int, default=1,
        help="Random Image Seed for I2T Debiasing")
    add("--debias_i2t_llm", type=str, default="flan-t5-xl", choices=LLM_MODELS,
        help="LLM for I2T Debiasing")
    add('--debias_i2t_llm_prompt', type=str, default='none', choices=prompt_names,
        help='Prompt for LLM I2T Debiasing (BART, FLAN, OPT)')
    add("--debias_i2t_laion_subset", type=str, default="subset_1000_01",
        help="subset for I2T Debiasing")

    # --- outputs ------------------------------------------------------------
    add("--output_dir", type=str, default="./results")

    return parser.parse_args()


def get_score_config(args):
    """Derive the score-cache name and scoring kwargs for the selected model.

    Args:
        args: parsed options; reads ``dataset``, ``model_name``,
            ``inference_mode`` and ``llm_prompt``.

    Returns:
        tuple[str, dict]: ``(name, kwargs)``. ``name`` is
        ``"<dataset>_<model_name with '/' replaced by '-'>"`` plus a suffix for
        BLIP models (the inference mode) or LLM models (the prompt key).
        ``kwargs`` holds ``mode`` for BLIP models, ``prompt`` for LLM models,
        and is empty otherwise.
    """
    model_tag = args.model_name.replace('/', '-')
    base_name = f"{args.dataset}_{model_tag}"

    if args.model_name in BLIP_MODELS:
        mode = args.inference_mode
        return f"{base_name}_{mode}", {'mode': mode}

    if args.model_name in LLM_MODELS:
        prompt_key = args.llm_prompt
        return f"{base_name}_{prompt_key}", {'prompt': llm_prompt_dict[prompt_key]}

    return base_name, {}

def get_debias_i2t_score_config(args, preprocess=None):
    """Derive the cache name and kwargs for computing I2T debiasing scores.

    Args:
        args: parsed options; reads ``dataset``, ``model_name``, ``debias_i2t``
            and the settings belonging to the chosen debiasing method.
        preprocess: image preprocessing object forwarded as ``preprocess`` when
            the method is ``'laion'``; unused otherwise.

    Returns:
        tuple[str, dict]: ``(name, kwargs)`` where ``kwargs`` always contains
        ``method`` and, depending on the method, the gaussian parameters, the
        LLM name and prompt, or the LAION split / root dir / preprocess.

    Raises:
        NotImplementedError: if ``args.model_name`` is not in ``BLIP_MODELS``.
    """
    name = f"debias_i2t_{args.dataset}_{args.model_name.replace('/','-')}_{args.debias_i2t}"
    kwargs = {'method': args.debias_i2t}

    # Debiasing scores are only defined for BLIP models.
    if args.model_name not in BLIP_MODELS:
        raise NotImplementedError()

    method = args.debias_i2t
    if method == 'gaussian':
        gaussian_fields = ('num_gaussian', 'mean_gaussian', 'std_gaussian', 'seed_gaussian')
        gaussian = {field: getattr(args, field) for field in gaussian_fields}
        kwargs.update(gaussian)
        name += (
            f"_num_{gaussian['num_gaussian']}"
            f"_m_{gaussian['mean_gaussian']}"
            f"_std_{gaussian['std_gaussian']}"
            f"_seed_{gaussian['seed_gaussian']}"
        )
    elif method == 'llm':
        kwargs['llm'] = args.debias_i2t_llm
        kwargs['prompt'] = llm_prompt_dict[args.debias_i2t_llm_prompt]
        name += f"_{args.debias_i2t_llm}_prompt_{args.debias_i2t_llm_prompt}"
    elif method == 'laion':
        kwargs['split'] = args.debias_i2t_laion_subset
        kwargs['root_dir'] = args.root_dir
        kwargs['preprocess'] = preprocess
        name += f"_{args.debias_i2t_laion_subset}"

    return name, kwargs

def debias(args, model, loader, scores, preprocess=None):
    """Apply I2T debiasing to the scores of a BLIP model run in ``'lm'`` mode.

    Args:
        args: parsed options; reads ``model_name``, ``inference_mode``,
            ``output_dir`` and ``alpha``.
        model: object exposing ``get_debias_i2t_scores_batched(loader, **kwargs)``.
        loader: data loader handed to the model when debias scores are computed.
        scores: either a ``(t2i_scores, i2t_scores)`` tuple or a single object
            used for both directions.
        preprocess: forwarded to ``get_debias_i2t_score_config``.

    Returns:
        tuple: ``(t2i_scores, i2t_scores)``. For BLIP models in ``'lm'`` mode the
        i2t scores are divided by ``debias_i2t_score ** args.alpha``; otherwise
        both entries are returned untouched.
    """
    if type(scores) is tuple:
        t2i_scores, i2t_scores = scores
    else:
        t2i_scores = i2t_scores = scores

    # Nothing to do unless this is a BLIP model scored in language-model mode.
    if not (args.model_name in BLIP_MODELS and args.inference_mode == 'lm'):
        return (t2i_scores, i2t_scores)

    cache_name, cache_kwargs = get_debias_i2t_score_config(args, preprocess=preprocess)
    cache_path = os.path.join(args.output_dir, f"{cache_name}.pth")

    if not os.path.exists(cache_path):
        # First run for this configuration: compute and cache the debias scores.
        debias_i2t_score = model.get_debias_i2t_scores_batched(loader, **cache_kwargs)
        torch.save(debias_i2t_score, cache_path)
    else:
        print(f"Loading debias_i2t_score from {cache_path}.")
        debias_i2t_score = torch.load(cache_path)

    debiased_i2t = i2t_scores / debias_i2t_score**args.alpha
    return (t2i_scores, debiased_i2t)

def main():
    """Score a model on a dataset, optionally debias, evaluate, and export EqBen scores.

    Steps:
      1. Parse CLI options and load the model on ``args.device``.
      2. Build the dataset and loader, then load cached scores from
         ``<output_dir>/<score_name>.pth`` or compute and cache them.
      3. Run ``debias`` unless the model is listed in ``BLIP_MODELS_RANDOM``.
      4. Evaluate the scores via ``dataset.evaluate_scores``.
      5. For the ``EqBen_All`` dataset, save the i2t scores as ``.npy`` under
         ``eqben_results/``.
    """
    args = config()

    print(f"Testing {args.model_name} on {args.dataset} dataset.")
    # Honor --device; it was previously parsed but ignored in favor of a
    # hard-coded "cuda". The default is still "cuda".
    model, preprocess = get_model(model_name=args.model_name, device=args.device, root_dir=args.root_dir)
    model.model.eval()
    print("Model is in evaluation mode.")

    dataset = get_dataset(args.dataset, image_preprocess=preprocess, download=True, root_dir=args.root_dir)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    score_name, score_kwargs = get_score_config(args)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)
    score_file = os.path.join(args.output_dir, f"{score_name}.pth")

    if os.path.exists(score_file):
        print(f"Loading scores from {score_file}.")
        scores = torch.load(score_file)
    else:
        scores = model.get_scores_batched(loader, **score_kwargs)
        torch.save(scores, score_file)

    run_debias = args.model_name not in BLIP_MODELS_RANDOM
    if run_debias:
        scores = debias(args, model, loader, scores, preprocess=preprocess)

    # Return values are not used here.
    # NOTE(review): evaluate_scores presumably reports the results itself - confirm.
    dataset.evaluate_scores(scores)

    if args.dataset == "EqBen_All":
        import numpy as np

        eqben_dir = 'eqben_results'
        # np.save does not create missing parent directories.
        os.makedirs(eqben_dir, exist_ok=True)

        # NOTE(review): assumes `scores` is a (t2i, i2t) pair here; when debias()
        # was skipped this indexes whatever get_scores_batched returned - confirm.
        i2t_scores = scores[1]

        # Mirror the condition under which debias() actually modifies the i2t
        # scores, so the debias name is only requested when it is defined
        # (get_debias_i2t_score_config raises for non-BLIP models).
        debiased = (
            run_debias
            and args.model_name in BLIP_MODELS
            and args.inference_mode == 'lm'
            and args.alpha != 0.0
        )
        if debiased:
            # Use only the name; the helper returns (name, kwargs) and the whole
            # tuple used to be interpolated into the filename.
            debias_name, _ = get_debias_i2t_score_config(args)
            out_file = f'{eqben_dir}/{score_name}_{debias_name}.npy'
        else:
            out_file = f'{eqben_dir}/{score_name}.npy'

        np.save(out_file, i2t_scores)
        # Report the path that was really written.
        print(f"Saved i2t scores to {out_file}.")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
