import argparse
import logging

import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
print("DISABLED TF GPUS")

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import clip
from tqdm import tqdm

import dataset_loaders.dataset_loaders as module_data
import model.model as module_arch
from model.metrics_deferf import (rank_at_k_precomputed,
                                  tensor_text_to_video_metrics,
                                  tensor_video_to_text_sim)

# Let INFO-level messages from the root logger through (the progress note and
# the metrics table logged during evaluation).
logging.getLogger().setLevel(logging.INFO)


def compute_recall(tensor_v, tensor_t, split="full-test", dataset_name="MSRVTT"):
    """Compute, log and return video-to-text and text-to-video retrieval metrics.

    Args:
        tensor_v: Video embeddings. When called from ``retrieval_evaluation``
            this is ``[n_videos, dim]``.
        tensor_t: Caption embeddings. When called from ``retrieval_evaluation``
            this is ``[n_videos, max_captions, dim]``, with padded caption rows
            filled with ``-inf``.
        split: Split name; used only in the result column labels.
        dataset_name: Dataset name used in the result column labels. The
            default ``"MSRVTT"`` reproduces the historical labels; pass the
            real dataset name (e.g. ``"MSVD"``) when evaluating another one.

    Returns:
        A ``pandas.DataFrame`` with one column for video-to-text and one for
        text-to-video metrics. The same frame is also logged at INFO level.
    """
    # Dot-product similarity of every caption embedding against every video.
    sim = tensor_t @ tensor_v.t()

    vtt_sim = tensor_video_to_text_sim(sim)
    vtt_metrics = rank_at_k_precomputed(vtt_sim)
    ttv_metrics = tensor_text_to_video_metrics(sim)

    df = pd.DataFrame({
        f"{dataset_name} {split} split Video to Text": vtt_metrics,
        f"{dataset_name} {split} split Text to Video": ttv_metrics,
    })
    logging.info(df)
    return df


# Model classes evaluated as image models: each sampled frame of a video is
# passed to the model as a separate image, and the outputs are averaged.
image_models = (
    module_arch.PretrainedCLIP,
    module_arch.PretrainedCLIP_finaltf,
)

# Model classes evaluated as video models: frames are grouped into
# fixed-length clips before the forward pass.
video_models = (
    module_arch.PretrainedCLIP_TimeSformer,
    module_arch.PretrainedCLIP_TimeSformer_finaltf,
)

# Model classes whose forward pass takes an extra comments tensor in addition
# to frames and captions.
models_needing_comments = (
    module_arch.PretrainedCLIP_finaltf,
    module_arch.PretrainedCLIP_TimeSformer_finaltf,
)


def load_model(checkpoint_path, device, model_type,
               residual_activation=None, branch_to_adapt=None):
    """Construct a retrieval model, optionally load weights, and move it to *device*.

    Args:
        checkpoint_path: Path to a checkpoint whose ``"state_dict"`` entry is
            loaded into the model, or ``None`` to keep the constructed weights.
        device: Target device passed to ``model.to``.
        model_type: One of ``"pretrained_clip"``, ``"clip_timesformer"``,
            ``"pretrained_clip_finaltf"`` or ``"clip_timesformer_finaltf"``.
        residual_activation: Forwarded to the model constructor. When
            ``None``, taken from the module-level ``args`` namespace (set when
            this file runs as a script), falling back to ``"none"``.
        branch_to_adapt: Forwarded as ``branch_to_adapt_val`` to the
            ``*_finaltf`` models. When ``None``, taken from the module-level
            ``args`` namespace, falling back to ``"text"``.

    Returns:
        The model, in eval mode, on *device*.

    Raises:
        ValueError: If *model_type* is not one of the supported names.
    """
    # TODO switch to load from config
    # Script runs keep reading the parsed CLI flags from the global ``args``;
    # importers that never define ``args`` get the CLI defaults instead of a
    # NameError, or can pass the settings explicitly.
    cli_args = globals().get("args")
    if residual_activation is None:
        residual_activation = getattr(cli_args, "residual_activation", "none")
    if branch_to_adapt is None:
        branch_to_adapt = getattr(cli_args, "branch_to_adapt", "text")

    if model_type == "pretrained_clip":
        model = module_arch.PretrainedCLIP(
            model_type="ViT-B/32",
            freeze=False,
            residual_activation=residual_activation,
        )
    elif model_type == "clip_timesformer":
        model = module_arch.PretrainedCLIP_TimeSformer(
            residual_activation=residual_activation,
        )
    elif model_type == "pretrained_clip_finaltf":
        model = module_arch.PretrainedCLIP_finaltf(
            branch_to_adapt_val=branch_to_adapt,
            residual_activation=residual_activation,
        )
    elif model_type == "clip_timesformer_finaltf":
        model = module_arch.PretrainedCLIP_TimeSformer_finaltf(
            branch_to_adapt_val=branch_to_adapt,
            residual_activation=residual_activation,
        )
    else:
        # Without this branch ``model`` would be unbound below.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of "
            "'pretrained_clip', 'clip_timesformer', "
            "'pretrained_clip_finaltf', 'clip_timesformer_finaltf'"
        )

    if checkpoint_path is not None:
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    model.to(device)
    return model


def _build_eval_dataset(datasetname, split):
    """Return the evaluation (``train=False``) dataset named *datasetname*."""
    if datasetname == "MSRVTT_videos":
        return module_data.VideoDatasetMSRVTT(train=False, split=split)
    if datasetname == "MSVD_videos":
        return module_data.VideoDatasetMSVD(train=False, split=split)
    raise ValueError(f"Unknown dataset: {datasetname!r}")


def _prepare_frames(model, frames, frame_stride, clip_len=8):
    """Reshape one ``[1, nframes, nchans, h, w]`` sample into *model*'s input layout.

    Image models receive ``[nframes, nchans, h, w]``. Video models receive the
    frames subsampled every *frame_stride* steps and split into clips of
    *clip_len* frames, giving ``[nchunks, clip_len, nchans, h, w]``. A short
    final clip is stretched to *clip_len* by repeating frames at evenly spaced
    indices.
    """
    if isinstance(model, image_models):
        # [1, nframes, nchans, h, w] -> [nframes, nchans, h, w]
        return frames[0]

    if isinstance(model, video_models):
        # [1, nframes, nchans, h, w] -> [nchunks, clip_len, nchans, h, w]
        strided = frames[:, ::frame_stride]
        clips = []
        for clip_frames in torch.split(strided, clip_len, dim=1):
            length = clip_frames.shape[1]
            if length != clip_len:
                index = torch.floor(
                    torch.linspace(0, length - 1, clip_len, device=clip_frames.device)
                ).to(torch.int64)
                clip_frames = torch.index_select(clip_frames, dim=1, index=index)
            clips.append(clip_frames)
        return torch.cat(clips, dim=0)

    raise TypeError(f"Unknown model_type: {type(model).__name__}")


def _dummy_comments(model, frames, captions, empty_tokens, device):
    """Build the placeholder comments tensor for models that take comments.

    Returns *empty_tokens* repeated once per frame/clip when the model's
    ``branch_to_adapt_val`` is ``"image"``, otherwise once per caption.
    """
    if model.branch_to_adapt_val == "image":
        ncomms = len(frames)
    else:
        ncomms = len(captions)
    # Same result as stacking ncomms copies, without re-tokenizing each time.
    return empty_tokens.unsqueeze(0).repeat(ncomms, 1, 1).to(device)


def _stack_embeddings(video_embeddings, caption_embeddings):
    """Collate per-sample model outputs into dense tensors.

    Returns ``(videos, captions)``:
    - ``videos[i]`` is the mean over the first dimension of sample *i*'s
      video-side features (one row per frame or per clip).
    - ``captions[i]`` holds sample *i*'s caption features, padded with
      ``-inf`` rows up to the largest caption count in the dataset.

    Raises:
        ValueError: If no samples were collected.
    """
    if not caption_embeddings:
        raise ValueError("No samples were embedded; the evaluation dataset is empty")

    # Pad captions for when there's a different number of captions per video.
    # NOTE(review): the -inf rows yield non-finite similarities after the
    # matmul in compute_recall; presumably the metric helpers mask them — verify.
    max_captions = max(emb.shape[0] for emb in caption_embeddings)
    padded_captions = torch.stack([
        torch.cat([
            emb,
            # new_full keeps the embedding's dtype and device.
            emb.new_full((max_captions - emb.shape[0], emb.shape[1]), float("-inf")),
        ])
        for emb in caption_embeddings
    ])

    # Average frame/clip features into a single vector per video.
    videos = torch.cat([emb.mean(dim=0, keepdim=True) for emb in video_embeddings])
    return videos, padded_captions


@torch.no_grad()
def retrieval_evaluation(model, datasetname, split, device, out_csv=None, frame_stride=16):
    """Embed an evaluation split with *model* and compute retrieval metrics.

    Args:
        model: A model whose class is in ``image_models`` or ``video_models``.
            Its ``forward`` must return ``(video_feats, caption_feats, sim)``.
        datasetname: ``"MSRVTT_videos"`` or ``"MSVD_videos"``.
        split: Dataset split name, forwarded to the dataset and the metric labels.
        device: Device the frames, captions and comments are moved to.
        out_csv: Optional path; when given, the metrics frame is written there.
        frame_stride: Temporal subsampling step for video models.

    Returns:
        The metrics ``DataFrame`` produced by ``compute_recall``.

    Raises:
        ValueError: Unknown *datasetname*, or an empty dataset.
        TypeError: *model* is neither an image nor a video model.
    """
    dataset = _build_eval_dataset(datasetname, split)
    data_loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=4,
        shuffle=False,
    )

    needs_comments = isinstance(model, models_needing_comments)
    # Still not sure what is best to do here
    # (mask-only comment, repeat title, skip adapting).
    # NB empty string will be replaced with the model's mask token.
    # Tokenized once here instead of once per comment per sample.
    empty_tokens = clip.tokenize(["", "", "", "", ""]) if needs_comments else None

    video_joint_embeddings = []
    caption_joint_embeddings = []

    logging.info("Computing joint embeddings")

    for frames, captions, _ in tqdm(data_loader):
        frames = frames.to(device)
        captions = captions.to(device)

        # Sanity checks on the loader's output layout (batch_size=1).
        assert captions.dim() == 3 and captions.shape[0] == 1
        assert frames.dim() == 5 and frames.shape[0] == 1 and frames.shape[2] == 3

        # [1, ncaptions, 77] --> [ncaptions, 77]
        captions = captions[0]
        frames = _prepare_frames(model, frames, frame_stride)

        if needs_comments:
            comments = _dummy_comments(model, frames, captions, empty_tokens, device)
            feats_a, feats_b, _sim = model.forward(frames, captions, comments)
        else:
            feats_a, feats_b, _sim = model.forward(frames, captions)

        video_joint_embeddings.append(feats_a.detach().cpu())
        caption_joint_embeddings.append(feats_b.detach().cpu())

    video_joint_tensor, caption_joint_tensor = _stack_embeddings(
        video_joint_embeddings, caption_joint_embeddings
    )

    outdf = compute_recall(video_joint_tensor, caption_joint_tensor, split=split)

    if out_csv is not None:
        outdf.to_csv(out_csv)
    return outdf


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate text-video retrieval of a CLIP-based model."
    )
    parser.add_argument(
        "-c",
        "--dataset",
        default="MSRVTT_videos",
        choices=["MSRVTT_videos", "MSVD_videos"],
        type=str,
        help="dataset to load",
    )
    parser.add_argument(
        "-r",
        "--checkpoint",
        default=None,
        type=str,
        help="path to checkpoint (default: None)",
    )
    # Required with explicit choices: load_model has no default architecture,
    # so a missing or misspelled value is rejected here with a usage message
    # instead of crashing later.
    parser.add_argument(
        "-m",
        "--model_type",
        required=True,
        choices=[
            "pretrained_clip",
            "clip_timesformer",
            "pretrained_clip_finaltf",
            "clip_timesformer_finaltf",
        ],
        type=str,
        help="model arch to be loaded",
    )
    parser.add_argument(
        "-d",
        "--device",
        default='cuda',
        type=str,
        help="device to load model on",
    )
    parser.add_argument(
        "-s",
        "--split",
        default="full-test",
        type=str,
        help="which test split to use",
    )
    parser.add_argument(
        "--branch_to_adapt",
        default="text",
        choices=["text", "image", "random", "skip"],
        type=str,
        help="which branch to adapt for finaltf models",
    )
    parser.add_argument(
        "--residual_activation",
        default="none",
        type=str,
        help="which activation fn to use on the residual",
    )
    parser.add_argument(
        "--out_csv",
        default=None,
        type=str,
        help="File to save output csv",
    )
    parser.add_argument(
        "--frame_stride",
        default=16,
        type=int,
        help="Video frame stride",
    )
    # ``args`` is deliberately a module-level name: load_model reads its
    # residual_activation / branch_to_adapt settings from it.
    # argparse already enforces ``choices`` on --dataset, so no extra check.
    args = parser.parse_args()

    model = load_model(args.checkpoint, args.device,
                       model_type=args.model_type)

    retrieval_evaluation(model, args.dataset, args.split, args.device,
                         out_csv=args.out_csv, frame_stride=args.frame_stride)
