import pandas as pd
import numpy as np
from pathlib import Path
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import copy
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer

# Names of the two input modalities. `MMIMDBDataset.__getitem__` looks for these
# same strings in `opt.modality` when running in single-modality mode.
MODALITY = ["vision", "text"]

def Give(opt, label_names, datapath):
    """Build the dataset splits from the preprocessed metadata frame.

    Reads ``<datapath>/preprocessed/meta_df.pkl``, optionally keeps only the
    rows that have exactly one positive label, and wraps the ``train``,
    ``val`` and ``test`` folds in :class:`MMIMDBDataset` instances.

    Args:
        opt: Options object (must support ``vars()``). Attributes read here:
            ``exclusive``, ``augmentation`` and ``not_pretrained``.
            ``opt.label_names`` is overwritten with ``label_names``. The
            object is also forwarded to ``MMIMDBDataset`` (both positionally
            and expanded as keyword arguments).
        label_names: Names of the label columns in the metadata frame.
        datapath: Directory that contains ``preprocessed/meta_df.pkl``.

    Returns:
        dict with the keys ``'training'``, ``'validation'``, ``'testing'``,
        ``'evaluation'`` (train fold, eval transform) and
        ``'evaluation_train'`` (train fold, train transform).
    """
    datapath = Path(datapath)
    # NOTE(security): unpickling executes arbitrary code; only load metadata
    # files that come from a trusted source.
    df = pd.read_pickle(datapath / 'preprocessed' / 'meta_df.pkl')

    opt.label_names = label_names

    if opt.exclusive:
        # Keep only the samples that carry exactly one positive label.
        # (Filtering out rare classes is currently disabled.)
        single_label = df[opt.label_names].sum(axis=1) == 1
        df = df[single_label]

    # Side length of the square the images are resized to.
    image_size = 224

    # Training pipeline: optional augmentation, then tensor conversion.
    # Resizing is not part of it because MMIMDBDataset resizes every image
    # before applying the transform.
    train_steps = []
    if opt.augmentation:
        train_steps.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ])
    train_steps.append(transforms.ToTensor())

    # Evaluation pipeline: deterministic resize + tensor conversion.
    eval_steps = [
        transforms.Resize([image_size, image_size]),
        transforms.ToTensor(),
    ]

    if not opt.not_pretrained:
        # Pretrained backbones expect inputs normalized with the ImageNet
        # channel statistics. This must be applied to BOTH pipelines:
        # normalizing only the training images makes the model see a
        # different input distribution at validation/test time.
        for steps in (train_steps, eval_steps):
            steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225]))

    train_transform = transforms.Compose(train_steps)
    eval_transform = transforms.Compose(eval_steps)

    # Each dataset makes its own re-indexed copy of the frame it receives,
    # so the filtered folds can be shared safely.
    train_df = df[df.fold == 'train']
    val_df = df[df.fold == 'val']
    test_df = df[df.fold == 'test']

    train_dataset = MMIMDBDataset(train_df, opt, train_transform, image_size, **vars(opt))
    eval_dataset = MMIMDBDataset(train_df, opt, eval_transform, image_size, **vars(opt))
    eval_train_dataset = MMIMDBDataset(train_df, opt, train_transform, image_size, **vars(opt))
    val_dataset = MMIMDBDataset(val_df, opt, eval_transform, image_size, **vars(opt))
    test_dataset = MMIMDBDataset(test_df, opt, eval_transform, image_size, **vars(opt))

    return {
        'training': train_dataset,
        'validation': val_dataset,
        'testing': test_dataset,
        'evaluation': eval_dataset,
        'evaluation_train': eval_train_dataset,
    }

class MMIMDBDataset(Dataset):
    """Image + text dataset backed by a metadata DataFrame.

    Each row of the frame supplies an image path (``path``), a text field
    (``all_text``) and one indicator column per entry of ``opt.label_names``.

    Attributes set by the constructor:
        df: Re-indexed copy of the input frame; ``path`` is rewritten to
            ``<opt.source_path>/dataset/<original file name>``.
        image_dict: ``{class_id: [[path, row_index], ...]}`` where
            ``class_id`` is the position of the label in ``opt.label_names``.
        image_list / image_paths: ``[[path, label], ...]`` ordered by row
            index (see :meth:`init_setup`).
        n_files: Total number of ``image_dict`` entries.
        avail_classes: Sorted class ids.
        tokenizer: ``bert-base-uncased`` tokenizer.
    """

    def __init__(self, df, opt, transform, image_size, **kwargs):
        """Prepare the frame, the per-class index and the tokenizer.

        Args:
            df: Metadata frame with ``path``, ``all_text`` and label columns.
            opt: Options object. Attributes read by this class:
                ``multimodal``, ``source_path``, ``label_names``,
                ``exclusive`` and (in single-modality mode) ``modality``.
            transform: Callable applied to the resized PIL image, or ``None``.
            image_size: Side length images are resized to.
            **kwargs: Accepted for call compatibility; not used.
        """
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.opt = opt
        self.multimodal = opt.multimodal
        self.transform = transform
        self.image_size = image_size

        # Point every path at the local copy of the image directory.
        image_root = Path(opt.source_path) / 'dataset'
        self.df['path'] = self.df['path'].apply(
            lambda p: str(image_root / Path(p).name))

        image_dict = {}
        for class_id, label_name in enumerate(self.opt.label_names):
            # astype(bool) keeps boolean columns as they are and also makes
            # 0/1 integer indicator columns usable as a row mask.
            members = self.df.loc[self.df[label_name].astype(bool), 'path']
            image_dict[class_id] = [[path, int(row_idx)]
                                    for row_idx, path in members.items()]

        self.image_dict = image_dict
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        self.init_setup()

    def init_setup(self):
        """Derive ``n_files``, ``avail_classes`` and ``image_list`` from ``image_dict``.

        ``image_list`` holds one ``[path, label]`` entry per distinct path,
        ordered by the smallest row index at which the path occurs. ``label``
        is a single class id when ``opt.exclusive`` is set, otherwise a NumPy
        array with all class ids of that path in ascending order.

        Raises:
            ValueError: if ``opt.exclusive`` is set and a path has more than
                one label.
        """
        self.n_files = sum(len(entries) for entries in self.image_dict.values())
        self.avail_classes = sorted(self.image_dict)

        # Flatten to (path, row_index, class_id) records. Building the records
        # directly (instead of concatenating one DataFrame per class and
        # renaming the columns by position) stays correct when a class has no
        # samples in this fold.
        records = [(path, row_idx, class_id)
                   for class_id, entries in self.image_dict.items()
                   for path, row_idx in entries]
        # list.sort is stable, so labels of the same row keep class-id order.
        records.sort(key=lambda record: record[1])

        labels_by_path = {}  # insertion order == order of first occurrence
        for path, _, class_id in records:
            labels_by_path.setdefault(path, []).append(class_id)

        image_list = [[path, labels] for path, labels in labels_by_path.items()]

        if self.opt.exclusive:
            for entry in image_list:
                if len(entry[-1]) != 1:
                    raise ValueError(
                        f"expected exactly one label for {entry[0]!r} in "
                        f"exclusive mode, got {len(entry[-1])}")
                entry[-1] = entry[-1][0]
        else:
            for entry in image_list:
                entry[-1] = np.array(entry[-1])

        self.image_list = image_list
        self.image_paths = self.image_list
        self.is_init = True

    def _load_image(self, img_path):
        """Open ``img_path``, convert to RGB, resize and apply ``self.transform``."""
        # The context manager releases the file handle; convert() forces the
        # pixel data to be read before the file is closed.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        img = transforms.Resize([self.image_size, self.image_size])(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def _encode_text(self, text):
        """Tokenize ``text`` to fixed-length (512) BERT inputs.

        Returns:
            Tuple ``(input_ids, token_type_ids, attention_mask)`` with the
            batch dimension removed.
        """
        enc = self.tokenizer.encode_plus(text, return_tensors="pt", max_length=512,
                                         padding='max_length', truncation=True)
        return enc['input_ids'][0, :], enc['token_type_ids'][0, :], enc['attention_mask'][0, :]

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Args:
            idx: Integer position (Python or NumPy integer) or the sample's
                ``path`` value.

        Returns:
            dict with ``'labels'`` and ``'idx'`` plus, depending on the mode,
            ``'x'`` (image) and/or ``'input_ids'``, ``'token_type_ids'``,
            ``'attention_mask'``. In single-modality mode "vision" takes
            precedence over "text".

        Raises:
            IndexError: if the position is out of range or no row has the
                given path.
            ValueError: if not multimodal and ``opt.modality`` contains
                neither "vision" nor "text".
        """
        if isinstance(idx, (int, np.integer)):
            # NumPy integers (e.g. from custom samplers) are positions too.
            idx = int(idx)
            row = self.df.iloc[idx]
        else:
            matches = self.df[self.df['path'] == idx]
            if matches.empty:
                raise IndexError(f"no sample with path {idx!r}")
            row = matches.iloc[0]
            idx = row.name

        label_values = (row[self.opt.label_names].values).astype(int)
        if self.opt.exclusive:
            # Index of the positive label column.
            labels = label_values.nonzero()[0]
        else:
            # Multi-hot indicator vector.
            labels = label_values

        # Only load what the selected mode actually returns.
        if self.multimodal:
            use_image = use_text = True
        else:
            modality = self.opt.modality
            use_image = "vision" in modality
            use_text = (not use_image) and "text" in modality
            if not (use_image or use_text):
                raise ValueError(
                    f"opt.modality must contain 'vision' or 'text', got {modality!r}")

        sample = {'labels': labels}
        if use_image:
            sample['x'] = self._load_image(row['path'])
        if use_text:
            input_ids, token_type_ids, attention_mask = self._encode_text(row['all_text'])
            sample['input_ids'] = input_ids
            sample['token_type_ids'] = token_type_ids
            sample['attention_mask'] = attention_mask
        sample['idx'] = idx
        return sample

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.df)