import json
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from collections import defaultdict, Counter
import random


def load_category_mapping(annotation_file,
                          primary_root='/software/ais2t/pytorch_datasets/coco/data/',
                          fallback_root='/home/htc/USER/SCRATCH/datasets/concept-learning/coco/'):
    """Return a ``{category_id: category_name}`` dict from a COCO annotation file.

    If the file is missing, is not valid JSON, or has no ``'categories'`` key,
    the same file is looked up under ``fallback_root`` (by substituting
    ``primary_root`` in the path) and loaded from there instead.

    Args:
        annotation_file: path (str or os.PathLike) to a COCO-style JSON file.
        primary_root: path prefix that is swapped out to build the fallback path.
        fallback_root: replacement prefix for the fallback location.

    Returns:
        dict mapping each category's ``'id'`` to its ``'name'``.

    Raises:
        FileNotFoundError: the primary load failed and no fallback file exists.
        The original exception is re-raised unchanged when the path does not
        contain ``primary_root`` (there is no alternative location to try).
    """
    def _read(path):
        with open(path, 'r') as f:
            coco_data = json.load(f)
        return {cat['id']: cat['name'] for cat in coco_data['categories']}

    # os.fspath: a pathlib.Path's .replace() would *rename the file* instead of
    # doing string substitution, so normalise to str first.
    annotation_file = os.fspath(annotation_file)
    try:
        return _read(annotation_file)
    except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
        print(f"Failed to load {annotation_file}: {e}")
        # try SCRATCH fallback
        fallback = annotation_file.replace(primary_root, fallback_root)
        if fallback == annotation_file:
            # Path is outside primary_root: retrying the same file is pointless.
            raise
        if not os.path.exists(fallback):
            raise FileNotFoundError(
                f"Both primary and fallback annotation files failed: {annotation_file}, {fallback}"
            ) from e
        print(f"Loading from fallback: {fallback}")
        return _read(fallback)


def split_dataset(dataset, train_ratio=0.8, seed=None):
    """Randomly split ``dataset`` into train and test subsets.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        train_ratio: fraction of items assigned to the training subset; the
            train size is ``int(train_ratio * len(dataset))`` and the test
            subset receives the remainder.
        seed: optional int. When given, a dedicated ``torch.Generator`` seeded
            with it drives the permutation, making the split reproducible and
            independent of the global RNG state. When ``None`` the global
            default generator is used (original behaviour).

    Returns:
        ``[train_subset, test_subset]`` as returned by ``random_split``.

    Raises:
        ValueError: if ``train_ratio`` is outside ``[0, 1]`` (which would
            otherwise produce a negative split length).
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    train_size = int(train_ratio * len(dataset))
    test_size = len(dataset) - train_size
    if seed is None:
        return random_split(dataset, [train_size, test_size])
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, test_size], generator=generator)


class COCOLogicDataset(Dataset):
    """COCO images labelled by hand-written logical rules over their object categories.

    Each image's COCO annotations are reduced to the set of category names it
    contains, and every rule in ``self.logical_classes`` is a predicate over
    that set. ``__getitem__`` returns ``(image, label_index)``, where
    ``label_index`` is a 0-dim ``torch.long`` tensor holding the index of the
    first rule the image satisfies.
    """

    # Default annotation root, and the SCRATCH mirror tried when loading from it fails.
    _PRIMARY_ROOT = '/software/ais2t/pytorch_datasets/coco/data/'
    _FALLBACK_ROOT = '/home/htc/USER/SCRATCH/datasets/concept-learning/coco/'

    # Every logical class, keyed by its display name. Each predicate receives
    # the set of COCO category names present in one image and returns a bool.
    # Rules are defined once here and selected per version in _VERSION_RULES.
    _RULES = {
        # An image features either a dog or a car, but not both.
        "Conflicted Companions (Leash vs Licence)": lambda cats: (
            ("dog" in cats) ^ ("car" in cats)
        ),
        # Either a cat or a dog (not both), and either a bicycle or a
        # motorcycle (not both).
        "Ambiguous Pairs (Pet vs Ride Paradox)": lambda cats: (
            (("cat" in cats) ^ ("dog" in cats))
            and (("bicycle" in cats) ^ ("motorcycle" in cats))
        ),
        # Exactly two of: cat, dog, bird.
        "Pair of Pets": lambda cats: (
            sum(c in cats for c in ("cat", "dog", "bird")) == 2
        ),
        # At least one rural animal (cow, horse, sheep) and no person.
        "Rural Animal Scene": lambda cats: (
            any(c in cats for c in ("cow", "horse", "sheep"))
            and "person" not in cats
        ),
        # A rural animal (horse, cow, sheep) together with a traffic-related
        # object (car, bus, traffic light).
        "Animal Meets Traffic": lambda cats: (
            any(c in cats for c in ("horse", "cow", "sheep"))
            and any(c in cats for c in ("car", "bus", "traffic light"))
        ),
        # Furniture (couch or chair) and a person. ``cats`` is a set, so
        # person *instances* cannot be counted here; the original extra
        # ``sum(c == "person" ...) == 1`` clause was always true and is dropped.
        "Occupied Interior": lambda cats: (
            any(c in cats for c in ("couch", "chair"))
            and "person" in cats
        ),
        # Furniture (couch or chair) with no person present.
        "Empty Seat": lambda cats: (
            any(c in cats for c in ("couch", "chair"))
            and "person" not in cats
        ),
        # Exactly one of: bicycle, motorcycle, bus, car.
        "Odd Ride Out": lambda cats: (
            sum(c in cats for c in ("bicycle", "motorcycle", "bus", "car")) == 1
        ),
        # A person alongside either a bicycle or a car, but not both.
        "Personal Transport XOR Car": lambda cats: (
            "person" in cats
            and (("bicycle" in cats) ^ ("car" in cats))
        ),
        # A bowl together with at least one animal (dog, cat, horse, cow, sheep).
        "Unlikely Breakfast Guests": lambda cats: (
            "bowl" in cats
            and any(c in cats for c in ("dog", "cat", "horse", "cow", "sheep"))
        ),
    }

    # Ordered rule names per dataset version; the order defines label indices.
    _VERSION_RULES = {
        7: (
            "Conflicted Companions (Leash vs Licence)",
            "Ambiguous Pairs (Pet vs Ride Paradox)",
            "Rural Animal Scene",
            "Animal Meets Traffic",
            "Empty Seat",
            "Personal Transport XOR Car",
            "Unlikely Breakfast Guests",
        ),
        8: (
            "Conflicted Companions (Leash vs Licence)",
            "Ambiguous Pairs (Pet vs Ride Paradox)",
            "Pair of Pets",
            "Rural Animal Scene",
            "Animal Meets Traffic",
            "Odd Ride Out",
            "Personal Transport XOR Car",
            "Unlikely Breakfast Guests",
        ),
        10: (
            "Ambiguous Pairs (Pet vs Ride Paradox)",
            "Pair of Pets",
            "Rural Animal Scene",
            "Conflicted Companions (Leash vs Licence)",
            "Animal Meets Traffic",
            "Occupied Interior",
            "Empty Seat",
            "Odd Ride Out",
            "Personal Transport XOR Car",
            "Unlikely Breakfast Guests",
        ),
    }

    def __init__(self, annotation_file, image_dir, category_id_to_name, transform=None,
                 filter_no_labels=True, exclusive_label=True, exclusive_match_only=True,
                 log_statistics=False, version='small'):
        """
        annotation_file: path to COCO annotations JSON
        image_dir: path to the images folder
        category_id_to_name: dict mapping COCO category id to names
        transform: torchvision transforms
        filter_no_labels: if True, drops images that satisfy no logical classes
        exclusive_label: if True, only assign first matching class as 1 (others 0).
            NOTE: __getitem__ always returns the index of the first matching
            class, so this flag currently does not change the returned target.
        exclusive_match_only: if True, only include images that match exactly one logical class
        log_statistics: if True, print per-class and co-occurrence statistics
        version: which rule set to use; one of 7, 8 or 10. The default
            'small' is not a supported rule set and raises ValueError.

            exclusive_match_only | exclusive_label | Resulting Effect
            False | False | Multi-label dataset, overlapping classes allowed.
            False | True | Overlapping images included, but only first class is used in label.
            True | False | Only images with one class are included, label is still multi-hot (only one 1).
            True | True | Clean single-class dataset, label is one-hot — ideal for multi-class classification.

        Raises:
            ValueError: if ``version`` is not a supported rule set.
            FileNotFoundError: if neither the annotation file nor its SCRATCH
                fallback can be loaded.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.exclusive_label = exclusive_label
        self.exclusive_match_only = exclusive_match_only
        self.version = version
        # Resolve the rule set first so a bad ``version`` fails before the
        # (potentially large) annotation JSON is parsed.
        self.logical_classes = self._build_logical_classes(version)

        coco_data = self._load_annotations(annotation_file)
        self.imgs = {img['id']: img for img in coco_data['images']}
        self.annotations = coco_data['annotations']

        # image id -> set of category names present in that image.
        self.image_to_categories = {}
        category_frequency = Counter()
        for ann in self.annotations:
            cat_name = category_id_to_name[ann['category_id']]
            self.image_to_categories.setdefault(ann['image_id'], set()).add(cat_name)
            category_frequency[cat_name] += 1

        class_counts = Counter()
        class_cooccurrence = Counter()
        self.image_ids = []
        for img_id in self.imgs:
            labels = self._labels_for(img_id)
            if filter_no_labels and not any(labels):
                continue
            if exclusive_match_only and sum(labels) != 1:
                continue
            self.image_ids.append(img_id)
            class_cooccurrence[tuple(labels)] += 1
            for (class_name, _), val in zip(self.logical_classes, labels):
                if val:
                    class_counts[class_name] += 1

        if log_statistics:
            self._log_statistics(len(self.imgs), class_counts, class_cooccurrence,
                                 category_frequency)

    @classmethod
    def _build_logical_classes(cls, version):
        """Return the ordered ``[(name, predicate), ...]`` list for ``version``."""
        try:
            names = cls._VERSION_RULES[version]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported version {version!r}; expected one of {sorted(cls._VERSION_RULES)}"
            ) from None
        return [(name, cls._RULES[name]) for name in names]

    @classmethod
    def _load_annotations(cls, annotation_file):
        """Load the COCO JSON, retrying under the SCRATCH mirror on failure."""
        # os.fspath: a pathlib.Path's .replace() would rename the file instead
        # of doing string substitution.
        annotation_file = os.fspath(annotation_file)
        try:
            with open(annotation_file, 'r') as f:
                return json.load(f)
        except (json.JSONDecodeError, FileNotFoundError) as e:
            print(f"Failed to load {annotation_file}: {e}")
            # try SCRATCH fallback
            fallback = annotation_file.replace(cls._PRIMARY_ROOT, cls._FALLBACK_ROOT)
            if fallback == annotation_file:
                # Path is outside the primary root: nothing else to try.
                raise
            if not os.path.exists(fallback):
                raise FileNotFoundError(
                    f"Both primary and fallback annotation files failed: {annotation_file}, {fallback}"
                ) from e
            print(f"Loading from fallback: {fallback}")
            with open(fallback, 'r') as f:
                return json.load(f)

    def _labels_for(self, img_id):
        """Return the 0/1 rule-match vector for ``img_id`` (all zeros if unannotated)."""
        cats = self.image_to_categories.get(img_id, set())
        return [int(rule(cats)) for _, rule in self.logical_classes]

    def _log_statistics(self, total_images, class_counts, class_cooccurrence, category_frequency):
        """Print dataset size, per-class counts and the most common patterns."""
        print(f"\nLogicalCOCODataset: Loaded {total_images} images.")
        print(f"Filtered to {len(self.image_ids)} images after applying logical class filters.\n")

        print("Per-Class Image Count:")
        for class_name, _ in self.logical_classes:
            print(f" - {class_name:<25}: {class_counts[class_name]}")

        print("\nTop 20 Class Co-occurrence Patterns:")
        for pattern, count in class_cooccurrence.most_common(20):
            pattern_str = ', '.join(self.logical_classes[i][0] for i, v in enumerate(pattern) if v)
            print(f" - [{pattern_str or 'None'}] : {count} images")

        print("\nTop 20 Category Frequencies:")
        for cat, count in category_frequency.most_common(20):
            print(f" - {cat:<20}: {count} annotations")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for the ``idx``-th kept image.

        ``label_index`` is a 0-dim ``torch.long`` tensor with the index of the
        first logical class the image satisfies (regardless of
        ``exclusive_label``). The image is RGB, passed through ``transform``
        when one was given.

        Raises:
            ValueError: if the image satisfies no logical class (only possible
                with ``filter_no_labels=False`` and ``exclusive_match_only=False``).
        """
        img_id = self.image_ids[idx]
        labels = self._labels_for(img_id)
        # Check the label before decoding the image so bad samples fail fast.
        try:
            label_index = labels.index(1)
        except ValueError:
            raise ValueError(
                f"Image {img_id} matches no logical class; construct the dataset with "
                f"filter_no_labels=True (or exclusive_match_only=True) to drop such images."
            ) from None

        img_path = os.path.join(self.image_dir, self.imgs[img_id]['file_name'])
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label_index, dtype=torch.long)
