import os
import random
import torch
import numpy as np
import pandas as pd
from randaugment import RandAugment, Cutout
from torchvision.transforms import v2
from PIL import ImageDraw
from dataset.handlers import FlickrDataset, TwitterDataset, RAFDataset, Emotion6Dataset, FBP5500Dataset


class TransformTwice:
    """Wrap a transform so that each call produces two augmented views.

    The wrapped transform is invoked twice on the same input, so any
    randomness inside it yields two independently sampled results.

    Args:
        transform (callable): Transform applied to the input on each view.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        # Evaluate the transform twice, in order, and hand back both views.
        first_view, second_view = (self.transform(inp) for _ in range(2))
        return first_view, second_view


def pad(x, border=4):
    """Reflect-pad the last two axes of a 3-D ``(C, H, W)`` array.

    Args:
        x: Array whose first axis is left unpadded.
        border (int): Number of pixels added on each side of the two
            trailing axes.

    Returns:
        np.ndarray: The padded array, shape ``(C, H + 2*border, W + 2*border)``.
    """
    side = (border, border)
    widths = ((0, 0), side, side)
    return np.pad(x, widths, mode='reflect')


class RandomPadandCrop(object):
    """Reflect-pad an image, then crop a random window of ``output_size``.

    The input is converted with ``np.array`` and treated as channels-last
    ``(H, W, C)``. Its two spatial axes are reflect-padded by ``padding``
    pixels on every side. A crop offset is then drawn uniformly with
    ``np.random`` over every valid position, including the one flush with
    the bottom/right edge.

    Args:
        output_size (tuple or int): Desired output size ``(new_h, new_w)``.
            If int, a square crop is made.
        padding (int): Reflect padding added to each spatial border before
            cropping. Defaults to 10.

    Returns (from ``__call__``):
        np.ndarray: ``(new_h, new_w, C)`` array with the input's dtype.

    Raises:
        TypeError: If ``output_size`` is neither an int nor a tuple.
        ValueError: If ``output_size`` is a tuple whose length is not 2, or
            if the padded image is smaller than ``output_size``.
    """

    def __init__(self, output_size, padding=10):
        # Explicit raises instead of asserts so validation survives ``python -O``.
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        elif not isinstance(output_size, tuple):
            raise TypeError('output_size must be an int or a tuple of two ints')
        elif len(output_size) != 2:
            raise ValueError('output_size tuple must have exactly two elements')
        self.output_size = output_size
        self.padding = padding

    def __call__(self, x):
        x = np.array(x)
        p = self.padding
        # Pad only the spatial (H, W) axes; the channel axis is left as is.
        x = np.pad(x, [(p, p), (p, p), (0, 0)], mode='reflect')

        h, w = x.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'padded image of size (%d, %d) is smaller than output_size (%d, %d)'
                % (h, w, new_h, new_w))

        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable, so shifts are symmetric around the centre and a
        # crop equal to the padded size (offset 0) is allowed.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        return x[top:top + new_h, left:left + new_w]


class GaussianNoise(object):
    """Add zero-mean Gaussian noise to an array.

    The input array is NOT modified; a new array is returned. Floating-point
    inputs keep their dtype. Integer inputs are promoted to ``float32``,
    because adding float noise to an integer array in place would raise.
    Noise is drawn from ``np.random`` with the same shape as the input.

    Args:
        std (float): Standard deviation of the noise. Defaults to 0.15.
    """

    def __init__(self, std=0.15):
        self.std = std

    def __call__(self, x):
        x = np.asarray(x)
        if np.issubdtype(x.dtype, np.floating):
            out_dtype = x.dtype
        else:
            out_dtype = np.float32
        noise = np.random.randn(*x.shape) * self.std
        # Out-of-place add so the caller's array is left untouched.
        return (x + noise).astype(out_dtype, copy=False)

class ToTensor:
    """Wrap a NumPy array as a ``torch.Tensor`` via ``torch.from_numpy``.

    The resulting tensor shares memory with the array. No copy, dtype
    conversion, value scaling or axis reordering is performed.
    """

    def __call__(self, x):
        return torch.from_numpy(x)
    

# Registry mapping a dataset name (looked up via ``args.dataset_name`` in
# get_datasets) to the handler class that wraps that dataset's image
# entries, labels, root directory and transform.
HANDLER_DICT = {
    'flickr': FlickrDataset,
    'twitter': TwitterDataset,
    'raf': RAFDataset,
    'emotion6': Emotion6Dataset,
    'fbp5500': FBP5500Dataset
}


def load_data(args):
    """Read the per-phase CSV files of a dataset.

    For each phase, ``<args.dataset_dir>/<phase>_data.csv`` is read. The
    first column holds the image entries; all remaining columns are labels.
    The phases are 'train_label', 'train_unlabel', 'val' and 'test', plus
    'wlabel' when ``args.train_unlabel`` is truthy.

    Args:
        args: Namespace providing ``dataset_dir`` and ``train_unlabel``.

    Returns:
        dict: ``{phase: {'labels': ndarray, 'images': ndarray}}``.
    """
    root = args.dataset_dir
    phase_names = ['train_label', 'train_unlabel', 'val', 'test']
    if args.train_unlabel:
        phase_names = phase_names + ['wlabel']

    splits = {}
    for name in phase_names:
        frame = pd.read_csv(os.path.join(root, f'{name}_data.csv'))
        splits[name] = {
            'labels': frame.iloc[:, 1:].values,
            'images': frame.iloc[:, 0].values,
        }
    return splits


def get_datasets(args):
    """Build the labeled, unlabeled, validation and test datasets.

    Training samples are resized, randomly flipped, reflect-padded and
    randomly cropped, then converted to tensors. Each unlabeled training
    sample yields two independently augmented views via TransformTwice.
    Validation and test samples are only resized and converted.

    Args:
        args: Namespace providing ``img_size``, ``dataset_name``,
            ``dataset_dir`` and ``train_unlabel`` (the last is consumed by
            load_data).

    Returns:
        tuple: ``(train_label_dataset, train_unlabel_dataset, val_dataset,
        test_dataset)``, each an instance of ``HANDLER_DICT[args.dataset_name]``.
    """
    transform_train = v2.Compose([
        v2.Resize((args.img_size, args.img_size)),
        # Flip BEFORE RandomPadandCrop. That transform returns a NumPy array,
        # and torchvision v2 transforms pass ndarrays through unchanged, so a
        # flip placed after it never takes effect. Presumably the handler
        # feeds PIL images here; TODO confirm.
        # The pad is symmetric, so flip-then-crop is equivalent in
        # distribution to crop-then-flip.
        v2.RandomHorizontalFlip(),
        RandomPadandCrop(args.img_size),
        v2.ToTensor(),
    ])

    val_transform = v2.Compose([
        v2.Resize((args.img_size, args.img_size)),
        v2.ToTensor(),
    ])

    # NOTE(review): load_data also reads a 'wlabel' split when
    # args.train_unlabel is set, but it is not used here.
    data = load_data(args)
    data_handler = HANDLER_DICT[args.dataset_name]

    def _make(phase, transform):
        # Wrap one split's images/labels in the dataset-specific handler.
        return data_handler(data[phase]['images'], data[phase]['labels'],
                            args.dataset_dir, transform=transform)

    train_label_dataset = _make('train_label', transform_train)
    train_unlabel_dataset = _make('train_unlabel', TransformTwice(transform_train))
    val_dataset = _make('val', val_transform)
    test_dataset = _make('test', val_transform)

    return train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset
