import numpy as np
import logging
import os

class ExpData:
    """Train/validation data for one experiment run, with permuted class ids.

    Construction draws a class permutation (``order_list``) via
    ``gen_class_mapping`` using ``exp.rngs['class_order']``, and relabels
    both splits through it so that an original label ``y`` becomes
    ``order_list[y]``. It then writes the resulting arrays to
    ``dataset.npz`` inside ``exp.path_expfolder``.
    """

    ######################### train/test data #########################
    X_train_total = None  # item[0] of every trainset.imgs entry
    Y_train_total = None  # training labels, remapped through order_list
    X_val_total = None    # item[0] of every testset.imgs entry
    Y_val_total = None    # validation labels, remapped through order_list
    order_list = None     # class permutation returned by gen_class_mapping
    path_dataset_prefix_cloned: str = ""  # copy of exp.path_dataset_prefix
    ######################## End of Attributes ########################

    def __init__(self, exp, trainset, testset):
        """Build the relabelled train/val arrays and save them to disk.

        Args:
            exp: experiment object; reads ``ds.n_tot_class``,
                ``rngs['class_order']``, ``path_dataset_prefix`` and
                ``path_expfolder``.
            trainset, testset: objects exposing an ``imgs`` iterable of
                ``(image, label)`` items.
        """
        n_tot = exp.ds.n_tot_class
        log = logging.getLogger()
        self.order_list = gen_class_mapping(n_tot, exp.rngs['class_order'])
        log.info(f'Order_list generated for this run: {self.get_order_dict_str(n_tot)}')
        self.path_dataset_prefix_cloned = exp.path_dataset_prefix

        # Split each dataset's (image, label) items, then relabel through the
        # permutation: original class y -> order_list[y].
        x_train, y_train_raw = split_images_labels(trainset.imgs)
        self.X_train_total = x_train
        self.Y_train_total = self.order_list[y_train_raw]

        x_val, y_val_raw = split_images_labels(testset.imgs)
        self.X_val_total = x_val
        self.Y_val_total = self.order_list[y_val_raw]

        self.store_current_data(exp.path_expfolder)

    def store_current_data(self, d_fullpath):
        """Save the four split arrays and ``order_list`` to ``<d_fullpath>/dataset.npz``."""
        payload = {
            'X_train_total': self.X_train_total,
            'Y_train_total': self.Y_train_total,
            'X_val_total': self.X_val_total,
            'Y_val_total': self.Y_val_total,
            'order_list': self.order_list,
        }
        target = os.path.join(d_fullpath, 'dataset.npz')
        np.savez(target, **payload)

    def get_order_dict_str(self, num_class):
        """Render the first ``num_class`` mapping entries as ``"0:a, 1:b, ..."``.

        Each entry is ``original_class:permuted_class``.
        """
        pairs = (f"{idx}:{self.order_list[idx]}" for idx in range(num_class))
        return ', '.join(pairs)

    def get_n_task_data(self, exp):
        """Count training samples selected by ``exp.r.Ti_data_indices``.

        Returns 0 when no training data has been loaded. The index object is
        applied to ``X_train_total`` as-is, so integer arrays, boolean masks
        and slices are all counted by the length of the resulting selection.
        """
        train_x = self.X_train_total
        return 0 if train_x is None else len(train_x[exp.r.Ti_data_indices])

def split_images_labels(imgs):
    """Split an iterable of ``(image, label)`` items into two NumPy arrays.

    Args:
        imgs: iterable of indexable items. ``item[0]`` is taken as the image
            and ``item[1]`` as its label; any further elements are ignored.
            The input is iterated exactly once, so generators are accepted.

    Returns:
        ``(images, labels)`` as NumPy arrays, in input order. For empty input
        the label array has an integer dtype, so it can still be used to
        index another array (e.g. ``order_list[labels]``).
    """
    images = []
    labels = []
    for item in imgs:
        images.append(item[0])
        labels.append(item[1])

    if not labels:
        # np.array([]) defaults to float64, and indexing with a float array
        # raises IndexError; keep empty label arrays integer-typed instead.
        return np.array(images), np.array(labels, dtype=np.intp)
    return np.array(images), np.array(labels)

def merge_images_labels(images, labels):
    """Pair images with labels into a list of ``(image, label)`` tuples.

    Args:
        images: iterable of images (materialised with ``list()``).
        labels: iterable of labels, same length as ``images``.

    Returns:
        List of ``(images[i], labels[i])`` tuples in input order.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """
    images = list(images)
    labels = list(labels)
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and zip() would then silently truncate the longer input.
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length "
            f"(got {len(images)} and {len(labels)})"
        )
    return list(zip(images, labels))

def gen_class_mapping(num_class, rng):
    """Return a random permutation of ``range(num_class)`` drawn from ``rng``.

    ``rng`` must provide a ``permutation`` method (e.g. a NumPy ``Generator``
    or ``RandomState``). Callers in this file index the result with original
    labels, so entry ``i`` acts as the new id of original class ``i``.
    """
    return rng.permutation(num_class)
