import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import datasets
from transformers import AutoProcessor

# Preferred compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



class KneeOA(Dataset):
    """
    Knee Osteoarthritis dataset that contains X-ray images of knees, binary labels
    that indicate an osteoarthritis diagnosis, and 16 different JSW measurements.

    Every item is a tuple that starts with ``(image, label)``; the remaining
    entries depend on the ``include_*`` flags (see ``__getitem__``).
    """

    def __init__(self, csv_df, include_extra=False, include_group=False, include_aux=False):
        """
        Initialize the dataset.
        :param csv_df: Pandas dataframe with image paths and labels
        :param include_extra: Indicates if should return img path
        :param include_group: Indicates if should return a group number
        :param include_aux: Indicates if should return an auxiliary label
        """
        self.csv_df = csv_df
        self.include_extra = include_extra
        # Dataframe columns holding the 16 JSW measurements, in output order.
        self.jsw_columns = [
            'V01JSW150',
            'V01JSW175',
            'V01JSW200',
            'V01JSW225',
            'V01JSW250',
            'V01JSW275',
            'V01JSW300',
            'V01LJSW700',
            'V01LJSW725',
            'V01LJSW750',
            'V01LJSW775',
            'V01LJSW800',
            'V01LJSW825',
            'V01LJSW850',
            'V01LJSW875',
            'V01LJSW900',
        ]
        self.include_group = include_group
        self.include_aux = include_aux

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.csv_df)

    def get_group_info(self):
        """
        Count how many rows belong to each group.
        :return: dict with ``n_groups`` (always 4) and ``group_counts``, a
                 tensor whose i-th entry is the number of rows with group == i
        """
        n_groups = 4
        groups = self.csv_df['group']
        group_counts = torch.tensor([int((groups == g).sum()) for g in range(n_groups)])

        return {'n_groups': n_groups, 'group_counts': group_counts}

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` as a float32 tensor divided by 255."""
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the tensor.
        with Image.open(path) as img:
            return transforms.PILToTensor()(img).to(dtype=torch.float32) / 255.0

    def _jsw_values(self, row):
        """Return the JSW measurements of ``row`` as a float32 tensor."""
        return torch.tensor(row[self.jsw_columns].values.astype(np.float32))

    def __getitem__(self, idx):
        """
        Return the sample at position ``idx``.

        The first flag that is set decides the layout of the returned tuple:
        include_extra -> (image, label, jsw_values, img_path)
        include_group -> (image, label, group)
        include_aux   -> (image, label, aux_label)
        otherwise     -> (image, label, jsw_values)

        Only the columns needed for the selected layout are read, so a
        dataframe does not need e.g. a 'group' column unless include_group
        is actually used.
        """
        # Materialise the row once instead of once per field.
        row = self.csv_df.iloc[idx]

        # Load X-ray image
        img_path = row['LOCATION']
        full_img = self._load_image(img_path)

        # Obtain OA label: KLG below 2 maps to 0, everything else to 1
        label = torch.tensor(0) if row['KLG'] < 2 else torch.tensor(1)

        if self.include_extra:
            return (full_img, label, self._jsw_values(row), img_path)
        if self.include_group:
            return (full_img, label, torch.tensor(row['group']))
        if self.include_aux:
            return (full_img, label, torch.tensor(row['aux_label']))
        return (full_img, label, self._jsw_values(row))



class WaterBirds(Dataset):
    """
    Water birds dataset contains images of land and water birds over land and
    water backgrounds, binary labels indicating if the image contains a land
    or water bird, and image segmentations of only the bird.

    Every item is a tuple that starts with ``(image, label)``; the remaining
    entries depend on the ``include_*`` flags (see ``__getitem__``).
    """

    def __init__(self, csv_df, include_extra=False, include_group=False, include_aux=False, include_med=False):
        """
        Initialize the dataset.
        :param csv_df: Pandas dataframe with image paths and labels
        :param include_extra: Indicates if should return img path
        :param include_group: Indicates if should return a group number
        :param include_aux: Indicates if should return an auxiliary label
        :param include_med: Indicates if should return the image read from
                            the 'full_seg' column instead of 'bird_seg'
        """
        self.csv_df = csv_df
        self.include_extra = include_extra
        self.include_group = include_group
        self.include_aux = include_aux
        self.include_med = include_med

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.csv_df)

    def get_group_info(self):
        """
        Count how many rows belong to each group.
        :return: dict with ``n_groups`` (always 4) and ``group_counts``, a
                 tensor whose i-th entry is the number of rows with group == i
        """
        n_groups = 4
        groups = self.csv_df['group']
        group_counts = torch.tensor([int((groups == g).sum()) for g in range(n_groups)])

        return {'n_groups': n_groups, 'group_counts': group_counts}

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` as a float32 tensor divided by 255."""
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the tensor.
        with Image.open(path) as img:
            return transforms.PILToTensor()(img).to(dtype=torch.float32) / 255.0

    def __getitem__(self, idx):
        """
        Return the sample at position ``idx``.

        The first flag that is set decides the layout of the returned tuple:
        include_extra -> (image, label, bird_segmentation, img_path)
        include_group -> (image, label, group)
        include_aux   -> (image, label, aux_label)
        include_med   -> (image, label, full_segmentation)
        otherwise     -> (image, label, bird_segmentation)

        Segmentation images are only opened when the selected layout returns
        them, which avoids two extra image decodes per sample for the
        group/aux variants; likewise only the required columns are read.
        """
        # Materialise the row once instead of once per field.
        row = self.csv_df.iloc[idx]

        # Load main image
        img_path = row['full_img']
        full_img = self._load_image(img_path)

        # Obtain label
        label = torch.tensor(int(row['label']))

        if self.include_extra:
            return (full_img, label, self._load_image(row['bird_seg']), img_path)
        if self.include_group:
            return (full_img, label, torch.tensor(row['group']))
        if self.include_aux:
            return (full_img, label, torch.tensor(row['aux_label']))
        if self.include_med:
            return (full_img, label, self._load_image(row['full_seg']))
        return (full_img, label, self._load_image(row['bird_seg']))



class FoodReview(Dataset):
    """
    Food review dataset built from a dataframe with 'Text' (review),
    'Summary' and 'Score' columns. The score is binarised into the label:
    scores below 4 map to 0, everything else to 1.

    Every item is a tuple that starts with ``(review, label)``; the remaining
    entries depend on the ``include_*`` flags (see ``__getitem__``).
    """

    def __init__(self, csv_df, tokenizer, include_extra=False, include_group=False, include_aux=False, include_med=False):
        """
        Initialize the dataset.
        :param csv_df: Pandas dataframe with review texts, summaries and scores
        :param tokenizer: Callable used to tokenize reviews and summaries
        :param include_extra: Indicates if should return the raw review text
        :param include_group: Indicates if should return a group number
        :param include_aux: Indicates if should return an auxiliary label
        :param include_med: Indicates if the review should be returned as an
                            untokenized "summarize: " prompt (and, when no
                            earlier flag applies, the summary as plain text)
        """
        self.csv_df = csv_df
        self.include_extra = include_extra
        self.tokenizer = tokenizer
        self.include_group = include_group
        self.include_aux = include_aux
        self.include_med = include_med

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.csv_df)

    def get_group_info(self):
        """
        Count how many rows belong to each group.
        :return: dict with ``n_groups`` (always 4) and ``group_counts``, a
                 tensor whose i-th entry is the number of rows with group == i
        """
        n_groups = 4
        groups = self.csv_df['group']
        group_counts = torch.tensor([int((groups == g).sum()) for g in range(n_groups)])

        return {'n_groups': n_groups, 'group_counts': group_counts}

    def _tokenize(self, text):
        """Tokenize ``text`` padded/truncated to 512 tokens, as PyTorch tensors."""
        return self.tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')

    def __getitem__(self, idx):
        """
        Return the sample at position ``idx``.

        The review entry is the string "summarize: " + review when
        include_med is set, otherwise the tokenized review. The first flag
        that is set decides the layout of the returned tuple:
        include_extra -> (review, label, summary_tokens, raw_review)
        include_group -> (review, label, group)
        include_aux   -> (review, label, aux_label)
        include_med   -> (review, label, summary_text)
        otherwise     -> (review, label, summary_tokens)

        The summary is only tokenized, and the group / auxiliary columns are
        only read, when the selected layout actually returns them.
        """
        # Materialise the row once instead of once per field.
        row = self.csv_df.iloc[idx]

        # Load review
        review = row['Text']
        if self.include_med:
            review_tokens = "summarize: " + review
        else:
            review_tokens = self._tokenize(review)

        # Load score: below 4 maps to 0, everything else to 1
        score = torch.tensor(0) if row['Score'] < 4 else torch.tensor(1)

        if self.include_extra:
            return (review_tokens, score, self._tokenize(str(row['Summary'])), review)
        if self.include_group:
            return (review_tokens, score, torch.tensor(row['group']))
        if self.include_aux:
            return (review_tokens, score, torch.tensor(row['aux_label']))

        # Load summary
        summary = str(row['Summary'])
        if self.include_med:
            return (review_tokens, score, summary)
        return (review_tokens, score, self._tokenize(summary))

    def get_med_dataset(self, idx):
        """
        Build a tokenized ``datasets.Dataset`` of review/summary pairs.
        :param idx: Positional indexer selecting several rows (e.g. a list of
                    ints or a slice); a single int does not work because the
                    selected column must be a Series
        :return: Dataset with 'review' and 'summary' columns plus the fields
                 added by ``preprocess_function``
        """
        rows = self.csv_df.iloc[idx]
        review = rows['Text'].astype(str)
        summary = rows['Summary'].astype(str)
        ds = datasets.Dataset.from_dict({"review": review.tolist(), "summary": summary.tolist()})

        tokenized_dataset = ds.map(preprocess_function, batched=True, fn_kwargs={"tokenizer": self.tokenizer})

        return tokenized_dataset

  
def preprocess_function(ds, tokenizer=None):
    """
    Tokenize a batch of reviews and summaries for a summarization model.
    :param ds: Mapping with a "review" sequence and a "summary" sequence
    :param tokenizer: Callable tokenizer; must be supplied (the None default
                      only exists so it can be passed by keyword)
    :return: The tokenizer output for the prompts, with the summaries'
             "input_ids" stored under "labels"
    """
    shared_kwargs = {"truncation": True, "padding": "max_length"}

    # Every review becomes a "summarize: " prompt, capped at 256 tokens.
    prompts = ["summarize: " + text for text in ds["review"]]
    encoded = tokenizer(prompts, max_length=256, **shared_kwargs)

    # Summaries are the targets, capped at 48 tokens.
    targets = tokenizer(text_target=ds["summary"], max_length=48, **shared_kwargs)

    encoded["labels"] = targets["input_ids"]
    return encoded

class TeacherDataset(Dataset):
    """
    Dataset that contains the logits generated by a teacher model and the covariates
    used to generate the logits
    """

    # Dataset names whose inputs are image paths / raw text respectively.
    _IMAGE_DATASETS = ('koa', 'koa_double', 'waterbirds', 'waterbirds_double')
    _TEXT_DATASETS = ('food_review', 'food_review_double')

    def __init__(self, x, logits, dataset, tokenizer=None):
        """
        Initialize the dataset.
        :param x: Data used to make a prediction (image paths for the image
                  datasets, texts for the food review datasets)
        :param logits: Logits produced by a teacher model
        :param dataset: Name of dataset being used
        :param tokenizer: Callable tokenizer, needed for the food review datasets
        """
        self.x = x
        self.logits = logits
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of stored logits."""
        return len(self.logits)

    def __getitem__(self, idx):
        """
        Return ``(x, logit, dummy_label)`` for position ``idx``.
        :raises ValueError: if the dataset name is not one of the supported ones
        """
        if self.dataset in self._IMAGE_DATASETS:
            # Get image; the context manager closes the file once it is decoded
            with Image.open(self.x[idx]) as full_img:
                x = transforms.PILToTensor()(full_img).to(dtype=torch.float32) / 255.0
        elif self.dataset in self._TEXT_DATASETS:
            x = self.tokenizer(self.x[idx], padding='max_length', truncation=True, return_tensors='pt')
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(f"Unsupported dataset name: {self.dataset!r}")

        # Obtain logit
        logit = self.logits[idx]

        # Dummy label for compatibility
        dummy_label = torch.tensor(0)

        return (x, logit, dummy_label)



class MediatorDataset(Dataset):
    """
    Dataset that contains the mediators generated by a mediator model and the features
    used to generate the mediators
    """

    # Dataset names whose inputs are image paths / raw text respectively.
    _IMAGE_DATASETS = ('koa', 'koa_double', 'waterbirds', 'waterbirds_double')
    _TEXT_DATASETS = ('food_review', 'food_review_double')

    def __init__(self, x, mediators, labels, dataset, tokenizer=None):
        """
        Initialize the dataset.
        :param x: Data used to make a prediction (image paths for the image
                  datasets, texts for the food review datasets)
        :param mediators: Mediators produced by a mediator model
        :param labels: Main labels, indexed like ``x`` and ``mediators``
        :param dataset: Name of dataset being used
        :param tokenizer: Callable tokenizer, needed for the food review datasets
        """
        self.x = x
        self.mediators = mediators
        self.dataset = dataset
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of stored mediators."""
        return len(self.mediators)

    def __getitem__(self, idx):
        """
        Return ``(x, label, mediator)`` for position ``idx``.
        :raises ValueError: if the dataset name is not one of the supported ones
        """
        if self.dataset in self._IMAGE_DATASETS:
            # Get image; the context manager closes the file once it is decoded
            with Image.open(self.x[idx]) as full_img:
                x = transforms.PILToTensor()(full_img).to(dtype=torch.float32) / 255.0
        elif self.dataset in self._TEXT_DATASETS:
            x = self.tokenizer(self.x[idx], padding='max_length', truncation=True, return_tensors='pt')
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(f"Unsupported dataset name: {self.dataset!r}")

        # Obtain mediator
        mediator = self.mediators[idx]

        # Obtain main label
        label = self.labels[idx]

        return (x, label, mediator)
