from absl import app
from absl import flags
from absl import logging

from td.environments import Environment, environments
from td.learning.tokenizer import Tokenizer
from td.learning.gpt import TreeDiffusion, TransformerConfig
from td.learning.evaluation import AREvaluator, ar_decoder
from td.samplers import ConstrainedRandomSampler

import pickle
import numpy as np
import os
import uuid
import torch
import tqdm

# Command-line configuration (absl flags).
# checkpoint_name and problem_filename default to None and must be supplied
# for an evaluation run.
flags.DEFINE_string("checkpoint_name", None, "Path to the checkpoint to evaluate")
flags.DEFINE_string("problem_filename", None, "Name of the test set to evaluate")
# Both integers are used as loop bounds / range steps and batch sizes, so values
# below 1 are rejected at parse time (a step of 0 would make range() raise, and
# negative values would silently evaluate nothing).
flags.DEFINE_integer(
    "max_expansions",
    5000,
    "Maximum number of expansions to evaluate",
    lower_bound=1,
)
flags.DEFINE_integer(
    "evaluation_batch_size",
    256,
    "Batch size for evaluation",
    lower_bound=1,
)
flags.DEFINE_string("evaluation_dir", "evals", "Evaluations directory")
flags.DEFINE_string("device", "cuda", "Device to use")

FLAGS = flags.FLAGS


def ar_evaluation(
    checkpoint_name,
    num_problems: int = 256,
    problem_size: int = 10,
    temperature=0.1,
    max_expansions: int = 100,
    evaluation_batch_size: int = 16,
    save_filename: str = None,
    device="cuda",
    problem_filename: str = None,
):
    """Evaluate an autoregressive checkpoint on a pickled set of target programs.

    For each target, the model is sampled in batches via ``ar_decoder``. Sampling
    continues until one sample's compiled output satisfies ``env.goal_reached``
    against the target, or until ``max_expansions`` samples have been drawn.
    Progress is pickled to ``save_filename`` after every problem.

    Args:
        checkpoint_name: Path to a pickled training state containing ``"config"``
            and ``"model"`` entries. NOTE: loaded with ``pickle``, which can
            execute arbitrary code; only evaluate trusted checkpoints.
        num_problems: Ignored. The number of problems is the length of the
            problem file. Kept for backward compatibility.
        problem_size: Unused. Kept for backward compatibility.
        temperature: Sampling temperature forwarded to ``ar_decoder``.
        max_expansions: Maximum number of samples drawn per problem.
        evaluation_batch_size: Number of samples decoded per ``ar_decoder`` call.
        save_filename: Where to pickle ``{"steps_to_solve", "seen_so_far"}``
            after each problem. If ``None``, no progress file is written.
        device: Torch device for the model and the target images.
        problem_filename: Path to the pickled list of target programs. Falls
            back to ``FLAGS.problem_filename`` when ``None``.

    Returns:
        A float numpy array with one entry per problem. Each entry is the 0-based
        index of the first sample that reached the goal, or ``inf`` if no sample
        within ``max_expansions`` did.

    Raises:
        ValueError: If no problem file is provided (neither argument nor flag),
            or if ``evaluation_batch_size`` is not positive.
    """
    if problem_filename is None:
        problem_filename = FLAGS.problem_filename
    if problem_filename is None:
        raise ValueError(
            "No problem file given: pass problem_filename or set --problem_filename."
        )
    if evaluation_batch_size <= 0:
        # A zero step would make trange() raise; a negative one would silently
        # evaluate nothing.
        raise ValueError(
            f"evaluation_batch_size must be positive, got {evaluation_batch_size}"
        )

    # NOTE(review): pickle.load executes arbitrary code from the file.
    with open(checkpoint_name, "rb") as f:
        state = pickle.load(f)

    config = state["config"]

    env_name = config["env"]
    image_model = config["image_model"]
    d_model = config["d_model"]
    n_layers = config["n_layers"]
    num_heads = config["num_heads"]
    max_sequence_length = config["max_sequence_length"]
    target_observation = config["target_observation"]

    for key, value in config.items():
        logging.info(f"{key}: {value}")

    env: Environment = environments[env_name]()
    tokenizer = Tokenizer(
        env.grammar,
        max_token_length=max_sequence_length,
        max_sequence_length=max_sequence_length,
    )

    model = TreeDiffusion(
        TransformerConfig(
            vocab_size=tokenizer.vocabulary_size,
            max_seq_len=tokenizer.max_sequence_length,
            n_layer=n_layers,
            n_head=num_heads,
            n_embd=d_model,
        ),
        input_channels=env.compiled_shape[-1],
        image_model_name=image_model,
    )

    model.load_state_dict(state["model"])
    model.to(device)
    # Inference only: switch off training-mode layers (e.g. dropout, if any).
    # This is harmless if ar_decoder already does the same.
    model.eval()

    # NOTE(review): pickle.load executes arbitrary code from the file.
    with open(problem_filename, "rb") as f:
        hard_tests = pickle.load(f)
    # The problem file, not the num_problems argument, decides how many run.
    num_problems = len(hard_tests)

    target_images = np.array(
        [
            env.compile_observation(e) if target_observation else env.compile(e)
            for e in hard_tests
        ]
    )

    # Move the last axis to position 1. Compiled images appear to be
    # channels-last (compiled_shape[-1] is used as the channel count above).
    # Use the `device` argument so the targets land on the same device as the
    # model.
    target_images_torch = (
        torch.tensor(target_images).to(device).float().permute(0, 3, 1, 2)
    )

    # inf marks "not solved (yet)"; finite entries are sample indices.
    steps_to_solve = np.full(num_problems, np.inf)

    for problem_i in range(num_problems):
        logging.info(f"Problem {problem_i} / {num_problems}")
        # One copy of the target per sample in a full batch; sliced below for
        # the final, possibly smaller batch.
        batch_target_image = target_images_torch[problem_i : problem_i + 1].repeat(
            evaluation_batch_size, 1, 1, 1
        )

        for i in tqdm.trange(0, max_expansions, evaluation_batch_size):
            current_batch_size = min(evaluation_batch_size, max_expansions - i)
            with torch.no_grad():
                predictions = ar_decoder(
                    model,
                    env,
                    tokenizer,
                    config["num_image_tokens"],
                    batch_target_image[:current_batch_size],
                    temperature=temperature,
                )

            for j in range(current_batch_size):
                compiled = env.compile(predictions[j])
                if env.goal_reached(compiled, target_images[problem_i]):
                    # Global 0-based index of the first successful sample.
                    steps_to_solve[problem_i] = i + j
                    break

            if np.isfinite(steps_to_solve[problem_i]):
                break

        solved_so_far = np.sum(np.isfinite(steps_to_solve)) / (problem_i + 1)
        logging.info(f"Solved so far: {solved_so_far:.2f}")

        # Rewrite the progress file after every problem so partial results
        # survive an interrupted run.
        if save_filename is not None:
            with open(save_filename, "wb") as f:
                pickle.dump(
                    {
                        "steps_to_solve": steps_to_solve,
                        "seen_so_far": problem_i + 1,
                    },
                    f,
                )

    return steps_to_solve


def generate_uuid():
    """Return a new random (version 4) UUID as its canonical 36-character string."""
    run_id = uuid.uuid4()
    return str(run_id)


def main(argv):
    """Evaluate ``--checkpoint_name`` and save results under a fresh run id.

    Results are written to ``<--evaluation_dir>/<random uuid>.pkl``. The run id
    is logged so that the output file can be located afterwards.

    Raises:
        app.UsageError: If ``--checkpoint_name`` or ``--problem_filename`` is unset.
    """
    del argv  # Unused; all configuration comes from absl flags.

    # Both flags default to None. Fail early with a usage message instead of
    # a TypeError from open(None) deep inside the evaluation.
    missing = [
        name
        for name in ("checkpoint_name", "problem_filename")
        if getattr(FLAGS, name) is None
    ]
    if missing:
        raise app.UsageError(
            "Missing required flag(s): " + ", ".join(f"--{name}" for name in missing)
        )

    logging.info(f"Evaluating {FLAGS.checkpoint_name}")

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(FLAGS.evaluation_dir, exist_ok=True)

    local_run_id = generate_uuid()
    logging.info(f"Local run id: {local_run_id}")

    save_filename = os.path.join(FLAGS.evaluation_dir, f"{local_run_id}.pkl")

    ar_evaluation(
        FLAGS.checkpoint_name,
        max_expansions=FLAGS.max_expansions,
        evaluation_batch_size=FLAGS.evaluation_batch_size,
        device=FLAGS.device,
        save_filename=save_filename,
    )


if __name__ == "__main__":
    # absl parses the command-line flags, then calls main(argv).
    app.run(main=main)
