#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Multi-view test a video classification model."""

import numpy as np
import os
import pickle
import torch
from iopath.common.file_io import g_pathmgr

import slowfast.utils.checkpoint as cu
import slowfast.utils.distributed as du
import slowfast.utils.logging as logging
import slowfast.utils.misc as misc
# import slowfast.visualization.tensorboard_vis as tb
# from slowfast.datasets import loader
from slowfast.models import build_model
from slowfast.utils.meters import AVAMeter, TestMeter

import sys
sys.path.append('/home/siddiqui/VideoMamba/videomamba/video_sm/')

from kinetics_dataloader import KineticsDL
from smth_loader import SmthSmthDL
from flexible_dataloader import FlexibleDataLoader
from kinetics_dataloader import KineticsDL, multiple_samples_collate
from COIN_loader import COINDL
from smth_loader import SmthSmthDL
from ucf_dataloader import UCFDL
from breakfast_loader import BkfstDL
from hmdb_loader import HMDBDL
from ntu_loader import NTU120DL
from diving_loader import DivingDL
from tqdm import tqdm
from torch.utils.data import DataLoader

# Module-level logger from slowfast.utils.logging (imported above as ``logging``).
logger = logging.get_logger(__name__)


def _extract_features(model, loader):
    """Run ``model`` on every batch of ``loader`` and collect per-sample outputs.

    Args:
        model: the network to evaluate. Inputs are moved with ``.cuda()``, so
            the model is expected to already live on the GPU.
        loader: iterable yielding ``(frames, labels)`` batches.

    Returns:
        tuple: ``(features, labels)`` -- a list of CPU tensors (one per sample,
        taken along dim 0 of the model output) and a list of Python scalars
        obtained with ``Tensor.item()``.
    """
    features = []
    labels = []
    for frames, batch_labels in tqdm(loader):
        output = model(frames.cuda()).detach().cpu()
        # zip() pairs outputs with labels along dim 0 and, like the original
        # loop, truncates to the shorter of the two.
        for feature, label in zip(output, batch_labels):
            features.append(feature)
            labels.append(label.item())
    return features, labels


def _count_nearest_neighbor_hits(train_features, train_labels, test_features,
                                 test_labels, chunk_size=256):
    """Count test samples whose most cosine-similar train sample shares its label.

    Each sample's feature is flattened to a vector. Cosine similarity is then
    computed as a matrix product of L2-normalised features, one chunk of
    ``chunk_size`` probes at a time to bound memory. This matches
    ``torch.nn.CosineSimilarity`` except for its epsilon handling of
    (near-)zero vectors.

    Args:
        train_features (Tensor): stacked train features, first dim = samples.
        train_labels (list): label per train sample.
        test_features (Tensor): stacked test features, first dim = samples.
        test_labels (list): label per test sample.
        chunk_size (int): number of test probes scored per matrix product.

    Returns:
        int: number of test samples whose top-1 neighbour has the same label.
    """
    normalize = torch.nn.functional.normalize
    gallery = normalize(train_features.reshape(train_features.shape[0], -1), dim=1)
    correct = 0
    for start in range(0, test_features.shape[0], chunk_size):
        chunk = test_features[start:start + chunk_size]
        probes = normalize(chunk.reshape(chunk.shape[0], -1), dim=1)
        nearest = (probes @ gallery.T).argmax(dim=1).tolist()
        for offset, idx in enumerate(nearest):
            if test_labels[start + offset] == train_labels[idx]:
                correct += 1
    return correct


def run_train_test_ret(model, train_loader, val_loader):
    """Evaluate ``model`` by top-1 nearest-neighbour retrieval.

    Features are extracted for every sample of ``train_loader`` (the gallery)
    and ``val_loader`` (the probes). Each probe is assigned the label of its
    most cosine-similar gallery sample.

    Args:
        model: video model whose output is used as the feature vector.
        train_loader: loader yielding ``(frames, labels)`` for the gallery.
        val_loader: loader yielding ``(frames, labels)`` for the probes.

    Returns:
        float: fraction of validation samples whose nearest training sample
        carries the same label.

    Raises:
        ValueError: if either loader yields no samples.
    """
    model.cuda()
    model.eval()
    # Inference only: skip building autograd graphs for the forward passes.
    with torch.no_grad():
        train_features, train_labels = _extract_features(model, train_loader)
        test_features, test_labels = _extract_features(model, val_loader)

    if not train_features or not test_features:
        raise ValueError(
            "run_train_test_ret needs at least one train and one validation sample"
        )

    train_features = torch.stack(train_features)
    test_features = torch.stack(test_features)
    print(train_features.shape, test_features.shape)

    correct = _count_nearest_neighbor_hits(
        train_features, train_labels, test_features, test_labels
    )
    # Divide by the true number of probes (the original divided by the last
    # enumerate index, i.e. len - 1).
    total = len(test_labels)
    accuracy = correct / total
    print(correct, total)
    print(f'Test Accuracy: {accuracy}')
    return accuracy



@torch.no_grad()
def perform_test(test_loader, model, test_meter, cfg, writer=None):
    """
    For classification:
    Perform multi-view testing that uniformly samples N clips from a video along
    its temporal axis. For each clip, it takes 3 crops to cover the spatial
    dimension, followed by averaging the softmax scores across all Nx3 views to
    form a video-level prediction. All video predictions are compared to
    ground-truth labels and the final testing performance is logged.
    For detection:
    Perform fully-convolutional testing on the full frames without crop.
    Args:
        test_loader (loader): video testing loader.
        model (model): the pretrained video model to test.
        test_meter (TestMeter): testing meters to log and ensemble the testing
            results.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        writer (TensorboardWriter object, optional): TensorboardWriter object
            to write Tensorboard logs.
    Returns:
        the same ``test_meter``, after ``finalize_metrics()`` has been called.
    """
    # Enable eval mode.
    model.eval()
    test_meter.iter_tic()

    for cur_iter, (inputs, labels, video_idx, meta) in enumerate(test_loader):
        if cfg.NUM_GPUS:
            # Transfer the data to the current GPU device.
            if isinstance(inputs, (list,)):
                for i in range(len(inputs)):
                    inputs[i] = inputs[i].cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)

            # Transfer the data to the current GPU device.
            labels = labels.cuda()
            video_idx = video_idx.cuda()
            for key, val in meta.items():
                if isinstance(val, (list,)):
                    for i in range(len(val)):
                        val[i] = val[i].cuda(non_blocking=True)
                else:
                    meta[key] = val.cuda(non_blocking=True)
        test_meter.data_toc()

        if cfg.DETECTION.ENABLE:
            # Compute the predictions.
            preds = model(inputs, meta["boxes"])
            ori_boxes = meta["ori_boxes"]
            metadata = meta["metadata"]

            preds = preds.detach().cpu() if cfg.NUM_GPUS else preds.detach()
            ori_boxes = (
                ori_boxes.detach().cpu() if cfg.NUM_GPUS else ori_boxes.detach()
            )
            metadata = (
                metadata.detach().cpu() if cfg.NUM_GPUS else metadata.detach()
            )

            if cfg.NUM_GPUS > 1:
                preds = torch.cat(du.all_gather_unaligned(preds), dim=0)
                ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes), dim=0)
                metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0)

            test_meter.iter_toc()
            # Update and log stats.
            test_meter.update_stats(preds, ori_boxes, metadata)
            test_meter.log_iter_stats(None, cur_iter)
        else:
            # Perform the forward pass.
            if cfg.TEST.ADD_SOFTMAX:
                preds = model(inputs).softmax(-1)
            else:
                preds = model(inputs)

            # Gather all the predictions across all the devices to perform ensemble.
            if cfg.NUM_GPUS > 1:
                preds, labels, video_idx = du.all_gather(
                    [preds, labels, video_idx]
                )
            if cfg.NUM_GPUS:
                preds = preds.cpu()
                labels = labels.cpu()
                video_idx = video_idx.cpu()

            test_meter.iter_toc()
            # Update and log stats.
            test_meter.update_stats(
                preds.detach(), labels.detach(), video_idx.detach()
            )
            test_meter.log_iter_stats(cur_iter)

        test_meter.iter_tic()

    # Log epoch stats and print the final testing results.
    if not cfg.DETECTION.ENABLE:
        all_preds = test_meter.video_preds.clone().detach()
        all_labels = test_meter.video_labels
        if cfg.NUM_GPUS:
            all_preds = all_preds.cpu()
            all_labels = all_labels.cpu()
        if writer is not None:
            writer.plot_eval(preds=all_preds, labels=all_labels)

        if cfg.TEST.SAVE_RESULTS_PATH != "":
            save_path = os.path.join(cfg.OUTPUT_DIR, cfg.TEST.SAVE_RESULTS_PATH)

            # Only the root process writes the file, so only it reports the
            # save; other ranks previously logged a save they never did.
            if du.is_root_proc():
                with g_pathmgr.open(save_path, "wb") as f:
                    pickle.dump([all_preds, all_labels], f)

                logger.info(
                    "Successfully saved prediction results to {}".format(save_path)
                )

    test_meter.finalize_metrics()
    return test_meter


def test(cfg, dataset="ucf", num_frames=8, resolution=224, batch_size=8,
         num_workers=8):
    """
    Evaluate a pretrained video model with nearest-neighbour retrieval.

    Builds the model from ``cfg``, loads the state dict stored at
    ``cfg.TEST.CHECKPOINT_FILE_PATH``, extracts features for the train and
    test splits of ``dataset`` and reports the top-1 retrieval accuracy
    computed by ``run_train_test_ret``.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        dataset (str): dataset to evaluate on; one of "ucf" (default),
            "hmdb", "breakfast", "coin", "ntu", "diving", "kinetics",
            "smthsmth".
        num_frames (int): frames sampled per clip.
        resolution (int): spatial resolution, passed to the loaders that
            accept one.
        batch_size (int): batch size of both data loaders.
        num_workers (int): worker processes per data loader.
    Returns:
        float: top-1 retrieval accuracy on the test split.
    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Dataset constructors, keyed by name. The call signatures mirror the
    # per-dataset code this script previously kept commented out.
    # NOTE(review): the kinetics/smthsmth entries were not exercised by the
    # previous code path; confirm their loaders yield (frames, labels).
    dataset_factories = {
        "ucf": lambda split: UCFDL(split, num_frames=num_frames, resolution=resolution),
        "hmdb": lambda split: HMDBDL(split, num_frames=num_frames, resolution=resolution),
        "breakfast": lambda split: BkfstDL(split, num_frames=num_frames, resolution=resolution),
        "coin": lambda split: COINDL(split, num_frames=num_frames, resolution=resolution),
        "ntu": lambda split: NTU120DL(split, num_frames=num_frames, resolution=resolution),
        "diving": lambda split: DivingDL(split, num_frames=num_frames, resolution=resolution),
        "kinetics": lambda split: KineticsDL(split, num_frames=num_frames, flexible=False),
        "smthsmth": lambda split: SmthSmthDL(split, num_frames=num_frames, flexible=False),
    }
    if dataset not in dataset_factories:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(dataset_factories)}"
        )
    make_dataset = dataset_factories[dataset]

    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Print config.
    logger.info("Test with config:")
    logger.info(cfg)

    # Build the video model and load the pretrained weights.
    model = build_model(cfg)
    # map_location="cpu": load_state_dict copies into the model's existing
    # parameters, so staging the checkpoint on the CPU works regardless of
    # the device it was saved from.
    model.load_state_dict(
        torch.load(cfg.TEST.CHECKPOINT_FILE_PATH, map_location="cpu")
    )
    logger.info("weights loaded!")

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Number of parameters: {total_params}")
    # (A stray exit() here used to terminate the process before any
    # evaluation ran.)

    train_dataset = make_dataset("train")
    test_dataset = make_dataset("test")

    loader_kwargs = dict(
        num_workers=num_workers,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=multiple_samples_collate,
    )
    test_loader = DataLoader(test_dataset, **loader_kwargs)
    train_loader = DataLoader(train_dataset, **loader_kwargs)
    return run_train_test_ret(model, train_loader, test_loader)
