import os
import json
import glob
from typing import Dict, List
import shutil
from collections import defaultdict, Counter

# Root folder that holds one sub-directory per participant (USER_ID).
results_dir = 'FULL_UNFILTERED'
# Output folder. Note that it is created *inside* results_dir,
# e.g. FULL_UNFILTERED/FULL_UNFILTERED_aggregated.
aggregated_dir = os.path.join(results_dir, f'{results_dir}_aggregated')

# Always start from an empty output folder: delete whatever a previous run
# produced, then recreate the (now empty) directory.
if os.path.exists(aggregated_dir):
    shutil.rmtree(aggregated_dir)
os.makedirs(aggregated_dir)

# Only participants who finished at least this many pages are aggregated.
MIN_FINISHED_PAGES = 10
# Legacy answer labels rewritten to their current wording during aggregation.
ANSWER_RENAMES = {'ignoring': 'not considering', 'ignore': 'not consider'}
# Folder containing one sub-directory per SAMPLE_ID, each with an n_people.json.
prompts_root = os.path.join('static', 'images', 'prompts')

# aggregated_data[sample_id][page_name][question][user_id] -> list of answers.
# aggregated_data[sample_id]['n_people'] additionally stores the sample's
# n_people value (not a page), so consumers must skip that key.
aggregated_data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
# n_people per sample, so each n_people.json is read only once.
n_people_by_sample: dict[str, int] = {}

# Iterate over USER_ID directories. Sorting makes the output (key order,
# Counter tie order, which duplicate answers file wins) reproducible.
for user_id in sorted(os.listdir(results_dir)):
    user_dir = os.path.join(results_dir, user_id)
    # Skip plain files, the output folder created above (it lives inside
    # results_dir), and the legacy 'AGGREGATED' folder name.
    if not os.path.isdir(user_dir) or user_id == 'AGGREGATED' or user_dir == aggregated_dir:
        continue

    # Only aggregate users whose finished_pages.json lists enough pages.
    finished_pages_path = os.path.join(user_dir, 'finished_pages.json')
    if not os.path.exists(finished_pages_path):
        continue
    with open(finished_pages_path, 'r', encoding='utf-8') as f:
        finished_pages = json.load(f)
    if len(finished_pages) < MIN_FINISHED_PAGES:
        continue

    # Aggregate answers for each SAMPLE_ID/PAGE_NAME
    for sample_id in sorted(os.listdir(user_dir)):
        user_sample_dir = os.path.join(user_dir, sample_id)
        if not os.path.isdir(user_sample_dir):
            continue

        if sample_id not in n_people_by_sample:
            prompts_sample_dir = os.path.join(prompts_root, sample_id)
            n_people_path = os.path.join(prompts_sample_dir, 'n_people.json')
            # Explicit raises instead of assert: asserts vanish under `python -O`.
            if not os.path.exists(n_people_path):
                raise FileNotFoundError(f'n_people.json does not exist in {prompts_sample_dir}')
            with open(n_people_path, 'r', encoding='utf-8') as f:
                n_people_data = json.load(f)
            if 'n_people' not in n_people_data:
                raise KeyError(f'n_people key not found in {n_people_path}')
            n_people_by_sample[sample_id] = n_people_data['n_people']
        n_people = n_people_by_sample[sample_id]

        for page_name in sorted(os.listdir(user_sample_dir)):
            page_dir = os.path.join(user_sample_dir, page_name)
            # Sorted so that, when several answers*.json files repeat a
            # question, the same (last) file wins on every run.
            answers_files = sorted(glob.glob(os.path.join(page_dir, 'answers*.json')))
            for answers_path in answers_files:
                with open(answers_path, 'r', encoding='utf-8') as f:
                    answers = json.load(f)
                # Reorganize data so that questions are keys
                for question, answer_data in answers.items():
                    aggregated_data[sample_id][page_name][question][user_id] = [
                        ANSWER_RENAMES.get(item, item) if isinstance(item, str) else item
                        for item in answer_data
                    ]
                    # Set only once the sample has at least one answer, so
                    # samples without answers never appear in the output.
                    aggregated_data[sample_id]['n_people'] = n_people

# NOTE(review): images_dir is not used anywhere in this script.
images_dir = os.path.join('static', 'images', 'train')


def get_top_common_answers(data):
    """Rank each question's answers by how many users gave them.

    Args:
        data: mapping of question key -> {user_id: list of answers}.

    Returns:
        A dict with three entries per question stem (the key without its
        last ``_``-separated segment; a key without ``_`` gives ``''``):
        ``<stem>`` -> answers sorted from most to least common,
        ``<stem>_probabilities`` -> fraction of counted users per answer,
        ``<stem>_n_answers`` -> number of users counted.

        For questions whose key contains ``robot_moving_direction``, a user's
        whole answer list (as a tuple) is one answer. For all other questions
        only the first element counts, and users with an empty list are
        skipped. If two keys share a stem, the later one overwrites the earlier.
    """
    result = {}
    for question, answers_by_user in data.items():
        if 'robot_moving_direction' in question:
            all_answers = [tuple(answers) for answers in answers_by_user.values()]
        else:
            # answers[0] would raise IndexError on an empty list.
            all_answers = [answers[0] for answers in answers_by_user.values() if answers]
        total_n_answers = len(all_answers)
        # most_common() lists (answer, count) pairs, highest count first.
        most_common = Counter(all_answers).most_common()
        # Same result as '_'.join(question.split('_')[:-1]).
        stem = question.rpartition('_')[0]
        result[stem] = [answer for answer, _ in most_common]
        result[stem + '_probabilities'] = [count / total_n_answers for _, count in most_common]
        result[stem + '_n_answers'] = total_n_answers
    return result


sample_info: Dict[str, List[Dict[str, str | int]]] = {'samples': []}
# Write aggregated answers to <aggregated_dir>/<SAMPLE_ID>/
for sample_id, pages_data in aggregated_data.items():
    sample_dir = os.path.join(aggregated_dir, sample_id)
    os.makedirs(sample_dir, exist_ok=True)

    # Combine the questions of all pages into one mapping. 'n_people' holds
    # the sample's people count, not a page. If two pages share a question
    # key, the later page overwrites the earlier one.
    all_questions_data: dict[str, dict] = {}
    for page_name, questions_data in pages_data.items():
        if page_name == 'n_people':
            continue
        all_questions_data.update(questions_data)

    # Record the n_people value collected during aggregation (it used to be
    # hard-coded to 0). .get() does not trigger the defaultdict factory.
    sample_info['samples'].append(
        {'folderpath': sample_id, 'n_people': pages_data.get('n_people', 0)}
    )
    answers_path = os.path.join(sample_dir, 'raw_answers.json')
    with open(answers_path, 'w', encoding='utf-8') as f:
        json.dump(all_questions_data, f, indent=4)

    most_common_data = get_top_common_answers(all_questions_data)
    common_answers_path = os.path.join(sample_dir, 'common_answers.json')
    with open(common_answers_path, 'w', encoding='utf-8') as f:
        json.dump(most_common_data, f, indent=4)

with open(os.path.join(aggregated_dir, 'sample_info.json'), 'w', encoding='utf-8') as f:
    json.dump(sample_info, f, indent=4)

print(f'Aggregation completed. Aggregated data is stored in: {aggregated_dir}')
