import pickle

import numpy as np
import torch
import tqdm
import uuid
import os
from absl import app, flags, logging
from rich.jupyter import print as rprint
from rich.table import Table

from td.environments import Environment, environments
from td.learning.compose_flow_engine import sample_model_compose_kv, unflatten
from td.learning.gpt import TransformerConfig, TreeDiffusion
from td.learning.tokenizer import Tokenizer

# Print numpy arrays and torch tensors in plain decimal notation instead of
# scientific notation so logged values are easier to read.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)


# Command-line flags controlling which checkpoint/problem set is evaluated,
# where results are written, and how much sampling is done per problem.
flags.DEFINE_string("problem_filename", None, "Problem filename to evaluate.")
flags.DEFINE_string("checkpoint_name", None, "Checkpoint name to evaluate.")
flags.DEFINE_string("evaluation_dir", "evals", "Evaluation directory to save results.")
flags.DEFINE_integer("max_steps", 10, "Maximum number of steps per attempt.")
flags.DEFINE_integer("num_attempts_per_target", 500, "Number of attempts per target.")
FLAGS = flags.FLAGS

# Both paths default to None and are passed straight to open(); requiring them
# here turns an opaque TypeError deep in the run into a clear usage error at
# flag-parsing time.
flags.mark_flags_as_required(["problem_filename", "checkpoint_name"])


def generate_uuid():
    """Return a freshly generated random (version 4) UUID as a string."""
    fresh_identifier = uuid.uuid4()
    return f"{fresh_identifier!s}"


def _log_config_table(config):
    """Print the checkpoint's config mapping as a two-column rich table."""
    table = Table(title="Config")
    table.add_column("Key", style="cyan")
    table.add_column("Value", style="white")
    for k, v in config.items():
        table.add_row(k, str(v))
    rprint(table)


def _steps_until_solved(env, expression_matrix, target_image):
    """Return how many expressions were tried before the first one that solves the target.

    The matrix is scanned attempt by attempt and, within an attempt, step by
    step. Expressions whose compilation or goal check raises are logged and
    still count as a tried step. Returns ``np.inf`` when nothing solves the
    target.
    """
    steps_used = 0
    for attempt_expressions in expression_matrix:
        for expression in attempt_expressions:
            try:
                compiled = env.compile(expression)
                if env.goal_reached(compiled, target_image):
                    return steps_used
            except Exception as e:
                # Best effort: a sampled expression that cannot be evaluated
                # is treated as a failed step rather than aborting the run.
                logging.error(f"Error: {e}")
            steps_used += 1
    return np.inf


def main(argv):
    """Evaluate a TreeDiffusion checkpoint on a pickled set of target expressions.

    For every target, ``num_attempts_per_target`` independent attempts are
    sampled for ``max_steps`` steps each, and the index of the first sampled
    expression that reaches the goal is recorded (``inf`` if none does).
    Results are re-written to ``<evaluation_dir>/<run uuid>.pkl`` after every
    problem so partial progress survives an interruption.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.
    """
    del argv  # Unused.

    checkpoint_name = FLAGS.checkpoint_name
    logging.info(f"Loading checkpoint from {checkpoint_name}")
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # load checkpoints and problem files that come from a trusted source.
    with open(checkpoint_name, "rb") as f:
        state = pickle.load(f)

    config = state["config"]
    _log_config_table(config)

    env_name = config["env"]
    image_model = config["image_model"]
    d_model = config["d_model"]
    n_layers = config["n_layers"]
    num_heads = config["num_heads"]
    max_sequence_length = config["max_sequence_length"]

    env: Environment = environments[env_name]()
    tokenizer = Tokenizer(
        env.grammar,
        max_token_length=max_sequence_length,
        max_sequence_length=max_sequence_length,
    )

    model = TreeDiffusion(
        TransformerConfig(
            vocab_size=tokenizer.vocabulary_size,
            max_seq_len=tokenizer.max_sequence_length,
            n_layer=n_layers,
            n_head=num_heads,
            n_embd=d_model,
        ),
        input_channels=env.compiled_shape[-1],
        image_model_name=image_model,
    )

    model.load_state_dict(state["model"])
    model.cuda()
    # Modules are constructed in training mode; switch to inference behaviour
    # (e.g. for dropout / normalisation layers) before sampling.
    model.eval()

    with open(FLAGS.problem_filename, "rb") as f:
        target_expressions = pickle.load(f)

    # exist_ok avoids the check-then-create race when several evaluation runs
    # start at the same time.
    os.makedirs(FLAGS.evaluation_dir, exist_ok=True)

    local_run_id = generate_uuid()
    logging.info(f"Local run id: {local_run_id}")

    save_filename = os.path.join(FLAGS.evaluation_dir, f"{local_run_id}.pkl")

    target_images = np.array([env.compile(e) for e in target_expressions])
    # Move the channel axis from last to second position for the model input.
    target_images_torch = torch.tensor(target_images).float().cuda().permute(0, 3, 1, 2)

    num_attempts_per_target = FLAGS.num_attempts_per_target
    max_steps = FLAGS.max_steps

    num_problems = len(target_expressions)
    # inf marks "not solved (yet)".
    steps_to_solve = np.full(num_problems, np.inf)

    for problem_i in range(num_problems):
        logging.info(f"Problem {problem_i}")

        # Every attempt for this problem conditions on the same target image.
        target_images_batched = (
            target_images_torch[problem_i]
            .unsqueeze(0)
            .repeat(num_attempts_per_target, 1, 1, 1)
        )
        # Each attempt starts from an empty sequence.
        current = [[] for _ in range(num_attempts_per_target)]

        # Expression matrix is attempt x step
        expression_matrix = [[""] * max_steps for _ in range(num_attempts_per_target)]

        for step in tqdm.trange(max_steps):
            new = sample_model_compose_kv(
                model,
                env,
                tokenizer,
                current,
                target_images_batched,
                temperature=0.1,
            )
            # The first element of each sample is dropped before unflattening
            # (presumably a leading marker token; confirm against the sampler).
            for attempt_i, sample in enumerate(new):
                expression_matrix[attempt_i][step] = unflatten(sample[1:])

            # The next step continues from this step's samples.
            current = new

        # Once the expression matrix is fully populated, we can check for the goal, top-left to bottom-right.
        steps_to_solve[problem_i] = _steps_until_solved(
            env, expression_matrix, target_images[problem_i]
        )

        logging.info(f"Steps to solve: {steps_to_solve[problem_i]}")

        # Checkpoint the results after every problem.
        with open(save_filename, "wb") as f:
            pickle.dump(
                {
                    "steps_to_solve": steps_to_solve,
                    "seen_so_far": problem_i + 1,
                },
                f,
            )

        seen_so_far = problem_i + 1
        num_solved = np.sum(steps_to_solve < np.inf)
        logging.info(
            f"Solved so far: {num_solved / seen_so_far:.2f} ({num_solved}/{seen_so_far})"
        )


# Script entry point: hand control to absl's app.run, which calls main.
if __name__ == "__main__":
    app.run(main)
