import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
import urllib.request
import zipfile
import shutil


def download_and_extract_omniglot(root_dir):
    """Fetch the Omniglot archives if missing and unpack them.

    After the call, ``<root_dir>/omniglot/images_background`` and
    ``<root_dir>/omniglot/images_evaluation`` hold the extracted images.
    Archives already present on disk are reused instead of re-downloaded,
    and the function is safe to call repeatedly.

    Args:
        root_dir: Directory under which the ``omniglot`` folder is created.
    """
    omniglot_url = "https://github.com/brendenlake/omniglot/raw/master/python"
    zip_files = ["images_background.zip", "images_evaluation.zip"]

    dataset_dir = os.path.join(root_dir, "omniglot")
    os.makedirs(dataset_dir, exist_ok=True)

    for zip_file in zip_files:
        zip_path = os.path.join(dataset_dir, zip_file)

        if not os.path.exists(zip_path):
            # Download under a temporary name and rename on success, so an
            # interrupted transfer never leaves a truncated archive that a
            # later run would mistake for a complete one.
            partial_path = zip_path + ".part"
            try:
                urllib.request.urlretrieve(
                    f"{omniglot_url}/{zip_file}", partial_path
                )
                os.replace(partial_path, zip_path)
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dataset_dir)

    # Some archive layouts unpack into an extra "omniglot/" level; flatten it.
    nested_dir = os.path.join(dataset_dir, "omniglot")
    for folder in ["images_background", "images_evaluation"]:
        src = os.path.join(nested_dir, folder)
        dst = os.path.join(dataset_dir, folder)
        # Only move when the destination is absent: shutil.move() into an
        # existing directory would nest the folder inside it, which the
        # dataset scan would then pick up as a bogus alphabet.
        if os.path.isdir(src) and not os.path.exists(dst):
            shutil.move(src, dst)
    if os.path.isdir(nested_dir):
        shutil.rmtree(nested_dir)


class OmniglotFullDataset(Dataset):
    """Omniglot images from both splits, labelled by alphabet.

    Scans ``<root_dir>/omniglot/images_background`` and
    ``<root_dir>/omniglot/images_evaluation``, expecting the layout
    ``<split>/<alphabet>/<character>/<image>``. Every image is labelled with
    the index of its alphabet in ``alphabet_classes``.

    Args:
        root_dir: Directory containing the ``omniglot`` folder.
        transform: Optional callable applied to each grayscale PIL image.

    Raises:
        FileNotFoundError: If either split directory is missing.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # List of (image_path, alphabet_label) pairs in deterministic order.
        self.data = []
        # Alphabet names; a name's position in this list is its label.
        self.alphabet_classes = []

        split_dirs = [
            os.path.join(root_dir, "omniglot", "images_background"),
            os.path.join(root_dir, "omniglot", "images_evaluation"),
        ]

        for split_dir in split_dirs:
            if not os.path.isdir(split_dir):
                raise FileNotFoundError(
                    f"Not found: {split_dir}"
                )

        alphabets_by_split = [
            (split_dir, self._list_subdirs(split_dir))
            for split_dir in split_dirs
        ]

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # labels follow the sorted background alphabets, then evaluation.
        self.alphabet_classes = list(
            dict.fromkeys(
                name for _, names in alphabets_by_split for name in names
            )
        )
        alphabet_to_label = {
            name: i for i, name in enumerate(self.alphabet_classes)
        }

        # Walk everything in sorted order so the sample order (and therefore
        # any seeded split of it) does not depend on filesystem enumeration.
        for split_dir, alphabet_names in alphabets_by_split:
            for alphabet_name in alphabet_names:
                alphabet_path = os.path.join(split_dir, alphabet_name)
                label = alphabet_to_label[alphabet_name]

                for char_name in self._list_subdirs(alphabet_path):
                    char_path = os.path.join(alphabet_path, char_name)
                    for img_name in sorted(os.listdir(char_path)):
                        img_path = os.path.join(char_path, img_name)
                        # Skip hidden files (e.g. .DS_Store) and anything
                        # that is not a regular file.
                        if img_name.startswith("."):
                            continue
                        if os.path.isfile(img_path):
                            self.data.append((img_path, label))

    @staticmethod
    def _list_subdirs(path):
        """Return the sorted names of the directories directly under path."""
        return sorted(
            name
            for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    def __len__(self):
        """Return the number of images found across both splits."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at ``idx``.

        The image is loaded as 8-bit grayscale and passed through
        ``transform`` when one was supplied.
        """
        img_path, label = self.data[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released immediately instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("L")
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def get_resnet18_model(num_classes):
    """Build the single-channel ResNet-18 classifier.

    This function used to be a line-for-line copy of
    ``get_resnet18_single_channel``; it now delegates to it so the two
    entry points cannot drift apart.

    Args:
        num_classes: Number of output classes for the final linear layer.

    Returns:
        The adapted ``torchvision`` ResNet-18 module.
    """
    return get_resnet18_single_channel(num_classes)


def get_resnet18_single_channel(num_classes):
    """Build an ImageNet-pretrained ResNet-18 for 1-channel input.

    The stem convolution is replaced by a single-input-channel layer and the
    final fully connected layer is resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the final linear layer.

    Returns:
        The adapted ``torchvision`` ResNet-18 module.
    """
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    pretrained_stem = model.conv1
    model.conv1 = nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False
    )
    # Seed the new stem with the pretrained filters instead of discarding
    # them: summing the RGB kernels gives the same response the original
    # layer would produce for a gray image replicated over three channels.
    with torch.no_grad():
        model.conv1.weight.copy_(
            pretrained_stem.weight.sum(dim=1, keepdim=True)
        )

    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    return model


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Puts the model in eval mode and leaves it there. Returns 0.0 when the
    loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def main():
    """Train a ResNet-18 alphabet classifier on Omniglot.

    Downloads the data into ``./data``, splits it 90/10 into train and
    validation sets, trains for ``EPOCHS`` epochs and, whenever the
    validation accuracy improves, writes a TorchScript trace of the model to
    the current directory.

    Raises:
        RuntimeError: If the dataset is too small to form both splits.
    """
    data_root = "./data"
    download_and_extract_omniglot(data_root)

    BATCH_SIZE = 256
    LEARNING_RATE = 0.003
    EPOCHS = 200
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): if the source images are larger than 64x64 these crops
    # keep only part of each character — confirm a Resize is not intended.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop((64, 64)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
        ]
    )

    valid_transform = transforms.Compose(
        [
            transforms.CenterCrop((64, 64)),
            transforms.ToTensor(),
        ]
    )

    # The base dataset yields untransformed images; each split gets its own
    # transform through the TransformDataset wrappers below.
    full_dataset = OmniglotFullDataset(root_dir=data_root, transform=None)

    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise RuntimeError(
            f"Not enough images under {data_root} to build train/val splits"
        )
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size]
    )

    train_dataset = TransformDataset(train_dataset, train_transform)
    val_dataset = TransformDataset(val_dataset, valid_transform)

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Size the classifier head from the data instead of a hard-coded count,
    # so labels can never fall outside the output range.
    num_alphabets = len(full_dataset.alphabet_classes)
    model = get_resnet18_single_channel(num_classes=num_alphabets).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # One fixed validation sample serves as the tracing input, rather than
    # whatever batch a loop variable happened to hold last.
    example_input = val_dataset[0][0].unsqueeze(0).to(device)
    best_accuracy = -1.0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(
            f"Ep {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader):.4f}"
        )

        accuracy = _evaluate_accuracy(model, val_loader, device)
        print(f"Valid Acc : {accuracy:.2f}%\n")

        # Save a traced model only when validation accuracy improves; the
        # model is still in eval mode here, which is what the trace records.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            model_jit = torch.jit.trace(model, example_input)
            model_jit.save(f"resnet18_omniglot_{accuracy:.2f}_jit.pth")


class TransformDataset(Dataset):
    """Wrap a dataset of ``(image, label)`` pairs and transform the image.

    Lets several views of the same underlying data (for example the subsets
    returned by ``random_split``) each use their own transform.
    """

    def __init__(self, dataset, transform=None):
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped sample at ``idx`` with its image transformed."""
        sample, target = self.dataset[idx]
        if not self.transform:
            # No transform configured: pass the sample through untouched.
            return sample, target
        return self.transform(sample), target


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
