#!/usr/bin/env python3
import argparse
import os
import numpy as np
from PIL import Image
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
from torchvision.transforms.functional import resize
from torchvision.models.segmentation import (
    deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights,
    deeplabv3_resnet101, DeepLabV3_ResNet101_Weights,
    deeplabv3_resnet50, DeepLabV3_ResNet50_Weights,
    fcn_resnet101, FCN_ResNet101_Weights,
    fcn_resnet50, FCN_ResNet50_Weights
)


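# NOTE: the CocoDataset class is omitted from this listing (unchanged from your version);
#       it is used by main() below and must be defined here for the script to run.
# NOTE: compute_metrics, dice_loss and encode_rle are likewise omitted (unchanged from
#       your version); compute_metrics and encode_rle are called by main() below.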


def get_model_and_weights(name):
    """Look up a segmentation model constructor and its default weights.

    Args:
        name: One of the supported model identifiers
            (``deeplabv3_resnet101``, ``deeplabv3_resnet50``,
            ``deeplabv3_mobilenet_v3_large``, ``fcn_resnet101``,
            ``fcn_resnet50``).

    Returns:
        A ``(constructor, weights)`` tuple, where ``weights`` is the
        ``DEFAULT`` member of the matching torchvision weights enum.

    Raises:
        KeyError: If ``name`` is not a supported identifier.
    """
    # Keyword form keeps each identifier next to the pair it maps to.
    registry = dict(
        deeplabv3_resnet101=(
            deeplabv3_resnet101,
            DeepLabV3_ResNet101_Weights.DEFAULT,
        ),
        deeplabv3_resnet50=(
            deeplabv3_resnet50,
            DeepLabV3_ResNet50_Weights.DEFAULT,
        ),
        deeplabv3_mobilenet_v3_large=(
            deeplabv3_mobilenet_v3_large,
            DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT,
        ),
        fcn_resnet101=(
            fcn_resnet101,
            FCN_ResNet101_Weights.DEFAULT,
        ),
        fcn_resnet50=(
            fcn_resnet50,
            FCN_ResNet50_Weights.DEFAULT,
        ),
    )
    constructor_and_weights = registry[name]
    return constructor_and_weights


def main(args):
    """Run segmentation inference over a COCO-style dataset and save the results.

    Builds the pretrained model named by ``args.model``, feeds it every sample
    produced by ``CocoDataset`` (batch size 1, no shuffling), and collects
    per-image metrics together with RLE encodings of the predicted and
    ground-truth masks. The collected lists are written with ``torch.save`` to
    ``outputs/<args.model>/<args.split>_inference.pth``.

    The model runs on CUDA when a CUDA device is available and falls back to
    the CPU otherwise, so the script no longer crashes on GPU-less hosts.

    Args:
        args: Parsed command-line namespace providing ``model``,
            ``image_root_dir``, ``annotation_file_path`` and ``split``.

    Raises:
        KeyError: If ``args.model`` is not a supported model name.
    """
    # Pick the device once; on GPU hosts this is equivalent to the former
    # unconditional .cuda() calls.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_class, weights = get_model_and_weights(args.model)
    model = model_class(weights=weights).eval().to(device)
    # Use the preprocessing pipeline bundled with the selected weights.
    transform = weights.transforms()

    dataset = CocoDataset(
        image_root_dir=args.image_root_dir,
        annotation_file_path=args.annotation_file_path,
        transform=transform,
        filter_images=False
    )
    # batch_size=1 is relied upon below (squeeze(0), img_id.item()).
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    image_ids, mean_ious, per_class_ious, pixel_accs = [], [], [], []
    rle_pred_masks, rle_true_masks = [], []

    with torch.no_grad():
        for img_id, image, mask in tqdm(dataloader, desc="Inferring..."):
            image = image.to(device)
            logits = model(image)["out"]
            # Arg-max over dim 1 of the model output, then drop the batch dim
            # and move the result back to the CPU for metrics and encoding.
            pred = logits.argmax(1).squeeze(0).cpu()
            gt = mask.squeeze(0)
            # NOTE(review): assumes the dataset yields masks with the same
            # spatial size as the transformed image, so that pred and gt
            # line up — confirm against CocoDataset.

            image_ids.append(img_id.item())
            miou, acc, ious = compute_metrics(pred, gt)
            mean_ious.append(miou)
            per_class_ious.append(ious)
            pixel_accs.append(acc)
            rle_pred_masks.append(encode_rle(pred))
            rle_true_masks.append(encode_rle(gt))

    out_dir = os.path.join("outputs", args.model)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{args.split}_inference.pth")

    torch.save({
        'image_ids': image_ids,
        'mean_ious': mean_ious,
        'per_class_ious': per_class_ious,
        'pixelwise_accs': pixel_accs,
        'rle_pred_masks': rle_pred_masks,
        'rle_true_masks': rle_true_masks
    }, out_path)

    print(f"Saved results to {out_path}")


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the namespace to main().
    parser = argparse.ArgumentParser()

    # Dataset location: both paths are mandatory plain strings.
    for path_flag in ('--image_root_dir', '--annotation_file_path'):
        parser.add_argument(path_flag, type=str, required=True)

    # The split name is only used to label the output file.
    split_choices = ['train', 'val']
    parser.add_argument('--split', type=str, required=True, choices=split_choices)

    # Model names accepted on the command line.
    model_choices = [
        'deeplabv3_resnet101',
        'deeplabv3_resnet50',
        'deeplabv3_mobilenet_v3_large',
        'fcn_resnet101',
        'fcn_resnet50',
    ]
    parser.add_argument('--model', type=str, required=True, choices=model_choices)

    args = parser.parse_args()
    main(args)
