import logging
import os
import pickle
import random
import re
from typing import Dict, List

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from clients.base import Client
from torchvision import transforms


class UnlabeledImageDataset(Dataset):
    """Dataset that yields every entry of a folder as an RGB image, without labels.

    Every directory entry is treated as an image file; entries are not
    filtered by type or extension.
    """

    def __init__(self, folder_path, transform=None):
        """Index the entries of ``folder_path``.

        Args:
            folder_path: Directory whose entries are the images.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.folder_path = folder_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so that a
        # given index maps to the same file on every run.
        self.image_paths = [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
        ]

    def __len__(self):
        """Return the number of indexed entries."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load entry ``idx`` as RGB and return it, transformed if requested."""
        # convert() hands back a new, fully loaded image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image
    
class MultiModalDataset(Dataset):
    """Dataset that yields ``(image, caption)`` pairs.

    Images are the entries of a folder; captions come from a pickled mapping
    that is looked up by image file name (basename).
    """

    def __init__(self, folder_path, captions_path, transform=None):
        """Index the entries of ``folder_path`` and load the caption mapping.

        Args:
            folder_path: Directory whose entries are the images.
            captions_path: Pickle file holding the mapping from image file
                name to caption.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.folder_path = folder_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so that a
        # given index maps to the same file on every run.
        self.image_paths = [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
        ]
        # SECURITY: pickle.load can execute arbitrary code while unpickling;
        # only point captions_path at files from a trusted source.
        with open(captions_path, "rb") as f:
            self.captions = pickle.load(f)

    def __len__(self):
        """Return the number of indexed entries."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the RGB image at ``idx`` (transformed if requested) and its caption.

        Raises:
            KeyError: If the caption mapping has no entry for the image's
                file name.
        """
        img_path = self.image_paths[idx]
        # convert() hands back a new, fully loaded image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # Captions are keyed by bare file name, not by full path.
        caption = self.captions[os.path.basename(img_path)]

        return image, caption


# Module-level preprocessing pipeline for public images: resize to 224x224,
# convert to a tensor, then normalize each channel with the mean/std below
# (the values match the commonly used ImageNet statistics).
transform_public = transforms.Compose(
    [
        transforms.Resize((224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)


def _client_files(data_dir, subdirs, split, num_clients, extension):
    """Return ``<data_dir>/<subdirs...>/<split>/client_<i>.<extension>`` for each client index."""
    return [
        os.path.join(data_dir, *subdirs, split, f"client_{i}.{extension}")
        for i in range(num_clients)
    ]


def setup_dataset(args) -> Dict:
    """Collect per-client data file paths and build the public dataset and loader.

    Args:
        args: Configuration object. Reads ``data_dir``, the ``num_*_clients``
            counters, ``root_train_image_folder`` / ``root_val_image_folder``
            (only when MLC or semantic-segmentation clients are requested),
            ``public_dataset_name`` and ``pub_batch_size``.

    Returns:
        A dict holding, for every task with at least one client, the lists of
        per-client train/test files, plus ``"public_loader"`` and
        ``"public_dataset"``.

    Raises:
        ValueError: If ``args.public_dataset_name`` is not a supported name.
    """
    datadict = {}

    if args.num_MLC_clients > 0:
        datadict["MLC_clients_train_ann_files"] = _client_files(
            args.data_dir, ("MLC_setup",), "train", args.num_MLC_clients, "json"
        )
        datadict["MLC_clients_test_ann_files"] = _client_files(
            args.data_dir, ("MLC_setup",), "test", args.num_MLC_clients, "json"
        )
        datadict["root_train_image_folder"] = os.path.join(args.root_train_image_folder)
        datadict["root_val_image_folder"] = os.path.join(args.root_val_image_folder)

    if args.num_semantic_segmentation_clients > 0:
        num_semseg = args.num_semantic_segmentation_clients
        datadict["semantic_seg_clients_train_ann_files"] = _client_files(
            args.data_dir, ("semseg_setup",), "train", num_semseg, "json"
        )
        datadict["semantic_seg_clients_test_ann_files"] = _client_files(
            args.data_dir, ("semseg_setup",), "test", num_semseg, "json"
        )
        # Same image roots as for the MLC clients; set here as well so they
        # are present when only segmentation clients are configured.
        datadict["root_train_image_folder"] = os.path.join(args.root_train_image_folder)
        datadict["root_val_image_folder"] = os.path.join(args.root_val_image_folder)

    if args.num_IC100_clients > 0:
        datadict["IC100_clients_train_files"] = _client_files(
            args.data_dir, ("IC", "CIFAR100"), "train", args.num_IC100_clients, "pkl"
        )
        datadict["IC100_clients_test_files"] = _client_files(
            args.data_dir, ("IC", "CIFAR100"), "test", args.num_IC100_clients, "pkl"
        )

    if args.num_IC10_clients > 0:
        datadict["IC10_clients_train_files"] = _client_files(
            args.data_dir, ("IC", "CIFAR10"), "train", args.num_IC10_clients, "pkl"
        )
        datadict["IC10_clients_test_files"] = _client_files(
            args.data_dir, ("IC", "CIFAR10"), "test", args.num_IC10_clients, "pkl"
        )

    if args.num_yahoo_topic_classification_clients > 0:
        num_yahoo = args.num_yahoo_topic_classification_clients
        datadict["yahoo_clients_train_files"] = _client_files(
            args.data_dir, ("TC", "yahoo_qa"), "train", num_yahoo, "pkl"
        )
        datadict["yahoo_clients_test_files"] = _client_files(
            args.data_dir, ("TC", "yahoo_qa"), "test", num_yahoo, "pkl"
        )

    # Resolve where the public images live. Only the multi-modal public set
    # also has a caption file; for every other set captions_path stays None.
    # A "coco_multi_modal" name has no folder mapping here, so it is rejected
    # as unsupported like any other unknown name.
    public_name = args.public_dataset_name
    captions_path = None
    if public_name == "coco":
        pub_img_folder = os.path.join(args.data_dir, "MLC_setup", "public")
    elif public_name == "flicker_multi_modal":
        pub_img_folder = os.path.join(args.data_dir, "TC", "flicker_public", "images")
        captions_path = os.path.join(
            args.data_dir, "TC", "flicker_public", "image_file_name_to_caption.pkl"
        )
    elif public_name == "CIFAR100":
        pub_img_folder = os.path.join(args.data_dir, "IC", "CIFAR100", "public")
    elif public_name == "CIFAR10":
        pub_img_folder = os.path.join(args.data_dir, "IC", "CIFAR10", "public")
    elif public_name == "pascal":
        pub_img_folder = os.path.join(args.data_dir, "MLC_setup", "public_Pascal")
    else:
        raise ValueError(f"Unsupported public_dataset_name: {public_name!r}")

    # NOTE(review): unlike the module-level transform_public, this pipeline has
    # no Normalize step — confirm that un-normalized public images are intended.
    public_transform = transforms.Compose(
        [
            transforms.Resize((224, 224), antialias=True),
            transforms.ToTensor(),
        ]
    )

    if captions_path is not None:
        public_dataset = MultiModalDataset(
            folder_path=pub_img_folder,
            captions_path=captions_path,
            transform=public_transform,
        )
    else:
        public_dataset = UnlabeledImageDataset(
            folder_path=pub_img_folder, transform=public_transform
        )

    public_loader = DataLoader(
        dataset=public_dataset,
        batch_size=args.pub_batch_size,
        num_workers=1,
        drop_last=False,
        pin_memory=True,
        shuffle=True,
    )

    logging.info(f"| Public dataset examples: {len(public_dataset)}")

    datadict["public_loader"] = public_loader
    datadict["public_dataset"] = public_dataset

    return datadict


def log_dataset_stats(clients_list: List[Client]) -> None:
    """Log per-client and aggregate train/test dataset sizes.

    For every client, in list order, the number of train and test examples is
    logged together with the client's index and task; the totals over all
    clients follow, framed by a header and a footer line.
    """
    logging.info("\n####### | Dataset stats: ")

    train_sizes = []
    test_sizes = []

    for client_idx, client in enumerate(clients_list):
        n_train = len(client.train_dataset)
        n_test = len(client.test_dataset)
        train_sizes.append(n_train)
        test_sizes.append(n_test)

        logging.info(
            f"| Train examples for client {client_idx} with task {client.task}: {n_train}"
        )
        logging.info(
            f"| Test examples for client {client_idx} with task {client.task}: {n_test}"
        )

    logging.info(f"| Total train examples: {sum(train_sizes)}")
    logging.info(f"| Total test examples: {sum(test_sizes)}")

    logging.info("#######\n")
