import logging
import numpy as np
import os
import torch
import json

from pathlib import Path
from pytorch_lightning import seed_everything
from torchvision.transforms import (
    CenterCrop,
    ColorJitter,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from torch.utils.data import Dataset
from glob import glob
from collections import OrderedDict
from PIL import Image

# Class identifiers (ImageNet-style "n########" IDs) of the 100 classes in this
# subset. Imagenet100ConceptDataset uses each ID as the per-class sub-directory
# name under a split directory, and assigns integer labels by position in the
# *sorted* version of this list (not this listing order).
# NOTE(review): the row order of the concept matrix at CONCEPTS_PATH is
# presumably aligned with that sorted order -- TODO confirm with the script
# that generated the .npy file.
IMAGENET100_CLASSES = [
    "n01968897",
    "n01770081",
    "n01818515",
    "n02011460",
    "n01496331",
    "n04347754",
    "n01687978",
    "n01740131",
    "n01537544",
    "n01491361",
    "n02007558",
    "n01735189",
    "n01630670",
    "n01440764",
    "n01819313",
    "n02002556",
    "n01667778",
    "n01755581",
    "n01924916",
    "n01751748",
    "n01984695",
    "n01729977",
    "n01614925",
    "n01608432",
    "n01443537",
    "n01770393",
    "n01855672",
    "n01560419",
    "n01592084",
    "n01914609",
    "n01582220",
    "n01667114",
    "n01784675",
    "n01820546",
    "n01773797",
    "n02006656",
    "n01986214",
    "n01484850",
    "n01749939",
    "n01828970",
    "n02018795",
    "n01695060",
    "n01729322",
    "n01677366",
    "n01734418",
    "n01843383",
    "n01806143",
    "n01773549",
    "n01775062",
    "n01728572",
    "n01601694",
    "n01978287",
    "n01930112",
    "n01739381",
    "n01883070",
    "n01774384",
    "n02037110",
    "n01795545",
    "n02027492",
    "n01531178",
    "n01944390",
    "n01494475",
    "n01632458",
    "n01698640",
    "n01675722",
    "n01877812",
    "n01622779",
    "n01910747",
    "n01860187",
    "n01796340",
    "n01833805",
    "n01685808",
    "n01756291",
    "n01514859",
    "n01753488",
    "n02058221",
    "n01632777",
    "n01644900",
    "n02018207",
    "n01664065",
    "n02028035",
    "n02012849",
    "n01776313",
    "n02077923",
    "n01774750",
    "n01742172",
    "n01943899",
    "n01798484",
    "n02051845",
    "n01824575",
    "n02013706",
    "n01955084",
    "n01773157",
    "n01665541",
    "n01498041",
    "n01978455",
    "n01693334",
    "n01950731",
    "n01829413",
    "n02093859",
]

# Natural-language concept vocabulary for ImageNet-100. The length of this list
# is N_CONCEPTS. The concept matrix loaded from CONCEPTS_PATH is indexed by
# class label (one row per label); its columns are presumably in the same
# order as this list -- TODO confirm against the script that built the .npy.
# Entries are kept verbatim, including near-duplicates (e.g. "a yard" /
# "A yard") and typos (e.g. "a long,neck"): editing or reordering them could
# misalign concept names with matrix columns.
IMAGENET100_CONCEPTS = [
    "a bait bucket",
    "flippers instead of legs",
    "a brightly colored face",
    "a medium-sized body",
    "a large, white bird",
    "a long, black neck",
    "a gull",
    "cacti",
    "a sky",
    "dunes",
    "a swimmer",
    "two long, antennae",
    "brown and white plumage",
    "a long, furry body",
    "a knife",
    "a large, flat beak",
    "a white tail",
    "a small opening at one end",
    "long, tentacles hanging down",
    "two long, curved fangs",
    "a large, oval-shaped body",
    "long hind legs for leaping",
    "pointed wings",
    "a lawn",
    "a heater",
    "other sharks",
    "leaves",
    "a body of water",
    "a diamond-patterned back",
    "a rodent",
    "pink or reddish feathers",
    "yellow or white belly",
    "a bell-shaped body",
    "a long, tailless hind end",
    "living in arid climates",
    "fruit",
    "mollusc",
    "legs that are long and thin",
    "a white rump",
    "brown and white feathers",
    "a reel",
    "wide head",
    "beekeeping",
    "a long, sharp beak",
    "amphibian",
    "a hard, leathery shell",
    "sharp talons and beak",
    "a small, plump body",
    "a sonar",
    "animal",
    "a give",
    "vertebrate",
    "a net",
    "a oval-shaped body",
    "a small, bird-like body",
    "large ears",
    "a small, brown shell",
    "a long, thin legs",
    "black legs",
    "strong legs and feet",
    "long, red legs",
    "a hard shell",
    "a water",
    "the sea",
    "a dark brown body",
    "usually found in shallow water",
    "two small eyes",
    "feed",
    "a garter",
    "a soft, translucent body",
    "green color",
    "a crab trap",
    "a long neck",
    "a pointed snout",
    "barbels around the mouth",
    "water",
    "crab pots",
    "a bee suit",
    "a small, sparrow-like bird",
    "a forest",
    "a dull color",
    "a grey or white body",
    "a small, dark-colored body",
    "a branch",
    "often brightly colored",
    "a hiding place",
    "a long, tapered tail",
    "bright plumage",
    "a sunbeam",
    "a shotgun",
    "a predator",
    "eight legs",
    "a long, whip-like tail",
    "a pair of trousers",
    "beeswax",
    "a baby stork",
    "cnidarian",
    "Mammal",
    "marsupial",
    "a set of wings",
    "a fishbowl",
    "can be aggressive",
    "a plump body",
    "arachnid",
    "a prey",
    "quick movements",
    "a tree",
    "a belt at the waist",
    "sugar water",
    "a ray-shaped mouth",
    "a small body",
    "a brown, black, or green color",
    "smooth, shiny scales",
    "a meadow",
    "a blade",
    "a peahen",
    "a black cap on the head",
    "a reedy call",
    "a nut",
    "a rattle",
    "a white border",
    "a shore",
    "black or brown scales",
    "a venomous bite",
    "a loud, cackling call",
    "a prey animal",
    "living thing",
    "a small, spiral shell",
    "a pointed beak",
    "a bright yellow color",
    "logs",
    "a long, snake-like shape",
    "a pointed tongue",
    "dark stripes on the sides",
    "a long, slender neck",
    "periscopes",
    "worm",
    "a birdbath",
    "a translucent body",
    "a curved bill",
    "a hole",
    "a webbed foot",
    "sharp claws",
    "powerful fins",
    "a home",
    "a tall, pink bird",
    "a large, muscular body",
    "a green color",
    "a crab fork",
    "a dark brown back",
    "a wormy appearance",
    "a small to medium-sized bird",
    "skin",
    "a silky, feathered coat",
    "a marsupial pouch",
    "a dark-colored carapace",
    "propellers",
    "white feathers",
    "legs with spurs",
    "the Sahara Desert",
    "a long, straight bill",
    "a lizard-like body",
    "a white or gray color",
    "gecko",
    "bacteria",
    "grayish upperparts",
    "paddle-like fins",
    "a pond",
    "a treat",
    "a brightly colored body",
    "a small, down-turned bill",
    "insects",
    "a hard, dark-colored shell",
    "a wet suit",
    "a food dish",
    "a location near water",
    "usually has a dark color",
    "a cave",
    "a reed",
    "algae",
    "a swan",
    "white stripes on the wings",
    "twigs",
    "four legs",
    "flying vertebrate",
    "adults",
    "a toad",
    "a small head and thorax",
    "a long, dense coat",
    "dog",
    "a beehive",
    "a log",
    "a park",
    "a crab mallet",
    "a long, prehensile tongue",
    "a long, curved beak",
    "a long tail",
    "eyes at the side of the head",
    "a rabbit",
    "a camera",
    "powerful swimming ability",
    "a steamer",
    "a triangular head",
    "a feeder",
    "a long, conspicuous tail",
    "heavy facial discs",
    "a lifeguard",
    "red, blue, and yellow feathers",
    "a crab cracking tool",
    "small size",
    "a white belly",
    "a sedentary lifestyle",
    "a leaves",
    "a large body",
    "a rock",
    "a fleshy mantle",
    "variable colors",
    "a small, pointed tail",
    "birds",
    "a long, thick tail",
    "organism",
    "a chicken coop",
    "talons",
    "orange bill and legs",
    "a birdwatcher",
    "a black belly patch",
    "viper",
    "smooth, shiny skin",
    "reptiles",
    "a turtle shell",
    "claws on its legs",
    "a stone",
    "smooth scales",
    "a narrow head",
    "chicks",
    "pointed barbs on the tail",
    "a small, segmented body",
    "an oyster",
    "white cheeks",
    "a black body",
    "on an animal",
    "a crew",
    "a hooked bill",
    "food",
    "a jelly-like texture",
    "insectivore",
    "a wide, triangular head",
    "long, thin legs",
    "black wings with white stripes",
    "a pool",
    "long, greenish-yellow legs",
    "a hay bale",
    "no shell",
    "a variety of colors",
    "a windy day",
    "a shrimp",
    "a stick",
    "a marsh",
    "a high, shrill call",
    "a conning tower",
    "orange fins",
    "a sandpiper",
    "a gray or blue color",
    "a bird of prey",
    "a baby cockatoo",
    "a long, thin tail",
    "a white breast",
    "a small, forked tail",
    "a large, coiled shell",
    "a large, cigar-shaped body",
    "a long, thin abdomen",
    "a tide pool",
    "a short tail",
    "pale gray or blue color",
    "a white underbelly",
    "wings that move very fast",
    "seals",
    "grayish color",
    "a dish",
    "branches",
    "a black head",
    "a light",
    "a nest",
    "a field",
    "seagrass",
    "short, powerful legs",
    "a scaly skin",
    "a long, black-and-white tail",
    "a soft, slug-like body",
    "iridescent feathers",
    "rattles",
    "black wings with white bars",
    "a rod",
    "a dark coloration",
    "a torpedo",
    "strong legs",
    "a dark color (usually black)",
    "Australia",
    "a round face",
    "a head with a beak",
    "chordate",
    "chickens",
    "a desert animal",
    "yellow feet",
    "a long, forked tongue",
    "a large, bulbous abdomen",
    "a water bowl",
    "a flared lip",
    "a soft body",
    "a flag",
    "short legs",
    "a bush",
    "a snow",
    "a sea",
    "a male grouse",
    "ice",
    "a mouse",
    "a long,neck",
    "a black stripe on the face",
    "sand",
    "a trees",
    "a white bill",
    "a zigzag pattern on the back",
    "a blue body",
    "relatively small wings",
    "a small head at the other end",
    "a slightly pointed end",
    "a long, slender shape",
    "long, narrow wings",
    "the ocean",
    "a short beak",
    "a hard, protective outer shell",
    "a fence",
    "short, round wings",
    "a large, round body",
    "vines",
    "a water pump",
    "a thick, strong neck",
    "a seal",
    "a wedge-shaped head",
    "a warm climate",
    "ocean",
    "a white or pale color",
    "a small, flattened body",
    "salamander",
    "long, orange legs",
    "eggs",
    "a machine",
    "a long, curved neck",
    "a snake charmer",
    "good eyesight",
    "large wings",
    "vertical pupils",
    "a upturned nose",
    "terrier",
    "a small songbird",
    "a small, lightweight body",
    "a long, blunt head",
    "two large, curved fangs",
    "rocks",
    "a long, torpedo-shaped body",
    "two legs",
    "a pink or white color",
    "a venom",
    "a long, thin body",
    "a ship",
    "a black or dark color",
    "a long, thick neck",
    "a hide box",
    "a bed of straw",
    "many legs",
    "grasses",
    "a small mammal",
    "a thick, pearly inside",
    "red or brown color",
    "large, webbed hind feet",
    "tetradactyla",
    "short, flippers for arms",
    "a long, feathered tail",
    "a large, tooth-filled beak",
    "a large, square head",
    "large, triangular fins",
    "cattails",
    "a small to medium size",
    "a long, barbed tail",
    "a crab pot",
    "a turtle",
    "reptile",
    "a baby",
    "large fins",
    "on a plant",
    "a grass",
    "reeds",
    "a green or olive color",
    "a beak",
    "no patterns or markings",
    "a small, squat body",
    "cactus",
    "a strike",
    "a smooth, green skin",
    "A collar",
    "a small mouth",
    "thick, heavy shell",
    "a cygnet",
    "no head or limbs",
    "a small, oblong shape",
    "a animal",
    "large horns",
    "a protruding crest on the head",
    "a cute, furry face",
    "a waterer",
    "a white wingbar",
    "a large, colorful bird",
    "bird seed",
    "bait",
    "white wingbars",
    "a vole",
    "a white or gray plumage",
    "scales",
    "a dog",
    "short, front legs",
    "no legs",
    "a large body size",
    "a leaf",
    "a jungle",
    "a plant",
    "a tweezers",
    "a farmer",
    "a long, red bill",
    "a black nose",
    "feathers",
    "fish",
    "a long, thick body",
    "a human",
    "a small head",
    "two small eyes on stalks",
    "a wide, flat head",
    "a aquarium",
    "grass",
    "black or brown stripes",
    "a thermometer",
    "a basking platform",
    "white underparts",
    "a cute, pug-like face",
    "a diving mask",
    "A leash",
    "egg",
    "a white throat",
    "a vine",
    "a lot of noise",
    "smooth skin",
    "short, stubby legs",
    "a short, black bill",
    "a long, slender body",
    "a hard, oval-shaped shell",
    "a brown or red color",
    "large flippers",
    "a dorsal fin",
    "good night vision",
    "eight long legs",
    "a white chest",
    "a periscope",
    "five pairs of legs",
    "a small, round shape",
    "dirt",
    "seagulls",
    "a small, stocky shorebird",
    "a cage",
    "a long, curved body",
    "a large, bulky body",
    "a kite stick",
    "two long, pedipalps (or fangs)",
    "a carnivorous diet",
    "a filter",
    "marsh plants",
    "a stocky body",
    "seafood",
    "a heat lamp",
    "a flat, circular shape",
    "large scales",
    "a black and white color scheme",
    "a fruit",
    "crocodilian",
    "brightly colored feathers",
    "a large, rounded shape",
    "large eyes",
    "dust",
    "a bird bath",
    "a segmented abdomen",
    "a short, stout body",
    "thick, woolly fur",
    "a perch",
    "tentacles",
    "a white color",
    "attached to a surface",
    "a pot",
    "fish food",
    "a nature reserve",
    "a grassland",
    "a strong, muscular body",
    "a large pair of pincers",
    "a white beak",
    "a river",
    "a long, thin snout",
    "a pointed face",
    "carnivore",
    "echinoderm",
    "two long pedipalps",
    "a red background",
    "a lack of appendages",
    "white or gray plumage",
    "large, round eyes",
    "a burrow",
    "a curved beak",
    "a wide opening",
    "a long, green body",
    "watercraft",
    "fungus",
    "pouch for carrying young",
    "a light color",
    "the United States",
    "a surfboard",
    "a eucalyptus tree",
    "long grass",
    "frogs",
    "a large, spiral-shaped shell",
    "a seed",
    "a captain",
    "a tall, slender body",
    "a desert",
    "a fork",
    "a long, thin beak",
    "a bird feeder",
    "arthropod",
    "a brown or grayish color",
    "fins on its back and tail",
    "a mottled brown plumage",
    "a person",
    "a hunter",
    "a white head and tail",
    "Shells",
    "two large claws",
    "under a rock",
    "a small, slimy body",
    "a decoration",
    "mud",
    "short, flipper-like limbs",
    "a flipper",
    "a garden",
    "a brown or white color",
    "a keeper",
    "a water dish",
    "a small, crab-like body",
    "a grouse",
    "A yard",
    "yellow feathers",
    "a tail with long feathers",
    "snakes",
    "a cactus",
    "a shark cage",
    "a large, stocky bird",
    "a kite string",
    "a white cheek",
    "owl pellets",
    "a horny, beak-like mouth",
    "a toy",
    "a lizard",
    "a rainforest",
    "relatively slow movement",
    "tundra",
    "a long, curved bill",
    "two small, compound eyes",
    "a smooth, unsegmented body",
    "a fishing rod",
    "usually found outdoors",
    "large, protruding eyes",
    "a warm place to hide",
    "a green or brown color",
    "a wing",
    "a fishing line",
    "object",
    "fang-like mouthparts",
    "a birdcage",
    "a small, hard shell",
    "a small, worm-like body",
    "a bill",
    "a exoskeleton",
    "hollow fangs",
    "loud, harsh call",
    "greenish-brown color",
    "a vibrant blue color",
    "a soft, slimy body",
    "a lake",
    "long, orange beak",
    "Eucalyptus trees",
    "able to fly long distances",
    "short, fur-covered legs",
    "a colorful exterior",
    "a mate",
    "long legs",
    "large wingspan",
    "a woods",
    "a stream",
    "a central mouth",
    "a large, powerful jaw",
    "a hooked beak",
    "plants",
    "a reddish-brown body",
    "soil",
    "a brown or black color",
    "crabs",
    "a large mouth",
    "a necktie",
    "feathers at the other end",
    "large, sharp teeth",
    "gray or white feathers",
    "a wave",
    "legs",
    "darkness",
    "flowers",
    "a blue-gray color",
    "in water",
    "an aquarium",
    "in the soil",
    "a body",
    "a moist, smooth skin",
    "a pair of long, antennae",
    "trees",
    "brown or black eyes",
    "a water filter",
    "a fish",
    "decaying matter",
    "a bait",
    "a black and white body",
    "the night",
    "yellow eyes",
    "a slender body",
    "a small, round body",
    "a bright green color",
    "a blood meal",
    "a brown, gray, or olive color",
    "a black cap and bib",
    "long, arms and legs",
    "a yard",
    "a large, oval shape",
    "a light body",
    "a chicken",
    "honey",
    "a big head",
    "powerful wings",
    "long, sharp claws",
    "a house",
    "insect eater",
    "a pump",
    "a small, sturdy body",
    "a shell",
    "a feather",
    "a swamp",
    "mollusk",
    "a zoo",
    "a dark olive-brown coloration",
    "invertebrate",
    "gray wings and back",
    "other birds",
    "fins",
    "large, dark eyes",
    "bushes",
    "seaweed",
    "a large size",
    "a string",
    "two long, feeler-like antennae",
    "black plumage",
    "a wingspan of 3-4 feet",
    "a long, thin neck",
    "a small, rounded body",
    "two small tentacles",
    "a wetland",
    "a small, slim body",
    "a shark fin",
    "sea turtle",
    "no antennae",
    "people",
    "an octagonal shape",
    "vessel",
    "a flower",
    "a hunched posture",
    "nectar",
    "slow moving",
    "a bug",
    "a worm",
    "a green or brown body",
    "a long beak",
    "a loud, cuckoo-like call",
    "a prairie",
    "a surfer",
    "black wingtips",
    "a gander",
    "a flat body",
    "no visible eyes or antennae",
    "a fin",
]

# Sizes of the concept vocabulary and the class list defined above.
N_CONCEPTS, N_CLASSES = len(IMAGENET100_CONCEPTS), len(IMAGENET100_CLASSES)

# Location of the .npy concept matrix that Imagenet100ConceptDataset loads and
# indexes by class label.
# NOTE(review): machine-specific absolute path; override for other environments.
CONCEPTS_PATH = (
    "/raid/ai24mtech12011/projects/temp/fca4nn/DATA/concepts/inet100_concept_matrix.npy"
)


class Imagenet100ConceptDataset(Dataset):
    """ImageNet-100 images paired with class-level concept annotations.

    Each item is ``(image, label, concept_vector)``:

    * ``image`` -- the image after ``self.transforms`` has been applied;
    * ``label`` -- index of the image's class in ``sorted(IMAGENET100_CLASSES)``;
    * ``concept_vector`` -- row ``label`` of the concept matrix, as a float32
      tensor (every image of a class gets the same vector).

    Images are discovered as ``<data_dir>/<split dir>/<class id>/*.JPEG``, where
    the split dir is ``train``, ``val`` or ``test_set`` for the splits
    ``"train"``, ``"val"`` and ``"test"`` respectively. The sample list is
    shuffled once at construction using NumPy's global RNG.

    Args:
        data_dir: Root directory containing the split sub-directories.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional callable applied to each PIL image. When ``None``,
            a random-crop/flip/color-jitter pipeline is used for ``"train"``
            and a resize/center-crop pipeline for the other splits.
        concepts_path: Path to the ``.npy`` concept matrix. ``None`` (default)
            uses the module-level ``CONCEPTS_PATH``.

    Raises:
        ValueError: If ``split`` is not one of the supported values.
    """

    # Supported split names -> sub-directory under ``data_dir``.
    _SPLIT_DIRS = {"train": "train", "val": "val", "test": "test_set"}
    # Per-channel mean/std for the default Normalize step (standard ImageNet
    # normalization statistics).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        data_dir,
        split="train",
        transform=None,
        concepts_path=None,
    ):
        # Validate first so a bad split fails before any disk I/O.
        if split not in self._SPLIT_DIRS:
            raise ValueError("Invalid split: {}".format(split))

        self.num_classes = N_CLASSES
        self.split = split
        # Labels are positions in the *sorted* class list, independent of the
        # listing order of IMAGENET100_CLASSES.
        self.class_list = sorted(IMAGENET100_CLASSES)
        self.concept_list = IMAGENET100_CONCEPTS
        # Resolved at call time so the default tracks the module constant.
        if concepts_path is None:
            concepts_path = CONCEPTS_PATH
        self.attr_npy = np.load(concepts_path)

        self.dir_idx = {class_id: idx for idx, class_id in enumerate(self.class_list)}

        if transform is None:
            self.transforms = self._default_transforms(split)
        else:
            self.transforms = transform

        split_dir = os.path.join(data_dir, self._SPLIT_DIRS[split])
        self.data = []
        for class_id in self.class_list:
            label = self.dir_idx[class_id]
            # glob() order is filesystem-dependent; sorting makes the seeded
            # shuffle below reproducible across machines.
            images = sorted(glob(os.path.join(split_dir, class_id, "*.JPEG")))
            self.data.extend((path, label) for path in images)

        if not self.data:
            # Most likely a wrong data_dir; surface it instead of silently
            # producing an empty dataset.
            logging.getLogger(__name__).warning(
                "No *.JPEG images found for split %r under %s", split, split_dir
            )

        # Uses NumPy's global RNG (generate_data seeds it via seed_everything).
        np.random.shuffle(self.data)

    @classmethod
    def _default_transforms(cls, split):
        """Return the default preprocessing pipeline for ``split``."""
        normalize = Normalize(mean=cls._MEAN, std=cls._STD)
        if split == "train":
            return Compose(
                [
                    RandomResizedCrop(224, interpolation=Image.BILINEAR),
                    RandomHorizontalFlip(),
                    ColorJitter(
                        brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                    ),
                    ToTensor(),
                    normalize,
                ]
            )
        return Compose(
            [
                Resize(size=256, interpolation=Image.BILINEAR),
                CenterCrop(224),
                ToTensor(),
                normalize,
            ]
        )

    def __len__(self):
        return len(self.data)

    def get_concept_count(self):
        """Return the size of the concept vocabulary (``N_CONCEPTS``)."""
        return N_CONCEPTS

    def __getitem__(self, index):
        """Return ``(image, label, concept_vector)`` for sample ``index``."""
        image_path, label = self.data[index]
        # The context manager releases the file handle deterministically;
        # convert() fully loads the pixels before the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transforms(image)
        # astype() always copies, so the tensor never aliases self.attr_npy.
        concept_vector = self.attr_npy[label].astype(np.float32)
        return image, label, torch.from_numpy(concept_vector)


def generate_data(
    config,
    root_dir,
    seed=42,
    output_dataset_vars=False,
    rerun=False,
):
    """Build train/val/test data loaders for the ImageNet-100 concept dataset.

    Args:
        config: Mapping providing ``"batch_size"`` and ``"num_workers"``; an
            optional truthy ``"weight_loss"`` enables computing ``imbalance``.
        root_dir: Dataset root, passed as ``data_dir`` to every split.
        seed: Seed passed to ``seed_everything`` before the datasets are built.
        output_dataset_vars: If True, also return
            ``(N_CONCEPTS, N_CLASSES, concept_group_map)``.
        rerun: Unused; kept for interface compatibility.

    Returns:
        ``(train_dl, val_dl, test_dl, imbalance)``, or that tuple followed by
        ``(N_CONCEPTS, N_CLASSES, None)`` when ``output_dataset_vars`` is set.
        ``imbalance`` is ``None`` unless ``config["weight_loss"]`` is truthy,
        in which case it is an array holding ``n_train / count_c - 1`` for each
        concept ``c`` (``inf`` for concepts that never occur in training).
    """
    concept_group_map = None
    seed_everything(seed)

    # Build the splits in a fixed order: each dataset constructor shuffles with
    # the global NumPy RNG, so the order determines each split's sample order.
    datasets = {}
    loaders = {}
    for split in ("train", "val", "test"):
        datasets[split] = Imagenet100ConceptDataset(
            data_dir=root_dir,
            split=split,
            transform=None,
        )
        loaders[split] = torch.utils.data.DataLoader(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=(split == "train"),
            num_workers=config["num_workers"],
        )
    train_dataset = datasets["train"]
    train_dl, val_dl, test_dl = loaders["train"], loaders["val"], loaders["test"]

    # Create imbalance
    if config.get("weight_loss", False):
        # A sample's concept vector depends only on its class label, so the
        # per-concept counts are (samples per class) @ (class x concept matrix).
        # This avoids decoding and augmenting every training image just to
        # read its label. The float32 round-trip mirrors the dtype the
        # per-sample vectors returned by __getitem__ have.
        concept_matrix = train_dataset.attr_npy.astype(np.float32).astype(np.float64)
        labels = np.fromiter(
            (label for _, label in train_dataset.data),
            dtype=np.int64,
            count=len(train_dataset.data),
        )
        class_counts = np.bincount(labels, minlength=concept_matrix.shape[0])
        attribute_count = class_counts.astype(np.float64) @ concept_matrix
        samples_seen = len(train_dataset)
        imbalance = samples_seen / attribute_count - 1
    else:
        imbalance = None

    if not output_dataset_vars:
        return train_dl, val_dl, test_dl, imbalance
    return (
        train_dl,
        val_dl,
        test_dl,
        imbalance,
        (N_CONCEPTS, N_CLASSES, concept_group_map),
    )
