import torch
import sys
from backbones.fastrcnn import FastRCNN
from datasets.utils.clevr_creation import CLEVR_Preprocess
import torchvision.transforms as T
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
import torchvision.transforms.functional as F

def plot_gt_box(img, target, save_dir="fastrcnn", filename="gt.png"):
    """Save ``img`` with its ground-truth boxes and class labels drawn on top.

    Args:
        img: Image tensor accepted by ``torchvision.transforms.functional.to_pil_image``
            (presumably a (C, H, W) tensor — TODO confirm against the dataset).
        target: Dict holding ``'boxes'`` (one ``x_min, y_min, x_max, y_max`` row
            per object) and ``'labels'`` (one label per box) tensors.
        save_dir: Directory the figure is written to; created if missing.
        filename: Name of the output file inside ``save_dir``. Previously this
            argument was ignored and ``"gt.png"`` was always used, so the default
            now matches that file name.
    """
    os.makedirs(save_dir, exist_ok=True)

    img = F.to_pil_image(img.cpu())
    boxes = target['boxes'].cpu().numpy()
    labels = target['labels'].cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)

    for box, label in zip(boxes, labels):
        x_min, y_min, x_max, y_max = box
        rect = plt.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            edgecolor='red', facecolor='none', linewidth=2
        )
        ax.add_patch(rect)
        # Put the label just above the box's top-left corner.
        ax.text(
            x_min, y_min - 5, f"Label: {label}",
            color='red', fontsize=12, backgroundcolor='white'
        )

    ax.axis('off')

    output_path = os.path.join(save_dir, filename)
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

def test_boundingbox(model, dataloader, device, ep):
    """Run ``model`` on one batch and save GT / predicted box plots for its first image.

    Args:
        model: Detection model; in eval mode it is called with a list of image
            tensors and returns a list of dicts with ``"boxes"`` and ``"scores"``.
        dataloader: Yields ``(images, targets, extra)`` batches, as produced by
            ``collate_fn``.
        device: Device the images are moved to before inference.
        ep: Tag (e.g. an epoch number) embedded in the predicted-boxes file name.

    The model's training/eval mode is restored afterwards, even if inference raises.
    """
    was_training = model.training
    model.eval()
    imgs, targets, _ = next(iter(dataloader))
    imgs = [img.to(device) for img in imgs]
    try:
        # Pure inference: skip building an autograd graph to save memory.
        with torch.no_grad():
            predictions = model(imgs)
    finally:
        model.train(was_training)

    img = imgs[0]
    # The target is only plotted (which copies it to CPU anyway), so it is not
    # moved to ``device``.
    target = targets[0]
    boxes = predictions[0]["boxes"].detach().cpu().numpy()
    scores = predictions[0]["scores"].detach().cpu().numpy()
    plot_gt_box(img, target)
    plot_bbox_predicted(img, boxes, scores, filename=f"predicted_boxes_{ep}.png")


def test(model, dataloader, optim, device):
    """Load the final checkpoint into ``model`` and save a visual check of one batch.

    Args:
        model: Model whose ``state_dict`` matches the saved checkpoint.
        dataloader: Loader passed through to ``test_boundingbox``.
        optim: Unused here; kept for call-site compatibility.
        device: Device the checkpoint tensors and inputs are mapped to.
    """
    # NOTE(review): directory name looks like a typo of "fastrcnn_weights";
    # left unchanged because it must match where the checkpoint was saved.
    save_dir = "fastcnn_weights"
    print("Loading weights")
    # map_location lets a checkpoint saved on one device (e.g. cuda:0) be
    # loaded on another device (or CPU) without errors.
    state_dict = torch.load(os.path.join(save_dir, "model_final.pt"), map_location=device)
    model.load_state_dict(state_dict)
    print("Saving a test...")
    test_boundingbox(model, dataloader, device, -1)

def plot_bbox_predicted(img, boxes, scores, save_dir="fastrcnn", filename="predicted_bboxes.png"):
    """Save ``img`` with predicted boxes and their confidence scores drawn on top.

    Args:
        img: Image as a (C, H, W) tensor or an (H, W, C) array; values above 1
            are treated as 0–255 and rescaled to [0, 1] for display.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        scores: One score per box, printed with two decimals.
        save_dir: Directory the figure is written to; created if missing.
        filename: Name of the output file inside ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    if isinstance(img, torch.Tensor):
        # detach() so tensors that still track gradients can be converted to numpy.
        img = img.detach().permute(1, 2, 0).cpu().numpy()

    # imshow expects floats in [0, 1]; rescale images stored on a 0–255 scale.
    if img.max() > 1:
        img = img / 255.0

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(img)

    for box, score in zip(boxes, scores):
        x_min, y_min, x_max, y_max = box
        width, height = x_max - x_min, y_max - y_min

        rect = patches.Rectangle((x_min, y_min), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        ax.text(
            x_min,
            y_min - 5,
            f"{score:.2f}",
            color='red',
            fontsize=12,
            backgroundcolor='white',
            alpha=0.8,
        )

    save_path = os.path.join(save_dir, filename)
    ax.axis('off')
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print(f"Bounding box predictions saved to {save_path}")

def save_bbox_predicted(img, boxes, scores, save_dir="fastrcnn", filename="bbox_predictions.png"):
    """Save ``img`` with predicted boxes and score labels (white text on a red patch).

    Args:
        img: Image as a (C, H, W) tensor or an (H, W, C) array; values above 1
            are treated as 0–255 and rescaled to [0, 1] for display, matching
            ``plot_bbox_predicted``.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        scores: One score per box, printed with two decimals.
        save_dir: Directory the figure is written to; created if missing.
        filename: Name of the output file inside ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    if isinstance(img, torch.Tensor):
        # (C, H, W) -> (H, W, C); detach() so grad-tracking tensors convert too.
        img = img.detach().permute(1, 2, 0).cpu().numpy()

    # Without this, imshow clips float images on a 0–255 scale to solid white.
    if img.max() > 1:
        img = img / 255.0

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(img)

    for box, score in zip(boxes, scores):
        x_min, y_min, x_max, y_max = box
        width, height = x_max - x_min, y_max - y_min

        rect = patches.Rectangle((x_min, y_min), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        ax.text(x_min, y_min - 5, f"{score:.2f}", color='white', fontsize=10, bbox=dict(facecolor='red', alpha=0.5))

    ax.axis("off")

    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)  # Release the figure's memory.
    print(f"Bounding box predictions saved to {save_path}")

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field collections.

    The first two fields (images, targets) come back as lists rather than
    being stacked into a tensor. Any further fields are returned as a list
    with one tuple per field. Samples with fewer than two fields raise
    ``ValueError``.
    """
    image_col, target_col, *extra_cols = tuple(zip(*batch))
    return [*image_col], [*target_col], extra_cols


def main(device=None):
    """Build the model and CLEVR loader, then load the final checkpoint and plot one batch.

    Args:
        device: Torch device string. When omitted, ``"cuda:2"`` (the original
            hard-coded choice) is used if at least three GPUs are visible.
            Otherwise it falls back to ``"cuda"``, then ``"cpu"``.
    """
    if device is None:
        if torch.cuda.device_count() > 2:
            device = "cuda:2"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

    model = FastRCNN(num_classes=15)
    model = model.to(device)
    # Only trainable parameters go to the optimizer; ``test`` receives it but
    # does not use it.
    optim = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=0.001, momentum=0.9, weight_decay=0.0005,
    )

    dataset = CLEVR_Preprocess("clevr", "train")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=16,
        num_workers=0,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test(model, dataloader, optim, device)

if __name__ == "__main__":
    main()
