import os
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image, ImageCms
from torchvision import transforms
from tqdm import tqdm

from datasets.basic_dataset_scaffold import BaseDataset

class CelebADataset(BaseDataset):
    """CelebA dataset wrapper that keeps per-image metadata aligned with the image list.

    Args:
        image_dict: forwarded unchanged to ``BaseDataset`` (``Give`` passes a mapping of
            identity label -> list of image paths).
        opt: options forwarded to ``BaseDataset``.
        metadata: DataFrame with an ``imagepath`` column plus attribute columns. A copy is
            stored, so the caller's frame is never modified.
        is_validation: when False, a random horizontal flip is prepended to the transform.
    """

    def __init__(self, image_dict, opt, metadata, is_validation=False):
        super(CelebADataset, self).__init__(image_dict, opt, is_validation=is_validation)

        image_size = 256
        # Map ToTensor's [0, 1] range to [-1, 1] per channel (ImageNet statistics are not used).
        self.f_norm = normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

        transf_list = []
        if not self.is_validation:
            transf_list.append(transforms.RandomHorizontalFlip())
        transf_list.extend([
            transforms.Resize(image_size),
            # NOTE(review): crop_size is not set in this class; presumably provided by BaseDataset.
            transforms.CenterCrop(self.crop_size),
            transforms.ToTensor(),
            normalize,
        ])
        self.normal_transform = transforms.Compose(transf_list)

        # Work on a copy: the re-indexing below rewrites the 'imagepath' column, and several
        # datasets may be built from the same metadata frame.
        self.metadata = metadata.copy() if metadata is not None else None

        # Re-order metadata to follow self.image_list so that row i describes image i.
        # Assumes each image_list entry holds the image path at position 0 (TODO confirm
        # against BaseDataset). Paths absent from image_list become NaN and sort last.
        if hasattr(self, 'image_list'):
            image_order = [entry[0] for entry in self.image_list]
            self.metadata['imagepath'] = pd.Categorical(self.metadata['imagepath'], categories=image_order)
            self.metadata = self.metadata.sort_values('imagepath')

    def get_attribute(self, idx, attributes=None):
        """Return the values of ``attributes`` for the metadata row at position ``idx``.

        Args:
            idx: positional (``iloc``) index into the re-ordered metadata.
            attributes: list of column names; defaults to ``['Pale_Skin']``.

        Returns:
            numpy array with the selected attribute values.
        """
        if attributes is None:
            attributes = ['Pale_Skin']
        return self.metadata[attributes].iloc[idx].values

def crop_skin(image, crop_info):
    """Cut square patches out of ``image`` and lay them side by side in one RGB image.

    Args:
        image: PIL image to crop from (it is not modified).
        crop_info: non-empty sequence of dicts, each with ``"point"`` (left, top) and
            ``"size"`` (side length in pixels) of one square patch.

    Returns:
        A new RGB image ``size * len(crop_info)`` wide and ``size`` tall, where ``size`` is
        taken from the first entry; patch ``i`` is pasted at x-offset ``size_i * i``.

    Raises:
        ValueError: if ``crop_info`` is empty.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if len(crop_info) == 0:
        raise ValueError("crop_info must contain at least one crop specification")

    # Template for the collage: one square slot per patch, laid out horizontally.
    collage = Image.new('RGB', (crop_info[0]["size"] * len(crop_info), crop_info[0]["size"]))

    for i, info in enumerate(crop_info):
        left, top = info["point"]
        size = info["size"]
        # Image.crop already returns a new image, so no defensive copy of the source is needed.
        cropped = image.crop((left, top, left + size, top + size))
        collage.paste(cropped, (size * i, 0))

    return collage

@lru_cache(maxsize=1)
def _srgb_to_lab_transform():
    """Return an sRGB -> LAB ImageCms transform, built on first use and reused afterwards."""
    srgb_p = ImageCms.createProfile("sRGB")
    lab_p = ImageCms.createProfile("LAB")
    return ImageCms.buildTransformFromOpenProfiles(srgb_p, lab_p, "RGB", "LAB")


def calculate_ita(image):
    """Return the Individual Typology Angle (ITA) of ``image`` in degrees.

    ITA = arctan((L* - 50) / b*) * 180 / pi, evaluated on the mean L* and mean b* of the
    image after an sRGB -> CIELAB conversion.

    Pillow's 8-bit "LAB" images store L* scaled to 0..255 and a*/b* offset by +128, so the
    raw channel means are converted back to L* in [0, 100] and signed b* before use.

    Args:
        image: PIL image in "RGB" mode (the transform is built for RGB input).

    Returns:
        numpy float angle in degrees within [-90, 90]; NaN when mean L* == 50 and mean b* == 0.
    """
    lab = ImageCms.applyTransform(image, _srgb_to_lab_transform())

    l_band, _a_band, b_band = lab.split()
    mean_l_star = np.asarray(l_band, dtype=np.float64).mean() * (100.0 / 255.0)
    mean_b_star = np.asarray(b_band, dtype=np.float64).mean() - 128.0

    # b* == 0 yields +/-inf (-> +/-90 degrees) or NaN; silence numpy's warnings for that case.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.arctan((mean_l_star - 50.0) / mean_b_star) * (180.0 / np.pi)
    
def find_fitzpatrick_category(ita, fitzpatrick_categories):
    """Bin an ITA angle into a category number from 1 to 6.

    ``fitzpatrick_categories[k]`` acts as an exclusive upper bound for category ``k``
    (k = 2..6). Bounds are tried from category 6 down to 2 and the first one that ``ita``
    is strictly below decides the result; anything else (including NaN) is category 1.
    The entry stored under key 1 is never consulted.
    """
    for category in (6, 5, 4, 3, 2):
        # Thresholds are looked up lazily, in the same order as the comparisons.
        if ita < fitzpatrick_categories[category]:
            return category
    return 1

def get_skintone_categories(opt, datapath, annotationpath, image_filepaths, landmarks, bbox=None, fitzpatrick_categories=None):
    """Return a skin-tone category per image, computing and caching it on first use.

    If ``<annotationpath>/fitzpatrick_categories.csv`` exists it is loaded and returned
    unchanged (regardless of ``image_filepaths``). Otherwise, for every image, three
    20x20 patches centred on the nose row (at the left-eye, right-eye and nose x
    coordinates) are joined with ``crop_skin``, scored with ``calculate_ita`` and binned
    with ``find_fitzpatrick_category``; the table is then written to that CSV.

    Args:
        opt, datapath, bbox: unused here; kept for interface compatibility.
        annotationpath: directory holding the cached CSV.
        image_filepaths: iterable of image paths; each must appear in ``landmarks['imagepath']``.
        landmarks: DataFrame with ``imagepath``, ``nose_x``, ``nose_y``, ``lefteye_x`` and
            ``righteye_x`` columns.
        fitzpatrick_categories: ITA thresholds keyed by category; defaults to
            ``{1: 90, 2: 50, 3: 40, 4: 30, 5: 20, 6: 10}``.

    Returns:
        DataFrame with columns ``imagepath`` and integer ``Skintone``.

    Raises:
        KeyError: if an image path has no landmark row.
    """
    if fitzpatrick_categories is None:
        fitzpatrick_categories = {
            1: 90,  # 50 <= x
            2: 50,  # 40 <= x < 50
            3: 40,  # 30 <= x < 40
            4: 30,  # 20 <= x < 30
            5: 20,  # 10 <= x < 20
            6: 10,  # x < 10
        }

    skintone_annotations = os.path.join(annotationpath, "fitzpatrick_categories.csv")
    if os.path.exists(skintone_annotations):
        return pd.read_csv(skintone_annotations, sep=",", header=0)

    # Build the landmark lookup once instead of filtering the whole table per image.
    # Plain ints are required: Pillow's crop box checks fail on pandas Series coordinates.
    landmark_lookup = {
        path: (int(nose_x), int(nose_y), int(lefteye_x), int(righteye_x))
        for path, nose_x, nose_y, lefteye_x, righteye_x in zip(
            landmarks['imagepath'], landmarks['nose_x'], landmarks['nose_y'],
            landmarks['lefteye_x'], landmarks['righteye_x'])
    }

    records = []
    for filename in tqdm(image_filepaths, desc="Retrieving skintones ..."):
        nose_x, nose_y, lefteye_x, righteye_x = landmark_lookup[filename]

        ## "Cheek" and nose area
        crop_info = [{"point": (x - 10, nose_y - 10), "size": 20} for x in (lefteye_x, righteye_x, nose_x)]

        # The context manager closes the file handle Image.open keeps open for lazy loading.
        with Image.open(filename) as image:
            collage = crop_skin(image, crop_info)
        ita = calculate_ita(collage)
        records.append({"imagepath": filename,
                        "Skintone": find_fitzpatrick_category(ita, fitzpatrick_categories)})

    # Build the frame once (DataFrame.append was removed in pandas 2.0 and was quadratic).
    skintone_categories = pd.DataFrame(records, columns=["imagepath", "Skintone"])
    skintone_categories["Skintone"] = skintone_categories["Skintone"].astype(int)
    skintone_categories.to_csv(skintone_annotations, sep=",", header=True, index=False)
    return skintone_categories

def _read_celeba_annotation(filepath):
    """Read a CelebA annotation file: skip line 0, take line 1 as header, rest as data rows.

    Returns:
        (header, rows): header tokens and one whitespace-split token list per non-blank row.
    """
    header = []
    rows = []
    with open(filepath, 'r') as file:
        for i, line in enumerate(file):
            tokens = line.split()
            if i == 1:
                header = tokens
            elif i > 1 and tokens:
                rows.append(tokens)
    return header, rows


def preprocess_celeba(opt, datapath):
    """Load CelebA annotations into one metadata table and split it.

    Reads identities, the evaluation partition, attributes, landmarks and bounding boxes
    from ``datapath``, attaches skin-tone categories via ``get_skintone_categories`` (inner
    join on ``imagepath``) and relabels identities to 0..K-1 in order of first appearance.

    Args:
        opt: options; ``opt.use_tv_split`` selects a separate validation split.
        datapath: CelebA root containing ``Anno``, ``Eval`` and ``Img/img_align_celeba``.

    Returns:
        (metadata, conversion, train_metadata, val_metadata, test_metadata) where
        ``conversion`` maps new identity labels back to the original ones. Training rows are
        split 0 (plus split 1 when ``use_tv_split`` is false), validation rows are split 1
        (``None`` when ``use_tv_split`` is false), test rows are split 2. The split frames
        drop the ``split`` column.
    """
    annotationpath = os.path.join(datapath, 'Anno')
    imagepath = os.path.join(datapath, 'Img/img_align_celeba')
    splitpath = os.path.join(datapath, 'Eval')

    identity = pd.read_csv(os.path.join(annotationpath, 'identity_CelebA.txt'), header=None, sep=' ')
    identity.columns = ['imagepath', 'identity']

    split = pd.read_csv(os.path.join(splitpath, 'list_eval_partition.txt'), header=None, sep=' ')
    split.columns = ['imagepath', 'split']

    # Attributes: the header lists only attribute names; each row starts with the file name.
    attr_cols, attribute_rows = _read_celeba_annotation(os.path.join(annotationpath, 'list_attr_celeba.txt'))
    attribute = pd.DataFrame(data=attribute_rows, columns=['imagepath'] + attr_cols)
    attribute = attribute.astype({**{col: int for col in attr_cols}, 'imagepath': str})
    # Clip attribute values into {0, 1} (negative flags become 0).
    attribute[attr_cols] = attribute[attr_cols].clip(lower=0, upper=1)

    metadata = pd.concat([df.set_index('imagepath') for df in (identity, split, attribute)], axis=1)
    metadata.reset_index(inplace=True)
    metadata['imagepath'] = [os.path.join(imagepath, image) for image in metadata['imagepath']]

    ### Get landmarks (header also omits the file-name column)
    landmark_cols, landmark_rows = _read_celeba_annotation(
        os.path.join(annotationpath, 'list_landmarks_align_celeba.txt'))
    landmarks = pd.DataFrame(data=landmark_rows, columns=['imagepath'] + landmark_cols)
    landmarks = landmarks.astype({**{col: int for col in landmark_cols}, 'imagepath': str})
    landmarks['imagepath'] = [os.path.join(imagepath, image) for image in landmarks['imagepath']]

    ### Get bbox information (header already names the file column 'image_id')
    # NOTE(review): bbox is parsed but not used further in this function.
    bbox_cols, bbox_rows = _read_celeba_annotation(os.path.join(annotationpath, 'list_bbox_celeba.txt'))
    bbox = pd.DataFrame(data=bbox_rows, columns=bbox_cols).rename(columns={'image_id': 'imagepath'})
    bbox = bbox.astype({col: int if col != 'imagepath' else str for col in bbox.columns})
    bbox['imagepath'] = [os.path.join(imagepath, image) for image in bbox['imagepath']]

    ### Get skintone information
    skintone_categories = get_skintone_categories(opt, datapath, annotationpath, metadata['imagepath'].values, landmarks=landmarks)
    metadata = metadata.merge(skintone_categories, how="inner", on="imagepath")

    ### Relabel identities to consecutive integers in order of first appearance
    unique_labels = metadata['identity'].unique()
    inverse_conversion = dict(zip(unique_labels, range(len(unique_labels))))
    # Assign back explicitly: an inplace replace on a selected column is chained assignment
    # and does not reliably update ``metadata`` under pandas copy-on-write.
    metadata['identity'] = metadata['identity'].map(inverse_conversion)

    if opt.use_tv_split:
        train_metadata = metadata.loc[metadata['split'] == 0].drop(labels='split', axis=1)
        val_metadata = metadata.loc[metadata['split'] == 1].drop(labels='split', axis=1)
    else:
        train_metadata = metadata.loc[metadata['split'].isin([0, 1])].drop(labels='split', axis=1)
        val_metadata = None

    test_metadata = metadata.loc[metadata['split'] == 2].drop(labels='split', axis=1)

    conversion = {new_label: original for original, new_label in inverse_conversion.items()}
    return metadata, conversion, train_metadata, val_metadata, test_metadata

def _group_paths_by_identity(split_metadata):
    """Return ``{identity: [imagepath, ...]}`` with keys ascending and paths in row order."""
    return {
        key: paths.tolist()
        for key, paths in split_metadata.groupby('identity', sort=True)['imagepath']
    }


def Give(opt, datapath):
    """Build the CelebA datasets used for training and evaluation.

    Args:
        opt: options forwarded to ``preprocess_celeba`` and ``CelebADataset``.
        datapath: CelebA root directory.

    Returns:
        dict with keys 'training', 'validation' (``None`` unless ``opt.use_tv_split``),
        'testing', 'evaluation' and 'evaluation_train'. Each dataset gets a ``conversion``
        attribute mapping its (relabelled) identity keys back to original identities.
    """
    metadata, conversion, train_metadata, val_metadata, test_metadata = preprocess_celeba(opt, datapath)

    # One groupby per split instead of a boolean filter per identity.
    train_image_dict = _group_paths_by_identity(train_metadata)
    train_conversion = {key: conversion[key] for key in train_image_dict}

    train_dataset = CelebADataset(train_image_dict, opt, metadata=train_metadata)
    eval_dataset = CelebADataset(train_image_dict, opt, metadata=train_metadata, is_validation=True)
    # NOTE(review): built exactly like train_dataset (is_validation=False); confirm intended.
    eval_train_dataset = CelebADataset(train_image_dict, opt, metadata=train_metadata, is_validation=False)

    if opt.use_tv_split:
        assert val_metadata is not None
        val_image_dict = _group_paths_by_identity(val_metadata)
        val_dataset = CelebADataset(val_image_dict, opt, metadata=val_metadata, is_validation=True)
        val_dataset.conversion = {key: conversion[key] for key in val_image_dict}
    else:
        val_dataset = None

    test_image_dict = _group_paths_by_identity(test_metadata)
    test_dataset = CelebADataset(test_image_dict, opt, metadata=test_metadata, is_validation=True)
    test_conversion = {key: conversion[key] for key in test_image_dict}

    train_dataset.conversion = train_conversion
    test_dataset.conversion = test_conversion
    eval_dataset.conversion = train_conversion
    eval_train_dataset.conversion = train_conversion

    return {'training': train_dataset, 'validation': val_dataset, 'testing': test_dataset,
            'evaluation': eval_dataset, 'evaluation_train': eval_train_dataset}
