import os

import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
import scallopy
import sys
import pickle
import gc
import heapq

# Resolve the `models` directory that sits two levels above this file and make
# it importable. This must run before the project imports that follow, which
# are resolved through the `sys.path` entry added here.
model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../models'))
# Fail fast at import time if the expected directory layout is missing.
assert os.path.exists(model_path)

sys.path.append(model_path)
from open_pvsg_predicate_model_gt_rel import PredicateModel
from openpvsg_dataset import *
from utils import *

def recall(gt_object_dict, cate_pred, gt_object_rels, binary_pred, recall_thres_ls = [1, 5, 10]):
    """Compute per-item recall@k hit lists for object categories and relations.

    Args:
        gt_object_dict: Iterable of ``(vid, oid, gt_label)`` ground-truth
            object categories. ``vid`` must be 0 (single-video batches only).
        cate_pred: Iterable of ``(prob, (oid, name))`` category predictions.
        gt_object_rels: Sequence whose index is the frame id; each element is
            an iterable of ``(from_id, to_id, rel)`` ground-truth relations.
        binary_pred: Iterable of ``(prob, (rel, fid, from_id, to_id))``
            relation predictions.
        recall_thres_ls: The ``k`` values to evaluate. Never mutated here.

    Returns:
        ``{"cate": {k: [0 or 1, ...]}, "binary": {k: [0 or 1, ...]}}`` with one
        entry per ground-truth object (``cate``) or per ground-truth
        ``(fid, from_id, to_id)`` key (``binary``); 1 means the ground-truth
        label is among the top-k predictions.

    Ground-truth items without any prediction count as a miss (0) for every k.
    Previously this was only true for relations; a ground-truth object with no
    category prediction raised ``KeyError``.
    """
    # Group category predictions by object id.
    cate_by_object = {}
    for prob, (oid, name) in cate_pred:
        cate_by_object.setdefault(oid, []).append((prob, name))

    result_unary = {thres: [] for thres in recall_thres_ls}
    result_binary = {thres: [] for thres in recall_thres_ls}

    for vid, oid, gt_label in gt_object_dict:
        assert vid == 0

        # Highest probability first; an object without predictions yields an
        # empty ranking and therefore a miss at every threshold.
        ranked = sorted(cate_by_object.get(oid, []), reverse=True)

        for thres in recall_thres_ls:
            top_names = [name for _, name in ranked[:thres]]
            result_unary[thres].append(1 if gt_label in top_names else 0)

    # Index ground-truth relations by (frame, subject, object). If the same key
    # occurs with several relations, only the last one is kept.
    gt_rel_dict = {}
    for fid, rel_ls in enumerate(gt_object_rels):
        for (from_id, to_id, rel) in rel_ls:
            gt_rel_dict[(fid, from_id, to_id)] = rel

    # Group relation predictions by the same key.
    rel_by_pair = {}
    for prob, (rel, fid, from_id, to_id) in binary_pred:
        rel_by_pair.setdefault((fid, from_id, to_id), []).append((prob, rel))

    for pair_key, gt_label in gt_rel_dict.items():
        ranked = sorted(rel_by_pair.get(pair_key, []), reverse=True)

        for thres in recall_thres_ls:
            top_rels = [rel for _, rel in ranked[:thres]]
            result_binary[thres].append(1 if gt_label in top_rels else 0)

    return {"cate": result_unary, "binary": result_binary}
    
class Trainer():
    """Evaluation harness around a ``PredicateModel`` for OpenPVSG.

    The constructor wires up the data loaders, a Scallop context built from
    ``common_scl_path``, the predicate model (freshly constructed or restored
    from ``model_dir``), and an Adam optimizer. ``eval`` runs the model over
    ``test_loader`` and writes per-video results and a running recall report.
    """

    def __init__(self, train_loader, test_loader, device,
                 caption2scl, common_scl_path,
                 latent_dim = 64,
                 provenance="difftopkproofs", k=3, save_scl_dir=None,
                 use_neg_spec=False, use_neg_kws=False,
                 model_dir=None, model_name=None, 
                 learning_rate=None,
                 load_model=False, save_model=True,
                 video_save_dir=None,
                 train_num_top_pairs=100, 
                 test_num_top_pairs=300, 
                 report_dir=None,
                 result_dir=None,
                 neg_spec_weight=0.1,
                 neg_entity_kw_cate_weight=0,
                 neg_entity_kw_binary_weight=0.1,
                 siglip_model_name="google/siglip-base-patch16-224", 
                 use_half=True):
        """Set up data, the Scallop reasoning function, the model and optimizer.

        Args:
            train_loader, test_loader: Data loaders; only ``test_loader`` is
                consumed by the methods of this class.
            device: Device handed to the model; an ``int`` is treated as a
                CUDA index when restoring a checkpoint.
            caption2scl: Stored as-is on the instance.
            common_scl_path: Path of the shared Scallop program; its text is
                kept in ``self.common_scl`` and it is imported into the
                Scallop context.
            latent_dim: ``hidden_dim`` of a newly constructed model.
            provenance, k: Passed to ``scallopy.ScallopContext``.
            save_scl_dir: Directory used by ``save_scl_file`` (may be None).
            model_dir, model_name, load_model: When ``load_model`` is truthy
                and ``model_dir`` is a non-empty directory, the newest
                ``{model_name}.<id>.model`` checkpoint is restored.
            learning_rate: Learning rate of the Adam optimizer.
            train_num_top_pairs, test_num_top_pairs: Values assigned to
                ``predicate_model.num_top_pairs`` for construction and for
                ``eval`` respectively.
            report_dir, result_dir: Output locations used by ``eval``.
            use_half: Selects ``BCEWithLogitsLoss`` (True) or ``BCELoss``.
            video_save_dir: Accepted for interface compatibility; unused here.
            Remaining arguments are stored on the instance unchanged.

        Raises:
            ValueError: If loading was requested but ``model_dir`` holds no
                checkpoint matching ``model_name``.
        """

        # Dataset and scallop file setup
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.caption2scl = caption2scl
        self.template2action = {}
        # Use a context manager so the file handle is closed deterministically.
        with open(common_scl_path) as scl_file:
            self.common_scl = scl_file.read()
        self.save_model = save_model
        self.report_dir = report_dir
        self.result_dir = result_dir
        self.model_dir = model_dir
        self.model_name = model_name

        # Contrastive learning type
        self.use_neg_spec = use_neg_spec
        self.use_neg_kws = use_neg_kws
        self.neg_spec_weight = neg_spec_weight
        self.neg_entity_kw_cate_weight = neg_entity_kw_cate_weight
        self.neg_entity_kw_binary_weight = neg_entity_kw_binary_weight

        # Hyperparameter controlling the number of binary pairs to consider for efficiency
        self.train_num_top_pairs = train_num_top_pairs
        self.test_num_top_pairs = test_num_top_pairs

        # Scallop context and forwarding setup
        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_gpt_prog_str_preds)

        self.reason = self.scallop_ctx.forward_function(output_mappings={
            "aligned_t1": None,
            "aligned_t2": None,
            "aligned_t3": None,
        }, retain_graph=True)

        # Training continuation setups
        self.epoch_ct = 0

        # Setting up the STSG model
        if load_model and os.path.exists(model_dir) and len(os.listdir(model_dir)) > 0:
            print(f"Loading Model: {model_dir}")
            # Checkpoints are named "<model_name>.<id>.model"; <id> is the
            # second-to-last dot-separated field. Names without a dot cannot
            # be checkpoints and are skipped instead of raising IndexError.
            model_ids = []
            for existing_model_name in os.listdir(model_dir):
                if model_name not in existing_model_name:
                    continue
                name_parts = existing_model_name.split('.')
                if len(name_parts) >= 2:
                    model_ids.append(name_parts[-2])
            digital_model_ids = [int(model_id) for model_id in model_ids if model_id.isdigit()]

            # Prefer the highest numbered checkpoint; fall back to "latest".
            # (The original tested 'latest' against the list of ints, which
            # could never match.)
            if digital_model_ids:
                latest_model_id = max(digital_model_ids)
            elif 'latest' in model_ids:
                latest_model_id = 'latest'
            else:
                raise ValueError(
                    f"No checkpoint matching '{model_name}.<id>.model' found in {model_dir}")

            checkpoint_name = model_name + f'.{latest_model_id}.model'

            # An int device is a CUDA index; anything else (e.g. "cpu") is
            # passed through unchanged instead of becoming "cuda:cpu".
            if isinstance(self.device, int):
                map_location = "cuda:" + str(self.device)
            else:
                map_location = self.device

            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            model_info = torch.load(os.path.join(model_dir, checkpoint_name), map_location=map_location)

            if isinstance(model_info, PredicateModel):
                predicate_model = model_info
            elif isinstance(model_info, torch.nn.parallel.distributed.DistributedDataParallel):
                predicate_model = model_info.module
            else:
                # The checkpoint is a state dict. `use_sparse` used to be read
                # from an undefined global `args`; False matches the override
                # applied right below. NOTE(review): confirm that checkpoints
                # were not saved from a model built with use_sparse=True.
                predicate_model = PredicateModel(hidden_dim = latent_dim, num_top_pairs=train_num_top_pairs, device=device, model_name=siglip_model_name, use_sparse=False).to(device)
                predicate_model.load_state_dict(model_info)

            predicate_model.use_sparse = False
            predicate_model.device = self.device
            print(f"Loading: {checkpoint_name}")
            if isinstance(latest_model_id, int):
                self.epoch_ct = latest_model_id
        else:
            print("Constructing Model")
            # Initialize a new predicate model
            predicate_model = PredicateModel(hidden_dim = latent_dim, num_top_pairs=train_num_top_pairs, device=device, model_name=siglip_model_name).to(device)

        predicate_model.num_top_pairs = self.train_num_top_pairs
        self.predicate_model = predicate_model

        # Setting up learning parameters
        self.optimizer = optim.Adam(self.predicate_model.parameters(), lr=learning_rate)
        self.min_loss = 10000000000
        self.use_half = use_half

        if use_half:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        else:
            self.loss_fn = nn.BCELoss(reduction='none')

        # Debugging utils. Always define the attribute so `save_scl_file` can
        # test it for None instead of raising AttributeError.
        self.save_scl_dir = save_scl_dir

    def get_valid_result(self, frame_ids, probs):
        """Pair probabilities with frame ids, dropping frame id -1.

        Args:
            frame_ids: Iterable of single-element sequences holding a frame id.
            probs: Iterable of probabilities aligned with ``frame_ids``.

        Returns:
            List of ``(prob, frame_id)`` sorted by ascending frame id.
        """
        result = []
        for frame_id, prob in zip(frame_ids, probs):
            assert len(frame_id) == 1
            frame_id = frame_id[0]

            # -1 marks an entry that is skipped.
            if frame_id == -1:
                continue

            result.append((prob, frame_id))

        return sorted(result, key=lambda x: x[1])

    def eval_video(self, 
                    batched_reshaped_raw_videos, 
                    batched_object_ids,
                    batched_gt_cates,
                    batched_gt_bboxes,
                    batched_gt_object_rels,
                    cate_kw = all_entities, unary_kw = [], 
                    binary_kw = all_binary_preds, 
                    recall_thres_ls = [1, 5, 10]):
        """Run the predicate model on one batch and score it with ``recall``.

        Args:
            batched_reshaped_raw_videos: Video input; its first dimension is
                passed to the model as the single video split.
            batched_object_ids: Iterable of ``(vid, fid, oid)``.
            batched_gt_cates: Ground-truth ``(vid, oid, label)`` entries.
            batched_gt_bboxes: Ground-truth boxes forwarded to the model.
            batched_gt_object_rels: Per-video ground-truth relations; element
                0 is scored.
            cate_kw, unary_kw, binary_kw: Candidate category, unary and binary
                keywords forwarded to the model.
            recall_thres_ls: ``k`` values for recall@k.

        Returns:
            Dict with keys ``"cate"``, ``"unary"``, ``"binary"`` (prediction
            lists of ``(prob, info)`` tuples) and ``"recall_res"``. When
            ``batched_object_ids`` is empty all lists are empty; previously a
            tuple ``({}, {})`` was returned, which broke ``eval``.
        """
        batched_video_splits = [batched_reshaped_raw_videos.shape[0]]
        if len(batched_object_ids) == 0:
            return {
                "cate": [],
                "unary": [],
                "binary": [],
                "recall_res": {
                    "cate": {thres: [] for thres in recall_thres_ls},
                    "binary": {thres: [] for thres in recall_thres_ls},
                },
            }

        # Obtain the predictions with pretrained model
        batched_image_unary_probs, batched_image_binary_probs, selected_pairs  = \
            self.predicate_model(batched_videos=batched_reshaped_raw_videos,
                                batched_bboxes=batched_gt_bboxes, 
                                batched_names=[cate_kw],
                                batched_object_ids=batched_object_ids,
                                batched_unary_kws=[unary_kw],
                                batched_binary_kws=[binary_kw],
                                batched_gt_obj_pairs=batched_gt_object_rels, 
                                batched_video_splits=batched_video_splits)

        # Defined up front so the result is well-formed even if the model
        # returns no per-video entries. As before, each list is rebuilt per
        # video, so the values of the last video are the ones reported.
        new_cate_pred_scl = []
        unary_pred_scl = []

        # Only categories live in the ground truth scene graph
        # Process the categories as predicates
        for vid, (image_cate_probs, image_unary_probs) in enumerate(batched_image_unary_probs):

            object_ids = [(fid, oid) for (ovid, fid, oid) in batched_object_ids if vid == ovid]
            assert image_cate_probs[0].shape[0] == 0 or len(object_ids) == image_cate_probs[0].shape[1]

            # oid -> category name -> probabilities collected over frames.
            cate_pred_scl = {}
            for cate_name, probs in zip(image_cate_probs[1], image_cate_probs[0]):
                for prob, (fid, oid) in zip(probs, object_ids):
                    cate_pred_scl.setdefault(oid, {}).setdefault(cate_name, []).append(prob)

            # One score per (object, category): the mean over that object's frames.
            new_cate_pred_scl = []
            for oid, object_cate_info in cate_pred_scl.items():
                for cate_name, prob in object_cate_info.items():
                    new_cate_pred_scl.append((torch.mean(torch.stack(prob)).item(), (oid, cate_name)))

            unary_pred_scl = []
            for unary_name, probs in zip(image_unary_probs[1], image_unary_probs[0]):
                for prob, (fid, oid) in zip(probs, object_ids):
                    # Append a single (prob, info) tuple; list.append takes
                    # exactly one argument.
                    unary_pred_scl.append((prob.item(), (oid, fid, unary_name)))

        # Process binary predicates
        binary_pred_scl = []
        for vid, image_binary_probs in enumerate(batched_image_binary_probs):
            binary_pred_scl = []

            if len(image_binary_probs) == 0:
                continue

            object_pairs = [ (fid, pair) for (ovid, fid, pair) in selected_pairs if vid == ovid]
            assert (len(object_pairs) == 0 and image_binary_probs[0].shape[0]== 0) or len(object_pairs) == image_binary_probs[0].shape[1]

            for binary_pred_name, probs in zip(image_binary_probs[1], image_binary_probs[0]):
                for prob, (fid, pair) in zip(probs, object_pairs):
                    binary_pred_scl.append((prob.item(), (binary_pred_name, fid, pair[0], pair[1])))

        recall_res = recall(gt_object_dict=batched_gt_cates, 
                            cate_pred=new_cate_pred_scl,
                            gt_object_rels=batched_gt_object_rels[0], 
                            binary_pred=binary_pred_scl, 
                            recall_thres_ls=recall_thres_ls)

        return {"cate": new_cate_pred_scl, "unary": unary_pred_scl, "binary": binary_pred_scl, "recall_res": recall_res}

    def eval(self, recall_thres_ls = [1, 5, 10]):
        """Evaluate the model on ``test_loader`` and persist results.

        For every batch this writes ``<result_dir>/<id>.json`` (when
        ``result_dir`` is set) and rewrites
        ``<report_dir>/<model_name>.recall_report.txt`` with the running
        recall@k averages (when ``report_dir`` is set). Puts the model in eval
        mode and sets ``num_top_pairs`` to ``test_num_top_pairs``.

        Args:
            recall_thres_ls: ``k`` values for recall@k.
        """
        self.predicate_model.eval()
        self.predicate_model.num_top_pairs = self.test_num_top_pairs
        total_recall_res = {
            'cate': {thres: [] for thres in recall_thres_ls},
            'binary': {thres: [] for thres in recall_thres_ls},
        }

        with torch.no_grad():
            iterator = tqdm(self.test_loader)
            processed_vids = []
            for ct, dp_list in enumerate(iterator):

                dp_id = dp_list['batched_ids'][0]

                batched_reshaped_raw_videos = dp_list['batched_reshaped_raw_videos']
                batched_gt_bboxes = dp_list['batched_gt_bboxes']

                # Collapse per-frame labels to one (vid, oid, label) per object.
                batched_gt_cates = list(set([(vid, oid, label)  for ((vid, fid, label), (_, _, oid)) in zip(dp_list['batched_gt_obj_names'], dp_list['batched_object_ids'])]))

                batched_gt_object_rels = dp_list['batched_gt_object_rels']
                batched_object_ids = dp_list['batched_object_ids']

                # The raw video is no longer copied into `result` via
                # `.tolist()`: it was never saved or read, only costing
                # time and memory.
                result = {}
                result['id'] = dp_id
                result['caption'] = dp_list['batched_captions'][0]

                result.update(self.eval_video(
                    batched_reshaped_raw_videos, 
                    batched_object_ids,
                    batched_gt_cates,
                    batched_gt_bboxes,
                    batched_gt_object_rels, 
                    recall_thres_ls=recall_thres_ls,
                    ))

                processed_vids.append(dp_id)

                result['bbox'] = dp_list['batched_gt_bboxes']
                result['obj_pairs'] = dp_list['batched_gt_object_rels'][0]
                result['obj_ids'] = dp_list['batched_object_ids']

                cate_recall = result['recall_res']['cate']
                binary_recall = result['recall_res']['binary']
                for recall_thres in recall_thres_ls:
                    # Thresholds are int keys when computed in-process and
                    # string keys when a result went through JSON; accept both.
                    if isinstance(list(cate_recall.keys())[0], str):
                        thres_key = str(recall_thres)
                    else:
                        thres_key = recall_thres

                    total_recall_res['cate'][recall_thres] += cate_recall[thres_key]
                    total_recall_res['binary'][recall_thres] += binary_recall[thres_key]

                result_save = {
                    'recall_res': result['recall_res'],
                    'cate': result['cate'],
                    'unary': result['unary'],
                    'binary': result['binary'],
                }

                # Running averages over everything processed so far.
                report = {'binary': {}, 'unary': {}, 'processed_vids': processed_vids}
                for recall_thres in recall_thres_ls:
                    cate_hits = total_recall_res['cate'][recall_thres]
                    binary_hits = total_recall_res['binary'][recall_thres]
                    report['unary'][recall_thres] = sum(cate_hits) / len(cate_hits) if len(cate_hits) > 0 else 0
                    report['binary'][recall_thres] = sum(binary_hits) / len(binary_hits) if len(binary_hits) > 0 else 0

                if self.result_dir is not None:
                    assert self.model_name in self.result_dir

                    # makedirs also creates missing parents and tolerates an
                    # already existing directory.
                    os.makedirs(self.result_dir, exist_ok=True)

                    result_path = os.path.join(self.result_dir, f"{dp_id}.json")

                    with open(result_path, 'w') as file:
                        json.dump(result_save, file)

                if self.report_dir is not None:
                    os.makedirs(self.report_dir, exist_ok=True)
                    report_path = os.path.join(self.report_dir, f'{self.model_name}.recall_report.txt')
                    with open(report_path, 'w') as file:
                        file.write(str(report))

                del dp_list

    def save_scl_file(self, datapoint, object_tps, current_constraint):
        """Write the Scallop program for ``datapoint`` to ``save_scl_dir``.

        The program text comes from ``obtain_scl_file`` and is written to
        ``<save_scl_dir>/<datapoint['id']>.scl``. Nothing is written when
        ``save_scl_dir`` is None.
        """
        scl_file_content = obtain_scl_file(object_tps, current_constraint, self.common_scl)
        scl_file_name = datapoint['id'] + '.scl'
        if self.save_scl_dir is not None:
            scl_path = os.path.join(self.save_scl_dir, scl_file_name)
            with open(scl_path, 'w') as scl_file:
                scl_file.write(scl_file_content)

if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean such as "true", "False", "1" or "no".

        `type=bool` would turn every non-empty string, including "False",
        into True. argparse reports the ValueError raised here as a usage
        error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        # The empty string stays False, as it was under `type=bool`.
        if lowered in ("", "0", "false", "f", "no", "n"):
            return False
        raise ValueError(f"expected a boolean, got {value!r}")

    # Set up data directories and paths
    dataset = "open_pvsg"
    cache_file_name = "gpt_specs_prog_str.json"
    data_file_name = 'pvsg.json'

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    assert os.path.exists(data_dir)
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    with open(cache_path, 'r') as cache_file:
        caption2scl = json.load(cache_file)

    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))
    assert os.path.exists(scl_dir)
    common_scl_path = os.path.join(scl_dir, 'ltl_disj.scl')
    assert os.path.exists(common_scl_path)

    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model')

    # Setup argument parser
    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default='test')
    parser.add_argument("--n-epochs", type=int, default=50)
    # Boolean options accept an explicit value ("--load-model false") or, for
    # the two model flags, no value at all ("--load-model" means True).
    parser.add_argument("--load-model", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--save-model", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--use-neg-spec", type=_str2bool, default=True)
    parser.add_argument("--use-neg-kws", type=_str2bool, default=True)
    parser.add_argument("--neg-example-ct", type=int, default=2)
    parser.add_argument("--neg-spec-weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_binary_weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_cate_weight", type=float, default=0)
    parser.add_argument("--siglip-model-name", type=str, default="google/siglip-base-patch16-224")

    parser.add_argument("--parallel",  action='store_true')
    parser.add_argument("--report-dir", type=str, default=os.path.join(data_dir, "eval_report_recall"))
    parser.add_argument("--result-dir", type=str, default=os.path.join(data_dir, "eval_results"))

    parser.add_argument("--train-num-top-pairs", type=int, default=10)
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=12)

    # setup question path
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=100)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    # The latent dimension is passed to the model as `hidden_dim`, so it is
    # parsed as an int (it was a float).
    parser.add_argument("--latent-dim", type=int, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    # NOTE(review): `store_false` means passing --use-cuda disables it; the
    # value is not read anywhere in this script. Left unchanged.
    parser.add_argument("--use-cuda", action="store_false")
    parser.add_argument("--use-half", action="store_true")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    assert args.batch_size == 1

    # Name a trained checkpoint would carry for these hyperparameters.
    record_siglip_model_name = args.siglip_model_name.split('/')[-1]
    model_name = f"laser_{dataset}" + \
                  f"_training_{args.train_percentage}" +\
                  f"_lr_{args.learning_rate}" + \
                  f"_model_{record_siglip_model_name}" + \
                  f"_negspec_{args.use_neg_spec}" + \
                  f"_kwweight_{args.neg_spec_weight}" + \
                  f"_negkw_{args.use_neg_kws}" + \
                  f"_negcate_{args.neg_entity_kw_cate_weight}" + \
                  f"_negbin_{args.neg_entity_kw_binary_weight}" + \
                  f"_mvl_{args.max_video_len}" + \
                  f"_seed_{args.seed}" + \
                  f"_batch_size_{args.batch_size}" + \
                  f"_prov_{args.provenance}" + \
                  f"_tpk_{args.train_top_k}"

    # This script evaluates the baseline, so the computed name above is
    # overridden unless --model-name is given.
    model_name = "origin_siglip"

    if args.model_name is None:
        args.model_name = model_name

    device = 0 if torch.cuda.is_available() else "cpu"

    with open(data_path, 'r') as f:
        anno = json.load(f)

    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.report_dir, exist_ok=True)

    # See video id in anno['split'].
    data = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    train_dataset, valid_dataset, train_loader, test_loader = open_pvsg_loader(
            cache_path=cache_path,
            dataset_dir=data_dir, 
            dataset_name=data_file_name, 
            batch_size=args.batch_size, 
            device=device, 
            training_percentage=args.train_percentage, 
            testing_percentage=args.test_percentage, 
            max_video_len=args.max_video_len,
            neg_kws=args.use_neg_kws,
            neg_spec=args.use_neg_spec,
            neg_example_ct=args.neg_example_ct,
            neg_example_file_name="neg_examples.json",
            backbone_model="siglip"
            )

    trainer = Trainer(train_loader=train_loader,
                      test_loader=test_loader, 
                      device=device, caption2scl=caption2scl, 
                      save_scl_dir=scl_dir, common_scl_path=common_scl_path,
                      latent_dim=args.latent_dim,
                      model_dir=args.model_dir, model_name=args.model_name,
                      learning_rate=args.learning_rate, load_model=args.load_model,
                      provenance=args.provenance,
                      save_model=args.save_model,
                      video_save_dir= args.video_save_dir,
                      train_num_top_pairs=args.train_num_top_pairs,
                      test_num_top_pairs=args.test_num_top_pairs,
                      report_dir=args.report_dir,
                      # Build the result directory from the effective model
                      # name: Trainer.eval asserts that the model name occurs
                      # in this path, which failed whenever --model-name
                      # differed from the hard-coded default.
                      result_dir=os.path.join(args.result_dir, args.model_name),
                      use_neg_kws=args.use_neg_kws,
                      use_neg_spec=args.use_neg_spec,
                      neg_spec_weight=args.neg_spec_weight,
                      neg_entity_kw_cate_weight=args.neg_entity_kw_cate_weight,
                      neg_entity_kw_binary_weight=args.neg_entity_kw_binary_weight,
                      siglip_model_name=args.siglip_model_name,
                      use_half=args.use_half,
                      )

    print(args.model_name)

    print("baseline eval")
    trainer.eval()

    print("end")