import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
import time
import math
import pandas as pd
import os

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length, in pixels, of every image produced by the pipelines below.
_IMAGE_SIZE = 64

# Per-channel mean / std used for input normalization (the commonly used
# ImageNet statistics).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _make_normalize():
    """Build a new ``Normalize`` transform from the shared channel statistics."""
    return transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD))


# Training pipeline: random crop-and-resize plus horizontal flips for
# augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(_IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _make_normalize(),
    ]
)

# Evaluation pipeline: deterministic resize + center crop, same normalization.
transform_test = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.CenterCrop(_IMAGE_SIZE),
        transforms.ToTensor(),
        _make_normalize(),
    ]
)

import os
import shutil

def organize_val_dataset(val_dir='./tiny-imagenet-200/val'):
    """Rearrange the Tiny ImageNet validation split into one folder per class.

    Images are read from ``<val_dir>/images``. Labels come from
    ``<val_dir>/val_annotations.txt``, which is tab-separated: the first column
    is the image filename and the second is its class id. Each image is moved
    to ``<val_dir>/<class id>/<filename>``, which is the per-class layout
    ``torchvision.datasets.ImageFolder`` loads. The flat ``images`` directory
    is then removed.

    Calling this more than once is safe:

    - If ``images`` no longer exists, the split was already reorganized and
      nothing is done.
    - Images already moved by an interrupted earlier run are skipped.

    Args:
        val_dir: Path to the validation split. The default matches the
            location used by the rest of this script.

    Raises:
        FileNotFoundError: If the annotations file is missing, or an annotated
            image exists neither in ``images`` nor in its class folder.
    """
    img_dir = os.path.join(val_dir, 'images')
    if not os.path.isdir(img_dir):
        # Already reorganized (e.g. by a previous run of this script).
        return

    ann_file = os.path.join(val_dir, 'val_annotations.txt')

    # Map each image filename to its class id. Strip the line terminator so
    # the class id never carries '\n'/'\r', and skip blank or malformed lines
    # instead of raising IndexError.
    val_img_dict = {}
    with open(ann_file, 'r', encoding='utf-8') as f:
        for line in f:
            words = line.rstrip('\r\n').split('\t')
            if len(words) < 2:
                continue
            val_img_dict[words[0]] = words[1]

    # Make class subdirectories and move each image into its class folder.
    for img, cls in val_img_dict.items():
        cls_dir = os.path.join(val_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        src = os.path.join(img_dir, img)
        dst = os.path.join(cls_dir, img)
        if not os.path.exists(src) and os.path.exists(dst):
            # Already moved by an earlier, interrupted run.
            continue
        shutil.move(src, dst)

    # The annotated images have been moved out; drop the flat directory.
    shutil.rmtree(img_dir)
# Reorganize the validation split only while the flat images/ directory still
# exists. organize_val_dataset() deletes that directory when it finishes, so an
# unconditional call crashes on every later run of this script. It would also
# crash in any process that re-imports this module — presumably DataLoader
# workers started with the "spawn" multiprocessing method; confirm for your
# platform.
if os.path.isdir('./tiny-imagenet-200/val/images'):
    organize_val_dataset()


# `device` is selected once near the top of the file and reused below.
train_dataset = torchvision.datasets.ImageFolder(
    root='./tiny-imagenet-200/train',
    transform=transform_train,
)

# After reorganization the validation split is laid out as
# val/<class id>/<image>, the per-class layout ImageFolder expects.
test_dataset = torchvision.datasets.ImageFolder(
    root='./tiny-imagenet-200/val',
    transform=transform_test,
)

# Pinned host memory speeds up host-to-GPU copies, so it is enabled only when
# the selected device is CUDA.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=device.type == 'cuda',
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
    pin_memory=device.type == 'cuda',
)