import torch
import sys
from backbones.fastrcnn import FastRCNN
from datasets.utils.clevr_creation import CLEVR_Preprocess
import torchvision.transforms as T
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm

def test_boundingbox(model, dataloader, device, ep):
    """Predict boxes for the first image of one batch and save them as a plot.

    Draws one batch from a fresh iterator over ``dataloader`` and runs ``model`` on it in eval mode. It then prints the raw prediction for the first image and writes that image's boxes and scores to ``predicted_boxes_{ep}.png`` via ``plot_bbox_predicted``.

    Args:
        model: detector called as ``model(list_of_images)``. Its output is indexed as ``predictions[0]["boxes"]`` and ``predictions[0]["scores"]``.
        dataloader: yields ``(images, targets, extras)`` batches. Only the images of the drawn batch are used.
        device: device the images are moved to before the forward pass.
        ep: epoch number, used only in the output file name.

    The model's previous train/eval mode is restored afterwards, even if the forward pass or the plotting raises.
    """
    was_training = model.training
    model.eval()
    try:
        imgs, _, _ = next(iter(dataloader))
        imgs = [img.to(device) for img in imgs]
        # Pure inference: without no_grad, autograd would record the whole forward pass and keep its activations alive for nothing.
        with torch.no_grad():
            predictions = model(imgs)
        first = predictions[0]
        print(first)
        boxes = first["boxes"].cpu().numpy()
        scores = first["scores"].cpu().numpy()
        plot_bbox_predicted(imgs[0], boxes, scores, filename=f"predicted_boxes_{ep}.png")
    finally:
        model.train(was_training)


def train(model, dataloader, optim, lr_scheduler, device, num_epochs=100,
          save_dir="fastcnn_weights", save_every=10,
          loss_keys=("loss_box_reg", "loss_rpn_box_reg", "loss_objectness")):
    """Train ``model``, periodically plotting its predictions and checkpointing it.

    Every ``save_every`` epochs, starting at epoch 0, predictions on one batch are plotted with ``test_boundingbox`` and the weights are written to ``<save_dir>/model_<ep>.pt``. After the last epoch the same happens once more, saving ``<save_dir>/model_final.pt``.

    Args:
        model: detector that, called as ``model(images, targets)``, returns a dict of scalar loss tensors.
        dataloader: yields ``(images, targets, extras)`` batches, where each target is a dict of tensors.
        optim: optimizer over the model's trainable parameters.
        lr_scheduler: stepped once per epoch; may be ``None``.
        device: device that images and targets are moved to.
        num_epochs: number of passes over ``dataloader``.
        save_dir: directory for checkpoints, created if missing.
        save_every: checkpoint interval in epochs. ``0`` or ``None`` disables the intermediate checkpoints; the final one is always written.
        loss_keys: entries of the model's loss dict that are summed into the optimized loss, or ``None`` to sum every returned entry. The default keeps the original objective, which omits any classification loss.
            NOTE(review): torchvision's Faster R-CNN also returns ``loss_classifier``. If ``FastRCNN`` wraps it, the class scores that get plotted are never trained — confirm this is intended.

    Raises:
        ValueError: if ``loss_keys`` is an empty sequence.
    """
    if loss_keys is not None:
        loss_keys = tuple(loss_keys)
        if not loss_keys:
            raise ValueError("loss_keys must name at least one loss, or be None")

    os.makedirs(save_dir, exist_ok=True)

    model.train()
    for ep in tqdm(range(num_epochs)):
        for imgs, targets, _ in dataloader:
            optim.zero_grad()
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(imgs, targets)
            keys = loss_dict.keys() if loss_keys is None else loss_keys
            loss = sum(loss_dict[k] for k in keys)
            # tqdm.write keeps the per-step log from mangling the progress bar.
            tqdm.write(f"Loss {loss.item()}")
            loss.backward()
            optim.step()

        if save_every and ep % save_every == 0:
            tqdm.write("Saving...")
            test_boundingbox(model, dataloader, device, ep)
            torch.save(model.state_dict(), os.path.join(save_dir, f"model_{ep}.pt"))

        if lr_scheduler is not None:
            lr_scheduler.step()

    print("Saving...")
    test_boundingbox(model, dataloader, device, num_epochs)
    torch.save(model.state_dict(), os.path.join(save_dir, "model_final.pt"))

def plot_bbox_predicted(img, boxes, scores, save_dir="fastrcnn", filename="predicted_bboxes.png",
                        score_threshold=None):
    """Draw boxes with their scores over ``img`` and save the figure to disk.

    Args:
        img: either a ``torch.Tensor`` laid out channel-first (it is permuted from (C, H, W) to (H, W, C)), or an array already laid out (H, W, C) or (H, W).
        boxes: iterable of ``(x_min, y_min, x_max, y_max)`` corner coordinates in the image's pixel space.
        scores: one score per box, written above the box with two decimals.
        save_dir: output directory, created if missing.
        filename: file name of the saved image inside ``save_dir``.
        score_threshold: if not ``None``, boxes scoring below it are skipped.

    If the image's maximum exceeds 1, it is assumed to be on a 0-255 scale and is divided by 255 before display.
    """
    os.makedirs(save_dir, exist_ok=True)

    if isinstance(img, torch.Tensor):
        img = img.detach().permute(1, 2, 0).cpu().numpy()
    # imshow rejects (H, W, 1); draw single-channel images as 2-D grayscale.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.max() > 1:
        img = img / 255.0

    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(img, cmap="gray" if img.ndim == 2 else None)

        for box, score in zip(boxes, scores):
            if score_threshold is not None and score < score_threshold:
                continue
            x_min, y_min, x_max, y_max = box
            width, height = x_max - x_min, y_max - y_min

            rect = patches.Rectangle((x_min, y_min), width, height, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

            # Score label placed 5 px above the box's top-left corner.
            ax.text(
                x_min,
                y_min - 5,
                f"{score:.2f}",
                color='red',
                fontsize=12,
                backgroundcolor='white',
                alpha=0.8,
            )

        ax.axis('off')
        save_path = os.path.join(save_dir, filename)
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release this figure even on failure so repeated calls don't pile up open figures.
        plt.close(fig)
    print(f"Bounding box predictions saved to {save_path}")

def save_bbox_predicted(img, boxes, scores, save_dir="fastrcnn", filename="bbox_predictions.png",
                        score_threshold=None):
    """Draw boxes with score badges over ``img`` and save the figure to disk.

    Args:
        img: either a ``torch.Tensor`` laid out channel-first (it is permuted from (C, H, W) to (H, W, C)), or anything ``imshow`` accepts.
        boxes: iterable of ``(x_min, y_min, x_max, y_max)`` corner coordinates in the image's pixel space.
        scores: one score per box, drawn above the box with two decimals.
        save_dir: output directory, created if missing.
        filename: file name of the saved image inside ``save_dir``.
        score_threshold: if not ``None``, boxes scoring below it are skipped.

    Array images whose maximum exceeds 1 are assumed to be on a 0-255 scale and are divided by 255. ``imshow`` clips float data to [0, 1], so such an image would otherwise render washed out.
    """
    os.makedirs(save_dir, exist_ok=True)

    if isinstance(img, torch.Tensor):
        img = img.detach().permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C)

    # Only array-like images are normalised; anything else goes to imshow as-is.
    if hasattr(img, "ndim"):
        # imshow rejects (H, W, 1); draw single-channel images as 2-D grayscale.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]
        if img.max() > 1:
            img = img / 255.0

    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(img, cmap="gray" if getattr(img, "ndim", None) == 2 else None)

        for box, score in zip(boxes, scores):
            if score_threshold is not None and score < score_threshold:
                continue
            x_min, y_min, x_max, y_max = box
            width, height = x_max - x_min, y_max - y_min

            rect = patches.Rectangle((x_min, y_min), width, height, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

            # Score badge placed 5 px above the box's top-left corner.
            ax.text(x_min, y_min - 5, f"{score:.2f}", color='white', fontsize=10, bbox=dict(facecolor='red', alpha=0.5))

        ax.axis("off")
        save_path = os.path.join(save_dir, filename)
        fig.savefig(save_path, bbox_inches="tight")
    finally:
        # Release this figure even on failure so repeated calls don't pile up open figures.
        plt.close(fig)
    print(f"Bounding box predictions saved to {save_path}")

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field columns.

    Samples are kept as plain lists rather than stacked into a tensor, because the model calls in this script take a list of images and a list of target dicts.

    Args:
        batch: sequence of samples, each a tuple ``(image, target, *extras)``.

    Returns:
        ``(images, targets, extras)``. ``images`` and ``targets`` are lists with one entry per sample. ``extras`` is a list holding one tuple per remaining field; it is empty when samples have exactly two fields.
    """
    image_col, target_col, *extra_cols = zip(*batch)
    return [*image_col], [*target_col], extra_cols


def _resolve_device(requested):
    """Return ``requested`` as a ``torch.device``, or the CPU if it names a CUDA device this machine lacks."""
    device = torch.device(requested)
    if device.type == "cuda":
        index = 0 if device.index is None else device.index
        if not torch.cuda.is_available() or index >= torch.cuda.device_count():
            print(f"Requested device {device} is not available; falling back to CPU.")
            return torch.device("cpu")
    return device


def main(device="cuda:2"):
    """Build a ``FastRCNN(num_classes=15)`` and train it on the CLEVR ``train`` split.

    Args:
        device: device to train on. Defaults to ``"cuda:2"`` as before. If that CUDA device is not present, training falls back to the CPU instead of crashing at ``model.to(device)``.
    """
    device = _resolve_device(device)

    model = FastRCNN(num_classes=15).to(device)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.SGD(trainable, lr=0.001, momentum=0.9, weight_decay=0.0005)
    # NOTE(review): train() runs 100 epochs and steps this scheduler once per epoch. The LR is therefore multiplied by 0.1 every 3 epochs and reaches about 1e-6 by epoch 9. Confirm this schedule is intended.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optim,
        step_size=3,
        gamma=0.1,
    )

    dataset = CLEVR_Preprocess("clevr", "train")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=16,
        num_workers=0,
        drop_last=True,
        # collate_fn returns per-sample lists instead of stacked tensors.
        collate_fn=collate_fn,
    )
    train(model, dataloader, optim, lr_scheduler, device)

if __name__ == "__main__":
    # Start training only when this file is run as a script, not when imported.
    main()
