import torch
import sys
from backbones.retina import RetinaNetModel
from datasets.utils.clevr_creation import CLEVR_Preprocess
import torchvision.transforms as T
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm


def collate_fn(batch):
    """Regroup a batch of per-sample tuples into per-field collections.

    The batch is transposed with ``zip`` so that the first field of every
    sample ends up in one list and the second field in another. Samples are
    kept in plain Python lists rather than being stacked together.

    Args:
        batch: Sequence of samples, each an iterable with at least two
            fields; the first two are treated as image and target.

    Returns:
        A 3-tuple ``(images, targets, rest)``: ``images`` and ``targets``
        are lists with one entry per sample, and ``rest`` is a list with
        one tuple per additional field (empty when samples carry exactly
        two fields).

    Raises:
        ValueError: If the transposed batch has fewer than two fields,
            which includes an empty ``batch``.
    """
    first_field, second_field, *extra_fields = zip(*batch)
    return [*first_field], [*second_field], extra_fields


def train(model, dataloader, optim, lr_scheduler, device, *,
          num_epochs=100, save_dir="retina", save_every=10):
    """Run the training loop and write periodic and final checkpoints.

    For every batch the images and the tensors inside each target dict are
    moved to ``device``, the model is called as ``model(imgs, targets)`` and
    the ``'bbox_regression'`` entry of its result is back-propagated. The
    scheduler is stepped once per epoch, after any checkpointing.

    Args:
        model: Module called with ``(imgs, targets)``; its result must be
            indexable by ``'bbox_regression'``.
        dataloader: Iterable yielding ``(imgs, targets, extra)`` triples,
            where ``targets`` is a sequence of dicts whose values have
            ``.to(device)``.
        optim: Optimizer updating the model parameters.
        lr_scheduler: Scheduler whose ``step()`` is called once per epoch.
        device: Device the batch data is moved to.
        num_epochs: Number of passes over ``dataloader``.
        save_dir: Directory for checkpoints; created if missing.
        save_every: Checkpoint interval in epochs. A preview plot and
            ``model_{ep}.pt`` are written whenever ``ep % save_every == 0``,
            which includes epoch 0.

    Raises:
        ValueError: If ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        # Guard the modulo below against division by zero / odd intervals.
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    os.makedirs(save_dir, exist_ok=True)

    model.train()

    for ep in tqdm(range(num_epochs)):
        for imgs, targets, _ in dataloader:
            optim.zero_grad()
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            results = model(imgs, targets)
            # NOTE(review): only the 'bbox_regression' entry is optimised;
            # any other loss terms the model returns are ignored. Confirm
            # this is intentional.
            loss = results['bbox_regression']
            print("Loss", loss.item())
            loss.backward()
            optim.step()

        if ep % save_every == 0:
            print("Saving...")
            # Dump a qualitative preview next to the checkpoint.
            test_boundingbox(model, dataloader, device, ep)
            torch.save(model.state_dict(), os.path.join(save_dir, f"model_{ep}.pt"))

        lr_scheduler.step()

    torch.save(model.state_dict(), os.path.join(save_dir, "model_final.pt"))

def plot_bbox_predicted(img, boxes, scores, save_dir="retina_out", filename="predicted_bboxes.png"):
    """Draw scored boxes on top of an image and save the figure to disk.

    Args:
        img: Image to show. A ``torch.Tensor`` is permuted with
            ``(1, 2, 0)``, moved to the CPU and converted to a NumPy array
            (assumes a channel-first layout — TODO confirm). Anything else
            is used as-is and must support ``max()`` and division.
        boxes: Iterable of 4-element boxes unpacked as
            ``(x_min, y_min, x_max, y_max)``.
        scores: Iterable of scores, paired with ``boxes`` through ``zip``
            and rendered with two decimals above each box.
        save_dir: Output directory; created if it does not exist.
        filename: Name of the image file written inside ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, filename)

    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).cpu().numpy()

    # A maximum above 1 is taken to mean a 0-255 value range; rescale it.
    if img.max() > 1:
        img = img / 255.0

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(img)

    for (x_min, y_min, x_max, y_max), score in zip(boxes, scores):
        outline = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor='r',
            facecolor='none',
        )
        ax.add_patch(outline)

        # Score label sits slightly above the top-left corner of the box.
        ax.text(
            x_min,
            y_min - 5,
            f"{score:.2f}",
            color='red',
            fontsize=12,
            backgroundcolor='white',
            alpha=0.8,
        )

    # Write the annotated image without axes or padding, then free the figure.
    ax.axis('off')
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print(f"Bounding box predictions saved to {out_path}")

def test_boundingbox(model, dataloader, device, ep):
    """Plot the model's predictions for the first image of one batch.

    The model is switched to eval mode, a single batch is drawn from
    ``dataloader``, and the ``"boxes"`` and ``"scores"`` predicted for the
    first image are rendered to ``predicted_boxes_{ep}.png`` through
    ``plot_bbox_predicted``. The model is always put back into training
    mode before returning, even if inference or plotting fails.

    Args:
        model: Module called as ``model(imgs)``; its result is indexed as
            ``predictions[0]["boxes"]`` and ``predictions[0]["scores"]``.
        dataloader: Iterable yielding ``(imgs, targets, extra)`` triples.
        device: Device the images are moved to before inference.
        ep: Epoch number, used only in the output file name.

    Raises:
        StopIteration: If ``dataloader`` yields no batch.
    """
    model.eval()
    try:
        imgs, _, _ = next(iter(dataloader))
        imgs = [img.to(device) for img in imgs]

        # Pure inference: disable autograd so no graph is built and the
        # activations are not kept alive on the device.
        with torch.no_grad():
            predictions = model(imgs)

        first = predictions[0]
        boxes = first["boxes"].detach().cpu().numpy()
        scores = first["scores"].detach().cpu().numpy()
        plot_bbox_predicted(imgs[0], boxes, scores, filename=f"predicted_boxes_{ep}.png")
    finally:
        # The caller is mid-training; never leave the model in eval mode.
        model.train()

def main():
    """Build the model, optimizer, scheduler and data loader, then train.

    Uses the first CUDA device when one is available and otherwise falls
    back to the CPU, so the script still runs on machines without a GPU.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = RetinaNetModel(num_classes=15)
    model = model.to(device)

    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9, weight_decay=0.0005)

    # NOTE(review): StepLR multiplies the learning rate by 0.1 every 3
    # scheduler steps and train() steps it once per epoch, so over 100
    # epochs the rate becomes vanishingly small — confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=3, gamma=0.1)

    dataset = CLEVR_Preprocess("clevr", "train")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=16,
        num_workers=0,
        drop_last=True,
        collate_fn=collate_fn,
    )

    train(model, dataloader, optim, lr_scheduler, device)

# Script entry point: start training only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
