import os
import random
import torch
import numpy as np
import pandas as pd
from torchvision.transforms import v2
from PIL import ImageDraw
from dataset.handlers import FlickrDataset, TwitterDataset, RAFDataset, Emotion6Dataset, FBP5500Dataset
 

# Registry mapping the ``args.dataset_name`` config string to the dataset
# handler class that get_datasets() instantiates for every split.
HANDLER_DICT = dict(
    flickr=FlickrDataset,
    twitter=TwitterDataset,
    raf=RAFDataset,
    emotion6=Emotion6Dataset,
    fbp5500=FBP5500Dataset,
)


def load_data(args):
    """Read the labeled/unlabeled train, validation and test split CSVs.

    Each split is read from ``<args.dataset_dir>/<split>_data.csv``. The
    first CSV column is taken as the image identifier and every remaining
    column as the label vector.

    Args:
        args: Namespace-like object; only ``args.dataset_dir`` is read.

    Returns:
        dict mapping each split name (``'train_label'``, ``'train_unlabel'``,
        ``'val'``, ``'test'``) to ``{'labels': <array of columns 1..>,
        'images': <array of column 0>}``.
    """
    root = args.dataset_dir

    def _read_split(split):
        # One CSV per split; column 0 = image entry, the rest = labels.
        frame = pd.read_csv(os.path.join(root, f'{split}_data.csv'))
        return {
            'labels': frame.iloc[:, 1:].values,
            'images': frame.iloc[:, 0].values,
        }

    splits = ('train_label', 'train_unlabel', 'val', 'test')
    return {split: _read_split(split) for split in splits}


def get_datasets(args):
    """Build the labeled-train, unlabeled-train, validation and test datasets.

    Args:
        args: Namespace-like config. Reads ``args.img_size`` (square resize
            target), ``args.dataset_dir`` (passed to load_data and to each
            handler) and ``args.dataset_name`` (a key of ``HANDLER_DICT``).

    Returns:
        Tuple ``(train_label_dataset, train_unlabel_dataset, val_dataset,
        test_dataset)``, each an instance of the handler class selected by
        ``args.dataset_name``. Both training splits use the augmenting
        transform; validation and test only resize and convert.

    Raises:
        KeyError: if ``args.dataset_name`` is not a registered dataset.
    """
    # Resolve the handler first so a typo fails fast, before any CSV is read,
    # and with a message that lists the accepted names.
    try:
        data_handler = HANDLER_DICT[args.dataset_name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset_name {args.dataset_name!r}; "
            f"expected one of {sorted(HANDLER_DICT)}"
        ) from None

    size = (args.img_size, args.img_size)
    # v2.ToTensor is deprecated; ToImage + ToDtype(float32, scale=True) is
    # torchvision's documented replacement (float tensor scaled to [0, 1]).
    # ToPureTensor strips the tv_tensors.Image wrapper so downstream code
    # still receives a plain torch.Tensor, exactly as with ToTensor.
    to_float_tensor = [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.ToPureTensor(),
    ]
    transform_train = v2.Compose([
        v2.Resize(size),
        v2.RandomHorizontalFlip(0.5),
        # Random translation of up to 12.5% of width/height, no rotation.
        v2.RandomAffine(degrees=0, translate=(0.125, 0.125)),
        *to_float_tensor,
    ])
    val_transform = v2.Compose([v2.Resize(size), *to_float_tensor])

    data = load_data(args)
    # Insertion order fixes the order of the returned tuple.
    phase_transforms = {
        'train_label': transform_train,
        'train_unlabel': transform_train,
        'val': val_transform,
        'test': val_transform,
    }
    return tuple(
        data_handler(
            data[phase]['images'],
            data[phase]['labels'],
            args.dataset_dir,
            transform=transform,
        )
        for phase, transform in phase_transforms.items()
    )
