import os
import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
# import scallopy
import sys
import pickle
import gc
import heapq
from datetime import datetime

model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../models'))
assert os.path.exists(model_path)
sys.path.append(model_path)

from llava_clip_model_v3 import PredicateModel as PredicateModel_v3
from llava_clip_model_v2 import PredicateModel as PredicateModel_v2
from llava_clip_model import PredicateModel as PredicateModel_v1

def load_model(model_dir, model_folder, model_name, model_epoch, device):
    """Load a pickled model checkpoint saved as ``<model_name>.<model_epoch>.model``.

    Args:
        model_dir: Root directory holding model folders.
        model_folder: Sub-folder of ``model_dir`` containing the checkpoint.
        model_name: Checkpoint base name (without the ``.<epoch>.model`` suffix).
        model_epoch: Epoch number embedded in the checkpoint filename.
        device: Where to map the loaded tensors. An int (or digit string) is
            treated as a CUDA ordinal (``cuda:<device>``), as before. A device
            string such as ``'cpu'`` / ``'cuda:1'`` or a ``torch.device`` is
            passed to ``torch.load`` unchanged.

    Returns:
        Whatever object was pickled into the checkpoint (``torch.load`` result).
    """
    checkpoint_name = f'{model_name}.{model_epoch}.model'
    checkpoint_path = os.path.join(model_dir, model_folder, checkpoint_name)

    # Keep the original "bare ordinal -> cuda:N" behaviour, but let explicit
    # device specs through (the old 'cuda:' + str(device) broke 'cpu'/'cuda:1').
    if isinstance(device, (str, torch.device)) and not str(device).isdigit():
        map_location = device
    else:
        map_location = f'cuda:{device}'

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    return torch.load(checkpoint_path, map_location=map_location, weights_only=False)

def emsemble_model(cate_sd=None, unary_sd=None, binary_sd=None, test_num_top_pairs=0,
                   *, cate_model=None, device=None):
    """Build a fresh PredicateModel_v3 and load the given CLIP sub-model weights into it.

    Args:
        cate_sd: State dict for ``clip_cate_model``; skipped when None.
        unary_sd: State dict for ``clip_unary_model``; skipped when None.
        binary_sd: State dict for ``clip_binary_model``; skipped when None.
        test_num_top_pairs: Forwarded to PredicateModel_v3 as ``num_top_pairs``.
        cate_model: Loaded reference model whose CLIP config name
            (``config._name_or_path``) is reused for the new model. Defaults to
            the module-level ``cate_model`` (set by the script entry point) for
            backward compatibility.
        device: Device passed to PredicateModel_v3 and ``.to()``. Defaults to
            the module-level ``device`` for backward compatibility.

    Returns:
        The newly constructed PredicateModel_v3 (built with ``hidden_dim=0``).

    Raises:
        ValueError: If ``cate_model``/``device`` are neither passed nor
            available as module globals.
        NotImplementedError: If ``cate_model`` is not a PredicateModel_v3 or
            PredicateModel_v1 instance.
    """
    # The original implementation silently depended on globals defined only in
    # the __main__ block; fall back to them so existing call sites keep working.
    if cate_model is None:
        cate_model = globals().get('cate_model')
    if device is None:
        device = globals().get('device')
    if cate_model is None or device is None:
        raise ValueError("emsemble_model requires `cate_model` and `device` "
                         "(pass them explicitly when importing this module)")

    # Exact type match (not isinstance) mirrors the original dispatch and avoids
    # accidentally accepting subclasses whose attribute layout may differ.
    reference_type = type(cate_model)
    if reference_type is PredicateModel_v3:
        clip_model_name = cate_model.clip_cate_model.config._name_or_path
    elif reference_type is PredicateModel_v1:
        clip_model_name = cate_model.clip_model.config._name_or_path
    else:
        raise NotImplementedError(
            f"Unsupported reference model type: {reference_type.__name__}")

    new_model = PredicateModel_v3(hidden_dim=0, num_top_pairs=test_num_top_pairs,
                                  device=device, model_name=clip_model_name).to(device)
    if cate_sd is not None:
        new_model.clip_cate_model.load_state_dict(cate_sd)
    if unary_sd is not None:
        new_model.clip_unary_model.load_state_dict(unary_sd)
    if binary_sd is not None:
        new_model.clip_binary_model.load_state_dict(binary_sd)

    return new_model
    
if __name__ == "__main__":
    # Script entry point: take the category, unary and binary CLIP sub-models
    # from (possibly different) checkpoints and save them as one PredicateModel_v3.
    # NOTE: `cate_model` and `device` are module globals that emsemble_model may
    # read, so their names must stay as-is.
    device = 0
    test_num_top_pairs = 0
    model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../../../data/LLaVA-Video-178K-v2/models"))

    ensemble_model_folder = "ensemble-05-10"
    # Second-resolution timestamp, e.g. 2025-05-10-13-45-07.
    current_time_str = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    ensemble_model_name = "vidvrd-ensemble-" + current_time_str + '.0.model'
    ensemble_model_path = os.path.join(model_dir, ensemble_model_folder, ensemble_model_name)

    # --- Category sub-model ---
    cate_model_folder = "backup-02-09"
    cate_model_name = "laser_clip_LLaVA_2025-01-30-15-33-12_training_100.0_lr_1e-06_fgl_False_negspec_True_ws_True_wns_True_negkw_True_mvl_20_bs_2_ddp_True"
    cate_model_epoch = 0
    cate_model = load_model(model_dir, cate_model_folder, cate_model_name, cate_model_epoch, device)
    if type(cate_model) is PredicateModel_v3:
        cate_state_dict = cate_model.clip_cate_model.state_dict()
    elif type(cate_model) is PredicateModel_v1:
        cate_state_dict = cate_model.clip_model.state_dict()
    else:
        # Previously fell through and crashed later with an unbound-name error.
        raise NotImplementedError(
            f"Unsupported category model type: {type(cate_model).__name__}")

    # --- Unary sub-model ---
    # NOTE(review): assumes the unary checkpoint exposes `clip_unary_model`
    # (as PredicateModel_v3 does) — confirm for other model versions.
    unary_model_folder = "backup-02-09"
    unary_model_name = "laser_clip_LLaVA_2025-01-30-15-33-12_training_100.0_lr_1e-06_fgl_False_negspec_True_ws_True_wns_True_negkw_True_mvl_20_bs_2_ddp_True"
    unary_model_epoch = 0
    unary_model = load_model(model_dir, unary_model_folder, unary_model_name, unary_model_epoch, device)
    unary_state_dict = unary_model.clip_unary_model.state_dict()

    # --- Binary sub-model ---
    binary_model_folder = "backup-01-15"
    binary_model_name = "laser_clip_LLaVA_2024-11-06-10-53-44_training_1.0_lr_1e-06_fgl_False_negspec_True_ws_True_wns_True_negkw_True_mvl_20_bs_2_ddp_True"
    binary_model_epoch = 1
    binary_model = load_model(model_dir, binary_model_folder, binary_model_name, binary_model_epoch, device)
    if type(binary_model) is PredicateModel_v3:
        binary_state_dict = binary_model.clip_binary_model.state_dict()
    elif type(binary_model) is PredicateModel_v1:
        binary_state_dict = binary_model.clip_model.state_dict()
    else:
        raise NotImplementedError(
            f"Unsupported binary model type: {type(binary_model).__name__}")

    ensembled_model = emsemble_model(cate_state_dict, unary_state_dict, binary_state_dict, test_num_top_pairs)
    # Record provenance of each sub-model alongside the saved ensemble.
    ensembled_model.meta_info = {"cate_model_folder": cate_model_folder,
                                 "cate_model_name": cate_model_name,
                                 "cate_model_epoch": cate_model_epoch,
                                 "unary_model_folder": unary_model_folder,
                                 "unary_model_name": unary_model_name,
                                 "unary_model_epoch": unary_model_epoch,
                                 "binary_model_folder": binary_model_folder,
                                 "binary_model_name": binary_model_name,
                                 "binary_model_epoch": binary_model_epoch
                                 }
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(ensemble_model_path), exist_ok=True)
    torch.save(ensembled_model, ensemble_model_path)
    print(f'Saved ensemble model to {ensemble_model_path}')