import os
import csv
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timm
from tqdm import tqdm
import numpy as np
import random
import argparse

from src.vit_segmentation import PostPoolingDetector


class VOCDetectionDataset(torchvision.datasets.VOCDetection):
    """
    Single-object view of the Pascal VOC detection dataset.

    Every sample is reduced to the FIRST annotated object of the image and
    returned as ``(img, label, bbox)``:

    * ``label`` -- index of that object's class name in ``self.classes``.
    * ``bbox``  -- float32 tensor ``[xmin, ymin, xmax, ymax]`` divided by the
      original image width/height, i.e. normalized to the image extent.

    Args:
        root: directory handed to ``torchvision.datasets.VOCDetection``.
        year: VOC challenge year, e.g. ``'2007'``.
        image_set: split name, e.g. ``'trainval'`` or ``'test'``.
        transforms: optional callable applied to the image ONLY. The box is
            not transformed, so this must not crop, pad or flip the image.
            A fixed-size resize is fine because the box is stored in
            normalized coordinates.
        download: forwarded to the parent class. Defaults to ``True`` (the
            previously hard-coded value); pass ``False`` once the data is on
            disk to skip the parent's download/extract step.
    """

    # VOC class names; a sample's label is its index in this sequence.
    CLASSES = ("aeroplane", "bicycle", "bird", "boat", "bottle",
               "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor")

    def __init__(self, root, year='2007', image_set='trainval', transforms=None, download=True):
        # The parent's transform hook is disabled so __getitem__ receives the
        # raw image and parsed annotation, and applies ``transforms`` itself.
        super().__init__(root, year=year, image_set=image_set, download=download, transform=None)
        # Stored under a different name than the parent's ``transforms``
        # attribute to avoid clobbering it.
        self.img_transforms = transforms
        # Per-instance list copy, so existing ``self.classes`` users still
        # see a list.
        self.classes = list(self.CLASSES)

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        anno = target['annotation']
        # Original image size, used to normalize the box coordinates.
        orig_width = int(anno['size']['width'])
        orig_height = int(anno['size']['height'])

        # Keep only the first object; handle both a list and a single dict.
        obj = anno['object']
        if isinstance(obj, list):
            obj = obj[0]
        label = self.classes.index(obj['name'])

        # Annotation values are strings; convert and normalize by image size.
        bndbox = obj['bndbox']
        bbox = torch.tensor(
            [
                float(bndbox['xmin']) / orig_width,
                float(bndbox['ymin']) / orig_height,
                float(bndbox['xmax']) / orig_width,
                float(bndbox['ymax']) / orig_height,
            ],
            dtype=torch.float32,
        )

        # Image-only transforms (see class docstring for restrictions).
        if self.img_transforms:
            img = self.img_transforms(img)
        return img, label, bbox

def compute_iou(box1, box2):
    """
    Return the Intersection over Union of two axis-aligned boxes.

    Each box is indexable as ``(x_min, y_min, x_max, y_max)``. The overlap
    width and height are clamped at zero, so disjoint boxes score 0. When
    the union area is not positive, ``0`` is returned instead of dividing.
    """
    a_x0, a_y0, a_x1, a_y1 = box1[0], box1[1], box1[2], box1[3]
    b_x0, b_y0, b_x1, b_y1 = box2[0], box2[1], box2[2], box2[3]

    overlap_w = max(0, min(a_x1, b_x1) - max(a_x0, b_x0))
    overlap_h = max(0, min(a_y1, b_y1) - max(a_y0, b_y0))
    intersection = overlap_w * overlap_h

    area_a = (a_x1 - a_x0) * (a_y1 - a_y0)
    area_b = (b_x1 - b_x0) * (b_y1 - b_y0)
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0

def train_one_epoch(model, loader, optimizer, cls_criterion, bbox_criterion, device, lambda_bbox=1.0):
    """
    Train ``model`` for one pass over ``loader``.

    Each batch ``(imgs, labels, bboxes)`` is moved to ``device``, and the
    model is expected to return ``(cls_logits, bbox_preds)``. The optimized
    loss is ``cls_criterion(cls_logits, labels) + lambda_bbox *
    bbox_criterion(bbox_preds, bboxes)``.

    Returns:
        ``(epoch_loss, epoch_cls_acc, epoch_loc_acc)``:
        - ``epoch_loss`` is the per-sample average of the combined loss.
        - ``epoch_cls_acc`` is the top-1 accuracy in percent.
        - ``epoch_loc_acc`` is the percentage of samples whose predicted box
          has IoU >= 0.5 with the ground truth, regardless of whether the
          class was predicted correctly.
    """
    model.train()
    running_loss = 0.0
    running_cls_correct = 0
    running_loc_correct = 0
    total_samples = 0

    for imgs, labels, bboxes in tqdm(loader, desc="Training", leave=False):
        imgs = imgs.to(device)
        labels = labels.to(device)
        bboxes = bboxes.to(device)
        batch_size = imgs.size(0)

        optimizer.zero_grad()
        cls_logits, bbox_preds = model(imgs)
        cls_loss = cls_criterion(cls_logits, labels)
        bbox_loss = bbox_criterion(bbox_preds, bboxes)
        loss = cls_loss + lambda_bbox * bbox_loss

        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample.
        running_loss += loss.item() * batch_size
        preds = torch.argmax(cls_logits, dim=1)
        running_cls_correct += (preds == labels).sum().item()

        # Copy the whole batch to the host once. Calling .cpu() per sample
        # forced 2 * batch_size device synchronizations per step.
        pred_boxes = bbox_preds.detach().cpu().numpy()
        true_boxes = bboxes.cpu().numpy()
        running_loc_correct += sum(
            1 for pred_box, true_box in zip(pred_boxes, true_boxes)
            if compute_iou(pred_box, true_box) >= 0.5
        )
        total_samples += batch_size

    epoch_loss = running_loss / total_samples
    epoch_cls_acc = 100. * running_cls_correct / total_samples
    epoch_loc_acc = 100. * running_loc_correct / total_samples
    return epoch_loss, epoch_cls_acc, epoch_loc_acc

def evaluate(model, loader, cls_criterion, bbox_criterion, device, lambda_bbox=1.0):
    """
    Evaluate ``model`` on ``loader`` without updating its parameters.

    Puts the model in eval mode and disables autograd. It computes the same
    combined loss (``cls_loss + lambda_bbox * bbox_loss``) and the same
    metrics as ``train_one_epoch``.

    Returns:
        ``(epoch_loss, epoch_cls_acc, epoch_loc_acc)``:
        - ``epoch_loss`` is the per-sample average of the combined loss.
        - ``epoch_cls_acc`` is the top-1 accuracy in percent.
        - ``epoch_loc_acc`` is the percentage of samples whose predicted box
          has IoU >= 0.5 with the ground truth, independent of classification
          correctness.
    """
    model.eval()
    running_loss = 0.0
    running_cls_correct = 0
    running_loc_correct = 0
    total_samples = 0

    with torch.no_grad():
        for imgs, labels, bboxes in tqdm(loader, desc="Evaluating", leave=False):
            imgs = imgs.to(device)
            labels = labels.to(device)
            bboxes = bboxes.to(device)
            batch_size = imgs.size(0)

            cls_logits, bbox_preds = model(imgs)
            cls_loss = cls_criterion(cls_logits, labels)
            bbox_loss = bbox_criterion(bbox_preds, bboxes)
            loss = cls_loss + lambda_bbox * bbox_loss

            running_loss += loss.item() * batch_size
            preds = torch.argmax(cls_logits, dim=1)
            running_cls_correct += (preds == labels).sum().item()

            # One host copy per batch instead of one per sample. No detach()
            # is needed here because autograd is disabled.
            pred_boxes = bbox_preds.cpu().numpy()
            true_boxes = bboxes.cpu().numpy()
            running_loc_correct += sum(
                1 for pred_box, true_box in zip(pred_boxes, true_boxes)
                if compute_iou(pred_box, true_box) >= 0.5
            )
            total_samples += batch_size

    epoch_loss = running_loss / total_samples
    epoch_cls_acc = 100. * running_cls_correct / total_samples
    epoch_loc_acc = 100. * running_loc_correct / total_samples
    return epoch_loss, epoch_cls_acc, epoch_loc_acc

###############################################
# Experiment Runner with Command-Line Args
###############################################
def run_experiment(args, run_id, device, train_loader, test_loader, pooling):
    """
    Train one frozen-backbone ViT detector and return its test metrics.

    Args:
        args: parsed CLI namespace. Reads ``seed``, ``model_type``,
            ``num_iterations``, ``lr``, ``epochs`` and ``lambda_bbox``.
        run_id: index of this run. It is added to ``args.seed`` so repeated
            runs use different but reproducible seeds.
        device: torch device the models are moved to.
        train_loader: loader of ``(imgs, labels, bboxes)`` used for training.
        test_loader: loader of ``(imgs, labels, bboxes)`` used once after
            the last epoch.
        pooling: forwarded to ``PostPoolingDetector`` as ``pooling_type``.

    Returns:
        ``(test_cls_acc, test_loc_acc)`` in percent.

    Raises:
        ValueError: if ``args.model_type`` is not ``'small_vit'`` or ``'base_vit'``.
    """
    # Seed every RNG this run touches (torch, numpy, Python's random).
    run_seed = args.seed + run_id
    torch.manual_seed(run_seed)
    np.random.seed(run_seed)
    random.seed(run_seed)

    NUM_CLASSES = 20
    # CLI model type -> timm model identifier.
    timm_model_names = {
        "small_vit": "vit_small_patch16_224",
        "base_vit": "vit_base_patch16_224",
    }
    model_name = timm_model_names.get(args.model_type)
    if model_name is None:
        raise ValueError("Unsupported model type. Choose 'small_vit' or 'base_vit'.")

    base_model = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=NUM_CLASSES,
        global_pool=None,
    )
    base_model.to(device)

    # Freeze every backbone parameter; only parameters added by the detector
    # wrapper (presumably its pooling/heads) stay trainable.
    for param in base_model.parameters():
        param.requires_grad = False
    # NOTE(review): with num_classes=NUM_CLASSES, timm builds a freshly
    # initialized classifier head, which the loop above also freezes.
    # Confirm PostPoolingDetector does not rely on base_model's head.

    model = PostPoolingDetector(base_model,
                                num_classes=NUM_CLASSES,
                                pooling_type=pooling,
                                num_iterations=args.num_iterations)
    model.to(device)

    # Give the optimizer only trainable parameters. Frozen backbone weights
    # never receive gradients, so listing them only adds per-step overhead.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=args.lr)
    cls_criterion = nn.CrossEntropyLoss()
    bbox_criterion = nn.L1Loss()

    for _ in range(args.epochs):
        train_one_epoch(model, train_loader, optimizer, cls_criterion, bbox_criterion, device, args.lambda_bbox)

    _, test_cls_acc, test_loc_acc = evaluate(model, test_loader, cls_criterion, bbox_criterion, device, args.lambda_bbox)
    return test_cls_acc, test_loc_acc

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="VOC Detection Experiment with ViT Pooling")
    parser.add_argument('--data_path', type=str, default='./VOCdevkit', help='Path to the VOC dataset (VOCdevkit folder).')
    parser.add_argument('--year', type=str, default='2007', help='VOC year (e.g., 2007).')
    parser.add_argument('--pooling', type=str, default='attention', 
                        choices=['cls', 'avg', 'sum', 'attention', 'weighted_avg', 'max'],
                        help='Pooling strategy to use.')
    parser.add_argument('--model_type', type=str, default='small_vit',
                        choices=['small_vit', 'base_vit'],
                        help="Type of ViT model to use: 'small_vit' or 'base_vit'.")
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate.')
    parser.add_argument('--runs', type=int, default=2, help='Number of experimental runs per pooling strategy.')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers.')
    parser.add_argument('--num_iterations', type=int, default=10, help='Number of iterations for svd pooling.')
    parser.add_argument('--lambda_bbox', type=float, default=1.0, help='Weight for bbox regression loss.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    # Previously hard-coded to 128 for both loaders; the default is unchanged.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for the train and test loaders.')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Fixed-size resize keeps the dataset's normalized boxes valid.
    # NOTE(review): these are ImageNet mean/std. timm's pretrained ViT
    # configs may specify different stats (e.g. 0.5/0.5); confirm with
    # timm.data.resolve_model_data_config on the chosen model.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = VOCDetectionDataset(root=args.data_path, year=args.year, image_set='trainval', transforms=transform)
    test_dataset = VOCDetectionDataset(root=args.data_path, year=args.year, image_set='test', transforms=transform)
    # Page-locked host memory speeds up host->GPU copies; it is useless on CPU.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=pin_memory)

    print(f"\nRunning experiments for pooling type: '{args.pooling}'")
    cls_accs = []
    loc_accs = []
    for run in range(args.runs):
        current_seed = args.seed + run
        print(f"  Experiment run {run+1}/{args.runs} with seed {current_seed}")
        test_cls_acc, test_loc_acc = run_experiment(args, run, device, train_loader, test_loader, pooling=args.pooling)
        print(f"Run {run+1}: Test Cls Acc = {test_cls_acc:.2f}%, Test Loc Acc = {test_loc_acc:.2f}%")
        cls_accs.append(test_cls_acc)
        loc_accs.append(test_loc_acc)

    # Aggregate across runs, since the point of multiple runs is their spread.
    if cls_accs:
        print(f"Summary over {len(cls_accs)} run(s): "
              f"Test Cls Acc = {np.mean(cls_accs):.2f} +/- {np.std(cls_accs):.2f}%, "
              f"Test Loc Acc = {np.mean(loc_accs):.2f} +/- {np.std(loc_accs):.2f}%")
