from tqdm import tqdm, trange
from ribs.archives import GridArchive
from ribs.visualize import grid_archive_heatmap

import numpy as np
import random
import sys
import matplotlib.pyplot as plt

import json
import pickle
import os

from init_mechanics import mech_1, mech_2, mech_3, mech_4, mech_5, mech_6, mech_7, mech_8
from init_games import game_class_1, game_class_2, make_game_1, make_game_2

from llm_emitter import MechanicLLMEmitter, GameLLMEmitter
from create_mechanics import mechanics_test
from create_games import get_games_scores
from llm_mcts import run_mechanic_mcts
from llm_proxy_client import LLMProxyClient

from utils import extract_function_name
import time

import ray

# Start a local Ray runtime with 24 CPUs, unless this process is already
# connected to one.
RAY_NUM_CPUS = 24
if not ray.is_initialized():
    ray.init(num_cpus=RAY_NUM_CPUS)

from configs import Configs
configs = Configs()  # experiment settings read throughout the rest of the script

# Add timing decorator
def timing_decorator(func):
    """Wrap *func* so every call prints how long it took.

    The wrapper forwards all positional and keyword arguments, returns
    *func*'s result unchanged and prints ``"<name> took X.XX seconds"``.
    ``functools.wraps`` copies ``__name__``, ``__doc__`` and friends onto the
    wrapper so decorated functions stay introspectable (and the printed name
    stays correct when decorators are stacked).
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution: the right clock for
        # measuring elapsed time (time.time() can jump with clock changes).
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} took {elapsed:.2f} seconds")
        return result
    return wrapper

def print_configs(configs):
    """Write a human-readable summary of *configs* to stdout.

    Output is a banner with the experiment name, then GENERAL, MCTS and
    ARCHIVE sections. An ADDITIONAL SETTINGS section is printed only when
    ``configs`` has an ``additional_settings`` mapping. The optional archive
    flags (``no_game_archive``, ``random_selection``) are shown as
    ``Not specified`` when the attribute is absent.
    """
    banner = "=" * 50

    def show(label, value):
        # One "label: value" line per setting.
        print(f"{label}: {value}")

    print("\n" + banner)
    print(f"CONFIGURATION SETTINGS FOR EXPERIMENT: {configs.experiment}")
    print(banner)

    print("\nGENERAL SETTINGS:")
    show("Model", configs.model)
    show("Experiment name", configs.experiment)
    show("Total generations", configs.generations)
    show("Max mechanics to add", configs.max_num_mechs_to_add)

    print("\nMCTS SETTINGS:")
    show("LLM MCTS iterations", configs.llm_mcts_iterations)
    show("Simulation depth", configs.simulation_depth)

    print("\nARCHIVE SETTINGS:")
    show("No game archive", getattr(configs, 'no_game_archive', 'Not specified'))
    show("Random selection", getattr(configs, 'random_selection', 'Not specified'))

    if hasattr(configs, 'additional_settings'):
        print("\nADDITIONAL SETTINGS:")
        for setting_name, setting_value in configs.additional_settings.items():
            show(setting_name, setting_value)

    print("\n" + banner + "\n")

# Log the experiment configuration once, before any work starts.
print_configs(configs=configs)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars and arrays.

    The stdlib encoder only knows built-in Python types. NumPy values of the
    kind that may end up in the dicts this script dumps (e.g. ``np.float32``,
    ``np.int64``, ``np.bool_`` or ``np.ndarray``) are converted to the
    equivalent Python value. Anything else is delegated to the base class,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            # tolist() recursively converts elements to Python scalars.
            return obj.tolist()
        return super().default(obj)


# Everything this run writes goes under cache/<experiment>/; snapshot the
# configuration there first so the run can be reproduced later.
os.makedirs(f'cache/{configs.experiment}', exist_ok=True)
with open(f'cache/{configs.experiment}/configs.json', 'w') as f:
    json.dump(configs.__dict__, f, indent=4, cls=NumpyEncoder)

# Shorthands for frequently used settings.
MODEL, EXPERIMENT = configs.model, configs.experiment
total_itrs = configs.generations
max_num_mechs = configs.max_num_mechs_to_add
init_games = 1

# NOTE(review): both flags are fixed here instead of being read from configs,
# even though print_configs reports configs.no_game_archive and
# configs.random_selection. Confirm which value is meant to win.
no_game_archive = True
random_selection = True

# Batch size handed to the mechanic emitter (distinct from configs.batch_size,
# which controls how many mechanics are requested per generation).
batch_size = 1

# Measure ranges for the game archive's two behaviour dimensions.
game_min_bound_1, game_max_bound_1 = -1, 1
game_min_bound_2, game_max_bound_2 = 0, 25

game_archive = GridArchive(
    solution_dim=1,
    dims=(25, 25),
    ranges=[
        (game_min_bound_1, game_max_bound_1),
        (game_min_bound_2, game_max_bound_2),
    ],
    dtype={"solution": np.dtype('O'), "objective": np.float32, "measures": np.float32},
)

# Run-wide counters; some of these are overwritten when a checkpoint is loaded.
mechanic_generation_error = game_generation_error = total_games_generated = 0
win_conditions = {}

total_iterations = 0
mcts_stats = {}
total_nodes = total_games = 0

#TODO: LOAD ARCHIVE FOR LONG RUNS

# Either resume from cache/<experiment>/ (configs.load_checkpoint) or start a
# fresh, empty mechanic archive.
if configs.load_checkpoint:
    # Read the stats log first: it records the last generation that finished.
    try:
        with open(f'cache/{EXPERIMENT}/all_stats.json', 'r') as f:
            stats = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        stats = []

    final_iteration = max((item["iteration"] for item in stats), default=0)
    itr_start = final_iteration + 1

    # The generation loop snapshots the archive as mechanic_archive_<itr>.pkl,
    # whereas mechanic_archive.pkl is only written after the final generation.
    # Prefer the snapshot matching the last logged generation so an
    # interrupted run resumes with the archive it actually reached; fall back
    # to the end-of-run pickle otherwise.
    archive_path = f'cache/{EXPERIMENT}/mechanic_archive_{final_iteration}.pkl'
    if not os.path.exists(archive_path):
        archive_path = f'cache/{EXPERIMENT}/mechanic_archive.pkl'
    # NOTE: pickle.load can execute arbitrary code; only resume from a trusted
    # cache directory.
    with open(archive_path, 'rb') as f:
        mechanic_archive = pickle.load(f)

    print("Loaded archive!: ", mechanic_archive)

    # An empty stats log (e.g. "[]") has no counters to restore.
    if stats:
        last_stats = stats[-1]
        total_nodes = last_stats["game_creation_stats"]["total_nodes"]
        total_games = last_stats["game_creation_stats"]["total_games"]
        mechanic_generation_error = last_stats["errors"]["mechanic_generation_error_number"]
        #total_itrs = total_itrs + final_iteration
        print(f"Resuming from iteration {itr_start} / {total_itrs}")
        print(f"Resuming from Total Nodes: {total_nodes}")
        print(f"Resuming from Total Games: {total_games}")
        print(f"Resuming from Mechanic Generation Error: {mechanic_generation_error}")

else:
    mechanic_archive = GridArchive(solution_dim=1,
                      dims=(configs.dim_1, configs.dim_2),
                      ranges=[(configs.min_bound_1, configs.max_bound_1), (configs.min_bound_2, configs.max_bound_2)],
                      dtype={"solution": np.dtype('O'), "objective": np.float32, "measures": np.float32})

    itr_start = 1

# Restore the game archive from disk only when game archiving is enabled
# (no_game_archive is currently hard-coded to True in this file, so this branch
# is skipped).
if not no_game_archive:
    with open(f'cache/{EXPERIMENT}/game_archive.pkl', 'rb') as f:
        game_archive = pickle.load(f)


mutation_individuals = configs.diversity_mutation_individuals

# LLM proxy client; NOTE(review): not referenced again in this file as shown.
llm_client = LLMProxyClient()

# Hand-written seed mechanics, one mechanic per row.
_seed_mechanics = np.array(
    [[mech] for mech in (mech_1, mech_2, mech_3, mech_4, mech_5, mech_6, mech_7, mech_8)]
)

mechanic_emitters = [
    MechanicLLMEmitter(
        mechanic_archive,
        initial_solutions=_seed_mechanics,
        bounds=None,
        mutation_individuals=mutation_individuals,
        batch_size=batch_size,
        operator="openai",
        operator_kwargs={"temperature": 1.0},
        mutation_prompt="",
        model=MODEL,
    )
]

@ray.remote
def _ask_mechanics_emmiters(selected_m_evo_operator, mech_emitter):
    """Ray task: request one mechanic from *mech_emitter* with the given operator.

    Prints how long the emitter's ``ask`` call took and returns the first
    mechanic of the batch it produced.
    NOTE(review): not called anywhere in this file as shown.
    """
    t0 = time.time()
    batch, _unused = mech_emitter.ask(selected_m_evo_operator)
    elapsed = time.time() - t0
    print(f"Mechanics emitter ask took {elapsed:.2f} seconds")
    return batch[0]

# Seed an empty archive with the hand-written initial mechanics: each one is
# run through mechanics_test and inserted with the objective and measures it
# returns.
if mechanic_emitters[0]._initial_solutions is not None and mechanic_archive.empty:
    initial_solutions = mechanic_emitters[0]._initial_solutions

    # Submit every seed test before waiting on any of them, so the tests run
    # concurrently instead of one blocking ray.get per seed.
    row_futures = [
        [mechanics_test.remote(sol, init_mechs=True, generation="initial") for sol in row]
        for row in initial_solutions
    ]

    for row, futures in zip(initial_solutions, row_futures):
        # Each row holds a single mechanic (the archive uses solution_dim=1),
        # so there is one result per row; the last one is used, as before.
        objective, measure, _ = ray.get(futures)[-1]
        mechanic_archive.add([row], objective, measure)



# Evolutionary operators used to propose new mechanics and their sampling weights.
MECHANICS_EVO_OPERATORS = ["mutation", "diversity_mutation", "crossover"]
MECHANICS_OPERATOR_WEIGHTS = [0.3, 0.5, 0.2]


# Declared once, outside the generation loop: re-declaring a @ray.remote
# function inside the loop re-exports it to Ray on every generation, and the
# generation number is now passed explicitly instead of captured from `itr`.
# NOTE(review): num_cpus=24 equals the whole 24-CPU ray.init budget, so only
# one of these tasks can be scheduled at a time and the configs.batch_size
# requests effectively run one after another. Confirm this throttling is
# intentional (e.g. LLM rate limits) before lowering it.
@ray.remote(num_cpus=24)
def _process_mechanic(operator, mechanic_emitters, generation):
    """Ask the first emitter for one mechanic via *operator* and start testing it.

    Returns ``(mechanic, test_ref)`` where ``test_ref`` is the ObjectRef of
    the ``mechanics_test`` task; the driver resolves it with ``ray.get``.
    """
    mechanics_batch, _ = mechanic_emitters[0].ask(operator)
    mechanic = mechanics_batch[0]
    test_result = mechanics_test.remote(mechanic, generation=generation)
    return mechanic, test_result


def _aggregate_mcts_stats(per_mechanic_stats):
    """Merge per-mechanic MCTS stat dicts into one summary for logging.

    ``max_*`` entries keep the largest value seen, ``avg_*`` entries are
    averaged over the runs that reported them, and every other entry
    (counts such as ``total_nodes``) is summed. Keys outside the default
    template are accepted instead of raising ``KeyError``.
    """
    summary = {'total_nodes': 0, 'games_created': 0, 'avg_depth': 0, 'max_depth': 0,
               'avg_visits': 0, 'max_visits': 0, 'root_value': 0, 'root_visits': 0}
    avg_counts = {}
    for run_stats in per_mechanic_stats:
        for key, value in run_stats.items():
            if key.startswith('max_'):
                summary[key] = max(summary.get(key, value), value)
            else:
                summary[key] = summary.get(key, 0) + value
                if key.startswith('avg_'):
                    avg_counts[key] = avg_counts.get(key, 0) + 1
    for key, count in avg_counts.items():
        summary[key] /= count
    return summary


for itr in trange(itr_start, total_itrs + 1, file=sys.stdout, desc='Iterations'):
    generation_start_time = time.time()
    print(f"GENERATION {itr} / {total_itrs}")

    # ---- 1. Propose and test new mechanics ---------------------------------
    mechanic_gen_start = time.time()
    evo_operators = [
        np.random.choice(MECHANICS_EVO_OPERATORS, p=MECHANICS_OPERATOR_WEIGHTS)
        for _ in range(configs.batch_size)
    ]

    # Put the emitters and the archive in the object store once per generation
    # so all remote tasks below share one copy instead of re-serialising them
    # per call. The archive is not modified until after MCTS finishes.
    emitters_ref = ray.put(mechanic_emitters)
    archive_ref = ray.put(mechanic_archive)

    mechanics_futures = [
        _process_mechanic.remote(operator, emitters_ref, itr) for operator in evo_operators
    ]
    results = ray.get(mechanics_futures)
    mechanic_gen_end = time.time()
    print(f"Mechanic generation took {mechanic_gen_end - mechanic_gen_start:.2f} seconds")

    mechanics_batch = [mechanic for mechanic, _ in results]
    # Each value is the mechanics_test result, indexed below as
    # (objective, measures, updated game classes); resolve all refs at once.
    values = ray.get([test_ref for _, test_ref in results])

    for i, value in enumerate(values):
        print(f"Value {i}: {value}")

    successful_mechanics = []
    successful_mechanics_behaviours = []
    for mechanic, value in zip(mechanics_batch, values):
        behaviour, updated_classes = value[1], value[2]
        if behaviour is None or updated_classes is None:
            print("Errors detected in mechanics generation!")
            mechanic_generation_error += 1
        else:
            successful_mechanics.append(mechanic)
            successful_mechanics_behaviours.append(behaviour)

    for mechanic in successful_mechanics:
        print(f"EVALUATING MECHANIC {extract_function_name(mechanic)}")

    # ---- 2. Evaluate each surviving mechanic with MCTS ---------------------
    mcts_start = time.time()
    mcts_futures = [
        run_mechanic_mcts.remote(
            emitters_ref,
            archive_ref,
            mechanic,
            llm_mcts_iterations=configs.llm_mcts_iterations,
            simulation_depth=configs.simulation_depth,
            save_path=f'cache/{EXPERIMENT}',
            generation=itr,
        )
        for mechanic in successful_mechanics
    ]
    results_llm_mcts = ray.get(mcts_futures)
    mcts_end = time.time()
    print(f"MCTS execution took {mcts_end - mcts_start:.2f} seconds")

    print(f"results_llm_mcts: {results_llm_mcts}")

    # Each MCTS result is indexed as
    # (stats, nodes, games_created, node_mechanic_data, shapley_fitness).
    mcts_stats = [result[0] for result in results_llm_mcts]
    all_nodes = [result[1] for result in results_llm_mcts]
    games_created = [result[2] for result in results_llm_mcts]
    node_mechanic_data = [result[3] for result in results_llm_mcts]
    shapley_fitness = [result[4] for result in results_llm_mcts]

    total_mcts_stats = _aggregate_mcts_stats(mcts_stats)

    total_nodes += sum(all_nodes)
    total_games += sum(games_created)

    # ---- 3. Insert evaluated mechanics into the archive --------------------
    archive_start = time.time()
    for mechanic, fitness, behaviour in zip(successful_mechanics, shapley_fitness,
                                            successful_mechanics_behaviours):
        mechanic_archive.add([[mechanic]], [fitness], behaviour)
        print(f"Added mechanic {extract_function_name(mechanic)} to archive")

    # MCTS may also create new mechanics while evaluating; archive those too.
    for specific_node_data in node_mechanic_data:
        for data in specific_node_data or []:
            if data:
                m = data["mechanic"]
                mechanic_archive.add([[m]], [data["shapley_fitness"]], data["behaviour"])
                print(f"Added mechanic {extract_function_name(m)} that was created in evaluation MCTS to archive")
    archive_end = time.time()
    print(f"Archive updates took {archive_end - archive_start:.2f} seconds")

    archive_stats = mechanic_archive.stats
    tqdm.write(f"\nIteration {itr:5d} \n "
               f"Mechanic Archive Coverage: {archive_stats.coverage * 100:6.3f}% \n"
               f"Mechanic Normalized QD Score: {archive_stats.norm_qd_score:6.3f} \n "
               f"Individuals In Mechanics Archive: {len(mechanic_archive)}")

    # ---- 4. Heatmap and per-generation archive snapshot --------------------
    viz_start = time.time()
    fig = plt.figure(figsize=(8, 6))
    grid_archive_heatmap(mechanic_archive, vmin=-0.5, vmax=0.5)
    plt.savefig(f'cache/{EXPERIMENT}/mechanics_archive_heatmap_{itr}.png')
    # pyplot keeps every figure alive until it is closed; without this each
    # generation leaks one figure.
    plt.close(fig)
    viz_end = time.time()
    print(f"Visualization took {viz_end - viz_start:.2f} seconds")

    # Snapshot the archive every generation.
    with open(f'cache/{EXPERIMENT}/mechanic_archive_{itr}.pkl', 'wb') as f:
        pickle.dump(mechanic_archive, f)

    total_iterations = itr

    # ---- 5. Statistics ------------------------------------------------------
    stats_start = time.time()
    stats = {
        'iteration': total_iterations,
        'mechanic_archive': {
            'num_elites': archive_stats.num_elites,
            'coverage': archive_stats.coverage,
            'qd_score': archive_stats.qd_score,
            'norm_qd_score': archive_stats.norm_qd_score,
            'max_fitness': archive_stats.obj_max,
            'mean_fitness': archive_stats.obj_mean,
        },
        'game_creation_stats': {
            'total_nodes': total_nodes,
            'total_games': total_games,
            'success_rate': (total_games / total_nodes) * 100 if total_nodes > 0 else 0
        },
        'mcts_stats': total_mcts_stats
    }

    # Errors are counted per generated mechanic and every generation requests
    # configs.batch_size mechanics, so normalise by attempts, not generations.
    # Assumes configs.batch_size is unchanged across resumed runs.
    attempted_mechanics = total_iterations * configs.batch_size
    stats['errors'] = {
        'mechanic_generation_errors_percentage':
            (mechanic_generation_error / attempted_mechanics) * 100 if attempted_mechanics else 0,
        'mechanic_generation_error_number': mechanic_generation_error
    }

    print("\nFinal Statistics:")
    print("=" * 50)
    print(f"Total Iterations: {stats['iteration']}")

    print("\nMechanic Archive Stats:")
    print("-" * 30)
    print(f"Number of Elites: {stats['mechanic_archive']['num_elites']}")
    print(f"Coverage: {stats['mechanic_archive']['coverage']*100:.2f}%")
    print(f"QD Score: {stats['mechanic_archive']['qd_score']:.2f}")
    print(f"Normalized QD Score: {stats['mechanic_archive']['norm_qd_score']:.2f}")
    print(f"Max Fitness: {stats['mechanic_archive']['max_fitness']:.2f}")
    print(f"Mean Fitness: {stats['mechanic_archive']['mean_fitness']:.2f}")

    print("\nGame Creation Stats:")
    print("-" * 30)
    print(f"Total Nodes Explored: {stats['game_creation_stats']['total_nodes']}")
    print(f"Total Games Created: {stats['game_creation_stats']['total_games']}")
    print(f"Success Rate: {stats['game_creation_stats']['success_rate']:.2f}%")

    print("\nError Stats:")
    print("-" * 30)
    print(f"Mechanic Generation Errors: {stats['errors']['mechanic_generation_error_number']}")
    print(f"Error Rate: {stats['errors']['mechanic_generation_errors_percentage']:.2f}%")
    print("=" * 50)

    # Append this generation's stats to the running log.
    try:
        with open(f'cache/{EXPERIMENT}/all_stats.json', 'r') as f:
            all_stats = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        all_stats = []

    all_stats.append(stats)

    with open(f'cache/{EXPERIMENT}/all_stats.json', 'w') as f:
        json.dump(all_stats, f, indent=4, cls=NumpyEncoder)
    stats_end = time.time()
    print(f"Stats collection and saving took {stats_end - stats_start:.2f} seconds")

    # ---- 6. Timing breakdown ------------------------------------------------
    generation_end_time = time.time()
    generation_duration = generation_end_time - generation_start_time
    print(f"\nGeneration {itr} timing breakdown:")
    print(f"Total generation time: {generation_duration:.2f} seconds")
    print(f"- Mechanic generation: {mechanic_gen_end - mechanic_gen_start:.2f}s ({(mechanic_gen_end - mechanic_gen_start)/generation_duration*100:.1f}%)")
    print(f"- MCTS execution: {mcts_end - mcts_start:.2f}s ({(mcts_end - mcts_start)/generation_duration*100:.1f}%)")
    print(f"- Archive updates: {archive_end - archive_start:.2f}s ({(archive_end - archive_start)/generation_duration*100:.1f}%)")
    print(f"- Visualization: {viz_end - viz_start:.2f}s ({(viz_end - viz_start)/generation_duration*100:.1f}%)")
    print(f"- Stats collection: {stats_end - stats_start:.2f}s ({(stats_end - stats_start)/generation_duration*100:.1f}%)")
    print(f"- Other operations: {generation_duration - (mechanic_gen_end - mechanic_gen_start) - (mcts_end - mcts_start) - (archive_end - archive_start) - (viz_end - viz_start) - (stats_end - stats_start):.2f}s")



# Final artefacts: a heatmap of the finished mechanic archive, plus pickles of
# the archive(s). configs.load_checkpoint can later read mechanic_archive.pkl.
plt.figure(figsize=(8, 6))
plot = grid_archive_heatmap(mechanic_archive, vmin=-1, vmax=1)
plt.savefig(f'cache/{EXPERIMENT}/mechanics_archive_heatmap_final.png')

final_archives = [('mechanic_archive.pkl', mechanic_archive)]
if not no_game_archive:
    final_archives.append(('game_archive.pkl', game_archive))

for archive_filename, archive_obj in final_archives:
    with open(f'cache/{EXPERIMENT}/{archive_filename}', 'wb') as f:
        pickle.dump(archive_obj, f)




