import argparse
import sys
import os
import torch
import torchvision
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
import numpy as np

# Put the directory two levels above this file's folder on sys.path. This
# presumably is the repository root that contains the BackdoorObjectDetection
# package imported below (TODO confirm), so the script runs without installation.
# These lines must stay above the project imports.
here = os.path.dirname(__file__)  
project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
sys.path.append(project_root)

from BackdoorObjectDetection.bd_dataset.wrapper import BDWrapper
from BackdoorObjectDetection.bd_models.utils.load_utils import train_collate_fn, test_collate_fn
from BackdoorObjectDetection.bd_models.models.build import build_model
from BackdoorObjectDetection.bd_dataset.dataset.ptsd import PTSDWrapper
from BackdoorObjectDetection.bd_dataset.dataset.ptsd_meta import PTSDMetaWrapper

# Number of class rows in the per-frame heatmap built by plot_ptsd_video.
# Predicted labels outside [0, NUM_CLASSES) are rejected there, and row 0 also
# receives a count of 1 for frames with no detections.
NUM_CLASSES = 5 

def plot_ptsd_video(path, data, folder, trigger_type):
    """Plot a per-frame class-score heatmap for one PTSD video and save ASR/RA counts.

    Each element of ``data`` is a prediction record as written by
    ``run_inference_ptsd``. The keys read are ``frame``, ``pred_labels``,
    ``pred_scores``, ``gt_class`` and ``bd_class``. Records without a ``frame``
    key are ignored. ``gt_class`` and ``bd_class`` are taken from the first
    record.

    For every frame, the scores of all predicted labels are summed into a
    (NUM_CLASSES x frames) matrix. A record with no predicted labels adds 1 to
    class 0 for its frame. ASR and RA counts are accumulated per record:

    * RMA (``bd_class != 0``): ASR if the backdoor class was predicted, and RA
      if the ground-truth class was predicted. Both can count for one record.
    * ODA (``bd_class == 0``): RA if the ground-truth class was predicted,
      otherwise ASR.

    Outputs go to ``<path>/ptsd_frame_plots/``:
    ``<folder>_<trigger_type>_ptsd_plot.png`` and
    ``<folder>_<trigger_type>_counts.txt``.

    Args:
        path: Directory under which ``ptsd_frame_plots/`` is created.
        data: Prediction records for a single (folder, trigger type) pair.
        folder: Video/folder identifier, used in output file names.
        trigger_type: Trigger identifier, used in the title and file names.

    Raises:
        ValueError: If no record has a ``frame``, a predicted label is outside
            ``[0, NUM_CLASSES)``, or a count exceeds the number of frames.
    """
    # Records without a frame id cannot be placed on the frame axis.
    records = [item for item in data if 'frame' in item]
    if not records:
        raise ValueError(
            f"[TestModel] No frame-indexed predictions for folder '{folder}', trigger '{trigger_type}'"
        )

    # Frame ids may be strings. Sort them numerically ("2" before "10") so the
    # columns follow frame order. int() conversion is already required for the
    # tick labels below.
    frames = sorted({item['frame'] for item in records}, key=int)
    num_frames = len(frames)
    frame_col = {frame: col for col, frame in enumerate(frames)}  # O(1) column lookup

    gt_class = data[0].get('gt_class', 0)   # ground-truth class
    bd_class = data[0].get('bd_class', 0)   # backdoor class (0 means ODA, see counts file)

    # (classes x frames) score matrix
    M = np.zeros((NUM_CLASSES, num_frames), dtype=float)

    asr_count = 0
    ra_count = 0

    for item in records:
        f_idx = frame_col[item['frame']]
        labels = item.get('pred_labels', [])
        scores = item.get('pred_scores', [])

        # No detections: count a single background (class 0) hit.
        if not labels:
            M[0, f_idx] += 1

        # Accumulate per-label scores into the heatmap matrix.
        for lbl, score in zip(labels, scores):
            if not 0 <= lbl < NUM_CLASSES:
                raise ValueError(f"Label index {lbl} out of bounds for NUM_CLASSES={NUM_CLASSES}")
            M[lbl, f_idx] += score

        # --- attack-type logic ---
        if bd_class != 0:
            # RMA: ASR++ if the backdoor label was seen; RA++ if the GT label was seen.
            if bd_class in labels:
                asr_count += 1
            if gt_class in labels:
                ra_count += 1
        else:
            # ODA: RA++ if GT detected, else ASR++.
            if gt_class in labels:
                ra_count += 1
            else:
                asr_count += 1

    # Sanity checks: at most one hit per record, and normally one record per frame.
    if not (0 <= asr_count <= num_frames):
        raise ValueError(f"[TestModel] ASR count {asr_count} out of bounds (0–{num_frames})")
    if not (0 <= ra_count <= num_frames):
        raise ValueError(f"[TestModel] RA count {ra_count} out of bounds (0–{num_frames})")

    # Plot heatmap. Keep a minimum width so short videos still have room for
    # the class labels and the colour bar.
    fig = plt.figure(figsize=(max(4.0, 0.1 * num_frames), 0.75 * NUM_CLASSES))
    plt.imshow(M, cmap='binary', aspect='auto', interpolation='nearest')
    plt.xlabel('Frame')
    plt.ylabel('Class')
    plt.title(f'Trigger: {trigger_type}, GT={gt_class}, BD={bd_class}')

    if num_frames > 50:
        step = 10
    elif num_frames > 20:
        step = 5
    else:
        step = 1
    frames_int = [int(f) for f in frames]
    plt.xticks(np.arange(0, num_frames, step), frames_int[::step])
    plt.yticks(np.arange(NUM_CLASSES), [f'Class {i}' for i in range(NUM_CLASSES)])
    plt.colorbar(label='Prediction Score')
    plt.tight_layout()

    # Save the figure, releasing it even if saving fails.
    out_dir = os.path.join(path, 'ptsd_frame_plots')
    os.makedirs(out_dir, exist_ok=True)
    try:
        plt.savefig(os.path.join(out_dir, f'{folder}_{trigger_type}_ptsd_plot.png'))
    finally:
        plt.close(fig)

    # Write the counts next to the plot.
    counts_file = os.path.join(out_dir, f'{folder}_{trigger_type}_counts.txt')
    with open(counts_file, 'w') as f:
        f.write(f'ASR Count: {asr_count}\n')
        f.write(f'RA Count: {ra_count}\n')
        f.write(f'Total Frames: {num_frames}\n')
        f.write(f'GT Class: {gt_class}\n')
        f.write(f'BD Class: {bd_class}\n')
        f.write(f'Trigger Type: {trigger_type}\n')
        f.write(f'Folder: {folder}\n')
        f.write(f'Attack Type: {"RMA" if bd_class != 0 else "ODA"}\n')
    
def parse_args(argv=None):
    """Parse command-line arguments for PTSD backdoor evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the options below.
    """
    parser = argparse.ArgumentParser(description="BD Model PTSD Evaluation")

    # Evaluation / inference options
    parser.add_argument("--record_path", required=True, type=str, help="Record path")
    parser.add_argument("--recursive", action='store_true', help="Whether to recursively run the folders in the record path")
    parser.add_argument("--use_meta", action='store_true', help="Whether to use the PTSD Meta Dataset instead of the PTSD Dataset")
    parser.add_argument("--force_inference", action='store_true', help="Whether to force inference on the PTSD dataset even if the inference directory already exists")
    parser.add_argument("--rma_target_class", type=int, default=4, help="Target class for RMA attack")
    parser.add_argument("--save_images", action='store_true', help="Whether to save the images with the predictions")
    # Read by the DataLoader builders (which fall back to 4 when absent).
    parser.add_argument("--num_workers", type=int, default=4, help="Number of DataLoader worker processes")
    parser.add_argument("--ptsd_root", type=str, default=None, help="Root directory of the PTSD dataset (overrides the PTSD_ROOT environment variable)")

    return parser.parse_args(argv)

def convert_bbox_format(bbox_format):
    """Translate a coordinate-layout name into the annotation-convention name.

    ``'xywh'`` -> ``'coco'``, ``'xyxy'`` -> ``'pascal_voc'``,
    ``'cxcywh'`` -> ``'yolo'``.

    Raises:
        ValueError: For any other ``bbox_format``.
    """
    known_layouts = (
        ('xywh', 'coco'),
        ('xyxy', 'pascal_voc'),
        ('cxcywh', 'yolo'),
    )
    for layout, convention in known_layouts:
        if bbox_format == layout:
            return convention
    raise ValueError(f"Unsupported bounding box format: {bbox_format}. Supported formats are 'xywh', 'xyxy', and 'cxcywh'.")

def build_loader(shared_base_path, split_name, filename, bd_flag, transform_fn, args):
    """Return a DataLoader for one split (train/val/test) in one mode (clean or backdoor).

    Args:
        shared_base_path: Directory containing the serialized dataset wrapper.
        split_name: Split name forwarded to ``BDWrapper`` as ``data_split``.
        filename: File name of the dataset wrapper inside ``shared_base_path``,
            loaded with ``torch.load``.
        bd_flag: Poison flag forwarded to ``BDWrapper.__get_bd__``.
        transform_fn: Callable taking a bbox-format name and returning
            ``(transform, bbox_return_format)``.
        args: Namespace. ``args.num_workers`` is used if present, otherwise 4.

    Returns:
        A shuffled ``torch.utils.data.DataLoader`` with batch size 1 and
        ``test_collate_fn``.
    """
    # 1) Load the serialized dataset wrapper.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load dataset files from trusted sources.
    path = os.path.join(shared_base_path, filename)
    ds_wrapper = torch.load(path, weights_only=False)

    # 2) Map the wrapper's coordinate layout to a convention name.
    bbox_fmt = convert_bbox_format(ds_wrapper.bbox_current_format)

    # 3) Get the model transform and the bbox format it returns.
    data_transform, bbox_return_fmt = transform_fn(bbox_fmt)

    # 4) Wrap in BDWrapper and set the poison flag.
    ds = BDWrapper(ds_wrapper, bbox_return_format=bbox_return_fmt, data_split=split_name, transform=data_transform)
    ds.__get_bd__(bd_flag)

    # 5) Build the DataLoader. ``args`` is not guaranteed to carry
    # ``num_workers``, so fall back to 4 rather than raising AttributeError.
    num_workers = getattr(args, 'num_workers', 4)
    return torch.utils.data.DataLoader(
        ds,
        batch_size=1,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=test_collate_fn,
    )

def initialize_ptsd_dataset(args, model_wrapper):
    """Create the PTSD (or PTSD Meta) dataset using the model's test transform.

    The dataset root is resolved in this order:
    1. ``args.ptsd_root``, if present and non-empty.
    2. The ``PTSD_ROOT`` environment variable.
    3. The placeholder path below, which you must edit if neither is set.

    Args:
        args: Namespace; ``args.use_meta`` selects ``PTSDMetaWrapper`` over
            ``PTSDWrapper``, and ``args.ptsd_root`` is optional.
        model_wrapper: Object exposing ``transform_test(bbox_format)``, which
            returns ``(transform, bbox_return_format)``.

    Returns:
        The constructed dataset wrapper.

    Raises:
        ValueError: If the model transform does not return 'pascal_voc' boxes.
    """
    base = (
        getattr(args, 'ptsd_root', None)
        or os.environ.get('PTSD_ROOT')
        or 'PATH/BackdoorObjectDetection/data/ptsd_large'  # YOU MUST SET THIS (or use --ptsd_root / PTSD_ROOT)
    )
    model_transform_fn, bbox_return_format = model_wrapper.transform_test('pascal_voc')

    if bbox_return_format != 'pascal_voc':
        raise ValueError(f"[TestModel] The model's bbox return format must be 'pascal_voc', but got '{bbox_return_format}'.")

    if args.use_meta:
        print('[TestModel] Using PTSD Meta Dataset')
        dataset_cls = PTSDMetaWrapper
    else:
        print('[TestModel] Using PTSD Dataset')
        dataset_cls = PTSDWrapper

    return dataset_cls(root=base, transform=model_transform_fn)

def _load_record_args(args_file):
    """Read a record's ``args.txt`` into an ``argparse.Namespace``.

    Each line containing ``:`` is split on the first ``:``. The stripped key
    becomes an attribute holding the stripped value, kept as a string. Lines
    without ``:`` are ignored.
    """
    temp_args = argparse.Namespace()
    with open(args_file, 'r') as f:
        for line in f:
            key, sep, value = line.partition(':')
            if sep:
                setattr(temp_args, key.strip(), value.strip())
    return temp_args


def _to_builtin(value):
    """Return ``value.item()`` for a tensor, otherwise ``value`` unchanged (for JSON output)."""
    return value.item() if isinstance(value, torch.Tensor) else value


def _save_prediction_image(img_tensor, output, keep, save_file):
    """Draw the kept predicted boxes over the first image of the batch and save it as PNG.

    Each kept box gets a ``label: (score)`` text. Successive boxes place their
    text at the top-left, top-right, bottom-left and bottom-right corner in
    turn, cycling every four boxes.
    """
    fig, ax = plt.subplots(1, figsize=(12, 9))
    try:
        ax.imshow(img_tensor[0].cpu().permute(1, 2, 0).numpy())
        cmap = plt.get_cmap('tab20')

        drawn = 0
        for box, label_t, score_t, is_kept in zip(output['boxes'], output['labels'], output['scores'], keep):
            if not is_kept:
                continue
            x1, y1, x2, y2 = box.cpu().numpy()
            label = label_t.item()
            score = score_t.item()
            colour = cmap(label % 20)

            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor=colour, facecolor='none', alpha=0.3))

            corners = ((x1, y1), (x2, y1), (x1, y2), (x2, y2))
            x_pos, y_pos = corners[drawn % 4]
            ax.text(x_pos, y_pos, f'{label}: ({score:.2f})', color='white', fontsize=12, bbox=dict(facecolor=colour, alpha=0.5))
            drawn += 1

        ax.axis('off')
        fig.savefig(save_file, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def run_inference_ptsd(args, path, device):
    """Run a trained record's model over the PTSD dataset and save per-image predictions.

    Steps:
    1. Read ``<path>/args.txt`` (``KEY: VALUE`` per line) and rebuild the model
       with ``build_model``.
    2. Run every PTSD image through the model.
    3. Keep predicted boxes whose best IoU with any ground-truth box exceeds 0.5.
    4. Write one record per image to ``<path>/ptsd/predictions.json``.

    With ``args.save_images``, an annotated PNG per image is also written to
    ``<path>/ptsd/images/``.

    Each record's ``bd_class`` is ``args.rma_target_class`` when the record's
    ``data_attack`` is ``'rma'`` or ``'rma_morph'``, and 0 otherwise.

    Args:
        args: CLI namespace. Reads ``save_images``, ``rma_target_class`` and
            ``use_meta``; ``num_workers`` (default 4) and ``ptsd_root`` are optional.
        path: Record directory containing ``args.txt`` and ``checkpoint.pth``.
        device: torch device for the model and tensors.

    Raises:
        FileNotFoundError: If ``args.txt`` or ``checkpoint.pth`` is missing.
        KeyError: If an annotation lacks ``folder``, ``trigger_type`` or ``frame``.
        ValueError: If those annotation fields do not have length 1.
    """
    args_file = os.path.join(path, 'args.txt')
    if not os.path.exists(args_file):
        raise FileNotFoundError(f"[TestModel] The specified args file does not exist: {args_file}")

    temp_args = _load_record_args(args_file)
    print(f"[TestModel] Loaded arguments from {args_file}: {temp_args}")

    # Check the checkpoint before the (potentially expensive) model build. It is
    # not loaded in this function; presumably build_model restores it from
    # `path` -- TODO confirm.
    checkpoint_path = os.path.join(path, 'checkpoint.pth')
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"[TestModel] The specified checkpoint does not exist: {checkpoint_path}")

    model_wrapper = build_model(temp_args.model, 'ptsd', temp_args.model_config_path, device, path, distributed=False, local_rank=0)

    dataset = initialize_ptsd_dataset(args, model_wrapper)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,  # inference one image at a time
        shuffle=False,
        num_workers=getattr(args, 'num_workers', 4),
    )

    # The backdoor target class is the same for every image of this record.
    bd_class = args.rma_target_class if temp_args.data_attack in ('rma', 'rma_morph') else 0

    inference_dir = os.path.join(path, 'ptsd')
    images_dir = os.path.join(inference_dir, 'images')
    os.makedirs(images_dir, exist_ok=True)  # also creates inference_dir

    model_wrapper.model.eval()
    meta_keys = ('folder', 'trigger_type', 'frame')
    img_predictions = []

    for img, annotation in tqdm(dataloader, desc="Processing images", unit="image"):
        # Validate the per-image metadata before anything (e.g. the image file name) uses it.
        if not all(key in annotation for key in meta_keys):
            raise KeyError(f"[TestModel] The annotation does not contain 'folder', 'trigger_type', or 'frame': {annotation}")
        if any(len(annotation[key]) != 1 for key in meta_keys):
            raise ValueError(f"[TestModel] The annotation 'folder', 'trigger_type', and 'frame' must have length 1: {annotation}")
        folder, trigger_type, frame = (_to_builtin(annotation[key][0]) for key in meta_keys)

        img_tensor = img.to(device)
        with torch.no_grad():
            output = model_wrapper.model(img_tensor)[0]

        # as_tensor avoids the copy warning when the collated boxes are already a tensor.
        gt_boxes = torch.as_tensor(annotation['bboxes'], dtype=torch.float32, device=device)
        gt_class = _to_builtin(annotation['category_ids'][0])

        pred_boxes_all = output['boxes']
        if pred_boxes_all.shape[0] == 0 or gt_boxes.numel() == 0:
            # max(dim=1) cannot reduce over an empty ground-truth axis;
            # with no GT or no predictions, nothing matches.
            valid_boxes = torch.zeros(pred_boxes_all.shape[0], dtype=torch.bool, device=pred_boxes_all.device)
        else:
            iou_values = torchvision.ops.box_iou(pred_boxes_all, gt_boxes)
            valid_boxes = iou_values.max(dim=1).values > 0.5

        if args.save_images:
            # Draw exactly the boxes that are saved below (same IoU > 0.5 mask).
            save_file = os.path.join(images_dir, f'inference_{folder}_{trigger_type}_{frame}.png')
            _save_prediction_image(img_tensor, output, valid_boxes, save_file)

        img_predictions.append({
            'pred_boxes': output['boxes'][valid_boxes].cpu().numpy().tolist(),
            'pred_labels': output['labels'][valid_boxes].cpu().numpy().tolist(),
            'pred_scores': output['scores'][valid_boxes].cpu().numpy().tolist(),
            'gt_boxes': gt_boxes.cpu().numpy().tolist(),
            'gt_class': gt_class,
            'bd_class': bd_class,
            'folder': folder,
            'trigger_type': trigger_type,
            'frame': frame,
        })

    predictions_file = os.path.join(inference_dir, 'predictions.json')
    with open(predictions_file, 'w') as f:
        json.dump(img_predictions, f, indent=4)

def run_processing_ptsd(args, path):
    """Group saved PTSD predictions by (folder, trigger type) and plot each group.

    Reads ``<path>/ptsd/predictions.json`` and calls ``plot_ptsd_video`` once
    per group, with ``<path>/ptsd`` as the output base. Records without a
    ``folder`` are skipped. Records without a ``trigger_type`` are grouped
    under ``'unknown'``, the same label used for the group key, so no group is
    ever empty. Groups are processed in first-seen order.

    Args:
        args: Unused; kept for interface parity with ``run_inference_ptsd``.
        path: Record directory containing ``ptsd/predictions.json``.

    Raises:
        FileNotFoundError: If the predictions file does not exist.
    """
    pred_path = os.path.join(path, 'ptsd', 'predictions.json')
    if not os.path.exists(pred_path):
        raise FileNotFoundError(f"[TestModel] The specified predictions file does not exist: {pred_path}")

    with open(pred_path, 'r') as f:
        predictions = json.load(f)

    # Single grouping pass instead of rescanning the list per folder/trigger.
    groups = {}
    for p in predictions:
        if 'folder' not in p:
            continue
        key = (p['folder'], p.get('trigger_type', 'unknown'))
        groups.setdefault(key, []).append(p)

    base_path = os.path.join(path, 'ptsd')
    for (folder, trigger), trigger_data in groups.items():
        plot_ptsd_video(base_path, trigger_data, folder, trigger)

def _process_record_dir(args, path, device):
    """Run (or reuse) PTSD inference for one record directory, then plot the results.

    Inference runs when ``--force_inference`` is set or
    ``<path>/ptsd/predictions.json`` does not exist yet. Otherwise the existing
    predictions are reused.
    """
    predictions_file = os.path.join(path, 'ptsd', 'predictions.json')
    if args.force_inference or not os.path.exists(predictions_file):
        print(f'[TestModel] Running inference on PTSD dataset in {path}')
        run_inference_ptsd(args, path, device)
    else:
        print(f'[TestModel] Skipping inference for {path} as predictions already exist and force_inference is not set.')

    run_processing_ptsd(args, path)


def main():
    """Evaluate one record directory, or every sub-directory of it with ``--recursive``.

    In recursive mode, entries are visited in sorted order. The following are
    skipped: names starting with ``baseline``, non-directories, and directories
    without a ``checkpoint.pth``.
    """
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'[TestModel] Device set to: {device}')

    if not args.recursive:
        print(f'[TestModel] Running on the record path: {args.record_path}')
        _process_record_dir(args, args.record_path, device)
        return

    print(f'[TestModel] Running recursively in the record path: {args.record_path}')
    # Sorted for a reproducible processing order (os.listdir order is arbitrary).
    dir_names = sorted(os.listdir(args.record_path))
    for i, dir_name in enumerate(dir_names, start=1):
        print(f'[TestModel] Processing directory {dir_name} (#{i}/{len(dir_names)})')
        full_path = os.path.join(args.record_path, dir_name)
        print(f'[TestModel] Full path: {full_path}')

        if dir_name.startswith('baseline'):
            print(f'[TestModel] Skipping baseline directory: {full_path}')
            continue

        if not os.path.isdir(full_path):
            continue

        print(f'[TestModel] Running on the full path: {full_path}')

        # Directories without a trained checkpoint are skipped, not treated as errors.
        if not os.path.exists(os.path.join(full_path, 'checkpoint.pth')):
            print(f'[TestModel] No checkpoint found in {full_path}, skipping this directory.')
            continue

        _process_record_dir(args, full_path, device)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()