import argparse
import collections

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.metric import RecallAtK

import dataset_loaders.dataset_loaders as module_data
import model.model as module_arch
from utils.parse_config import ConfigParser
import numpy as np
import json

def _squeeze_keep_batch(tensor):
    """Remove every size-1 dimension of ``tensor`` except the leading batch dimension.

    For batches larger than one this gives the same result as ``torch.squeeze``.
    When the DataLoader yields a final batch of a single sample, plain
    ``torch.squeeze`` would also drop dimension 0 and hand the model an
    un-batched tensor; this keeps it.
    """
    if tensor.dim() == 0:
        return tensor
    kept = [size for size in tensor.shape[1:] if size != 1]
    return tensor.reshape(tensor.shape[0], *kept)


def _experiment_tag(branch_to_adapt, add_comments, comment_fusion,
                    num_comms, num_imlabels, random_words):
    """Build the experiment identifier that is embedded in the results filename."""
    if branch_to_adapt is None:
        if add_comments != "always":
            return "title_only"
        return f"{comment_fusion}_{num_comms}_comms_{num_imlabels}_imlabels"
    if random_words:
        return f"adapted_{branch_to_adapt}_random_words"
    return f"adapted_{branch_to_adapt}_{num_comms}_comms_{num_imlabels}_imlabels"


def _extract_features(model, data_loader, device, add_comments):
    """Run ``model`` over every batch and return ``(visual_feats, text_feats)`` as numpy arrays.

    Comments are passed to the model only when ``add_comments == "always"``.
    The per-batch outputs are concatenated along the first (batch) axis.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    vis_chunks = []
    text_chunks = []
    with torch.no_grad():
        for vis, title, comments, _meta in tqdm(data_loader):
            inputs = [
                _squeeze_keep_batch(vis).to(device),
                _squeeze_keep_batch(title).to(device),
            ]
            if add_comments == "always":
                inputs.append(comments.to(device))
            feats_vis, feats_text, _sim = model.forward(*inputs)
            vis_chunks.append(feats_vis.cpu().numpy())
            text_chunks.append(feats_text.cpu().numpy())
    if not vis_chunks:
        raise ValueError("Test dataset is empty; nothing to evaluate.")
    return np.concatenate(vis_chunks), np.concatenate(text_chunks)


def main(config, checkpoint_path, device='cuda', num_workers=10):
    """Evaluate image<->title retrieval on the test split and save Recall@{1,5,10} as JSON.

    Args:
        config: parsed experiment config. It provides ``get_logger``,
            ``init_obj`` and item access to ``arch``/``dataset``/``batch_size``.
        checkpoint_path: path to a checkpoint containing a ``'state_dict'``
            entry, or ``None`` to evaluate the freshly built (zero-shot) model.
        device: torch device string the model and inputs are moved to.
        num_workers: number of DataLoader worker processes.

    Side effects:
        Prints progress and recall values, logs the model, and writes the
        results JSON. It is written next to the checkpoint, or to
        ``zero_shot_res_<comment_fusion>.json`` in the CWD when no checkpoint
        is given.
    """
    from pathlib import Path

    logger = config.get_logger("test")

    dataset = config.init_obj("dataset", module_data, train=False, test=True)

    arch_args = config['arch']['args']
    data_args = config['dataset']['args']
    comment_fusion = arch_args.get('comment_fusion', None)
    add_comments = data_args['add_comments']
    exp_combo = _experiment_tag(
        branch_to_adapt=arch_args.get('branch_to_adapt_val', None),
        add_comments=add_comments,
        comment_fusion=comment_fusion,
        num_comms=data_args.get('num_comms', None),
        num_imlabels=data_args.get('num_imlabels', None),
        random_words=data_args.get('random_words', False),
    )

    if checkpoint_path is not None:
        checkpoint_path = Path(checkpoint_path)
        # Strip the real extension (e.g. ".pth") instead of slicing off 4 chars.
        stem = checkpoint_path.absolute().with_suffix('').as_posix()
        save_path = f"{stem}_res_{exp_combo}.json"
    else:
        # NOTE(review): zero-shot results are keyed only by comment_fusion, so
        # runs with different dataset settings write to the same file.
        save_path = f"zero_shot_res_{comment_fusion}.json"
    print(f'Saving results to {save_path}')

    data_loader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        num_workers=num_workers,
        shuffle=False,
    )

    # build model architecture, then print to console
    model = config.init_obj("arch", module_arch)
    if checkpoint_path is not None:
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    logger.info(model)

    model = model.to(device)

    res_vis, res_text = _extract_features(model, data_loader, device, add_comments)

    ks = [1, 5, 10]
    recall_all_k_title_from_im = RecallAtK('images', 'titles', list(ks)).compute(res_vis, res_text)
    recall_all_k_im_from_title = RecallAtK('titles', 'images', list(ks)).compute(res_text, res_vis)

    print('Recall im from title: ', recall_all_k_im_from_title)
    print('Recall title from im: ', recall_all_k_title_from_im)

    # Assumes compute() returns one entry per k (in ks order), with the recall
    # value at index 1 of each entry — TODO confirm against RecallAtK.
    out = {f"R{k}_title_from_im": recall_all_k_title_from_im[i][1] for i, k in enumerate(ks)}
    out.update({f"R{k}_im_from_title": recall_all_k_im_from_title[i][1] for i, k in enumerate(ks)})

    with open(save_path, 'w') as f:
        # default=float lets numpy/torch scalar types (e.g. np.float32) serialize;
        # it is only consulted for objects json cannot encode natively.
        json.dump(out, f, default=float)


def str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True`` for
    every non-empty string (including ``"False"``). Boolean overrides such as
    ``--random_words`` therefore need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    args = argparse.ArgumentParser(description="PyTorch Template")
    args.add_argument(
        "-c",
        "--config",
        default='configs/pretrained_clip.jsonc',
        type=str,
        help="config file path (default: configs/pretrained_clip.jsonc)",
    )
    args.add_argument(
        "-r",
        "--resume",
        default=None,
        type=str,
        help="path to checkpoint (default: None)",
    )
    args.add_argument(
        "-d",
        "--device",
        default='3',
        type=str,
        help="CUDA device index, or a full torch device string such as 'cpu' (default: 3)",
    )
    CustomArgs = collections.namedtuple("CustomArgs", "flags type target")
    options = [
        CustomArgs(["--lr", "--learning_rate"], type=float, target="optimizer;args;lr"),
        CustomArgs(
            ["--bs", "--batch_size"], type=int, target="batch_size",
        ),
        CustomArgs(["--bv", "--branch_to_adapt_val"], type=str, target="arch;args;branch_to_adapt_val"),
        # NOTE(review): counts are parsed as str; confirm the dataset accepts
        # strings here (config files may supply ints).
        CustomArgs(["--nc", "--num_comms"], type=str, target="dataset;args;num_comms"),
        CustomArgs(["--nl", "--num_imlabels"], type=str, target="dataset;args;num_imlabels"),
        CustomArgs(["--rw", "--random_words"], type=str2bool, target="dataset;args;random_words"),
        CustomArgs(["--am", "--comment_fusion"], type=str, target="arch;args;comment_fusion"),
        CustomArgs(["--ac", "--add_comments"], type=str, target="dataset;args;add_comments")
    ]
    config = ConfigParser.from_args(args, options)
    _args = args.parse_args()

    # A bare index ("3") selects that CUDA device; any other value ("cpu",
    # "cuda", "cuda:1") is passed to torch unchanged.
    # NOTE(review): if ConfigParser.from_args exports --device to
    # CUDA_VISIBLE_DEVICES, the selected GPU is renumbered to 0 and
    # "cuda:3" may not exist — confirm.
    device = f"cuda:{_args.device}" if _args.device.isdigit() else _args.device
    main(config, config.resume, device=device)
