from pyexpat import model
from torchvision import datasets, transforms as T
from PIL import PngImagePlugin

# Pillow refuses PNGs whose individual text chunks exceed MAX_TEXT_CHUNK.
# Lift that cap to LARGE_ENOUGH_NUMBER MiB (100 MiB) so such images load.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024
import os, sys
import engine.models as models
import engine.utils as utils
from functools import partial
import numpy as np
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
import torchvision

# Per-channel mean/std passed to T.Normalize, keyed by dataset name.
# The "_224" entries repeat the statistics of their base CIFAR dataset;
# "cub200" uses the standard ImageNet statistics, matching the
# ImageNet-style backbones used for that dataset.
NORMALIZE_DICT = {
    "cifar10": {
        "mean": (0.4914, 0.4822, 0.4465),
        "std": (0.2023, 0.1994, 0.2010),
    },
    "cifar100": {
        "mean": (0.5071, 0.4867, 0.4408),
        "std": (0.2675, 0.2565, 0.2761),
    },
    "cifar10_224": {
        "mean": (0.4914, 0.4822, 0.4465),
        "std": (0.2023, 0.1994, 0.2010),
    },
    "cifar100_224": {
        "mean": (0.5071, 0.4867, 0.4408),
        "std": (0.2675, 0.2565, 0.2761),
    },
    "cub200": {
        "mean": (0.485, 0.456, 0.406),
        "std": (0.229, 0.224, 0.225),
    },
}


# CIFAR-sized architectures keyed by name. Depths 18/34/50/101/152 come from
# models.cifar.resnet; the 6n+2 depths (20/32/44/56/110) from resnet_tiny.
# Each constructor is looked up by its attribute name "resnet<depth>".
MODEL_DICT = {
    **{
        f"resnet{depth}": getattr(models.cifar.resnet, f"resnet{depth}")
        for depth in (18, 34, 50, 101, 152)
    },
    **{
        f"resnet{depth}": getattr(models.cifar.resnet_tiny, f"resnet{depth}")
        for depth in (20, 32, 44, 56, 110)
    },
}

# ImageNet-style backbones (selected for "cub200" in get_model), keyed by the
# attribute name of their constructor in models.imagenet.
IMAGENET_MODEL_DICT = {
    arch: getattr(models.imagenet, arch) for arch in ("resnet50", "resnet34")
}

import pandas as pd
from torch.utils.data import Dataset
from PIL import Image


class CUB(Dataset):
    """CUB-200-2011 bird images, indexed from the dataset's metadata text files.

    ``root`` must contain ``images.txt`` (``<id> <relative path>``),
    ``image_class_labels.txt`` (``<id> <class id>``),
    ``train_test_split.txt`` (``<id> <is_train>``) and an ``images/``
    directory that the relative paths resolve against. All three files are
    read as space-separated with no header row.

    Each item is ``(image, target)``: the image converted to RGB (then passed
    through ``transform``) and the class id minus one (then passed through
    ``target_transform``).

    Args:
        root: Dataset directory laid out as described above.
        train: Truthy selects rows with ``is_train == 1``; falsy selects
            rows with ``is_train == 0``.
        data_len: If not None, keep only the first ``data_len`` rows of the
            selected split.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        data_len: int = None,
        transform=None,
        target_transform=None,
    ):
        super().__init__()
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        self._load_metadata(data_len)

    def _load_metadata(self, data_len: int):
        """Read the metadata files and store the selected split in ``self.metadata``.

        The resulting frame has columns ``id``, ``path``, ``label`` (shifted
        down by one), ``is_train`` and ``full_path`` (absolute-ish path under
        ``<root>/images``).
        """
        images = pd.read_csv(
            os.path.join(self.root, "images.txt"), sep=" ", names=["id", "path"]
        )
        labels = pd.read_csv(
            os.path.join(self.root, "image_class_labels.txt"),
            sep=" ",
            names=["id", "label"],
        )
        splits = pd.read_csv(
            os.path.join(self.root, "train_test_split.txt"),
            sep=" ",
            names=["id", "is_train"],
        )

        df = images.merge(labels, on="id").merge(splits, on="id")

        wanted_flag = 1 if self.train else 0
        metadata = df[df["is_train"] == wanted_flag]
        if data_len is not None:
            metadata = metadata.iloc[:data_len]
        # One explicit copy after all row selection, so the column writes below
        # act on an independent frame instead of a slice of another frame.
        metadata = metadata.copy()

        # The label file's class ids are shifted down by one so that labels
        # start at 0 (assuming the file numbers classes from 1).
        metadata["label"] -= 1

        images_dir = os.path.join(self.root, "images")
        metadata["full_path"] = metadata["path"].apply(
            lambda rel_path: os.path.join(images_dir, rel_path)
        )
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th row of the split."""
        # Fetch the row once instead of building two row Series per item.
        row = self.metadata.iloc[idx]

        # The context manager closes the file handle deterministically;
        # convert() returns a fully loaded, independent image.
        with Image.open(row["full_path"]) as raw:
            img = raw.convert("RGB")

        target = row["label"]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


def get_model(
    name: str, num_classes, pretrained=False, target_dataset="cifar100", **kwargs
):
    """Instantiate architecture ``name`` for ``target_dataset``.

    Args:
        name: Key into MODEL_DICT (CIFAR targets) or IMAGENET_MODEL_DICT
            (the "cub200" target).
        num_classes: Output classes; forwarded only to the CIFAR constructors.
        pretrained: Forwarded only to the ImageNet constructors (cub200).
        target_dataset: Any string containing "cifar", or exactly "cub200".
        **kwargs: Accepted and ignored.

    Returns:
        The constructed model.

    Raises:
        KeyError: If ``name`` is not registered for the selected family.
        NotImplementedError: For any other ``target_dataset``.
    """
    uses_cifar_models = "cifar" in target_dataset
    if not uses_cifar_models and target_dataset != "cub200":
        raise NotImplementedError("Not supported dataset")

    if uses_cifar_models:
        return MODEL_DICT[name](num_classes=num_classes)

    # NOTE(review): num_classes is not forwarded here, so the classifier head
    # is whatever the ImageNet constructor builds by default — confirm callers
    # replace it for CUB's 200 classes.
    return IMAGENET_MODEL_DICT[name](pretrained=pretrained)


def _build_cifar(name, data_root):
    """Return ``(num_classes, train_dst, fitness_dst, val_dst, input_size)`` for CIFAR-10/100."""
    dataset_cls, num_classes = {
        "cifar10": (datasets.CIFAR10, 10),
        "cifar100": (datasets.CIFAR100, 100),
    }[name]
    normalize = T.Normalize(**NORMALIZE_DICT[name])
    train_transform = T.Compose(
        [
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ]
    )
    val_transform = T.Compose([T.ToTensor(), normalize])

    root = os.path.join(data_root, "torchdata")
    train_dst = dataset_cls(root, train=True, download=True, transform=train_transform)
    # Same training images as train_dst, but with the deterministic val transform.
    fitness_dst = dataset_cls(root, train=True, download=True, transform=val_transform)
    val_dst = dataset_cls(root, train=False, download=True, transform=val_transform)
    return num_classes, train_dst, fitness_dst, val_dst, (1, 3, 32, 32)


def _build_cub200(data_root):
    """Return ``(num_classes, train_dst, fitness_dst, val_dst, input_size)`` for CUB-200-2011."""
    root = os.path.join(data_root, "CUB_200_2011")
    normalize = T.Normalize(**NORMALIZE_DICT["cub200"])
    train_transform = T.Compose(
        [
            T.Resize(512),
            T.RandomResizedCrop(448),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ]
    )
    val_transform = T.Compose(
        [
            T.Resize(512),
            T.CenterCrop(448),
            T.ToTensor(),
            normalize,
        ]
    )

    train_dst = CUB(root, train=True, transform=train_transform)
    # Same training images as train_dst, but with the deterministic val transform.
    fitness_dst = CUB(root, train=True, transform=val_transform)
    val_dst = CUB(root, train=False, transform=val_transform)
    return 200, train_dst, fitness_dst, val_dst, (1, 3, 448, 448)


def _sample_fitness_subset(dataset, ratio):
    """Return a random ``Subset`` of ``int(len(dataset) * ratio)`` items, or None if ratio is 0."""
    if ratio == 0:
        return None
    num_samples = int(len(dataset) * ratio)
    # permutation(n) is arange(n) shuffled with NumPy's global RNG, so seeding
    # np.random makes the selection reproducible.
    selected_indices = np.random.permutation(len(dataset))[:num_samples]
    return Subset(dataset, selected_indices)


def get_dataset(
    name: str,
    data_root: str = "data",
    FPVE_fitness_data_ratio=0,
    return_transform=False,
    args=None,
):
    """Build the train/val datasets and an optional fitness-evaluation subset.

    Args:
        name: Dataset name, case-insensitive: "cifar10", "cifar100" or "cub200".
        data_root: Base directory. CIFAR data is downloaded into
            ``<data_root>/torchdata``; CUB is read from
            ``<data_root>/CUB_200_2011``.
        FPVE_fitness_data_ratio: Fraction in [0, 1] of the training split to
            sample at random (NumPy global RNG) for fitness evaluation.
            0 or None disables sampling.
        return_transform: Currently unused; kept for call-site compatibility.
        args: Currently unused; kept for call-site compatibility.

    Returns:
        ``(num_classes, train_dst, val_dst, input_size, FPVE_fitness_data_dst)``
        where ``input_size`` is a dummy NCHW shape and the last element is a
        ``Subset`` or None.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
        ValueError: If ``FPVE_fitness_data_ratio`` is outside [0, 1].
    """
    name = name.lower()
    if name in ("cifar10", "cifar100"):
        build = partial(_build_cifar, name)
    elif name == "cub200":
        build = _build_cub200
    else:
        raise NotImplementedError

    # Validate before building, so a bad ratio fails fast instead of after a
    # dataset download or metadata scan.
    ratio = FPVE_fitness_data_ratio or 0
    if not 0 <= ratio <= 1:
        raise ValueError(
            f"FPVE_fitness_data_ratio must be in [0, 1], got {FPVE_fitness_data_ratio!r}"
        )

    num_classes, train_dst, fitness_dst, val_dst, input_size = build(data_root)

    # Sample from fitness_dst (train images, deterministic val transform) rather
    # than train_dst, so fitness scores are not perturbed by random augmentation.
    FPVE_fitness_data_dst = _sample_fitness_subset(fitness_dst, ratio)

    return num_classes, train_dst, val_dst, input_size, FPVE_fitness_data_dst
