from collections import Counter

import numpy as np
import torch
import torchvision.transforms as T

from datasets.funnybirds.funnybirds import FunnyBirds, funnybirds_augmentation, PART_IDS

def get_funnybirds_attributes(data_paths,
                              normalize_data=True,
                              image_size=32,
                              **kwargs):
    """Create a FunnyBirds dataset whose items also carry attribute labels.

    Args:
        data_paths: Forwarded unchanged to ``FunnyBirdsAttribtues``.
        normalize_data: When True, a ``T.Normalize`` with per-channel mean
            and std of 0.5 is appended after ``T.ToTensor``.
        image_size: Size passed to ``T.Resize`` (bicubic interpolation).
        **kwargs: Accepted so callers can pass extra options; ignored here.

    Returns:
        A ``FunnyBirdsAttribtues`` built with the composed transform and
        ``funnybirds_augmentation`` as its augmentation.
    """
    bicubic = T.functional.InterpolationMode.BICUBIC

    # Optional normalization stage; empty when normalization is disabled.
    normalization = (
        [T.Normalize(torch.tensor([.5, .5, .5]), torch.tensor([.5, .5, .5]))]
        if normalize_data
        else []
    )
    pipeline = T.Compose([
        T.Resize(image_size, interpolation=bicubic),
        T.ToTensor(),
        *normalization,
    ])

    return FunnyBirdsAttribtues(data_paths,
                                transform=pipeline,
                                augmentation=funnybirds_augmentation)
    

class FunnyBirdsAttribtues(FunnyBirds):
    """FunnyBirds dataset whose items also carry binary part-attribute labels.

    Each attribute is named ``"<part>::<value>"``. The names are built from the
    distinct values of every ``PART_IDS`` column of ``self.metadata``, skipping
    the value ``"placeholder"``. Parts that contribute only a single attribute
    carry no diversity and are dropped.

    Items are ``(x, y, attr)``, where ``(x, y)`` is whatever ``FunnyBirds``
    returns and ``attr`` is a ``torch.long`` vector with one 0/1 entry per
    name in ``self.attributes``.

    NOTE: the class name keeps its historical spelling so existing callers
    keep working.
    """

    def __init__(self, data_paths, transform=None, augmentation=None):
        """Initialise the base dataset, then derive the attribute vocabulary.

        Args:
            data_paths, transform, augmentation: Forwarded to
                ``FunnyBirds.__init__``.
        """
        super().__init__(data_paths, transform, augmentation)
        # NOTE(review): assumes FunnyBirds sets self.metadata to a pandas
        # DataFrame with one column per entry of PART_IDS -- confirm.
        attributes = [f"{p}::{v}"
                      for p in PART_IDS
                      for v in self.metadata[p].drop_duplicates().values
                      if v != "placeholder"]

        ## drop attributes without diversity: keep only parts that contribute
        ## more than one attribute. Counting by the exact part prefix (rather
        ## than a substring test over the whole "part::value" string) avoids
        ## miscounting when a part name appears inside another attribute.
        part_counts = Counter(attr.partition("::")[0] for attr in attributes)
        self.attributes = [attr for attr in attributes
                           if part_counts[attr.partition("::")[0]] > 1]

    def get_attribute_labels(self, i):
        """Return the 0/1 attribute vector (``torch.long``) for sample ``i``.

        Entry ``k`` is 1 when the metadata value of the part named by
        ``self.attributes[k]`` equals that attribute's value, else 0.
        """
        row = self.metadata.iloc[i]
        labels = []
        for attr in self.attributes:
            # partition splits on the first "::" only, so values that contain
            # "::" themselves stay intact.
            part, _, value = attr.partition("::")
            # Attribute names were produced with an f-string, so compare the
            # identically formatted metadata value; a raw comparison would
            # never match for non-string columns.
            labels.append(1 if f"{row[part]}" == value else 0)
        return torch.tensor(labels, dtype=torch.long)

    def build_pos_neg_concept_indexes(self):
        """Map every attribute to its positive and sampled negative indices.

        Returns:
            ``(concept_index_pos, concept_index_neg)``, two dicts keyed by
            attribute name:

            - ``concept_index_pos[a]`` is the increasing list of sample
              indices whose label for ``a`` is 1.
            - ``concept_index_neg[a]`` is a numpy array of the same length,
              drawn from the remaining indices. Sampling uses replacement only
              when there are fewer negatives than positives. The array is
              empty when no negative exists.

            Each attribute draws from a freshly seeded ``default_rng(0)``, so
            its negatives do not depend on the other attributes.
        """
        n_samples = len(self)

        ## Pos Samples -- labels come straight from the metadata. Going through
        ## self[i] would also load, transform and augment every sample.
        concept_index_pos = {attr: [] for attr in self.attributes}
        for i in range(n_samples):
            for attr, label in zip(self.attributes, self.get_attribute_labels(i)):
                if label == 1:
                    concept_index_pos[attr].append(i)

        ## Neg Samples
        all_idxs = set(np.arange(n_samples))  # built once, reused per attribute
        concept_index_neg = {}
        for attr, idxs_pos in concept_index_pos.items():
            rng = np.random.default_rng(0)
            idxs_neg_all = list(all_idxs - set(idxs_pos))
            if not idxs_neg_all:
                # Generator.choice cannot sample from an empty population.
                concept_index_neg[attr] = np.array([], dtype=np.int64)
                continue
            replace = len(idxs_neg_all) < len(idxs_pos)
            concept_index_neg[attr] = rng.choice(idxs_neg_all, len(idxs_pos), replace=replace)

        return concept_index_pos, concept_index_neg

    def __getitem__(self, i):
        """Return ``(x, y, attr)``: the base dataset's item plus attribute labels."""
        x, y = super().__getitem__(i)
        return x, y, self.get_attribute_labels(i)