import os
import yaml
import json
import base64
import shutil
import numpy as np
import tqdm
import cv2
from typing import Dict
from survey_loader import load_survey_questions_independent, load_survey_questions_cot, load_survey_questions_cot_with_gt, get_image_prompt
from utils import load_yaml, load_model_class, compute_top_k_human_accuracy, Answer, compute_cross_entropy, compute_top_k_accuracy, load_human_answer
from utils import copy_config_files, validate_prompts_in_human_answers, save_debug_images, REASONING_GROUPS
import argparse


def _encode_image_base64(img_np, size=(512, 512)):
    """Resize an image, JPEG-encode it and return the bytes as a base64 string.

    The input is treated as RGB: it is converted with ``COLOR_RGB2BGR`` before
    being handed to ``cv2.imencode``.
    """
    resized = cv2.resize(img_np, size)
    bgr = cv2.cvtColor(resized, cv2.COLOR_RGB2BGR)
    encoded = cv2.imencode('.jpg', bgr)[1]
    return base64.b64encode(encoded).decode()


def _find_reasoning_group(question_key):
    """Return the name of the ``REASONING_GROUPS`` entry containing *question_key*.

    Keys starting with ``q_goal_position`` are looked up verbatim; for any
    other key everything from the last ``_p`` onwards is stripped first.

    Raises:
        ValueError: if the resulting base question is in no reasoning group.
    """
    if question_key.startswith("q_goal_position"):
        base_question = question_key
    else:
        base_question = question_key.rsplit('_p', 1)[0]
    for group_name, questions in REASONING_GROUPS.items():
        if base_question in questions:
            return group_name
    raise ValueError(f"Base question {base_question} not found in any reasoning group...")


def _parse_model_answer(generated_text):
    """Extract the ``'answer'`` field from a JSON model reply, or ``"INVALID"``.

    Only the failures that a malformed reply can cause are swallowed
    (invalid JSON, missing key, non-dict / non-string payload), so that
    KeyboardInterrupt and unrelated errors still propagate.
    """
    try:
        return json.loads(generated_text)['answer']
    except (ValueError, KeyError, TypeError):
        return "INVALID"


def _entropy_bits(probabilities):
    """Return the base-2 entropy of *probabilities*, skipping non-positive entries."""
    return -np.sum([p * np.log2(p) for p in probabilities if p > 0])


def evaluate_baseline(baseline_model: str, 
                      survey_folder: str, 
                      prompts_folder: str,
                      evaluation_folder: str,
                      model_to_api_key: Dict,
                      prompt_image_fp: str,
                      method: str = 'independent',
                      prompts_data: str = 'prompts/survey_prompt.json',
                      samples_json_fp: str = 'prompts/sample_info.json',
                      relevant_prev_qs_fp: str = 'prompts/relevant_prev_questions.json',
                      config: Dict = None,
                      dataset_cfg: str = 'dataset_cfg.yaml',
                      debug: bool = False,
                      resume_folder: str = None) -> None:
    """Query a baseline model on every survey sample and score it against human answers.

    For each sample listed in ``samples_json_fp`` the survey prompts are
    loaded according to ``method``, the model is queried question by question,
    and per-question metrics are computed against the sample's
    ``common_answers.json``. Results are written incrementally into a
    per-sample folder (``conversation.json``, ``question_answers.json``,
    ``evaluation.json``, ``human_common_answers.json``), so a partial run
    leaves usable output behind.

    Args:
        baseline_model: Model identifier passed to ``load_model_class``; also
            used in the name of a newly created experiment folder.
        survey_folder: Root folder holding the human answers of each sample.
        prompts_folder: Root folder holding the sample directories.
        evaluation_folder: Folder under which a new ``experiment_<n>_...``
            directory is created when not resuming.
        model_to_api_key: Mapping passed to ``load_model_class``.
        prompt_image_fp: Accepted for interface compatibility; not read here.
        method: One of ``'independent'``, ``'cot'`` or ``'cot_with_gt'``.
        prompts_data: Path to the survey prompt JSON.
        samples_json_fp: Path to the JSON listing samples (``folderpath`` and
            ``n_people`` per entry under the ``samples`` key).
        relevant_prev_qs_fp: Path to the JSON of relevant previous questions,
            forwarded to the CoT prompt loaders.
        config: Forwarded to ``get_image_prompt`` as the prompt configuration.
        dataset_cfg: Forwarded to ``get_image_prompt`` and ``copy_config_files``.
        debug: When true, ``save_debug_images`` is called for every sample.
        resume_folder: Existing experiment folder to continue. Samples whose
            name already appears in that folder are skipped.

    Raises:
        AssertionError: if ``method`` is invalid or a sample/answer directory
            is missing.
        ValueError: if a question belongs to no reasoning group.
    """
    model = load_model_class(baseline_model, model_to_api_key)
    assert method in ['independent', 'cot', 'cot_with_gt'], 'Invalid method'
    
    # Load sample directories and corresponding answers
    with open(samples_json_fp, 'r') as f:
        samples = json.load(f)
    sample_base_folders = [sample['folderpath'] for sample in samples['samples']]
    sample_directories = [os.path.join(prompts_folder, s) for s in sample_base_folders]
    answers_directories = [os.path.join(survey_folder, s) for s in sample_base_folders]
    n_people_per_sample = [sample['n_people'] for sample in samples['samples']]
        
    for sample_dir in sample_directories:
        assert os.path.exists(sample_dir), f"Could not find {sample_dir}"
        
    for answer_dir in answers_directories:
        assert os.path.exists(answer_dir), f"Could not find {answer_dir}"
        
    assert len(sample_directories) == len(answers_directories), "Mismatch in directory counts."

    # Keep each sample's directory, answer directory and people count together
    # so that filtering (when resuming) can never pair a sample with another
    # sample's answers.
    sample_entries = list(zip(sample_directories, answers_directories, n_people_per_sample))

    if resume_folder:
        processed_samples = set(os.listdir(resume_folder))
        sample_entries = [
            entry for entry in sample_entries
            if os.path.basename(entry[0]) not in processed_samples
        ]
        eval_dir = resume_folder
    else:
        # The experiment index is derived from the folder listing, so the
        # folder has to exist even on the very first run.
        os.makedirs(evaluation_folder, exist_ok=True)
        n_experiments = sum(
            1 for f in os.listdir(evaluation_folder)
            if os.path.isdir(os.path.join(evaluation_folder, f))
        )
        idx_new_experiment = n_experiments + 1
        eval_dir = os.path.join(evaluation_folder, f'experiment_{idx_new_experiment}_{baseline_model}_{method}')
        print(f"Saving evaluation results to {eval_dir}")
        os.makedirs(eval_dir, exist_ok=True)
    
    # Load relevant previous questions mapping
    with open(relevant_prev_qs_fp, 'r') as f:
        relevant_prev_qs = json.load(f)
    
    # Config files are snapshotted only for a fresh experiment.
    if not resume_folder:
        copy_config_files(eval_dir, prompts_data, samples_json_fp, relevant_prev_qs_fp, dataset_cfg)

    n_samples_total = len(sample_entries)

    # Process each sample
    for sample_number, (sample_dir, answer_dir, n_people) in enumerate(sample_entries, start=1):
        if method == 'independent':
            prompts = load_survey_questions_independent(prompts_data, n_people)
        elif method == 'cot':
            prompts = load_survey_questions_cot(prompts_data, n_people, relevant_prev_qs)
        elif method == 'cot_with_gt':
            prompts = load_survey_questions_cot_with_gt(prompts_data, n_people, relevant_prev_qs)
        else:
            raise ValueError('Invalid method')
        
        validate_prompts_in_human_answers(answer_dir, prompts)

        sample_id = os.path.basename(sample_dir)
        print(f"\nProcessing sample: {sample_id} ({sample_number}/{n_samples_total})")
        eval_sample_dir = os.path.join(eval_dir, sample_id)
        os.makedirs(eval_sample_dir, exist_ok=True)
        
        # Prepare image prompt: the sample id is split on '_' and its first
        # three parts name the frame folder, the fourth the frame file.
        sample_parent_dir = "scand_spot/frames/val"
        id_parts = sample_id.split('_')
        sample_filename = f"{sample_parent_dir}/{id_parts[0]}_{id_parts[1]}_{id_parts[2]}/{id_parts[3]}.jpg"
        
        images_prompt = get_image_prompt(dataset_cfg, sample_filename, config)
        
        if model.baseline_type == 'api':
            # Build a new list rather than overwriting the one returned by
            # get_image_prompt in place.
            images_prompt = [_encode_image_base64(img_np) for img_np in images_prompt]
        
        # Save debug images if requested
        if debug:
            save_debug_images(baseline_model, model, images_prompt, sample_id)

        human_answers_fp = os.path.join(answer_dir, 'common_answers.json')

        # Keep a copy of the human answers next to the model's results. This
        # does not depend on the question, so it is done once per sample.
        shutil.copy(human_answers_fp, os.path.join(eval_sample_dir, 'human_common_answers.json'))

        actual_conversation_history = []
        question_answers_export = {}
        eval_export = {}

        # `prompts` is deliberately mutated while iterating: answers to earlier
        # questions are substituted into the text of later ones (CoT methods).
        for i, (question_key, prompt, choices, question_type) in enumerate(tqdm.tqdm(prompts)):
            reasoning_group = _find_reasoning_group(question_key)

            use_ground_truth = (
                method == 'cot_with_gt'
                and reasoning_group in ("Spatial reasoning", "Spatiotemporal reasoning")
            )
            if use_ground_truth:
                # Use the most common human answer instead of querying the model.
                human_answers = load_human_answer(human_answers_fp, question_key, choices, question_type)
                clean_answer = human_answers.get_most_common_answer()
                generated_text = json.dumps({'answer': clean_answer})
            else:
                generated_text = model.generate_text(prompt, images_prompt)
                actual_conversation_history.extend([
                    {'entity': ['user'], 'response': prompt.split('\n')}, 
                    {'entity': ['assistant'], 'response': [generated_text]}
                ])
                
                with open(os.path.join(eval_sample_dir, 'conversation.json'), 'w') as of:
                    json.dump(actual_conversation_history, of, indent=4)
                
                clean_answer = _parse_model_answer(generated_text)
            
            question_answers_export[question_key] = clean_answer
            # A single model query per question: its answer gets probability 1.0.
            vlm_answer = Answer([clean_answer], [1.0], choices, len(choices), question_type)
            
            if 'cot' in method:
                answer_dummy_txt = "{QUESTION_KEY_ANSWER}".replace("QUESTION_KEY", question_key)
                # NOTE(review): string answers are inserted unquoted while
                # non-strings are str()-ed and quoted. This looks inverted
                # compared with the json.dumps form used for ground-truth
                # answers above, but it is kept as-is because changing it
                # alters the prompt text sent to the model — confirm intent.
                if isinstance(clean_answer, str):
                    clean_answer_with_dict = "{\"answer\": " + clean_answer + "}"
                else:
                    clean_answer_with_dict = "{\"answer\": \"" + str(clean_answer) + "\"}"
                for j in range(i + 1, len(prompts)):
                    fqkey, fprompt, fchoices, fqtype = prompts[j]
                    fprompt = fprompt.replace(answer_dummy_txt, clean_answer_with_dict)
                    prompts[j] = (fqkey, fprompt, fchoices, fqtype)
            
            human_answer = load_human_answer(human_answers_fp, question_key, choices, question_type)
            n_choices = len(human_answer.choices)

            eval_export[question_key] = {
                'top_1_accuracy': compute_top_k_accuracy(vlm_answer, human_answer, 1),
                'top_2_accuracy': compute_top_k_accuracy(vlm_answer, human_answer, 2),
                'top_1_human_accuracy': compute_top_k_human_accuracy(human_answer, 1),
                'top_2_human_accuracy': compute_top_k_human_accuracy(human_answer, 2),
                'top_1_random_accuracy': 1 / n_choices,
                'top_2_random_accuracy': 2 / n_choices,
                'cross_entropy': compute_cross_entropy(vlm_answer, human_answer),
                'vlm_entropy': _entropy_bits(vlm_answer.answers_probabilities),
                'human_entropy': _entropy_bits(human_answer.answers_probabilities),
                # JSON object keys must be strings, so answers are stringified.
                'vlm_probabilities': {str(k): v for k, v in vlm_answer.answer_to_probability.items()},
                'human_probabilities': {str(k): v for k, v in human_answer.answer_to_probability.items()},
            }
            
            if 'cot' in method:
                # NOTE(review): nothing in this function resets the model's
                # conversation history between samples — confirm the model
                # class handles that.
                model.add_to_conversation_history(('user', prompt))
                model.add_to_conversation_history(('assistant', generated_text))
            
            # Rewrite both result files after every question so an interrupted
            # run still leaves the answers gathered so far on disk.
            with open(os.path.join(eval_sample_dir, 'question_answers.json'), 'w') as of:
                json.dump(question_answers_export, of, indent=4)

            with open(os.path.join(eval_sample_dir, 'evaluation.json'), 'w') as of:
                json.dump(eval_export, of, indent=4)


def main():
    """Command-line entry point: read ``eval_cfg.yaml`` and run the evaluation.

    The only command-line option is ``--resume_folder``. Every other setting
    comes from the YAML config, and the three prompt JSON files are located
    inside the configured ``prompts_folder``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--resume_folder', type=str, default=None, help='Folder to resume evaluation from')
    resume_folder = arg_parser.parse_args().resume_folder
    if resume_folder:
        print(f"Resuming evaluation from {resume_folder}")

    config = load_yaml('eval_cfg.yaml')

    # Mandatory settings, read in a fixed order so that a missing key is
    # reported for the first one that is absent.
    required_keys = (
        'baseline_model',
        'survey_folder',
        'prompts_folder',
        'evaluation_folder',
        'model_to_api_key',
        'prompt_image_fp',
        'method',
    )
    call_kwargs = {key: config[key] for key in required_keys}

    # The prompt-related JSON files all live in the prompts folder.
    prompts_root = call_kwargs['prompts_folder']
    call_kwargs.update(
        prompts_data=os.path.join(prompts_root, 'survey_prompt.json'),
        samples_json_fp=os.path.join(prompts_root, 'sample_info.json'),
        relevant_prev_qs_fp=os.path.join(prompts_root, 'relevant_prev_questions.json'),
        config=config,
        dataset_cfg=config.get('dataset_cfg', None),
        debug=config.get('debug', False),
        resume_folder=resume_folder,
    )

    evaluate_baseline(**call_kwargs)

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()