from typing import Optional
import numpy as np
import pandas as pd
from lib.data.base import VLDataset

def split_data(inputs, targets, n_splits, seed=42):
    """Split a labelled dataset into ``n_splits`` class-disjoint experiences.

    Classes are randomly relabelled, then grouped into consecutive blocks of
    new ids, so experience ``k`` holds new class ids
    ``k * per_split .. (k + 1) * per_split - 1``.

    Args:
        inputs: samples aligned with ``targets``; must support boolean-mask
            indexing (presumably a NumPy array — confirm against callers).
        targets: class label per sample. Any sortable label values work; the
            labels need not be the contiguous range ``0..K-1``.
        n_splits: number of experiences to produce.
        seed: seed passed to ``np.random.seed``.

    Returns:
        A list of ``n_splits`` dicts with keys ``"inputs"`` and ``"targets"``,
        where ``"targets"`` holds the new (relabelled) class ids.

    Notes:
        - If the number of distinct classes is not divisible by ``n_splits``,
          the classes with the largest label values are dropped, together
          with their samples.
        - This seeds NumPy's *global* RNG. That side effect is kept on purpose,
          because later ``np.random`` calls in callers are affected by it.
    """
    np.random.seed(seed)
    targets = np.asarray(targets)

    # class_idx[i] is the rank of targets[i] among the sorted distinct labels.
    # For labels 0..K-1 this is the label itself, so behaviour matches the
    # original `targets < n_classes` trimming in that case.
    classes, class_idx = np.unique(targets, return_inverse=True)
    class_idx = class_idx.reshape(-1)

    # Keep only as many classes as divide evenly into n_splits groups.
    n_classes = len(classes) // n_splits * n_splits
    keep = class_idx < n_classes
    inputs = inputs[keep]

    # Random relabelling: the j-th kept class becomes new id perm[j].
    perm = np.random.permutation(np.arange(n_classes))
    new_targets = perm[class_idx[keep]]

    per_split = n_classes // n_splits
    experiences = []
    for split_id in range(n_splits):
        low = split_id * per_split
        in_split = (new_targets >= low) & (new_targets < low + per_split)
        experiences.append(dict(inputs=inputs[in_split], targets=new_targets[in_split]))
    return experiences

def subsample_dataset(
    dataset: VLDataset, 
    n_samples_per_class: Optional[int] = None,
    percent_samples_per_class: Optional[float] = None
):
    """Return a new ``VLDataset`` holding a random per-class subset of *dataset*.

    Sampling uses NumPy's global RNG and does not reseed it.

    Args:
        dataset: source dataset. Its ``images``, ``texts`` and ``labels`` are
            indexed with the same positions.
        n_samples_per_class: keep exactly this many samples, drawn without
            replacement, from every class.
        percent_samples_per_class: keep ``int(len(class) * percent)`` samples
            from every class. If both options are given, this one determines
            the result.

    Returns:
        A new ``VLDataset`` with the processors of *dataset*. Samples are
        grouped by class, in sorted label order.

    Raises:
        ValueError: if neither option is given, if the dataset has no labels,
            or if ``n_samples_per_class`` exceeds the smallest class size.
    """
    if n_samples_per_class is None and percent_samples_per_class is None:
        raise ValueError(
            "Either n_samples_per_class or percent_samples_per_class must be given"
        )

    labels = dataset.labels
    pos_per_class = [np.flatnonzero(labels == label) for label in np.unique(labels)]
    if not pos_per_class:
        raise ValueError("Cannot subsample an empty dataset")

    if n_samples_per_class is not None:
        if n_samples_per_class > min(len(x) for x in pos_per_class):
            raise ValueError("Samples per class is too high, some classes have less examples!")
        subsampled_pos_per_class = [np.random.choice(x, n_samples_per_class, replace=False) for x in pos_per_class]

    # When both options are given, this overwrites the selection above. The
    # choice() draws above still advance the global RNG, which keeps seeded
    # results reproducible with earlier versions.
    if percent_samples_per_class is not None:
        subsampled_pos_per_class = [np.random.permutation(x)[:int(len(x) * percent_samples_per_class)] for x in pos_per_class]

    pos = np.concatenate(subsampled_pos_per_class)
    images = dataset.images[pos]
    texts = dataset.texts[pos]
    labels = labels[pos]

    # The constructor receives images and texts stacked column-wise with
    # np.c_, mirroring how the other helpers here rebuild a VLDataset.
    return VLDataset(np.c_[images, texts],
                     labels,
                     dataset.image_processor,
                     dataset.text_processor)

def merge_vl_datasets(dataset_1: VLDataset, dataset_2: VLDataset):
    """Concatenate two ``VLDataset`` objects into a new one.

    Samples from *dataset_1* come first, followed by those of *dataset_2*.
    Only *dataset_1*'s image and text processors are carried over;
    *dataset_2*'s processors are ignored.
    """
    sources = (dataset_1, dataset_2)
    merged_images = np.concatenate([ds.images for ds in sources])
    merged_texts = np.concatenate([ds.texts for ds in sources])
    merged_labels = np.concatenate([ds.labels for ds in sources])

    # Images and texts are stacked column-wise for the VLDataset constructor.
    stacked = np.c_[merged_images, merged_texts]
    return VLDataset(
        stacked,
        merged_labels,
        dataset_1.image_processor,
        dataset_1.text_processor,
    )
    
def compute_dataset_sampling_weights(dataset: VLDataset):
    """Return inverse-class-frequency sampling weights, one per sample.

    Each sample gets weight ``1 / (number of samples sharing its label)``, so
    the weights within every class sum to 1 and each class is equally likely
    under weighted sampling.

    Args:
        dataset: object exposing ``labels``, one label per sample.

    Returns:
        A list of floats aligned with ``dataset.labels``.
    """
    labels = np.asarray(dataset.labels)
    # np.unique maps each sample to its class (inverse) and returns the class
    # sizes (counts). Unlike np.bincount, this accepts any label values
    # (sparse, negative, non-integer) and never divides by a zero count.
    _, class_idx, counts = np.unique(labels, return_inverse=True, return_counts=True)
    return (1.0 / counts)[class_idx.reshape(-1)].tolist()