import os
import json
import random
import cv2
from argparse import ArgumentParser
import math
from torch import nn, optim
from tqdm import tqdm
import torch
import sys
import json
import gc
import heapq
import setproctitle
import psutil
from datetime import datetime
import matplotlib.colors as colors

from metrics import *

# Directory containing the model definitions (e.g. llava_clip_model_v3): the
# `models` folder two levels above the directory holding this file.
model_path = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "models")
)
if not os.path.exists(model_path):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"Model directory not found: {model_path}")

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

# Make the model definitions importable; append only once so re-executing this
# module in the same process does not duplicate the sys.path entry.
if model_path not in sys.path:
    sys.path.append(model_path)
from llava_clip_model_v3 import PredicateModel


from vidvrd_dataset import *   # Make sure this is correct

def print_time_diff(start_time: datetime, end_time: datetime, message: str = "Elapsed time"):
    """
    Print ``message`` followed by the span from ``start_time`` to ``end_time``.

    The span is shown as HH:MM:SS.ss (two-decimal seconds).

    Parameters:
        start_time (datetime): Beginning of the interval.
        end_time (datetime): End of the interval.
        message (str): Label printed before the duration.
    """
    elapsed = (end_time - start_time).total_seconds()
    leftover = elapsed % 3600  # seconds beyond the whole hours
    hh, mm, ss = elapsed // 3600, leftover // 60, leftover % 60
    print("{}: {:02d}:{:02d}:{:05.2f}".format(message, int(hh), int(mm), ss))

def ddp_setup(rank, world_size):
    """Join the single-host NCCL process group as ``rank`` and bind to GPU ``rank``."""
    # Rendezvous point shared by every process on this machine.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12446")
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def parse_args(model_name=None, epoch_num=None):
    """Parse the evaluation command line and prepare the output directories.

    ``model_name`` and ``epoch_num`` only provide the defaults for ``--model-name``
    and ``--model-epoch``. The output directories are derived from the *parsed*
    values, so a command-line override also redirects where results are cached.

    Side effects: creates the result/report/html directories (printing each
    path) and seeds ``torch`` and ``random`` with ``--seed``. Exits through
    ``parser.error`` when the dataset root or run-info directory is missing.

    Returns:
        The parsed ``argparse.Namespace``, extended with ``data_dir``,
        ``video_save_dir``, ``report_dir``, ``result_dir`` and ``html_dir``.

    Raises:
        FileNotFoundError: ``default_model_dir`` does not exist.
    """
    # VidVRD uses a directory tree rather than a single pvsg.json, so there is
    # no annotation-caching step here.
    # NOTE(review): default_model_dir is not defined in this file; presumably it
    # comes from one of the star imports -- confirm.
    if not os.path.exists(default_model_dir):
        raise FileNotFoundError(f"default model directory not found: {default_model_dir}")

    parser = ArgumentParser("Eval Clip")
    parser.add_argument(
        "--dataset",
        type=str,
        default="vidvrd-dataset",
        choices=["vidvrd-dataset", "ActionGenome"],
    )
    parser.add_argument("--phase", type=str, default='test')
    # Boolean options are converted with _str2bool so that e.g. "--load-model False"
    # really disables loading (a raw string value would always be truthy).
    parser.add_argument("--load-model", type=_str2bool, default=True)
    parser.add_argument("--save-model", type=_str2bool, default=False)
    parser.add_argument("--clip-model-name", type=str, default="openai/clip-vit-base-patch32")

    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=999999)

    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--model-name", type=str, default=model_name)
    parser.add_argument("--model-epoch", type=int, default=epoch_num)

    # Locations.
    parser.add_argument("--model-dir", type=str, default=default_model_dir)
    # TODO: replace the environment-variable fallbacks with site defaults if desired.
    parser.add_argument(
        "--data-dir",
        type=str,
        default=os.environ.get("VIDVRD_DATA_DIR"),
        help="Root of the VidVRD dataset (default: $VIDVRD_DATA_DIR).",
    )
    parser.add_argument(
        "--run-info-dir",
        type=str,
        default=os.environ.get("VIDVRD_RUN_INFO_DIR"),
        help="Directory that receives reports/results/visualizations "
             "(default: $VIDVRD_RUN_INFO_DIR).",
    )
    parser.add_argument("--use-cuda", type=_str2bool, default=True)
    parser.add_argument("--use-half", action="store_true")
    parser.add_argument("--use-ddp", action="store_true")
    # None means "derive from the final model name" (resolved after parsing).
    parser.add_argument("--use-ensemble", type=_str2bool, default=None)
    parser.add_argument("--gpu", type=int, default=-1)
    parser.add_argument("--rel_top_k", type=int, default=1)
    # Name of the debug experiment; outputs go under <run-info-dir>/debug/<name>/.
    # Pass an empty string to write directly under <run-info-dir>.
    parser.add_argument("--debug", default="all_data")

    parser.add_argument("--splice-start", type=int, default=0)
    parser.add_argument("--splice-size", type=int, default=1)
    parser.add_argument("--sgdet", type=_str2bool, default=False)

    args = parser.parse_args()

    if args.use_ensemble is None:
        # Guarded so a missing model name does not raise TypeError.
        args.use_ensemble = args.model_name is not None and "ensemble" in args.model_name
    if args.data_dir is None:
        parser.error("the VidVRD dataset root is required: pass --data-dir or set VIDVRD_DATA_DIR")
    if args.run_info_dir is None:
        parser.error("an output directory is required: pass --run-info-dir or set VIDVRD_RUN_INFO_DIR")

    # Expected VidVRD layout under args.data_dir:
    #    ├── train/*.json, test/*.json
    #    ├── videos/*.mp4
    #    └── info/objects.txt, info/predicates.txt
    args.video_save_dir = os.path.join(args.data_dir, 'pred_video')

    # Use the parsed epoch (not the function default) so that evaluating another
    # epoch never reuses a different epoch's cached results.
    args.report_dir = _run_subdir(
        args.run_info_dir, args.debug, "reports", args.dataset, args.model_name, args.model_epoch
    )
    args.result_dir = _run_subdir(
        args.run_info_dir, args.debug, "results", args.dataset, args.model_name, args.model_epoch
    )
    args.html_dir = _run_subdir(
        args.run_info_dir, args.debug, "visualize", args.dataset, args.model_name, args.model_epoch
    )

    for out_dir in (args.result_dir, args.report_dir, args.html_dir):
        os.makedirs(out_dir, exist_ok=True)
        print(out_dir)

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    return args


def _str2bool(value):
    """Convert a command-line token such as "true", "0" or "no" to a bool (argparse ``type=``).

    Raises:
        ValueError: the token is not a recognised boolean spelling (argparse
            reports this as an invalid value).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def _run_subdir(run_info_dir, debug, kind, dataset, model_name, epoch):
    """Return the absolute path <run_info_dir>[/debug/<debug>]/<kind>/<dataset>/<model_name>.<epoch>."""
    parts = [run_info_dir]
    if debug:
        # Separate components: an empty debug name must not produce a leading
        # "/" segment, which would make os.path.join discard run_info_dir.
        parts += ["debug", debug]
    parts += [kind, dataset, f"{model_name}.{epoch}"]
    return os.path.abspath(os.path.join(*parts))



class Tester:
    """Evaluate a ``PredicateModel`` on a VidVRD-style loader and aggregate the results.

    Per-video metrics are cached in ``<result_dir>/<video_id>.pkl`` (the
    extension is historical; the content is JSON). Confusion matrices are
    cached in ``<result_dir>/confusion/<video_id>.npz``.
    ``finalize_cached_results`` averages whatever is cached into a report under
    ``report_dir``.
    """

    def __init__(
        self,
        test_loader,
        device,
        dataset,
        model_dir=None,
        model_name=None,
        model_epoch=None,
        load_model=False,
        video_save_dir=None,
        test_num_top_pairs=300,
        report_dir=None,
        result_dir=None,
        clip_model_name="google/clip-base-patch16-224",
        use_half=False,
        world_size=1,
        use_ddp=False,
    ):
        """Store evaluation settings and load (or freshly construct) the predicate model.

        Args:
            test_loader: Loader whose ``dataset`` exposes ``objects``,
                ``predicates``, ``rel_weights`` and ``samples_all``.
            device: CUDA device index (checkpoints are mapped to ``cuda:<device>``).
            dataset: Dataset name (stored only).
            model_dir: Directory holding ``<model_name>.<epoch>.model`` checkpoints.
            model_name: Checkpoint name prefix.
            model_epoch: Epoch to load. ``None`` selects the highest numeric
                epoch, falling back to ``latest``.
            load_model: Load a checkpoint if ``model_dir`` exists and is
                non-empty. Otherwise a new, untrained model is built.
            video_save_dir: Stored only.
            test_num_top_pairs: Number of object pairs the model considers per
                video (an efficiency knob).
            report_dir: Destination of the aggregated report and plots.
            result_dir: Per-video cache directory (``None`` disables caching).
            clip_model_name: Backbone name passed to ``PredicateModel``.
            use_half: Selects ``BCEWithLogitsLoss`` instead of ``BCELoss``.
            world_size: Number of processes (stored only).
            use_ddp: Whether DDP is in use (stored only).

        Raises:
            FileNotFoundError: no epoch was requested and ``model_dir`` holds no
                ``<model_name>.<epoch>.model`` file.
        """
        self.dataset = dataset
        self.test_loader = test_loader
        self.device = device
        self.report_dir = report_dir
        self.result_dir = result_dir
        self.model_dir = model_dir
        self.model_name = model_name
        self.video_save_dir = video_save_dir
        self.world_size = world_size
        self.use_ddp = use_ddp

        self.test_num_top_pairs = test_num_top_pairs
        # Epoch label used in report names; overwritten by the loaded checkpoint's epoch.
        self.epoch_ct = model_epoch

        has_checkpoints = (
            load_model
            and model_dir is not None
            and os.path.exists(model_dir)
            and len(os.listdir(model_dir)) > 0
        )
        if has_checkpoints:
            predicate_model = self._load_predicate_model(
                model_dir, model_name, model_epoch, clip_model_name
            )
        else:
            print("Constructing Model. Checkpoint was not loaded!")
            predicate_model = self._build_predicate_model(clip_model_name)

        predicate_model.num_top_pairs = self.test_num_top_pairs
        self.predicate_model = predicate_model

        if use_half:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        else:
            self.loss_fn = nn.BCELoss(reduction='none')

    def _build_predicate_model(self, clip_model_name):
        """Construct an untrained ``PredicateModel`` on ``self.device``."""
        return PredicateModel(
            hidden_dim=0,
            num_top_pairs=self.test_num_top_pairs,
            device=self.device,
            model_name=clip_model_name,
        ).to(self.device)

    @staticmethod
    def _select_checkpoint_epoch(file_names, model_name, model_epoch=None):
        """Return the checkpoint id to load for ``model_name``.

        ``model_epoch`` wins when given. Otherwise, among files named exactly
        ``<model_name>.<id>.model``, the highest numeric id is chosen, and
        ``"latest"`` is used when no numeric id exists.

        Raises:
            FileNotFoundError: no epoch was given and no usable checkpoint exists.
        """
        if model_epoch is not None:
            return model_epoch
        prefix, suffix = f"{model_name}.", ".model"
        # Strip the exact prefix/suffix so model names containing dots (e.g.
        # "..._100.0_lr_1e-06...") parse correctly and other models are ignored.
        ids = [
            name[len(prefix):-len(suffix)]
            for name in file_names
            if name.startswith(prefix) and name.endswith(suffix)
        ]
        numeric_ids = [int(i) for i in ids if i.isdecimal()]
        if numeric_ids:
            return max(numeric_ids)
        if "latest" in ids:
            return "latest"
        raise FileNotFoundError(f"no checkpoint named '{prefix}<epoch>{suffix}' was found")

    def _load_predicate_model(self, model_dir, model_name, model_epoch, clip_model_name):
        """Load the selected checkpoint from ``model_dir`` and return an unwrapped ``PredicateModel``."""
        print(f"Loading Model: {model_dir}")
        epoch_id = self._select_checkpoint_epoch(os.listdir(model_dir), model_name, model_epoch)
        checkpoint_name = f"{model_name}.{epoch_id}.model"
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from a trusted source.
        model_info = torch.load(
            os.path.join(model_dir, checkpoint_name),
            map_location=f"cuda:{self.device}",
            weights_only=False,
        )

        # Accepted checkpoint shapes: a pickled model, a pickled DDP wrapper, a
        # training dict holding 'model_state_dict', or a bare state dict.
        if isinstance(model_info, PredicateModel):
            predicate_model = model_info
        elif isinstance(model_info, torch.nn.parallel.DistributedDataParallel):
            predicate_model = model_info.module
        else:
            predicate_model = self._build_predicate_model(clip_model_name)
            if isinstance(model_info, dict) and "model_state_dict" in model_info:
                state_dict = model_info["model_state_dict"]
            else:
                state_dict = model_info
            predicate_model.load_state_dict(state_dict)

        predicate_model.use_sparse = False
        predicate_model.device = self.device
        print(f"Loading: {checkpoint_name}")
        if isinstance(epoch_id, int):
            self.epoch_ct = epoch_id
        return predicate_model

    def get_valid_result(self, frame_ids, probs):
        """Pair probabilities with frame ids, drop frame id -1, and sort by frame id.

        Each element of ``frame_ids`` must be a length-1 sequence holding the id.
        Returns a list of ``(prob, frame_id)`` tuples in ascending frame-id order.
        """
        result = []
        for frame_id, prob in zip(frame_ids, probs):
            assert len(frame_id) == 1
            frame_id = frame_id[0]
            if frame_id == -1:
                continue
            result.append((prob, frame_id))
        return sorted(result, key=lambda x: x[1])

    @torch.inference_mode()
    def eval_video(
        self,
        batched_video_ids,
        batched_reshaped_raw_videos,
        batched_object_ids,
        batched_gt_cates,
        batched_gt_masks,
        batched_gt_bboxes,
        batched_gt_object_rels,
        dp_id,
        cate_kw=None,
        unary_kw=None,
        binary_kw=None,
        rel_weights=None,
        recall_thres_ls=None,
        precision_thres_ls=None,
    ):
        """Run the predicate model on one batch and score it against ground truth.

        Args:
            batched_gt_cates: ``(video_id, object_id, label)`` tuples. Each
                label, with '/' and '_' mapped to spaces, must be in ``cate_kw``.
            batched_gt_object_rels: Per video, per frame, lists of
                ``(from_id, to_id, predicate)`` triples.
            dp_id: Sample id (unused; kept for interface compatibility).
            cate_kw: Object vocabulary; defaults to the loader's ``objects``
                with '_' replaced by spaces.
            unary_kw: Unary vocabulary; defaults to empty.
            binary_kw: Predicate vocabulary; defaults to the loader's
                ``predicates`` with '_' replaced by spaces.
            rel_weights: Forwarded to ``compute_metrics_top_k``.
            recall_thres_ls: Recall cut-offs (default ``[1, 5, 10]``).
            precision_thres_ls: Precision cut-offs (default ``[1, 5, 10]``).
            The remaining arguments are forwarded to the model unchanged.

        Returns:
            ``None`` when ``batched_object_ids`` is empty. Otherwise a pair
            ``(result, confusion)``: ``result`` holds the raw
            ``cate``/``unary``/``binary`` outputs and ``metrics_res``, and
            ``confusion`` holds the ``object``/``binary`` matrices returned by
            ``compute_metrics_top_k``.

        Raises:
            ValueError: a ground-truth object label is missing from ``cate_kw``.
        """
        if unary_kw is None:
            unary_kw = []
        if recall_thres_ls is None:
            recall_thres_ls = [1, 5, 10]
        if precision_thres_ls is None:
            precision_thres_ls = [1, 5, 10]
        if cate_kw is None:
            cate_kw = [name.replace("_", " ") for name in self.test_loader.dataset.objects]
        if binary_kw is None:
            binary_kw = [name.replace("_", " ") for name in self.test_loader.dataset.predicates]

        if len(batched_object_ids) == 0:
            return None

        # Validate labels before the forward pass so bad samples fail cheaply.
        known_categories = set(cate_kw)
        for gt_tps in batched_gt_cates:
            gt_kw = gt_tps[-1].replace('/', ' ').replace('_', ' ')
            if gt_kw not in known_categories:
                raise ValueError(f"{gt_kw} is not in all vocabs")

        # Ground-truth (video, frame, (from_id, to_id)) pairs, given to the model
        # as batched_obj_pairs.
        batched_gt_object_pairs = [
            (video_id, frame_id, (from_id, to_id))
            for video_id, relations in enumerate(batched_gt_object_rels)
            for frame_id, rel_lst in enumerate(relations)
            for from_id, to_id, _ in rel_lst
        ]

        batched_image_cate_probs, batched_image_unary_probs, batched_image_binary_probs, dummy_prob = \
            self.predicate_model(
                batched_video_ids=batched_video_ids,
                batched_videos=batched_reshaped_raw_videos,
                batched_masks=batched_gt_masks,
                batched_bboxes=batched_gt_bboxes,
                batched_names=[cate_kw],  # dataset-wise categorical labels
                batched_object_ids=batched_object_ids,
                batched_unary_kws=[unary_kw],  # dataset-wise unary predicate labels
                batched_binary_kws=[binary_kw],  # dataset-wise binary predicate labels
                batched_obj_pairs=batched_gt_object_pairs,
                batched_video_splits=[len(batched_reshaped_raw_videos)],
                batched_binary_predicates=[None],  # None indicates inference time
                multi_class=True,
            )

        # TODO: top_k_classes is hard-coded to 1.
        metrics_res, object_confusion, binary_confusion = compute_metrics_top_k(
            gt_object_dict=batched_gt_cates,
            cate_pred=batched_image_cate_probs[0],
            gt_object_rels=batched_gt_object_rels[0],
            binary_pred=batched_image_binary_probs[0],
            rel_weights=rel_weights,
            top_k_classes=1,
            precision_thres_ls=precision_thres_ls,
            recall_thres_ls=recall_thres_ls,
            all_objects=cate_kw,
            all_predicates=binary_kw,
        )

        result = {
            "cate": batched_image_cate_probs,
            "unary": batched_image_unary_probs,
            "binary": batched_image_binary_probs,
            "metrics_res": metrics_res,
        }
        confusion = {
            "object": object_confusion,
            "binary": binary_confusion,
        }
        return result, confusion

    @torch.inference_mode()
    def eval(self, recall_thres_ls=None, precision_thres_ls=None):
        """Evaluate every sample of ``self.test_loader`` and cache per-video results.

        Samples whose metrics and confusion files already exist in
        ``result_dir`` are not re-run. A sample that raises is reported and
        skipped, so one bad video does not abort the run. After the loop the
        cached results are aggregated once via ``finalize_cached_results``.

        Returns:
            The ids of samples that were evaluated or found in the cache.
        """
        if recall_thres_ls is None:
            recall_thres_ls = [1, 5, 10]
        if precision_thres_ls is None:
            precision_thres_ls = [1, 5, 10]

        self.predicate_model.eval()
        self.predicate_model.num_top_pairs = self.test_num_top_pairs

        confusion_dir = None
        if self.result_dir is not None:
            confusion_dir = os.path.join(self.result_dir, "confusion")
            os.makedirs(confusion_dir, exist_ok=True)

        processed_vids = []
        first_time = datetime.now()
        sample_start = datetime.now()
        for ct, dp_list in enumerate(tqdm(self.test_loader)):
            try:
                dp_id = dp_list['batched_ids'][0]
                print(dp_id)
                print(f"[Rank {self.device}] ct={ct}, dp_id={dp_id}")
                if self._eval_sample(
                    ct, dp_id, dp_list, confusion_dir, sample_start, first_time,
                    recall_thres_ls, precision_thres_ls,
                ):
                    processed_vids.append(dp_id)
            finally:
                # Release this sample's data (including its masks) on every path,
                # skips and failures included, before the next batch is loaded.
                dp_list.pop('batched_gt_masks', None)
                del dp_list
                gc.collect()
                torch.cuda.empty_cache()
                sample_start = datetime.now()

        # Aggregate once at the end rather than after every sample.
        self.finalize_cached_results(
            recall_thres_ls=recall_thres_ls, precision_thres_ls=precision_thres_ls
        )
        return processed_vids

    def _eval_sample(self, ct, dp_id, dp_list, confusion_dir, sample_start, first_time,
                     recall_thres_ls, precision_thres_ls):
        """Evaluate one loader sample and cache its outputs; return True if evaluated or already cached."""
        import numpy as np

        result_path = confusion_path = None
        if self.result_dir is not None:
            # Cache files are keyed by the first id of the batch (eval assumes batch_size=1).
            result_path = os.path.join(self.result_dir, f"{dp_id}.pkl")
            confusion_path = os.path.join(confusion_dir, f"{dp_id}.npz")
            if os.path.exists(result_path) and os.path.exists(confusion_path):
                print(f"{dp_id} cached. Moving on.")
                return True

        # Deduplicate (video_id, object_id, label) triples; the name tuples also
        # carry a frame id, so an object presumably repeats across frames.
        batched_gt_cates = list({
            (vid, oid, label)
            for (vid, _fid, label), (_, _, oid) in zip(
                dp_list['batched_gt_obj_names'], dp_list['batched_object_ids']
            )
        })

        try:
            output = self.eval_video(
                dp_list['batched_ids'],
                dp_list['batched_reshaped_raw_videos'],
                dp_list['batched_object_ids'],
                batched_gt_cates,
                dp_list['batched_gt_masks'],
                dp_list['batched_gt_bboxes'],
                dp_list['batched_gt_object_rels'],
                dp_id=dp_id,
                rel_weights=self.test_loader.dataset.rel_weights,
                recall_thres_ls=recall_thres_ls,
                precision_thres_ls=precision_thres_ls,
            )
        except Exception as e:
            # Deliberately best-effort: report the failure and move on.
            print()
            print(f"Error processing {dp_id}: {e}")
            return False

        # eval_video returns None (not a pair) when the sample has no objects.
        if output is None:
            print(f"Error processing {dp_id}. None. Is video already done?")
            return False
        result, confusion = output

        print_time_diff(sample_start, datetime.now(), f"Sample timing {ct}")
        print_time_diff(first_time, datetime.now(), "All timing")

        if result_path is not None:
            # Despite the .pkl extension, metrics are stored as JSON.
            with open(result_path, "w") as f:
                json.dump(result["metrics_res"], f)
            np.savez(confusion_path, object=confusion["object"], binary=confusion["binary"])
        return True

    def finalize_cached_results(self, recall_thres_ls=None, precision_thres_ls=None, map_color="seismic"):
        """Aggregate cached per-video results into a report without re-running inference.

        For each id in ``self.test_loader.dataset.samples_all``, the last five
        characters (assumed to be ".json" -- confirm) are stripped to form the
        cache stem. ``<result_dir>/<stem>.pkl`` (JSON metrics) and
        ``<result_dir>/confusion/<stem>.npz`` are read when present.
        Precision/recall are averaged per threshold; thresholds missing from a
        video's metrics are skipped. If ``report_dir`` is set, the method writes
        ``<model_name>.<epoch_ct>.metrics_report.txt`` (JSON) and log-scaled
        heatmaps of the averaged object and binary confusion matrices.

        Args:
            recall_thres_ls: Recall thresholds to report (default ``[1, 5, 10]``).
            precision_thres_ls: Precision thresholds to report (default ``[1, 5, 10]``).
            map_color: Matplotlib colormap name for the heatmaps.

        Returns:
            The report dict, or ``None`` when ``result_dir`` is unset.
        """
        import numpy as np

        if recall_thres_ls is None:
            recall_thres_ls = [1, 5, 10]
        if precision_thres_ls is None:
            precision_thres_ls = [1, 5, 10]
        if self.result_dir is None:
            print("finalize_cached_results: result_dir is not set; nothing to aggregate.")
            return None

        thresholds = {"precision": precision_thres_ls, "recall": recall_thres_ls}
        kinds = ("cate", "binary")
        # Materialise once: the ids are iterated twice below.
        samples = list(self.test_loader.dataset.samples_all)

        gathered = {
            metric: {kind: {thres: [] for thres in thres_list} for kind in kinds}
            for metric, thres_list in thresholds.items()
        }
        processed_vids = []
        for dp_id in samples:
            result_path = os.path.join(self.result_dir, f"{dp_id[:-5]}.pkl")
            if not os.path.exists(result_path):
                continue
            with open(result_path, "r") as f:
                metrics_res = json.load(f)
            for metric, thres_list in thresholds.items():
                for kind in kinds:
                    per_thres = metrics_res[metric][kind]
                    for thres in thres_list:
                        # JSON turns int keys into strings; accept either form.
                        value = per_thres.get(str(thres))
                        if value is None:
                            value = per_thres.get(thres)
                        # Skip missing values rather than letting None break the mean.
                        if value is not None:
                            gathered[metric][kind][thres].append(value)
            processed_vids.append(dp_id)

        report = {
            'epoch_ct': self.epoch_ct,
            'processed_vids': processed_vids,
            'precision': {'cate': {}, 'binary': {}},
            'recall': {'cate': {}, 'binary': {}},
        }
        for metric, thres_list in thresholds.items():
            for kind in kinds:
                for thres in thres_list:
                    values = gathered[metric][kind][thres]
                    report[metric][kind][thres] = sum(values) / len(values) if values else 0.0

        if self.report_dir is not None:
            report_path = os.path.join(
                self.report_dir, f"{self.model_name}.{self.epoch_ct}.metrics_report.txt"
            )
            with open(report_path, 'w') as file:
                json.dump(report, file, indent=2)

        # Sum the per-video confusion matrices, then average and plot them.
        confusion_dir = os.path.join(self.result_dir, "confusion")
        object_sum = binary_sum = None
        num_matrices = 0
        for dp_id in samples:
            confusion_file = os.path.join(confusion_dir, f"{dp_id[:-5]}.npz")
            if not os.path.exists(confusion_file):
                continue
            # The context manager closes the underlying .npz file handle.
            with np.load(confusion_file) as data:
                obj_conf = data["object"]
                bin_conf = data["binary"]
            object_sum = obj_conf if object_sum is None else object_sum + obj_conf
            binary_sum = bin_conf if binary_sum is None else binary_sum + bin_conf
            num_matrices += 1

        if self.report_dir is not None and num_matrices > 0:
            self._save_confusion_plot(
                object_sum / num_matrices,
                [name.replace("_", " ") for name in self.test_loader.dataset.objects],
                title="Averaged Object Confusion Matrix",
                path=os.path.join(self.report_dir, "averaged_object_confusion.png"),
                figsize=(10, 8),
                tick_fontsize=8,
                xtick_rotation=45,
                xtick_ha="right",
                map_color=map_color,
            )
            self._save_confusion_plot(
                binary_sum / num_matrices,
                [name.replace("_", " ") for name in self.test_loader.dataset.predicates],
                title="Averaged Binary Confusion Matrix",
                path=os.path.join(self.report_dir, "averaged_binary_confusion.png"),
                figsize=(20, 20),
                tick_fontsize=6,
                xtick_rotation=90,
                xtick_ha="center",
                map_color=map_color,
            )
        return report

    @staticmethod
    def _save_confusion_plot(matrix, labels, title, path, figsize, tick_fontsize,
                             xtick_rotation, xtick_ha, map_color):
        """Render ``matrix`` as a heatmap (log-scaled when it has positive cells) and save it to ``path``."""
        import numpy as np
        import matplotlib.pyplot as plt
        import matplotlib.colors as mcolors

        positive = matrix[matrix > 0]
        # LogNorm needs a strictly positive vmin; fall back to a linear scale
        # when every cell is zero.
        norm = mcolors.LogNorm(vmin=positive.min(), vmax=matrix.max()) if positive.size else None

        fig, ax = plt.subplots(figsize=figsize)
        try:
            im = ax.imshow(matrix, interpolation='nearest', cmap=map_color, norm=norm)
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            ticks = np.arange(len(labels))
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
            ax.set_xticklabels(labels, rotation=xtick_rotation, ha=xtick_ha, fontsize=tick_fontsize)
            ax.set_yticklabels(labels, fontsize=tick_fontsize)
            ax.set_xlabel("Predicted", fontsize=10)
            ax.set_ylabel("Ground Truth", fontsize=10)
            ax.set_title(title, fontsize=12)

            # Minor ticks on cell boundaries draw a faint grid between cells.
            boundaries = np.arange(-0.5, len(labels), 1)
            ax.set_xticks(boundaries, minor=True)
            ax.set_yticks(boundaries, minor=True)
            ax.grid(which="minor", color="gray", linestyle='-', linewidth=0.5)
            ax.tick_params(which="minor", bottom=False, left=False)

            fig.tight_layout()
            fig.savefig(path, dpi=300)
        finally:
            plt.close(fig)



from torch.utils.data.distributed import DistributedSampler

def main(rank: int, world_size: int, args):
    """Run the evaluation phase selected by ``args.phase`` in one process.

    ``rank`` is the process index (0 when DDP is off) and is also used as the
    device index. Phases:

    * ``"eval"`` runs ``Tester.eval`` on the validation loader.
    * ``"test"`` runs ``Tester.eval`` on the test loader.
    * ``"cache_test"`` only rebuilds the report from cached per-video results.

    The process group is torn down even if evaluation raises.

    Raises:
        ValueError: unknown phase, or a dataset without a loader in this script.
    """
    valid_phases = ("eval", "test", "cache_test")
    if args.phase not in valid_phases:
        raise ValueError(f"Unknown phase {args.phase!r}; expected one of {valid_phases}")

    supported_datasets = {
        "vidvrd-dataset": open_vidvrd_loader,
    }
    if args.dataset not in supported_datasets:
        raise ValueError(
            f"Dataset {args.dataset!r} has no loader here; supported: {sorted(supported_datasets)}"
        )

    if args.use_ddp:
        ddp_setup(rank, world_size)
    try:
        device = rank

        # The sampler class itself is handed to the loader factory.
        test_sampler = DistributedSampler if args.use_ddp else None
        data_args = {
            "dataset_dir": args.data_dir,
            "batch_size": args.batch_size,
            "device": device,
            "training_percentage": 1,
            "testing_percentage": args.test_percentage,
            "max_video_len": args.max_video_len,
            "neg_kws": False,
            "neg_spec": False,
            "neg_example_ct": 0,
            "neg_example_file_name": "neg_examples.json",
            "backbone_model": "clip",
            "sampler": test_sampler,
            "splice_start": args.splice_start,
            "splice_size": args.splice_size,
        }
        _, _, _, _, valid_loader, test_loader = supported_datasets[args.dataset](**data_args)

        print(f"[Rank {rank}] has {len(test_loader)} batches total.")
        print(f"[Rank {rank}] sees {len(test_loader.dataset)} samples overall (not just for this rank).")

        eval_loader = test_loader if args.phase in ("test", "cache_test") else valid_loader
        tester = Tester(
            test_loader=eval_loader,
            device=device,
            dataset=args.dataset,
            model_dir=args.model_dir,
            model_name=args.model_name,
            model_epoch=args.model_epoch,
            load_model=args.load_model,
            video_save_dir=args.video_save_dir,
            test_num_top_pairs=args.test_num_top_pairs,
            report_dir=args.report_dir,
            result_dir=args.result_dir,
            clip_model_name=args.clip_model_name,
            use_half=args.use_half,
            world_size=world_size,
            use_ddp=args.use_ddp,
        )

        if args.phase == "eval":
            print("our eval")
            tester.eval()
        elif args.phase == "test":
            print("baseline eval")
            tester.eval()
        else:  # "cache_test"
            print("eval for cache")
            tester.finalize_cached_results()
    finally:
        if args.use_ddp:
            destroy_process_group()


if __name__ == "__main__":
    setproctitle.setproctitle("VidVRD SGCLS Run")
    torch.multiprocessing.set_start_method('spawn', force=True)

    # Defaults for --model-name / --model-epoch; command-line values override them.
    model_name = "laser_clip_LLaVA_2025-01-19-17-12-44_training_100.0_lr_1e-06_fgl_False_negspec_True_ws_True_wns_True_negkw_True_mvl_20_bs_2_ddp_True"
    epoch_num = 1
    world_size = torch.cuda.device_count()
    args = parse_args(model_name, epoch_num)

    if args.use_ddp:
        if world_size < 1:
            # With no visible GPUs nprocs would be 0 and no worker would run.
            raise RuntimeError("--use-ddp requires at least one visible CUDA device")
        mp.spawn(main, args=(world_size, args), nprocs=world_size)
    else:
        main(0, world_size, args)

    print(args.model_name)
    print("end")
