import os
import yaml
import json
import base64
import shutil
import numpy as np
import tqdm
import cv2
from typing import List, Tuple, Dict
from itertools import chain, combinations
from baseline import Baseline
from aggregate_eval_data import process_evaluation_file
from survey_loader import load_survey_questions_independent, load_survey_questions_cot, load_survey_questions_cot_with_gt, get_image_prompt
from utils import load_yaml, load_model_class, compute_top_k_human_accuracy, Answer, compute_cross_entropy, compute_top_k_accuracy, load_human_answer
from utils import copy_config_files, validate_prompts_in_human_answers, save_debug_images, REASONING_GROUPS
import argparse


def _sample_image_path(sample_id: str) -> str:
    """Build the frame path that is handed to ``get_image_prompt``.

    ``sample_id`` is split on ``'_'``: the first three fields are re-joined
    with ``'_'`` to form the directory name and the fourth field is the file
    stem. Any further fields are ignored.

    Raises:
        IndexError: if ``sample_id`` has fewer than four ``'_'``-separated fields.
    """
    sample_parent_dir = "vlm-sn/scand_spot/frames/val"
    parts = sample_id.split('_')
    return f"{sample_parent_dir}/{parts[0]}_{parts[1]}_{parts[2]}/{parts[3]}.jpg"


def _entropy_bits(probabilities):
    """Return the base-2 entropy of ``probabilities``, skipping entries that are not > 0."""
    return -np.sum([p * np.log2(p) for p in probabilities if p > 0])


def evaluate_baseline(baseline_model: str, 
                      survey_folder: str, 
                      prompts_folder: str,
                      evaluation_folder: str,
                      model_to_api_key: Dict,
                      prompt_image_fp: str,
                      method: str = 'independent',
                      prompts_data: str = 'prompts/survey_prompt.json',
                      samples_json_fp: str = 'prompts/sample_info.json',
                      relevant_prev_qs_fp: str = 'prompts/relevant_prev_questions.json',
                      config: Dict = None,
                      dataset_cfg: str = 'dataset_cfg.yaml',
                      debug: bool = False,
                      resume_folder: str = None):
    """Score the most common human answer against the human answer distribution.

    For every sample listed in ``samples_json_fp`` and every survey question,
    the most common human answer is used as the predicted answer (with
    probability 1.0) and compared against the human answers loaded from
    ``<survey_folder>/<sample>/common_answers.json``. The model is loaded but
    never asked to generate anything here; it is only used for its
    ``baseline_type``, for debug image dumps and, in ``cot`` methods, to
    receive the conversation history.

    Per sample, three files are written to
    ``<evaluation_folder>/human_ground_truth/<sample_id>/``:
    ``question_answers.json``, ``evaluation.json`` and a copy of the human
    answers named ``human_common_answers.json``.

    Args:
        baseline_model: Model identifier passed to ``load_model_class`` and
            ``save_debug_images``.
        survey_folder: Root folder holding one answer directory per sample.
        prompts_folder: Root folder holding one prompt directory per sample.
        evaluation_folder: Root folder under which results are written.
        model_to_api_key: Mapping forwarded to ``load_model_class``.
        prompt_image_fp: Not used by this function; kept for interface
            compatibility.
        method: One of ``'independent'``, ``'cot'``, ``'cot_with_gt'``. Any
            method containing ``'cot'`` enables answer substitution into later
            prompts and conversation-history updates.
        prompts_data: Path of the survey prompt JSON.
        samples_json_fp: Path of a JSON file with a ``"samples"`` list whose
            entries provide ``"folderpath"`` and ``"n_people"``.
        relevant_prev_qs_fp: Path of a JSON file; it is loaded and (when not
            resuming) copied, but not otherwise used here.
        config: Forwarded to ``get_image_prompt`` as the prompt configuration.
        dataset_cfg: Path of the dataset YAML configuration.
        debug: When true, ``save_debug_images`` is called for every sample.
        resume_folder: When ``None``, the configuration files are copied into
            the output directory; any other value skips that copy.

    Raises:
        AssertionError: if ``method`` is not supported or a sample/answer
            directory does not exist.
    """
    # Validate the cheap argument first so an invalid method fails before the
    # model is loaded.
    assert method in ['independent', 'cot', 'cot_with_gt'], 'Invalid method'
    model = load_model_class(baseline_model, model_to_api_key)

    # Load sample directories and corresponding answers
    with open(samples_json_fp, 'r') as f:
        samples = json.load(f)
    sample_entries = samples['samples']
    sample_base_folders = [entry['folderpath'] for entry in sample_entries]
    sample_directories = [os.path.join(prompts_folder, s) for s in sample_base_folders]
    answers_directories = [os.path.join(survey_folder, s) for s in sample_base_folders]
    n_people_per_sample = [entry['n_people'] for entry in sample_entries]

    for sample_dir in sample_directories:
        assert os.path.exists(sample_dir), f"Could not find {sample_dir}"

    for answer_dir in answers_directories:
        assert os.path.exists(answer_dir), f"Could not find {answer_dir}"

    eval_dir = os.path.join(evaluation_folder, 'human_ground_truth')

    print(f"Saving evaluation results to {eval_dir}")
    os.makedirs(eval_dir, exist_ok=True)

    # The mapping is loaded but not referenced again in this function.
    with open(relevant_prev_qs_fp, 'r') as f:
        relevant_prev_qs = json.load(f)

    if resume_folder is None:
        copy_config_files(eval_dir, prompts_data, samples_json_fp, relevant_prev_qs_fp, dataset_cfg)

    # NOTE(review): the parsed config is not used below; get_image_prompt is
    # given the path `dataset_cfg` instead. Confirm which one it expects.
    dataset_config = load_yaml(dataset_cfg)

    # Only a single query per question is supported.
    n_vlm_queries = 1
    uses_cot = 'cot' in method
    n_samples = len(sample_directories)

    # Process each sample
    for sample_idx, (sample_dir, answer_dir, n_people) in enumerate(
            zip(sample_directories, answers_directories, n_people_per_sample), start=1):
        prompts = load_survey_questions_independent(prompts_data, n_people)

        validate_prompts_in_human_answers(answer_dir, prompts)

        sample_id = os.path.basename(sample_dir)
        print(f"\nProcessing sample: {sample_id} ({sample_idx}/{n_samples})")
        eval_sample_dir = os.path.join(eval_dir, sample_id)
        os.makedirs(eval_sample_dir, exist_ok=True)

        # Prepare image prompt
        sample_filename = _sample_image_path(sample_id)
        images_prompt = get_image_prompt(dataset_cfg, sample_filename, config)

        if model.baseline_type == 'api':
            # Replace each image in place by its base64-encoded 512x512 JPEG.
            for idx, img_np in enumerate(images_prompt):
                img_data = cv2.resize(img_np, (512, 512))
                img_data = cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR)
                encoded_cv2 = cv2.imencode('.jpg', img_data)[1]
                images_prompt[idx] = base64.b64encode(encoded_cv2).decode()

        # Save debug images if requested
        if debug:
            save_debug_images(baseline_model, model, images_prompt, sample_id)

        human_answers_fp = os.path.join(answer_dir, 'common_answers.json')
        question_answers_fp = os.path.join(eval_sample_dir, 'question_answers.json')
        evaluation_fp = os.path.join(eval_sample_dir, 'evaluation.json')

        # The human answers file does not change while the questions are
        # processed, so it is copied once per sample rather than per question.
        shutil.copy(human_answers_fp, os.path.join(eval_sample_dir, 'human_common_answers.json'))

        question_answers_export = {}
        eval_export = {}

        # `prompts` is updated in place below, so later iterations see the
        # prompts with earlier answers already substituted.
        for i, (question_key, prompt, choices, question_type) in enumerate(tqdm.tqdm(prompts)):
            # The "predicted" answer is the most common human answer.
            human_answers = load_human_answer(human_answers_fp, question_key, choices, question_type)
            clean_answer = human_answers.get_most_common_answer()
            generated_text = json.dumps({'answer': clean_answer})

            question_answers_export[question_key] = clean_answer
            if n_vlm_queries == 1:
                vlm_answer = Answer([clean_answer], [1.0], choices, len(choices), question_type)
            else:
                raise NotImplementedError('Multiple VLM queries not implemented')

            if uses_cot:
                # Fill this question's answer placeholder in all later prompts.
                answer_dummy_txt = "{QUESTION_KEY_ANSWER}".replace("QUESTION_KEY", question_key)
                # NOTE(review): string answers are inserted without quotes
                # while other answers are stringified and quoted, which
                # differs from `generated_text`. Kept as-is; confirm intent.
                if isinstance(clean_answer, str):
                    clean_answer_with_dict = "{\"answer\": " + clean_answer + "}"
                else:
                    clean_answer_with_dict = "{\"answer\": \"" + str(clean_answer) + "\"}"
                for j in range(i + 1, len(prompts)):
                    fqkey, fprompt, fchoices, fqtype = prompts[j]
                    fprompt = fprompt.replace(answer_dummy_txt, clean_answer_with_dict)
                    prompts[j] = (fqkey, fprompt, fchoices, fqtype)

            n_choices = len(human_answers.choices)
            eval_export[question_key] = {
                'top_1_accuracy': compute_top_k_accuracy(vlm_answer, human_answers, 1),
                'top_2_accuracy': compute_top_k_accuracy(vlm_answer, human_answers, 2),
                'top_1_human_accuracy': compute_top_k_human_accuracy(human_answers, 1),
                'top_2_human_accuracy': compute_top_k_human_accuracy(human_answers, 2),
                'top_1_random_accuracy': 1 / n_choices,
                'top_2_random_accuracy': 2 / n_choices,
                'cross_entropy': compute_cross_entropy(vlm_answer, human_answers),
                'vlm_entropy': _entropy_bits(vlm_answer.answers_probabilities),
                'human_entropy': _entropy_bits(human_answers.answers_probabilities),
                # JSON object keys must be strings, so answers are stringified.
                'vlm_probabilities': {str(k): v for k, v in vlm_answer.answer_to_probability.items()},
                'human_probabilities': {str(k): v for k, v in human_answers.answer_to_probability.items()},
            }

            if uses_cot:
                # NOTE(review): the history is never reset in this function,
                # so it accumulates across samples — confirm this is intended.
                model.add_to_conversation_history(('user', prompt))
                model.add_to_conversation_history(('assistant', generated_text))

            # Both result files are rewritten after every question, so the
            # files on disk always cover the questions processed so far.
            with open(question_answers_fp, 'w') as of:
                json.dump(question_answers_export, of, indent=4)

            with open(evaluation_fp, 'w') as of:
                json.dump(eval_export, of, indent=4)


def main():
    """Parse the command line, read ``eval_cfg.yaml`` and run ``evaluate_baseline``.

    The only command-line option is ``--resume_folder``. All other settings
    come from ``eval_cfg.yaml``; the prompt, sample-info and
    relevant-previous-questions files are looked up inside the configured
    ``prompts_folder``.

    Raises:
        KeyError: if a required key is missing from ``eval_cfg.yaml``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--resume_folder', type=str, default=None, help='Folder to resume evaluation from')
    args = parser.parse_args()
    resume_folder = args.resume_folder
    if resume_folder:
        print(f"Resuming evaluation from {resume_folder}")

    config = load_yaml('eval_cfg.yaml')
    baseline_model = config['baseline_model']
    survey_folder = config['survey_folder']
    prompts_folder = config['prompts_folder']
    evaluation_folder = config['evaluation_folder']
    model_to_api_key = config['model_to_api_key']
    prompt_image_fp = config['prompt_image_fp']
    method = config['method']

    prompt_data_fp = os.path.join(prompts_folder, 'survey_prompt.json')
    samples_json_fp = os.path.join(prompts_folder, 'sample_info.json')
    relevant_prev_qs_fp = os.path.join(prompts_folder, 'relevant_prev_questions.json')

    # Passing None explicitly would override evaluate_baseline's own default
    # and hand None to the YAML loader, so fall back to that default here when
    # the key is absent or null.
    dataset_cfg = config.get('dataset_cfg')
    if dataset_cfg is None:
        dataset_cfg = 'dataset_cfg.yaml'

    evaluate_baseline(
        baseline_model=baseline_model,
        survey_folder=survey_folder,
        prompts_folder=prompts_folder,
        evaluation_folder=evaluation_folder,
        model_to_api_key=model_to_api_key,
        prompt_image_fp=prompt_image_fp,
        method=method,
        prompts_data=prompt_data_fp,
        samples_json_fp=samples_json_fp,
        relevant_prev_qs_fp=relevant_prev_qs_fp,
        config=config,
        dataset_cfg=dataset_cfg,
        debug=config.get('debug', False),
        resume_folder=resume_folder,
    )

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()