import argparse
import os
import time
import pandas as pd
import torch
from torch.utils.data import DataLoader
import numpy as np
from model_zoo import get_model, BLIP_MODELS, LLM_MODELS
from dataset_zoo import get_dataset
from dataset_zoo.perturbations import get_text_perturb_fn, get_image_perturb_fn
from misc import seed_all, _default_collate, save_scores

from main import debias_prompt_dict, llm_prompt_dict


def config():
    """Define and parse the command-line options of the retrieval evaluation script.

    Returns:
        argparse.Namespace: The parsed options, covering dataset/model selection,
        I2T and T2I debiasing settings, the output directory, and the
        data-loading, perturbation and reranking options.
    """
    # Names accepted by --dataset (order kept so the generated help text is stable).
    dataset_choices = [
        "COCO_Retrieval", "Flickr30k_Retrieval", "COCO_Retrieval_Val", "Flickr30k_Retrieval_Val",
        'Laion_Retrieval_subset_100_00', 'Laion_Retrieval_subset_100_01',
        'Laion_Retrieval_subset_500_00', 'Laion_Retrieval_subset_500_01',
        'Laion_Retrieval_subset_1000_00', 'Laion_Retrieval_subset_1000_01',
        'Laion_Retrieval_subset_2000_00_sum', 'Laion_Retrieval_subset_1000_00_sum', 'Laion_Retrieval_subset_5000_00_sum',
        'Laion_Retrieval_subset_2000_00', 'Laion_Retrieval_subset_2000_01',
        'Laion_Retrieval_subset_5000_00', 'Laion_Retrieval_subset_5000_01',
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", default="/scratch/visual_gpt/", type=str,
                        help='Root directory for data. Will download VG-Relation and VG-Attribution and COCO images here.')
    # The default must itself be one of the choices: argparse does not check a
    # default against `choices`, so the former default ("COCO") would have been
    # handed to the dataset factory although it is not a declared option.
    parser.add_argument("--dataset", default="COCO_Retrieval", type=str, choices=dataset_choices)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--model_name", default="openai-clip:ViT-B/32", type=str)
    parser.add_argument("--inference_mode", default='lm', type=str, choices=['itc', 'itm', 'lm'], help="Inference mode.")
    parser.add_argument('--llm_prompt', default='none', choices=llm_prompt_dict.keys(), type=str, help='Prompt for LLM (BART, FLAN, OPT)')

    # Exponent applied to the debiasing term (0.0 leaves the scores unchanged).
    parser.add_argument("--alpha", default=0.0, type=float, help="Alpha for Debiasing")

    # Image-to-text debiasing options.
    parser.add_argument("--debias_i2t", default="gaussian", type=str, choices=['gaussian', 'llm', 'monte_carlo'],
                        help="Method for I2T Debiasing for GPTScore")
    parser.add_argument("--num_gaussian", default=3, type=int, help="Number of Random Images for I2T Debiasing")
    parser.add_argument("--mean_gaussian", default=0.45, type=float, help="Mean for Random Image for I2T Debiasing")
    parser.add_argument("--std_gaussian", default=0.25, type=float, help="Std for Random Image for I2T Debiasing")
    parser.add_argument("--seed_gaussian", default=1, type=int, help="Random Image Seed for I2T Debiasing")
    parser.add_argument("--debias_i2t_llm", default="flan-t5-xl", type=str, choices=LLM_MODELS, help="LLM for I2T Debiasing")
    parser.add_argument('--debias_i2t_llm_prompt', default='none', choices=llm_prompt_dict.keys(), type=str, help='Prompt for LLM I2T Debiasing (BART, FLAN, OPT)')

    # Text-to-image debiasing options.
    parser.add_argument("--debias_t2i", default="prompt", type=str, choices=['prompt', 'entropy', 'monte_carlo'],
                        help="Method for T2I Debiasing for GPTScore")
    parser.add_argument('--debias_t2i_prompt', default='common_photo', choices=debias_prompt_dict.keys(), type=str, help='Prompt for T2I Debiasing')

    parser.add_argument("--output_dir", default="./results", type=str)

    # different from main.py
    parser.add_argument("--num_workers", default=4, type=int)
    parser.add_argument("--text_perturb_fn", default=None, type=str, help="Perturbation function to apply to the text.")
    parser.add_argument("--image_perturb_fn", default=None, type=str, help="Perturbation function to apply to the images.")
    parser.add_argument("--reranking", default=0, type=int, help="If 0, no reranking; otherwise rerank with top-k queries using ITC score.")
    return parser.parse_args()

def get_dataset_name(args):
    """Return the dataset identifier used in result file names.

    Args:
        args: Options object; reads ``dataset``, ``text_perturb_fn`` and
            ``image_perturb_fn``.

    Returns:
        str: The dataset name, followed by a ``_text_perturb_<fn>`` and/or
        ``_image_perturb_<fn>`` suffix for every perturbation that is set.
    """
    parts = [f"{args.dataset}"]
    perturbations = (
        ("text", args.text_perturb_fn),
        ("image", args.image_perturb_fn),
    )
    for kind, fn_name in perturbations:
        if fn_name is not None:
            parts.append(f"{kind}_perturb_{fn_name}")
    return "_".join(parts)

def get_model_name(args):
    """Return a file-name-safe model identifier.

    Args:
        args: Options object; reads ``model_name`` and ``reranking``.

    Returns:
        str: ``model_name`` with every ``/`` replaced by ``-``, followed by
        ``_rerank_<k>`` when ``reranking`` is positive.
    """
    safe_name = args.model_name.replace('/', '-')
    rerank_suffix = f"_rerank_{args.reranking}" if args.reranking > 0 else ""
    return f"{safe_name}{rerank_suffix}"

def get_score_config(args):
    """Return the cache-file stem and scoring kwargs for the raw retrieval scores.

    Args:
        args: Options object; reads ``model_name``, ``reranking`` and
            ``inference_mode``, plus the fields read by ``get_dataset_name``
            and ``get_model_name``.

    Returns:
        tuple[str, dict]: ``(name, kwargs)``. ``name`` is
        ``<dataset name>_<model name>`` (with ``_<inference_mode>`` appended for
        BLIP models). ``kwargs`` holds ``reranking`` (and ``mode`` for BLIP
        models) and is empty for models that are neither BLIP nor LLM models.
    """
    name = f"{get_dataset_name(args)}_{get_model_name(args)}"

    if args.model_name in BLIP_MODELS:
        name += f"_{args.inference_mode}"
        return name, {'reranking': args.reranking, 'mode': args.inference_mode}

    if args.model_name in LLM_MODELS:
        # A leftover `pdb.set_trace()` used to halt every LLM run here; it is
        # removed so the script can run unattended.
        # NOTE(review): --llm_prompt is not forwarded on this path (the code that
        # did so was commented out), so only `reranking` is passed on and the
        # file name carries no prompt suffix — confirm this is intended.
        return name, {'reranking': args.reranking}

    # Remaining models receive no extra scoring arguments.
    return name, {}

def get_debias_i2t_score_config(args):
    """Return the cache-file stem and kwargs for the image-to-text debiasing scores.

    Args:
        args: Options object; reads ``model_name``, ``debias_i2t``, ``reranking``
            and, depending on the method, the ``*_gaussian`` or
            ``debias_i2t_llm*`` fields.

    Returns:
        tuple[str, dict]: ``(name, kwargs)`` where ``name`` starts with
        ``debias_i2t_`` and encodes the method-specific settings, and ``kwargs``
        carries ``method``, ``reranking`` and those same settings.

    Raises:
        NotImplementedError: If ``args.model_name`` is not a BLIP model.
    """
    name = f"debias_i2t_{get_dataset_name(args)}_{get_model_name(args)}_{args.debias_i2t}"
    method = args.debias_i2t
    kwargs = {'method': method, 'reranking': args.reranking}

    # Only BLIP models are supported for I2T debiasing.
    if args.model_name not in BLIP_MODELS:
        raise NotImplementedError()

    if method == 'gaussian':
        num, mean, std, seed = (
            args.num_gaussian, args.mean_gaussian, args.std_gaussian, args.seed_gaussian
        )
        kwargs['num_gaussian'] = num
        kwargs['mean_gaussian'] = mean
        kwargs['std_gaussian'] = std
        kwargs['seed_gaussian'] = seed
        name += f"_num_{num}_m_{mean}_std_{std}_seed_{seed}"
    elif method == 'llm':
        kwargs['llm'] = args.debias_i2t_llm
        kwargs['prompt'] = llm_prompt_dict[args.debias_i2t_llm_prompt]
        name += f"_{args.debias_i2t_llm}_prompt_{args.debias_i2t_llm_prompt}"
    # Any other method adds nothing beyond `method` and `reranking`.
    return name, kwargs

def get_debias_t2i_score_config(args):
    """Return the cache-file stem and kwargs for the text-to-image debiasing scores.

    Args:
        args: Options object; reads ``model_name``, ``debias_t2i``, ``reranking``
            and, for the ``prompt`` method, ``debias_t2i_prompt``.

    Returns:
        tuple[str, dict]: ``(name, kwargs)`` where ``name`` starts with
        ``debias_t2i_`` and ``kwargs`` carries ``method`` and ``reranking``
        (plus ``debias_prompt`` for the ``prompt`` method on BLIP models).

    Raises:
        NotImplementedError: For a BLIP model with a method other than
            ``prompt`` or ``entropy``.
    """
    name = f"debias_t2i_{get_dataset_name(args)}_{get_model_name(args)}_{args.debias_t2i}"
    method = args.debias_t2i
    kwargs = {'method': method, 'reranking': args.reranking}

    # Non-BLIP models get the generic configuration without validation.
    if args.model_name not in BLIP_MODELS:
        return name, kwargs

    if method == 'prompt':
        kwargs['debias_prompt'] = debias_prompt_dict[args.debias_t2i_prompt]
        name += f"_{args.debias_t2i_prompt}"
    elif method != 'entropy':
        # 'entropy' needs no extra settings; everything else is unsupported.
        raise NotImplementedError()
    return name, kwargs

def debias(args, model, loader, scores):
    """Apply the configured debiasing to a pair of retrieval score matrices.

    Args:
        args: Options object; reads ``debias_i2t``, ``debias_t2i``, ``reranking``,
            ``alpha``, ``model_name``, ``inference_mode`` and ``output_dir``.
        model: Model wrapper; only used to compute the I2T debiasing scores via
            ``get_debias_i2t_scores_retrieval`` when no cached file exists.
        loader: Data loader forwarded to the model for that computation.
        scores: ``(i2t_scores, t2i_scores)`` pair. The Monte-Carlo path treats
            axis 0 / axis 1 of ``i2t_scores`` as images / captions and requires
            ``t2i_scores`` to have the transposed shape.

    Returns:
        tuple: ``(i2t_scores, t2i_scores)`` after debiasing. The Monte-Carlo path
        returns new arrays; the BLIP ``lm`` path updates ``i2t_scores`` in place
        and leaves ``t2i_scores`` untouched; otherwise the inputs are returned
        unchanged.

    Raises:
        AssertionError: On the Monte-Carlo path, if ``debias_t2i`` is not
            ``'monte_carlo'``, ``reranking`` is non-zero, or the two matrices do
            not have transposed shapes.
    """
    # Important: unlike main.py, scores come in (and go out) as an (i2t, t2i) pair.
    i2t_scores, t2i_scores = scores

    if args.debias_i2t in ['monte_carlo']:
        # Monte-Carlo debiasing must be selected for both directions, without reranking.
        assert args.debias_t2i == 'monte_carlo'
        assert args.reranking == 0
        num_images, num_captions = i2t_scores.shape[0], i2t_scores.shape[1]
        assert t2i_scores.shape[0] == num_captions
        assert t2i_scores.shape[1] == num_images

        # Each entry is divided by (mean over axis 0 of its column) ** alpha.
        i2t_prior = np.tile(np.mean(i2t_scores, axis=0), (num_images, 1))
        t2i_prior = np.tile(np.mean(t2i_scores, axis=0), (num_captions, 1))
        return (i2t_scores / i2t_prior**args.alpha, t2i_scores / t2i_prior**args.alpha)

    # Score-based debiasing is only available for BLIP models in 'lm' mode.
    if args.model_name not in BLIP_MODELS or args.inference_mode != 'lm':
        return (i2t_scores, t2i_scores)

    cache_stem, score_kwargs = get_debias_i2t_score_config(args)
    cache_path = os.path.join(args.output_dir, f"{cache_stem}.pth")
    if not os.path.exists(cache_path):
        # No cached result yet: compute once and store it for later runs.
        debias_i2t_score = model.get_debias_i2t_scores_retrieval(loader, **score_kwargs)
        torch.save(debias_i2t_score, cache_path)
    else:
        print(f"Loading debias_i2t_score from {cache_path}.")
        debias_i2t_score = torch.load(cache_path)

    # Only entries with a strictly positive debiasing score are rescaled.
    positive = (debias_i2t_score > 0)
    i2t_scores[positive] = i2t_scores[positive] / debias_i2t_score[positive]**args.alpha

    # NOTE: the corresponding T2I debiasing is disabled on this path, so
    # t2i_scores is returned as received.
    return (i2t_scores, t2i_scores)  # Important: This is different from main.py

def main():
    """Run the retrieval evaluation end to end.

    Parses the command line, loads the model and dataset, optionally perturbs
    the dataset text, computes the retrieval scores (or loads them from the
    cache file in ``--output_dir``), applies debiasing and hands the result to
    the dataset's ``evaluate_scores``.
    """
    args = config()

    print(f"Testing {args.model_name} with reranking top-k {args.reranking} on {args.dataset} dataset.")
    # Honor --device; the model used to be placed on "cuda" regardless of the option.
    model, image_preprocess = get_model(model_name=args.model_name, device=args.device, root_dir=args.root_dir)
    model.model.eval()
    print("Model is in evaluation mode.")

    text_perturb_fn = get_text_perturb_fn(args.text_perturb_fn)
    image_perturb_fn = get_image_perturb_fn(args.image_perturb_fn)

    dataset = get_dataset(args.dataset, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, download=True, root_dir=args.root_dir)
    # For some models we just pass the PIL images, so we'll need to handle them in the collate_fn.
    collate_fn = _default_collate if image_preprocess is None else None

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)

    if text_perturb_fn is not None:
        # Perturb the captions once, up front, by rewriting the dataset's text list.
        start = time.time()
        loader.dataset.text = [text_perturb_fn(t) for t in loader.dataset.text]
        end = time.time()
        print(f"Text perturbation took {end - start:.2f} seconds.")

    score_name, score_kwargs = get_score_config(args)
    score_file = os.path.join(args.output_dir, f"{score_name}.pth")
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(args.output_dir, exist_ok=True)

    if os.path.exists(score_file):
        print(f"Loading scores from {score_file}.")
        scores = torch.load(score_file)
    else:
        print(f"Computing scores for {score_file}.")
        scores = model.get_scores_retrieval(loader, **score_kwargs)
        torch.save(scores, score_file)

    scores = debias(args, model, loader, scores)

    dataset.evaluate_scores(scores)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()