import timeit

import cv2
import json
import re
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import decord
import statistics
import argparse
import hiera
import pickle
import os

from decord import VideoReader

decord.bridge.set_bridge('torch')
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torchvision.models.video import r3d_18, R3D_18_Weights, swin3d_b, Swin3D_B_Weights, mvit_v1_b
from tqdm import tqdm
from timm.models import create_model
from models.AIM.AIM_vit import ViT_CLIP
from models.st_adapter.st_adapter import clip_vit_base_patch16_adapter12x384
from models.EVL.EVL import EVLTransformer
from models.flex_mvit import FlexMViT

from models import *
from models.FluxViT.single_modality.models.fluxvit import FluxViT
# from dataloader import omniDataLoader

from kinetics_dataloader import KineticsDL
from smth_loader import SmthSmthDL
from flexible_dataloader import FlexibleDataLoader
from kinetics_dataloader import KineticsDL, multiple_samples_collate
from COIN_loader import COINDL
from smth_loader import SmthSmthDL
from ucf_dataloader import UCFDL
from breakfast_loader import BkfstDL
from hmdb_loader import HMDBDL
from ntu_loader import NTU120DL
from diving_loader import DivingDL
from k700_loader import K700
from lvu_loader import LVUDL

from torch.utils.data import DataLoader
from torch.autograd import Variable
import utils
from collections import OrderedDict

import time
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
import numpy as np
# from models.omnivore.omnivore.models.omnivore_model import omnivore_swinB


def run_ret(model, val_loader):
    model.cuda()
    model.eval()
    features = []
    labels_list = []
    for frames, labels in tqdm(val_loader):
        frames, labels = frames.cuda(), labels.cuda()
        feat = model.forward_features(frames)
        zipped = zip(feat, labels)
        for feature, lbl in zipped:
            features.append(feature.detach().cpu())
            labels_list.append(lbl.detach().cpu().item())

    features = torch.stack(features)

    correct = 0
    for i, row in enumerate(tqdm(features)):
        row_sim = torch.nn.CosineSimilarity()(row.detach().cpu(), features.detach().cpu())
        first, arg = torch.topk(row_sim.flatten(), 2).indices
        # print(labels_list[i] == labels_list[arg.item()])
        if labels_list[i] == labels_list[arg.item()]:
            correct += 1
            
    print(i)
    accuracy = correct / i
    print(correct, i)
    print(f'Test Accuracy: {accuracy}')
    return accuracy



def run_train_test_ret(model, train_loader, val_loader, args):
    model.cuda()
    model.eval()
    train_features = []
    test_features = []
    train_labels = []
    labels_list = []
    
    for frames, labels in tqdm(train_loader):
        frames, labels = frames.cuda(), labels.cuda()
        if args.model == 'videomamba':
            # print(frames.shape)
            feat = model.forward_features(frames)
        elif args.model in ['omnivore', 'hiera', 'aim', 'evl', 'st_adapter', 'fluxvit', 'flexmvit']:
            feat = model(frames)
        elif args.model in 'swin':
            feat = model(frames)['features'].squeeze(dim=(2,3,4))
        elif args.model == 'mvit':
            feat = model(frames)['features']
            #print(feat.shape)
            

        zipped = zip(feat, labels)
        for feature, lbl in zipped:
            train_features.append(feature.detach().cpu())
            train_labels.append(lbl.detach().cpu().item())
    
    for frames, labels in tqdm(val_loader):
        frames, labels = frames.cuda(), labels.cuda()
        if args.model == 'videomamba':
            feat = model.forward_features(frames)
        elif args.model in ['omnivore', 'hiera', 'aim', 'evl', 'st_adapter', 'fluxvit', 'flexmvit']:
            feat = model(frames)
        elif args.model in 'swin':
            feat = model(frames)['features'].squeeze(dim=(2,3,4))
        elif args.model == 'mvit':
            feat = model(frames)['features']
            #print(feat.shape)
             
        zipped = zip(feat, labels)
        for feature, lbl in zipped:
            test_features.append(feature.detach().cpu())
            labels_list.append(lbl.detach().cpu().item())

    train_features = torch.stack(train_features)
    test_features = torch.stack(test_features)


    # if 'k400' in args.ckpt:
    #     assert int(args.frames) == 16, 'set frames to 16 for fair baseline comps'
    #     torch.save(train_features, f'viz/{args.dataset}/{args.dataset}-{args.resolution}-base-train-feat.pt')
    #     torch.save(test_features,  f'viz/{args.dataset}/{args.dataset}-{args.resolution}-base-test-feat.pt')
    # else:
    #     assert int(args.frames) == 16, 'set frames to 16 for fair baseline comps'
    #     torch.save(train_features, f'viz/{args.dataset}/{args.dataset}-{args.resolution}-st-train-feat.pt')
    #     torch.save(test_features, f'viz/{args.dataset}/{args.dataset}-{args.resolution}-st-test-feat.pt')
    #
    # if not os.path.exists(f'viz/{args.dataset}/{args.dataset}-train_labels.pkl'):
    #     print(f'labels for {args.dataset} does not exist!!!')
    #     pickle.dump(train_labels, open(f'viz/{args.dataset}/{args.dataset}-train_labels.pkl', 'wb'))
    #     pickle.dump(labels_list, open(f'viz/{args.dataset}/{args.dataset}-test_labels.pkl', 'wb'))


    correct = 0
    print(train_features.shape, test_features.shape)
    # dist_mat = []
    for i, probe in enumerate(tqdm(test_features)):
        probe_sim = torch.nn.CosineSimilarity()(probe.unsqueeze(0).detach().cpu(), train_features.detach().cpu())
        # dist_mat.append(probe_sim.cuda())
        first, arg = torch.topk(probe_sim.flatten(), 2).indices
        # print(labels_list[i] == labels_list[arg.item()])
        if labels_list[i] == train_labels[first.item()]:
            correct += 1
    # dist_mat = torch.stack(dist_mat)
    # print(dist_mat.shape)
    # torch.save(dist_mat, 'u-96-base-distmat.pt')
    print(i)
    accuracy = correct / i
    print(correct, i)
    print(f'Test Accuracy: {accuracy}')
    return accuracy


def finetune(args):
    """Fine-tune or linear-probe a VideoMamba-middle checkpoint on a downstream dataset.

    Loads ``args.ckpt`` into a ``videomamba_middle`` model and replaces the
    classification head for the target dataset. If needed, it interpolates the
    temporal position embedding to ``args.frames`` frames. It then trains for
    50 epochs, reporting test accuracy (and the best so far) after each epoch.

    ``args.train == 'probe'`` trains only parameters whose name contains
    ``'head'`` (SGD). Any other truthy value trains the whole model (Adam).

    Raises:
        ValueError: if ``args.dataset`` has no fine-tuning setup here.
    """
    print(args)
    model_name = 'videomamba_middle'
    num_frames = int(args.frames)
    resolution = int(args.resolution)
    cpath = args.ckpt

    checkpoint = torch.load(cpath)
    # 'f16' in the name (ignoring the last 4 chars, presumably '.pth') marks a
    # 16-frame checkpoint; everything else is treated as 64-frame.
    model_frames = 64 if 'f16' not in cpath[:-4] else 16

    bs = 8
    model = create_model(
        model_name,
        img_size=224,
        pretrained=None,
        num_classes=400, # pretrained on kinetics with 400 classes, head is changed later
        fc_drop_rate=0,
        drop_path_rate=0.4,
        kernel_size=1,
        num_frames=model_frames,
        use_checkpoint=True,
        checkpoint_num=0,
        flexible=True,
        spatial_flex=False,
        flexivit=False,
        flex_all=False,
        static_tokens=True,
    )

    # Strip a leading 'module.' from checkpoint keys so they match the model.
    checkpoint = OrderedDict(
        (key[len('module.'):] if key.startswith('module.') else key, value)
        for key, value in checkpoint.items()
    )
    model.load_state_dict(checkpoint)
    print('weights loaded: ', cpath)

    if args.dataset == 'kinetics':
        train_dataset = KineticsDL('train', num_frames=num_frames, flexible=False)
        test_dataset = KineticsDL('test', num_frames=num_frames, flexible=False)
        # Assumes KineticsDL is Kinetics-400, matching the pretraining head — TODO confirm.
        nb_classes = 400
    elif args.dataset == 'coin':
        train_dataset = COINDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = COINDL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 180
    elif args.dataset == 'ucf':
        train_dataset = UCFDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = UCFDL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 101
    elif args.dataset == 'SSV2':
        train_dataset = SmthSmthDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = SmthSmthDL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 174
    elif args.dataset == 'breakfast':
        train_dataset = BkfstDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = BkfstDL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 10
    elif args.dataset == 'hmdb':
        train_dataset = HMDBDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = HMDBDL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 51
    elif args.dataset == 'ntu':
        train_dataset = NTU120DL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = NTU120DL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 120
    elif args.dataset == 'diving':
        train_dataset = DivingDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = DivingDL('test', num_frames=num_frames, resolution=resolution)
        nb_classes = 48
    else:
        raise ValueError(f'fine-tuning is not set up for dataset {args.dataset!r}')

    test_loader = DataLoader(test_dataset, num_workers=8, batch_size=bs, shuffle=True,
                             collate_fn=multiple_samples_collate)
    train_loader = DataLoader(train_dataset, num_workers=8, batch_size=bs, shuffle=True,
                              collate_fn=multiple_samples_collate)

    # 576 is the embedding width the head consumes for videomamba_middle.
    model.head = nn.Linear(576, nb_classes)
    model.cuda()

    if model_frames != num_frames:
        print(f"Temporal interpolate from {model_frames} to {num_frames}")
        temp_pos_embed = model.temporal_pos_embedding.permute(0, 2, 1)
        temp_pos_embed = torch.nn.functional.interpolate(
            temp_pos_embed, size=(num_frames,), mode='linear', align_corners=False
        )
        temp_pos_embed = temp_pos_embed.permute(0, 2, 1)
        model.temporal_pos_embedding = nn.Parameter(temp_pos_embed)
    else:
        print('no interp needed: ', model_frames, num_frames)

    if args.train == 'probe':
        params_to_opt = []
        for n, p in model.named_parameters():
            if 'head' in n:
                p.requires_grad = True
                params_to_opt.append(p)
            else:
                p.requires_grad = False
            print(n, p.requires_grad)
        optimizer = optim.SGD(params_to_opt, lr=0.001, momentum=0.9, weight_decay=5e-4)
    else:
        optimizer = optim.Adam(model.parameters())

    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    for epoch in tqdm(range(50)):
        # Training mode must be restored every epoch because evaluation below
        # switches the model to eval mode.
        model.train()
        losses = []
        for frames, labels in tqdm(train_loader):
            optimizer.zero_grad()
            frames, labels = frames.cuda(), labels.cuda()
            loss = criterion(model(frames), labels)
            loss.backward()
            optimizer.step()
            # Keep a float, not the tensor, so the GPU value is released.
            losses.append(loss.item())
        if losses:
            print(f'epoch {epoch} mean train loss: {sum(losses) / len(losses)}')

        # Evaluate without dropout/drop-path and without building a graph.
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for frames, labels in tqdm(test_loader):
                total += frames.shape[0]
                frames, labels = frames.cuda(), labels.cuda()
                outputs = model(frames).argmax(dim=1)
                correct += (outputs == labels).sum().item()
        acc = correct / total if total else 0.0
        best_acc = max(best_acc, acc)
        print(f'final accuracy: {acc}, max: {best_acc}')

            


def eval_flexVM(args):
    num_frames = int(args.frames)
    resolution = int(args.resolution)
    ######################### LOAD VM MODEL ####################################
    if args.model == 'videomamba':
        model_name = 'videomamba_middle'
        cpath = args.ckpt
        checkpoint = torch.load(cpath)
        static_tokens = True if 'static-tokens' in cpath else False
        flex_all = True if 'flex_all' in cpath else False
        flexivit = True if 'flexivit' in cpath else False
#        if '.pth' in cpath:
#            model_frames = 8
#            checkpoint = checkpoint['model']
#            nb_classes = 101
        if 'k400' in cpath: #only do this for vm baseline weights
            ch_frames = int(cpath.split('_')[4][1:])
            print(ch_frames)
            model_frames = ch_frames
            nb_classes = 400
        else:
            model_frames = 64
            nb_classes = 400
        #model_frames = 64 if 'f16' not in cpath[:-4] else 16
        spatial_flex = False if (resolution == 224 or static_tokens or flexivit) else True

        model = create_model(
            model_name,
            img_size=224,
            pretrained=None,
            num_classes=nb_classes,
            fc_drop_rate=0,
            drop_path_rate=0.4,
            kernel_size=1,
            num_frames=model_frames,
            use_checkpoint=True,
            checkpoint_num=0,
            flexible=True,
            spatial_flex=spatial_flex,
            flexivit=flexivit,
            flex_all=flex_all,
            static_tokens=static_tokens,
        )



        new_dict = OrderedDict()
        all_keys = list(checkpoint.keys())

        for key in all_keys:
            if key.startswith('module.'):
                new_dict[key[7:]] = checkpoint[key]
            else:
                new_dict[key] = checkpoint[key]
        checkpoint = new_dict
        model.load_state_dict(checkpoint, strict=False)
        print('weights loaded: ', cpath)

        ####################################### Interpolate Temporal Emebeddings ###################################
        if model_frames != num_frames:
            print(f"Temporal interpolate from {model_frames} to {num_frames}")
            temp_pos_embed = model.temporal_pos_embedding.permute(0, 2, 1)
            temp_pos_embed = torch.nn.functional.interpolate(
                temp_pos_embed, size=(num_frames,), mode='linear', align_corners=False
            )
            temp_pos_embed = temp_pos_embed.permute(0, 2, 1)
            model.temporal_pos_embedding = nn.Parameter(temp_pos_embed)
        else:
            print('no interp needed: ', model_frames, num_frames)
        ####################################### Interpolate Temporal Emebeddings ###################################

    ######################### LOAD LVU-VM MODEL ####################################

    ######################### LOAD CUSTOM MODELS ####################################

    elif args.model == 'R3D':
        model = r3d_18(weights='DEFAULT')
        model = create_feature_extractor(model, return_nodes={"avgpool": "features"})

    elif args.model == 'swin':
        model = swin3d_b(weights='DEFAULT')
        model = create_feature_extractor(model, return_nodes={"avgpool": "features"})
    elif args.model == 'mvit':
        model = mvit_v1_b(weights='DEFAULT')
        model = create_feature_extractor(model, return_nodes={"getitem_1": "features"})
    elif args.model == 'flexmvit':
        model = FlexMViT.from_pretrained("v1_b", "DEFAULT")
        cpath = "/lustre/fs1/home/VideoMamba/videomamba/video_sm/exp/k400/videomamba_middle/mvit-flex/checkpoint-flex_best.pth"
        model.load_state_dict(torch.load(cpath))
        model.mode = 'test'
    elif args.model == 'omnivore':
        model_name = "omnivore_swinB"
        model = torch.hub.load("facebookresearch/omnivore:main", model=model_name, force_reload=False)

    elif args.model == 'hiera':
        model = hiera.hiera_large_16x224(pretrained=True, checkpoint="mae_k400_ft_k400")

    elif args.model == 'aim':
        model = ViT_CLIP(input_resolution=resolution, num_frames=16, patch_size=16, width=768, layers=12, heads=12,
                 drop_path_rate=0.2, num_tadapter=1, adapter_scale=0.5, pretrained=None)
        checkpoint = torch.load("models/AIM/vit_b_clip_16frame_k400.pth")
        new_dict = OrderedDict()
        all_keys = list(checkpoint.keys())

        for key in all_keys:
            if key.startswith('backbone.'):
                new_dict[key[9:]] = checkpoint[key]
            elif 'cls_head' in key:
                continue
            else:
                new_dict[key] = checkpoint[key]
        checkpoint = new_dict
        model.load_state_dict(checkpoint)
        print('weights loaded!')

    elif args.model == 'evl':
        model = EVLTransformer()
        checkpoint = torch.load("models/EVL/k400_vitb16_16f_dec4x768.pth")['model']
        new_dict = OrderedDict()
        all_keys = list(checkpoint.keys())

        for key in all_keys:
            if key.startswith('module.'):
                new_dict[key[7:]] = checkpoint[key]
            elif 'cls_head' in key:
                continue
            else:
                new_dict[key] = checkpoint[key]
        checkpoint = new_dict
        model.load_state_dict(checkpoint)
        print('weights loaded!')

    elif args.model == 'st_adapter':
        model = clip_vit_base_patch16_adapter12x384(num_classes=400).float()
        checkpoint = torch.load("models/st_adapter/k400_vit_base_p16_adapt12x384.pth")
        # print(checkpoint.keys())
        checkpoint = checkpoint['model']
        print(checkpoint.keys())
        # print(model.load_state_dict(checkpoint, strict=False))

        print('weights loaded!')

    elif args.model == 'fluxvit':
        cpath = "/lustre/fs1/home/VideoMamba/videomamba/video_sm/models/FluxViT/single_modality/models/fluxvit_s14_k400_ft_upload.pt"
        checkpoint = torch.load(cpath)
        if 'b14' in cpath:
            model = FluxViT(patch_size=14, embed_dim=768,
            img_size=int(args.resolution), num_frames=32, # doesn't matter, just for loading
            depth=12, num_heads=12, mlp_ratio=4,
            attn_pool_num_heads=16, clip_embed_dim=768,
            use_gpe_proj=True, use_lpe=True, dual_norm_in_patch_embed=True)
        else:
            model = FluxViT(patch_size = 14, embed_dim = 384,
            img_size = 252, num_frames = 24,  # doesn't matter, just for loading
            depth = 12, num_heads = 6, mlp_ratio = 4,
            attn_pool_num_heads = 16, clip_embed_dim = 768,
            use_gpe_proj = True, use_lpe = True, dual_norm_in_patch_embed = True,)
        model.load_state_dict(checkpoint)
        print('weights loaded!')

    num_frames = int(args.frames)
    img_size = int(args.resolution)
    model.cuda()
    model.eval()
    # arr =  torch.rand(1, 3, num_frames, img_size, img_size).cuda()
    # s = timeit.default_timer()
    # feat = model(arr)
    # print(timeit.default_timer()-s)


    # flops = FlopCountAnalysis(model, arr)
    # print(flop_count_table(flops, max_depth=1))
    # print(num_frames, img_size)
    
    # exit()

    ######################### LOAD CUSTOM MODELS ####################################
    
    if num_frames < 32 and resolution < 384:
        bs = 4
    elif num_frames == 32 and resolution < 288:
        bs = 2
    else:
        bs = 1
    #bs = 2
    ####################################### Interpolate Positional Emebeddings ###################################
    # if not static_tokens:
    #     orig_size = 14
    #     # height (== width) for the new position embedding
    #     new_size = int(resolution / 16)
    #     num_extra_tokens = 1  # appended cls token  = +1
    #     print(orig_size, new_size)
    #     embedding_size = model.pos_embed.shape[-1]
    #     if orig_size != new_size:
    #         print("Position interpolate from %d to %d" % (orig_size, new_size))
    #         extra_tokens = model.pos_embed[:, :num_extra_tokens]
    #         # only the position tokens are interpolated
    #         pos_tokens = model.pos_embed[:, num_extra_tokens:]
    #         # B, L, C -> B, H, W, C -> B, C, H, W
    #         pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    #         pos_tokens = torch.nn.functional.interpolate(
    #             pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    #         # B, C, H, W -> B, H, W, C ->  B, H, W, C
    #         pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_size, new_size, embedding_size)
    #         pos_tokens = pos_tokens.flatten(1, 2)  # B, L, C
    #         new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
    #         model.pos_embed = nn.Parameter(new_pos_embed)

    ####################################### Interpolate Positional Emebeddings ###################################

    model.to('cuda')
    
    
    ######################### Init Data ####################################
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {total_params}")  # 74,503,818
    model.eval()
    print('model loaded, all layers frozen')
    if args.dataset == 'kinetics':
        train_dataset = KineticsDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = KineticsDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'coin':
        train_dataset = COINDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = COINDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'ucf':
        train_dataset = UCFDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = UCFDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'smthsmth':
        train_dataset = SmthSmthDL('train', num_frames=num_frames, flexible=False)
        test_dataset = SmthSmthDL('test', num_frames=num_frames, flexible=False)
    elif args.dataset == 'breakfast':
        train_dataset = BkfstDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = BkfstDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'hmdb':
        train_dataset = HMDBDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = HMDBDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'K700':
        train_dataset = K700('train', num_frames=num_frames, resolution=resolution)
        test_dataset = K700('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'diving':
        train_dataset = DivingDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = DivingDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'SSV2':
        train_dataset = SmthSmthDL('train', num_frames=num_frames, resolution=resolution)
        test_dataset = SmthSmthDL('test', num_frames=num_frames, resolution=resolution)
    elif args.dataset == 'lvu':
        genre = args.genre
        train_dataset = LVUDL('train', num_frames=num_frames, resolution=resolution, genre=genre)
        test_dataset = LVUDL('test', num_frames=num_frames, resolution=resolution, genre=genre)


    test_loader = DataLoader(test_dataset, num_workers=8, batch_size=bs, shuffle=False, collate_fn=multiple_samples_collate, drop_last=True)
    train_loader = DataLoader(train_dataset, num_workers=8, batch_size=bs, shuffle=False, collate_fn=multiple_samples_collate, drop_last=True)
    ######################### Init Data ####################################

    ######################### Eval ####################################

    if args.eval_type == 'train_test':
        run_train_test_ret(model, train_loader, test_loader, args)
        print('frames + reso: ', num_frames, args.resolution, args.dataset, args.model)
        print('weights loaded: ', cpath)
        if 'lvu' in args.dataset:
            print('genre: ', genre)

    else:
        run_ret(model, val_loader=test_loader)

    ######################### Eval ####################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Short sample app')
    parser.add_argument('--dataset', choices=['kinetics', 'coin', 'smthsmth', 'ucf', 'breakfast', 'hmdb', 'ntu', 'diving', 'K700', 'SSV2', 'lvu'], required=True)
    parser.add_argument('--eval_type', choices=['zs', 'train_test'], default='train_test', required=False)
    parser.add_argument('--model', choices=['videomamba', 'R3D', 'swin', 'omnivore', 'uniformer', 'hiera', 'aim', 'evl', 'st_adapter', 'mvit', 'flexmvit', 'fluxvit'], required=True)
    # Parsed as int so a malformed value is rejected here instead of failing
    # later inside model or dataset setup.
    parser.add_argument('--frames', type=int, required=True)
    parser.add_argument('--resolution', type=int, required=True)
    parser.add_argument('--genre', required=False, help='LVU genre; required with --dataset lvu')
    parser.add_argument('--ckpt', required=True)
    # When set, run finetune() instead of the retrieval evaluation.
    parser.add_argument('--train', choices=['finetune', 'probe'], default=False)

    args = parser.parse_args()
    if args.dataset == 'lvu' and args.genre is None:
        parser.error('--genre is required when --dataset is lvu')

    if args.train:
        finetune(args)
    else:
        eval_flexVM(args)