import json
import os
from agent.agent import FunctionCallAgent
from config.argument_parser import config

from memory.experience_memory import Memory
from streaming import MDSWriter
from help_functions import *
from tqdm import tqdm
import random
import argparse

# Substrings checked against a conversation's final response; a file whose
# response contains any of them is rejected by process_single_file.
negative_phrases = [
    'Early stop',
    'cannot',
    'not found',
    'not available',
    "can't",
]

# Column name -> encoding mapping passed as `columns=` to MDSWriter in main().
columns = dict(
    messages='json',
    response='str',
    task_id='str',
    task_description='str',
    similar_trajectories='json',
    recent_trajectory='json',
)

# Defaults for the --save_path and --training_data_path command-line options.
SAVE_PATH = 'GUI-Agent/test_mds'
training_data_path = 'training_data'

def process_single_file(file_path, dataset, domain, model, memory, out):
    """Turn one recorded conversation file into per-round training samples.

    The file is parsed as a single JSON document. A conversation whose final
    ``response`` contains any entry of ``negative_phrases`` is rejected.
    Otherwise similar conversations are looked up through ``memory`` and, if
    at least one usable similar trajectory was found, one sample per round is
    written to ``out``.

    Args:
        file_path: Path of the conversation file to read.
        dataset: Dataset name; prefixed to the task description to build the
            retrieval query.
        domain: Domain name; prefixed to the task description to build the
            retrieval query.
        model: Forwarded unchanged as ``model=`` to
            ``memory.retrieve_similar_conversations``.
        memory: Object exposing ``retrieve_similar_conversations``; each item
            of its result is opened here as a path to a JSON file.
        out: Sink exposing ``write(sample)``; receives one dict per round.

    Returns:
        0 when the conversation is rejected because of its final response,
        1 otherwise (including the case where no sample is written because
        no similar trajectory was available).
    """
    # Read the whole document and release the file handle immediately, so it
    # is not held open during retrieval and sample writing.
    with open(file_path, 'r', encoding='utf-8') as source_file:
        data = json.load(source_file)

    question = data['task_description']
    task_id = data['conversation_id']
    final_answer = data['response']
    if any(phrase in final_answer for phrase in negative_phrases):
        print('Invalid Answer')
        return 0
    conversation_data = data['rounds']

    # The image of the first round (if any) accompanies the text query.
    current_image = None
    if conversation_data:
        current_image = get_base64_image_from_conversation(conversation_data[0])

    # From here on `question` carries the dataset/domain prefix; the prefixed
    # form is also what gets stored as 'task_description' in each sample.
    question = f"{dataset}_{domain}: {question}"
    selected_conversations = memory.retrieve_similar_conversations(
        question,
        current_image=current_image,
        model=model,
        similar_num=5
    )

    # Keep at most three similar trajectories, skipping retrieved
    # conversations whose JSON content has fewer than two entries.
    similar_trajectories = []
    for conversation in selected_conversations:
        with open(conversation, 'r', encoding='utf-8') as similar_file:
            similar_conversation_data = json.load(similar_file)
        if len(similar_conversation_data) < 2:
            continue
        similar_actions_list, similar_images_list = organize_similar_tajectory(similar_conversation_data)
        similar_trajectories.append({
            'actions': similar_actions_list,
            'images': similar_images_list
        })
        if len(similar_trajectories) == 3:
            break

    # Whether samples get written does not change from round to round.
    has_similar = bool(similar_trajectories)

    # NOTE: the per-round history is still collected, but samples currently
    # ship an empty 'recent_trajectory'. To emit the recent history instead,
    # pass history_trajectories[-3:] for that key.
    history_trajectories = []
    for idx, single_round in enumerate(conversation_data):
        if has_similar:
            out.write({
                'messages': single_round['messages'],
                'response': single_round['response'],
                'task_id': f'{task_id}_{idx}',
                'task_description': question,
                'similar_trajectories': similar_trajectories,
                'recent_trajectory': []
            })

        image = get_base64_image_from_conversation(single_round)
        action = single_round['response']
        if image and isinstance(action, str):
            history_trajectories.append({
                'actions': [action],
                'images': [image]
            })
    return 1
    
    
def main():
    # Parse command line arguments for memory configuration
    parser = argparse.ArgumentParser(description='Prepare training data with multimodal memory matching')
    parser.add_argument('--multimodal', type=bool, default=True, 
                       help='Enable multimodal memory matching (text + image)')
    parser.add_argument('--faiss_index_path', type=str, default="",
                       help='Path to existing FAISS index to load (optional)')
    parser.add_argument('--save_path', type=str, default=SAVE_PATH,
                       help='Path to save the prepared data')
    parser.add_argument('--training_data_path', type=str, default=training_data_path,
                       help='Path to training data directory')
    
    args = parser.parse_args()
    
    # Initialize the Memory class with multimodal capabilities
    print(f"Initializing Memory class with multimodal={args.multimodal}")
    agent = FunctionCallAgent(config())
    memory = Memory(
        agent=agent,
        training_data_path=args.training_data_path,
        faiss_index_path=args.faiss_index_path,
        multimodal=args.multimodal
    )
    
    # Print available datasets and domains
    available_datasets = memory.get_available_datasets_and_domains()
    print("Available datasets and domains:")
    for dataset, domains in available_datasets.items():
        print(f"{dataset}: {domains}")
    
    all_datasets = os.listdir(args.training_data_path)
    all_datasets = [dataset for dataset in all_datasets if os.path.isdir(os.path.join(args.training_data_path, dataset))]

    with MDSWriter(out=args.save_path,
                        columns=columns, 
                        compression=None) as out:
        for dataset in tqdm(all_datasets):
            all_domains = os.listdir(f'{args.training_data_path}/{dataset}')
            all_domains = [domain for domain in all_domains if os.path.isdir(f'{args.training_data_path}/{dataset}/{domain}')]
            for domain in tqdm(all_domains):
                try:
                    all_tests = os.listdir(f'{args.training_data_path}/{dataset}/{domain}/qwen2.5-vl-32b')
                    all_tests = [test for test in all_tests if os.path.isdir(f'{args.training_data_path}/{dataset}/{domain}/qwen2.5-vl-32b/{test}')]
                except Exception as e:
                    print(f'Error listing tests for {dataset} {domain}: {e}')
                    continue
                seen_configs = set()
                for test in tqdm(all_tests):
                    if 'test' not in test:
                        continue
                    success_files = os.listdir(f'{args.training_data_path}/{dataset}/{domain}/qwen2.5-vl-32b/{test}/success')
                    positive_files = os.listdir(f'{args.training_data_path}/{dataset}/{domain}/qwen2.5-vl-32b/{test}/positive')
                    all_files = [f'success/{file}' for file in success_files] + [f'positive/{file}' for file in positive_files]
                    all_files = [file for file in all_files if file.endswith('.jsonl')]
                    random.shuffle(all_files)
                    print('*'*50, f'{dataset} {domain} {test}', '*'*50)
                    valid_files = 0
                    for file in tqdm(all_files):
                        if file in seen_configs:
                            continue
                        try:
                            file_path = f'{args.training_data_path}/{dataset}/{domain}/qwen2.5-vl-32b/{test}/{file}'
                            valid_files += process_single_file(file_path, dataset, domain, 'qwen2.5-vl-32b', memory, out)
                            seen_configs.add(file)
                        except Exception as e:
                            print(f'Error processing file {file}: {e}')
                            continue
                        
                    print(f'Valid files: {valid_files}')
                    

# Run the data-preparation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()