import json
from typing import Any, Optional

import numpy as np

from src.searchlight.utils import AbstractLogged
# from searchlight.headers import State

class DataLoader(AbstractLogged):
    '''
    In-memory collection of data points with JSON save/load and random sampling.

    Each data point is stored as a 7-tuple:
    (discussion_history_summary, state_info, intended_actions,
     private_informations, roles, dialogue, speaking_order).
    '''

    def __init__(self, rng: Optional[np.random.Generator] = None):
        '''
        Args:
            rng: generator used by ``sample_data_point``. When omitted, a new
                ``np.random.default_rng()`` is created for this instance.
        '''
        super().__init__()
        self.data = []  # This will store the list of data points (7-tuples)
        # Build the default generator here rather than in the signature: a
        # default argument is evaluated only once, which would make every
        # default-constructed DataLoader share the same RNG stream.
        self.rng = np.random.default_rng() if rng is None else rng

    def sample_data_point(self) -> tuple[list[str], tuple, list[Any], list[str], list[int], list[str], list[int]]:
        '''
        Return a data point chosen uniformly at random, as a tuple containing:
        - discussion_history_summary: list of strings with the discussion
          history (presumably one entry per agent — confirm with producers)
        - state_info: current state of the game, which is a tuple
        - intended_actions: list of intended actions of each agent (or final action)
        - private_informations: list of private information strings for each agent
        - roles: list[int] of roles for each player
        - dialogue: list[str] of the dialogues generated by each agent starting from the quest leader in order
        - speaking_order: list[int] of the speaking order of the agents for the dialogue (i.e who said what when), should have same length as dialogue

        Raises:
            ValueError: if no data points have been added or loaded.
        '''
        if not self.data:
            raise ValueError("No data available to sample.")
        # rng.integers(n) draws from [0, n); cast the numpy integer to a plain int.
        index = int(self.rng.integers(len(self.data)))
        self.logger.info(f"Sampling data point at index {index}")
        self.logger.info(f"State tuple: {self.data[index][1]}")
        return self.data[index]

    def save_data(self, save_path: str):
        '''
        Write all data points to ``save_path`` as a JSON list of dicts.

        Tuples (including nested ones inside state_info) are written as JSON arrays.
        '''
        json_data = [self.convert_to_serializable(dp) for dp in self.data]
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f)
        self.logger.info(f"Data saved to {save_path}")

    def load_data(self, load_path: str):
        '''
        Replace ``self.data`` with the data points read from the JSON file at ``load_path``.

        Lists inside state_info are converted back into (nested) tuples.
        '''
        with open(load_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        self.data = [self.convert_from_serializable(dp) for dp in json_data]
        self.logger.info(f"Data loaded from {load_path}")

    def convert_to_serializable(self, data_point):
        '''
        Convert a 7-tuple data point into a JSON-friendly dict keyed by field name.
        '''
        discussion_history_summary, state_info, intended_actions, private_informations, roles, dialogue, speaking_order = data_point
        return {
            'discussion_history_summary': discussion_history_summary,
            'state_info': list(state_info),
            'intended_actions': intended_actions,
            'private_informations': private_informations,
            'roles': roles,
            'dialogue': dialogue,
            'speaking_order': speaking_order
        }

    def convert_from_serializable(self, serializable_data):
        '''
        Inverse of ``convert_to_serializable``: rebuild the 7-tuple from a dict.

        state_info has every (nested) list turned into a tuple; the other
        fields are returned exactly as they were decoded.
        '''
        return (
            serializable_data['discussion_history_summary'],
            self.recursive_convert_all_lists_to_tuples(serializable_data['state_info']),
            serializable_data['intended_actions'],
            serializable_data['private_informations'],
            serializable_data['roles'],
            serializable_data['dialogue'],
            serializable_data['speaking_order']
        )

    def recursive_convert_all_lists_to_tuples(self, state_info):
        '''
        Return a tuple copy of ``state_info`` with every nested list also turned into a tuple.

        Only ``list`` instances are recursed into; other items are kept as-is.
        '''
        return tuple(
            self.recursive_convert_all_lists_to_tuples(item) if isinstance(item, list) else item
            for item in state_info
        )

    def add_data_point(self,
                       discussion_history_summary: list[str],
                       state_info: tuple[int, int, int, int, int, bool, bool, list[int], list[bool], list[bool], list[bool], list[int]], # TODO: next time use dataclass for this
                       intended_actions: list[Any],
                       private_informations: list[str],
                       roles: list[int],
                       dialogue: list[str],
                       speaking_order: list[int]):
        '''
        Append one data point, stored as a 7-tuple in argument order.
        '''
        data_point = (discussion_history_summary,
                      state_info,
                      intended_actions,
                      private_informations,
                      roles,
                      dialogue,
                      speaking_order)
        self.data.append(data_point)
        self.logger.info("Data point added")

    @staticmethod
    def redact_state_info(state_tuple: tuple) -> tuple:
        '''
        Redact the state info tuple to remove any hidden information.

        ``state_tuple`` must have exactly 12 fields (ValueError otherwise):
        (num_players, quest_leader, phase, turn, round, done, good_victory,
         quest_team, team_votes, quest_votes, quest_results, roles).

        Returns:
            (quest_leader, phase, turn, round, quest_team, team_votes, quest_results)
        '''
        # Unpack all 12 fields so malformed tuples still fail loudly; the
        # underscore-prefixed fields are dropped from the redacted result.
        # ``round_`` avoids shadowing the builtin ``round``.
        (_num_players, quest_leader, phase, turn, round_, _done, _good_victory,
         quest_team, team_votes, _quest_votes, quest_results, _roles) = state_tuple

        return (quest_leader, phase, turn, round_, quest_team, team_votes, quest_results)
    