import os
import sys
sys.path.append("..")
import random
from utils import load_json, save_json


# Sampling settings: the "sampling_config" section of config.json, loaded once
# at import time. Keys read in this file: random_seed, env_data_fn,
# sample_n_queries, conv_folder, sample_n_sessions, sample_n_states,
# qa_n_query_per_item, output_dir.
CONFIG = load_json("config.json")["sampling_config"]

def _clamped_sample(population, k, label):
    """Sample ``k`` items without replacement, or every item (shuffled) if fewer exist."""
    population = list(population)
    if k > len(population):
        print(f"Warning: {label}: requested {k} samples but only "
              f"{len(population)} available; using all of them.")
        k = len(population)
    return random.sample(population, k)


def _sample_env_queries(env_data):
    """Annotate each session's exposed states with context and sample queries per item."""
    queries_to_annotate = []
    for data in env_data:
        schema = data['state_schema']
        ids_and_queries = []
        for i, period in enumerate(data['periods']):
            # ``updates`` maps state name -> {'old': ..., ...}. It may be empty
            # (e.g. the first period has no updates), so states without a
            # recorded update get ``None`` as their past value.
            updates = period['updates'] or {}
            past_values = {key: updates[key]['old'] for key in updates}

            for j, session_info in enumerate(period['sessions']):
                query_id = os.path.join(data['id'], f"{i}-{j}")
                exposed = session_info['exposed_states']
                # Replace each raw value in place with its annotation context
                # (only values change, so mutating during iteration is safe).
                for state in exposed:
                    exposed[state] = {
                        "this_state_value": exposed[state],
                        "past_state_value": past_values.get(state),
                        "all_state_values": schema[state],
                    }
                ids_and_queries.append((query_id, session_info))

        queries_to_annotate.extend(_clamped_sample(
            ids_and_queries, CONFIG["sample_n_queries"], f"queries of {data['id']}"))
    return queries_to_annotate


def _sample_conversations(env_data):
    """Build (conv_id, {user_turns, unchanged_states}) pairs and sample them per item."""
    grouped_sessions = []
    for env_data_item in env_data:
        periods_data = [period['sessions'] for period in env_data_item['periods']]

        # Normalise exposed_states to plain values; entries may already be in the
        # {"this_state_value": ...} dict form, in which case the value is unwrapped.
        for sessions in periods_data:
            for session in sessions:
                session['exposed_states'] = {
                    k: v['this_state_value'] if isinstance(v, dict) else v
                    for k, v in session['exposed_states'].items()
                }

        # Attach recorded messages: one file per period (top level of the
        # interactions dir only), each holding one message list per session.
        # NOTE(review): files are ordered by plain lexicographic sort, so
        # unpadded numeric names ("10.json" < "2.json") would misalign periods.
        agent_state_dir = os.path.join(
            CONFIG['conv_folder'], env_data_item['id'], 'interactions')
        top_level = next(os.walk(agent_state_dir), None)
        if top_level is None:
            raise FileNotFoundError(
                f"Interactions directory not found: {agent_state_dir}")
        agent_states = [load_json(os.path.join(agent_state_dir, fn))
                        for fn in sorted(top_level[2])]

        for sessions, period_messages in zip(periods_data, agent_states):
            for session, messages in zip(sessions, period_messages):
                session['messages'] = messages

        period_states = [period['state'] for period in env_data_item['periods']]
        item_sessions = []
        for i, (sessions, states) in enumerate(zip(periods_data, period_states)):
            if i == 0:
                continue  # The first round does not include any state updates.
            for j, conv in enumerate(sessions):
                conv_id = f"{env_data_item['id']}-{i}-{j}"
                unchanged_states = {k: v for k, v in states.items()
                                    if k not in conv['exposed_states']}
                user_turns = [turn for turn in conv['messages']
                              if turn['role'] == "user"]
                item_sessions.append(
                    (conv_id, {"user_turns": user_turns,
                               "unchanged_states": unchanged_states}))
        grouped_sessions.append(item_sessions)

    conv_to_annotate = []
    for sessions in grouped_sessions:
        sampled_sessions = _clamped_sample(
            sessions, CONFIG['sample_n_sessions'], "conversation sessions")
        for _, data in sampled_sessions:
            unchanged = data['unchanged_states']
            # Only draw randomly when trimming is needed, so the random stream
            # (and thus seeded reproducibility) matches the original behaviour.
            if len(unchanged) > CONFIG['sample_n_states']:
                kept = random.sample(list(unchanged), CONFIG['sample_n_states'])
            else:
                kept = list(unchanged)
            data['unchanged_states'] = {k: unchanged[k] for k in kept}
        conv_to_annotate.extend(sampled_sessions)
    return conv_to_annotate


def _sample_qas(env_data):
    """Pair each period's state with its QA, mark the correct choice, and sample per item."""
    sampled_qas = []
    for env_data_item in env_data:
        all_queries = []
        for period, qa in zip(env_data_item['periods'], env_data_item['qas']):
            state = period['state']
            required_state_types = qa['required_info']
            answer_choices = qa['answer_choices']

            required_state_kvs = {k: state[k] for k in required_state_types}
            required_state_values = tuple(state[k] for k in required_state_types)

            choices = [answer['answer'] for answer in answer_choices]
            is_answer = [tuple(answer['state']) == required_state_values
                         for answer in answer_choices]
            # Explicit check (not ``assert``) so it still runs under ``python -O``.
            if sum(is_answer) != 1:
                raise ValueError(
                    f"Expected exactly one correct answer choice, found "
                    f"{sum(is_answer)} for query {qa['query']!r} "
                    f"in item {env_data_item.get('id')!r}")

            all_queries.append({
                "query": qa['query'],
                "choices": choices,
                "is_answer": is_answer,
                "required_state": required_state_kvs,
            })

        sampled_qas.extend(_clamped_sample(
            all_queries, CONFIG['qa_n_query_per_item'], "QA queries"))
    return sampled_qas


def sample_meta_eval_data():
    """
    Sample data for meta-evaluation annotation tasks and write it to CONFIG["output_dir"].

    The source data is read from CONFIG["env_data_fn"]; recorded conversations
    are read from CONFIG["conv_folder"]/<item id>/interactions. Sampling is
    seeded with CONFIG["random_seed"] and the three stages run in a fixed order,
    so results are reproducible for a given config and input data.

    Workflow:
    1. Environment data: for every session, each exposed state is replaced by
       {"this_state_value", "past_state_value", "all_state_values"} (past value
       is None when the period records no update for that state), then
       CONFIG["sample_n_queries"] sessions are sampled per item.
    2. Conversation data: for every session except those in the first period
       (which has no state updates), collects the user turns and the states NOT
       exposed in that session ("unchanged_states"). Samples
       CONFIG["sample_n_sessions"] sessions per item and keeps at most
       CONFIG["sample_n_states"] unchanged states each, to check dialogue
       content for conflicts with the current state.
    3. QA data: pairs each period's state with its QA entry, marks which answer
       choice matches the required state values, and samples
       CONFIG["qa_n_query_per_item"] QAs per item.

    If fewer items exist than a requested sample size, all of them are used
    (a warning is printed).

    Output files (written with ``save_json``):
    - env_data.json: (query_id, session) pairs, query_id being
      os.path.join(<item id>, "<period>-<session>")
    - conv_data.json: (conv_id, {"user_turns", "unchanged_states"}) pairs,
      conv_id being "<item id>-<period>-<session>"
    - qa_data.json: dicts with "query", "choices", "is_answer", "required_state"

    Raises:
        FileNotFoundError: an item's interactions directory does not exist.
        ValueError: a QA entry does not have exactly one correct answer choice.
    """
    random.seed(CONFIG["random_seed"])

    # 1. Environment data
    queries_to_annotate = _sample_env_queries(load_json(CONFIG['env_data_fn']))
    print("Environment data-queries to annotate:", len(queries_to_annotate))

    # 2. Conversation data. The data is reloaded rather than reused because
    # step 1 rewrites exposed_states in place.
    conv_to_annotate = _sample_conversations(load_json(CONFIG['env_data_fn']))
    first_info = (
        f"---- The first one has {len(conv_to_annotate[0][1]['unchanged_states'])} states."
        if conv_to_annotate else "---- No sessions sampled.")
    print("Conversation data-conversation sessions to annotate:",
          len(conv_to_annotate), first_info)

    # 3. QA data (reloaded because step 2 also mutates sessions in place)
    sampled_qas = _sample_qas(load_json(CONFIG['env_data_fn']))

    # 4. Write to files
    output_dir = CONFIG["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    save_json(os.path.join(output_dir, "env_data.json"), queries_to_annotate)
    save_json(os.path.join(output_dir, "conv_data.json"), conv_to_annotate)
    save_json(os.path.join(output_dir, "qa_data.json"), sampled_qas)

# Script entry point: run the full sampling pipeline when executed directly.
if __name__ == "__main__":
    sample_meta_eval_data()