import os
import json
import random
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

class CustomDataset(Dataset):
    """
    Few-shot style dataset returning a query image, its label index and a support set.

    Dataset layout:
      dataset/
        ├── train/                      # training images
        ├── val/                        # validation images
        └── db/
            ├── DB.csv                  # columns: Label, Zone, Path, DB, ImageName
            └── db_support_set.json     # mapping: split -> query image -> support per class

    Args:
        root_dir: Dataset root containing ``train/``, ``val/`` and ``db/``.
        split: Value of the ``DB`` column to keep. Query images are read from
            ``train/`` when it is ``'Train'``, and from ``val/`` otherwise.
        transform: Callable applied to every PIL image (query and support). It
            must return a tensor: the support set is stacked with ``torch.stack``,
            and the empty-support branch reads ``image.shape``.
        num_classes: Class dimension of the support tensor when ``support_size == 0``.
        support_size: Number of support images per class (must be >= 0).
        random_support: If True, use a fixed, seeded random pick of training
            images per class (cached in ``db/fixed_random_support_<k>.json``).
            Otherwise use the per-query lists from ``db_support_set.json``.

    ``__getitem__`` returns ``(image, label, support_set_tensor)``, where ``label``
    is the index from ``label_map``. ``support_set_tensor`` has shape
    ``(len(label_map), support_size, *image.shape)``, or
    ``(num_classes, 0, *image.shape)`` when ``support_size == 0``.
    """

    def __init__(self, root_dir, split='Train', transform=None,
                 num_classes=10, support_size=5, random_support=False):
        if support_size < 0:
            raise ValueError(f"support_size must be >= 0, got {support_size}")
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.num_classes = num_classes
        self.support_size = support_size
        self.random_support = random_support

        csv_path = os.path.join(root_dir, 'db', 'DB.csv')
        # Read the CSV once; the random-support path below reuses it.
        full_df = pd.read_csv(csv_path)
        self.data_df = full_df[full_df['DB'] == self.split].reset_index(drop=True)

        # Original label -> contiguous index, sorted by original label.
        # NOTE(review): built from this split's labels only. If a class is missing
        # from e.g. 'Val', indices shift relative to 'Train'. Confirm this is intended.
        unique_labels = sorted(self.data_df['Label'].unique())
        self.label_map = {orig: new for new, orig in enumerate(unique_labels)}
        print("Label mapping:", self.label_map)

        # support_data[image_name][str(label)] -> list of training image names.
        support_json_path = os.path.join(root_dir, 'db', 'db_support_set.json')
        if os.path.exists(support_json_path):
            with open(support_json_path, 'r', encoding='utf-8') as f:
                self.support_data = json.load(f)[self.split]
        else:
            self.support_data = {}

        if self.random_support:
            train_df = full_df[full_df['DB'] == 'Train']
            # Insertion order follows CSV row order. The seeded sampler below
            # depends on it, so keep it stable.
            self.train_imgs_by_class = {}
            for lbl, name in zip(train_df['Label'], train_df['ImageName']):
                self.train_imgs_by_class.setdefault(lbl, []).append(name)
            self.fixed_supports = self._load_or_create_fixed_supports()
        else:
            self.fixed_supports = {}

    def _load_or_create_fixed_supports(self):
        """Return the cached fixed support picks, or create and cache them with seed 42."""
        fixed_json = os.path.join(self.root_dir, 'db', f'fixed_random_support_{self.support_size}.json')
        if os.path.exists(fixed_json):
            with open(fixed_json, 'r', encoding='utf-8') as f:
                return json.load(f)

        rng = random.Random(42)
        fixed = {}
        for lbl, imgs in self.train_imgs_by_class.items():
            if len(imgs) >= self.support_size:
                chosen = rng.sample(imgs, self.support_size)
            else:
                # Not enough distinct images: sample with replacement.
                chosen = rng.choices(imgs, k=self.support_size)
            fixed[str(lbl)] = chosen
        with open(fixed_json, 'w', encoding='utf-8') as f:
            json.dump(fixed, f)
        return fixed

    def __len__(self):
        return len(self.data_df)

    def _load_image(self, path):
        """Open ``path`` as RGB (closing the file) and apply ``self.transform`` if set."""
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        if self.transform:
            rgb = self.transform(rgb)
        return rgb

    def _load_class_support(self, key, support_imgs):
        """Load the first ``support_size`` images of ``support_imgs`` from ``train/`` and stack them.

        Raises:
            ValueError: fewer than ``support_size`` names are available for class ``key``.
            FileNotFoundError: a support image is not present under ``train/``.
        """
        # Explicit raises instead of ``assert``: asserts are stripped under ``-O``,
        # which previously let a stale or unbound support image slip through.
        if self.support_size > len(support_imgs):
            raise ValueError(
                "support size must be lower or equal to len(support_imgs)! "
                f"(class {key}: requested {self.support_size}, available {len(support_imgs)})"
            )
        class_support = []
        for name in support_imgs[:self.support_size]:
            supp_img_path = os.path.join(self.root_dir, 'train', name)
            if not os.path.exists(supp_img_path):
                raise FileNotFoundError(
                    f"Support set image must be from the training set! Missing: {supp_img_path}"
                )
            class_support.append(self._load_image(supp_img_path))
        return torch.stack(class_support, dim=0)

    def __getitem__(self, idx):
        row = self.data_df.iloc[idx]
        image_dir = 'train' if self.split == 'Train' else 'val'
        image = self._load_image(os.path.join(self.root_dir, image_dir, row['ImageName']))

        label = self.label_map[int(row['Label'])]
        image_name = row['ImageName']

        if self.support_size == 0:
            # NOTE(review): uses num_classes, while the non-empty branch below yields
            # len(label_map) classes. These agree only if num_classes == len(label_map).
            support_set_tensor = torch.empty((self.num_classes, 0, *image.shape))
        else:
            support_info = self.support_data.get(image_name, {})
            per_class = []
            for key in self.label_map:
                source = self.fixed_supports if self.random_support else support_info
                per_class.append(self._load_class_support(key, source.get(str(key), [])))
            # Shape: (len(label_map), support_size, *image.shape)
            support_set_tensor = torch.stack(per_class, dim=0)

        return image, label, support_set_tensor
