
from typing import List, Union

import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from aiau.oracles.oracle import Oracle

class DataManager(object):
    """Bookkeeping for a dataset whose labels are revealed through an oracle.

    The manager stores the full set of observations and targets, tracks which
    indices are currently labelled, and records every noisy label obtained
    for an index from an oracle.

    Attributes:
        indices: Identifiers of the observations, as passed in.
        observations: Observations, as passed in.
        targets: Targets, as passed in.
        initially_labelled_indices: Indices labelled from the start.
        labelled_indices: Set of all indices labelled so far.
        targets_dict: Mapping ``index -> target``.
        observations_dict: Mapping ``index -> observation``.
        full_X: ``np.array(observations)``.
        full_y: ``np.array(targets)``.
        labelling_history: Mapping ``iteration -> [indices labelled]``.
        noisy_targets_dict: Mapping ``index -> [noisy_label, ...]`` in the
            order the labels were obtained.
        correlated_draws_labels: Mapping ``iteration -> {index: noisy_label}``
            holding the most recent noisy label of each index at the time it
            was recorded for that iteration.
    """

    def __init__(self, indices: Union[List[int], np.ndarray],
                 observations: Union[List[int], np.ndarray],
                 targets: Union[List[int], np.ndarray],
                 initially_labelled_indices: Union[List[int], np.ndarray]) -> None:
        """Store the dataset and set up empty labelling records.

        Args:
            indices: Identifiers of the observations.
            observations: Observations, paired positionally with ``indices``.
            targets: Targets, paired positionally with ``indices``.
            initially_labelled_indices: Indices considered labelled from the
                start.
        """
        self.indices = indices
        self.observations = observations
        self.targets = targets
        self.initially_labelled_indices = initially_labelled_indices
        self.labelled_indices = set(initially_labelled_indices)

        # zip() pairs positionally and stops at the shortest input.
        self.targets_dict = dict(zip(indices, targets))
        self.observations_dict = dict(zip(indices, observations))
        self.full_X = np.array(observations)
        self.full_y = np.array(targets)

        self.labelling_history = defaultdict(list)  # {iteration: [labelled_indices]}

        self.noisy_targets_dict = defaultdict(list)  # {idx: [noisy_label, ..., noisy_label]}

        # Created here as well as in initialise() so that
        # update_labelled_indices() does not fail with AttributeError when it
        # is called before initialise().
        self.correlated_draws_labels = defaultdict(dict)  # {iteration: {idx: noisy_label}}

        print(f"DataManager initialised with {len(self)} observations.")
        print(f"Full dataset shape: {self.full_X.shape}")
        print(f"Full target shape: {self.full_y.shape}")

    @staticmethod
    def _as_index_list(idx):
        """Wrap a scalar index in a list; return any other value unchanged.

        NumPy integer scalars (e.g. ``np.int64``) are not ``int`` subclasses,
        so they are checked for explicitly.
        """
        if isinstance(idx, (int, np.integer)):
            return [idx]
        return idx

    def initialise(self, oracle: "Oracle") -> None:
        """Initialise the noisy targets for the initially labelled indices.

        Queries ``oracle`` for the initially labelled indices, records the
        current labelled set as iteration 0 of the labelling history, and
        resets ``correlated_draws_labels`` before filling in iteration 0.

        Args:
            oracle: Object providing ``query_target_value(data_manager, indices)``.
        """
        self._construct_initial_noisy_targets(oracle)
        # NOTE(review): the returned arrays are discarded. The call is kept so
        # that existing behaviour (including a KeyError for a labelled index
        # that has no observation) is unchanged.
        self.construct_noisy_dataset()
        self.labelling_history[0] = list(self.labelled_indices)
        self.correlated_draws_labels = defaultdict(dict)  # {iteration: {idx: noisy_label}}
        self._update_correlated_draws_labels(iteration=0, indices=self.initially_labelled_indices)

    def _update_correlated_draws_labels(self, iteration, indices) -> None:
        """Record the latest noisy label of each index under ``iteration``.

        Raises:
            IndexError: If an index has no noisy label recorded yet.
        """
        labels_for_iteration = self.correlated_draws_labels[iteration]
        for idx in indices:
            labels_for_iteration[idx] = self.noisy_targets_dict[idx][-1]

    def _construct_initial_noisy_targets(self, oracle: "Oracle") -> None:
        """Query the oracle for the initially labelled indices and store the labels."""
        noisy_targets = oracle.query_target_value(self, self.initially_labelled_indices)
        for idx, target in zip(self.initially_labelled_indices, noisy_targets):
            self.noisy_targets_dict[idx].append(target)

    def update_noisy_targets(self, oracle: "Oracle",
                             idx: Union[int, List[int], np.ndarray]) -> None:
        """Query the oracle for the given indices and append the noisy labels.

        Args:
            oracle: Object providing ``query_target_value(data_manager, indices)``.
            idx: A single index (Python or NumPy integer) or a collection of
                indices.
        """
        idx = self._as_index_list(idx)
        noisy_targets = oracle.query_target_value(self, idx)
        # Use a distinct loop variable so the ``idx`` argument is not rebound
        # while it is being iterated.
        for single_idx, target in zip(idx, noisy_targets):
            self.noisy_targets_dict[single_idx].append(target)

    def construct_noisy_dataset(self):
        """Construct the dataset from the observations and noisy labels.

        Each labelled index contributes one row per noisy label recorded for
        it, so an index with several noisy labels appears several times.

        Returns:
            A pair ``(X, y)`` of NumPy arrays: the repeated observations and
            the matching noisy labels.
        """
        X = []
        y = []
        for idx in self.labelled_indices:
            observation = self.observations_dict[idx]
            for target in self.noisy_targets_dict[idx]:
                X.append(observation)
                y.append(target)

        return np.array(X), np.array(y)

    def update_labelled_indices(self, idx: Union[int, List[int], np.ndarray],
                                iteration: int) -> None:
        """Update the labelled indices.

        The indices are added to the labelled set, appended to the labelling
        history of ``iteration`` and their latest noisy labels are recorded
        for ``iteration``.

        Args:
            idx (int or list): The index or indices to be labelled. NumPy
                integer scalars are accepted as a single index.
            iteration (int): The current iteration.

        Raises:
            IndexError: If one of the indices has no noisy label recorded yet.
        """
        idx = self._as_index_list(idx)
        self.labelled_indices.update(idx)
        # Extend rather than overwrite: repeated calls for the same iteration
        # accumulate, consistent with labelled_indices and
        # correlated_draws_labels.
        self.labelling_history[iteration].extend(idx)
        self._update_correlated_draws_labels(iteration, idx)

    def __len__(self) -> int:
        """Return the number of indices managed."""
        return len(self.indices)
    

