import os
import random
import torch
import numpy as np
import pandas as pd
from randaugment import RandAugment, Cutout
from torchvision.transforms import v2
from PIL import ImageDraw
from dataset.handlers import FlickrDataset, TwitterDataset, RAFDataset, Emotion6Dataset, FBP5500Dataset
    


class CutoutPIL(object):
    """Cutout augmentation for PIL images.

    Draws one rectangle, filled with a random 3-channel colour, onto the
    image. The rectangle is about ``cutout_factor`` times the image height
    by ``cutout_factor`` times the image width. It is centred on a uniformly
    random pixel and clipped to the image borders.

    The image is drawn on directly through ``ImageDraw.Draw``. The caller's
    image object is therefore modified, and it is also returned.
    """

    def __init__(self, cutout_factor=0.5):
        # Box side length as a fraction of the image height / width.
        self.cutout_factor = cutout_factor

    def __call__(self, x):
        """Apply the cutout to PIL image ``x`` and return ``x``."""
        img_draw = ImageDraw.Draw(x)
        # PIL's Image.size is (width, height), not (height, width).
        w, h = x.size
        h_cutout = int(self.cutout_factor * h + 0.5)
        w_cutout = int(self.cutout_factor * w + 0.5)
        y_c = np.random.randint(h)
        x_c = np.random.randint(w)

        y1 = np.clip(y_c - h_cutout // 2, 0, h)
        y2 = np.clip(y_c + h_cutout // 2, 0, h)
        x1 = np.clip(x_c - w_cutout // 2, 0, w)
        x2 = np.clip(x_c + w_cutout // 2, 0, w)
        # NOTE(review): a 3-tuple fill presumably assumes an RGB image mode.
        fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        # ImageDraw.rectangle takes [x0, y0, x1, y1] (column first, then row).
        img_draw.rectangle([x1, y1, x2, y2], fill=fill_color)

        return x

class TransformUnlabeled_WS(object):
    """Turn one input image into a ``(weak, strong)`` pair of augmented views.

    Both pipelines resize to ``args.img_size`` x ``args.img_size`` and end
    with ``v2.ToTensor()``.
    """

    def __init__(self, args):
        target_size = (args.img_size, args.img_size)

        # Weak view: random horizontal flip plus a small random translation
        # (up to 12.5% of width/height), with no rotation.
        self.weak = v2.Compose([
            v2.Resize(target_size),
            v2.RandomHorizontalFlip(0.5),
            v2.RandomAffine(degrees=0, translate=(0.125, 0.125)),
            v2.ToTensor(),
        ])

        # Strong view: torchvision RandAugment, then the project's Cutout
        # (from the local ``randaugment`` module). Cutout runs before
        # ToTensor, so it presumably expects a PIL image; confirm.
        # Alternatives left disabled in the original: CutoutPIL(0.5),
        # ColorJitter(0.4, 0.4, 0.4, 0.1) with p=0.8, RandomGrayscale(p=0.2).
        self.strong = v2.Compose([
            v2.Resize(target_size),
            v2.RandAugment(),
            Cutout(size=16),
            v2.ToTensor(),
        ])

    def __call__(self, x):
        """Return ``(weak_view, strong_view)`` for image ``x``."""
        weak_view = self.weak(x)
        strong_view = self.strong(x)
        return weak_view, strong_view
    
class TransformUnlabeled_Wlabel(object):
    """Turn one image into a plain view followed by three weak views.

    ``__call__`` returns a 4-tuple ``(plain, weak_1, weak_2, weak_3)``.
    Each weak view is an independent draw from the same weak pipeline.
    """

    def __init__(self, args):
        target_size = (args.img_size, args.img_size)

        # Weak pipeline: random horizontal flip plus a small random
        # translation (up to 12.5% of width/height), with no rotation.
        self.weak = v2.Compose([
            v2.Resize(target_size),
            v2.RandomHorizontalFlip(0.5),
            v2.RandomAffine(degrees=0, translate=(0.125, 0.125)),
            v2.ToTensor(),
        ])

        # Plain pipeline: resize and convert only, with no random ops.
        self.norm = v2.Compose([
            v2.Resize(target_size),
            v2.ToTensor(),
        ])

    def __call__(self, x):
        """Return ``(plain_view, weak_1, weak_2, weak_3)`` for image ``x``."""
        weak_views = [self.weak(x) for _ in range(3)]
        return (self.norm(x), *weak_views)
    

# Maps ``args.dataset_name`` to the dataset handler class that
# ``get_datasets`` instantiates for every split.
HANDLER_DICT = dict(
    flickr=FlickrDataset,
    twitter=TwitterDataset,
    raf=RAFDataset,
    emotion6=Emotion6Dataset,
    fbp5500=FBP5500Dataset,
)


def load_data(args):
    """Read the per-split CSV files from ``args.dataset_dir``.

    Each split ``<phase>`` is read from ``<dataset_dir>/<phase>_data.csv``.
    The first column holds the image identifiers and every remaining column
    is treated as a label column. The ``'wlabel'`` split is read only when
    ``args.train_unlabel`` is truthy.

    Returns:
        dict mapping phase name to ``{'labels': ndarray, 'images': ndarray}``.
    """
    root = args.dataset_dir
    split_names = ['train_label', 'train_unlabel', 'val', 'test']
    if args.train_unlabel:
        split_names = split_names + ['wlabel']

    splits = {}
    for name in split_names:
        frame = pd.read_csv(os.path.join(root, name + '_data.csv'))
        splits[name] = {
            'labels': frame.iloc[:, 1:].values,
            'images': frame.iloc[:, 0].values,
        }
    return splits


def get_datasets(args):
    """Build the dataset handlers for every split.

    Args:
        args: namespace providing at least ``img_size``, ``dataset_dir``,
            ``dataset_name`` (a key of ``HANDLER_DICT``), ``train_ensemble``
            and ``train_unlabel``.

    Returns:
        ``(train_label, train_unlabel, val, test, weak_setlabel)`` datasets.
        When ``args.train_unlabel`` is truthy, a sixth ``weak_label`` dataset
        built from the ``'wlabel'`` split is appended.
    """
    train_transform = TransformUnlabeled_WS(args)
    weak_label_transform = TransformUnlabeled_Wlabel(args)

    # Validation and test images are only resized and converted to tensors.
    val_transform = v2.Compose([
        v2.Resize((args.img_size, args.img_size)),
        v2.ToTensor()])

    test_transform = v2.Compose([
        v2.Resize((args.img_size, args.img_size)),
        v2.ToTensor()])

    data = load_data(args)
    data_handler = HANDLER_DICT[args.dataset_name]

    def _make(phase, transform):
        # Images and labels are taken from the same split so rows stay aligned.
        split = data[phase]
        return data_handler(split['images'], split['labels'], args.dataset_dir, transform=transform)

    # With train_ensemble the labelled split yields (plain, weak x3) views
    # instead of a (weak, strong) pair.
    label_transform = weak_label_transform if args.train_ensemble else train_transform
    train_label_dataset = _make('train_label', label_transform)
    train_unlabel_dataset = _make('train_unlabel', train_transform)
    val_dataset = _make('val', val_transform)
    test_dataset = _make('test', test_transform)
    weak_setlabel_dataset = _make('train_unlabel', weak_label_transform)

    if not args.train_unlabel:
        return train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset, weak_setlabel_dataset

    # The 'wlabel' images must be paired with the 'wlabel' labels. Taking
    # labels from 'train_unlabel' misaligns rows whenever the two CSVs
    # differ in length or order.
    weak_label_dataset = _make('wlabel', train_transform)

    return train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset, weak_setlabel_dataset, weak_label_dataset
