import os
import pickle
import random
import uuid

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from absl import app, flags, logging
from torch.utils.data import DataLoader

from td.environments import environments
from td.learning.compose_flow_engine import ComposeFlowDataset
from td.learning.gpt import TransformerConfig, TreeDiffusion
from td.learning.tokenizer import Tokenizer

# Command-line configuration for the training script. All values are read
# through the shared absl `FLAGS` object inside `main`.
FLAGS = flags.FLAGS

# --- Optimisation / loop control ---
flags.DEFINE_integer("batch_size", 256, "Batch size")
# Metrics are averaged and logged once every `print_every` steps.
flags.DEFINE_integer("print_every", 100, "Print every")
# Values <= 0 never trigger the stop condition in `main`, so training runs
# until the dataloader is exhausted.
flags.DEFINE_integer("training_steps", -1, "Training steps")

# --- Checkpointing ---
# Checkpoints are written to <checkpoint_dir>/<local run id>/.
flags.DEFINE_string("checkpoint_dir", "checkpoints", "Checkpoint directory")
# Values <= 0 disable checkpoint writing.
flags.DEFINE_integer("checkpoint_steps", 10000, "Checkpoint steps")

# --- Data loading ---
flags.DEFINE_integer("num_workers", 16, "Number of workers for data loading")
flags.DEFINE_float("learning_rate", 3e-4, "Learning rate")
flags.DEFINE_bool("wandb", True, "Log to wandb")
# `main` currently asserts that this is "csg2da".
flags.DEFINE_string("env", "csg2da", "Environment to use")
flags.DEFINE_integer("max_sequence_length", 512, "Maximum sequence length")
flags.DEFINE_integer("min_primitives", 2, "Minimum number of primitives")
flags.DEFINE_integer("max_primitives", 8, "Maximum number of primitives")

# --- Model architecture ---
flags.DEFINE_integer("n_layers", 3, "Number of layers")
flags.DEFINE_integer("d_model", 128, "Model dimension")
flags.DEFINE_integer("num_heads", 8, "Number of heads")
flags.DEFINE_string("device", "cuda", "Device to use")
flags.DEFINE_string("image_model", "nf_resnet26", "Vision model to use")

# When set, this value is used as the local run id, so checkpoints are looked
# up in (and written to) <checkpoint_dir>/<resume_from>/.
flags.DEFINE_string("resume_from", None, "Resume from checkpoint")
flags.DEFINE_bool(
    "target_observation", False, "Use observation compiler for target image."
)


def loss_fn(model, batch):
    """Compute the masked next-token cross-entropy loss for one batch.

    Args:
        model: Callable invoked as ``model(tokens, target_images,
            mutated_images)``. Its output is sliced along dim 1 as sequence
            positions and its last dim is used as the class dimension.
        batch: 4-tuple ``(tokens, mask, target_images, mutated_images)``.
            ``mask`` weights each token position's contribution to the loss.

    Returns:
        ``(loss, aux)`` where ``loss`` is a scalar tensor (the mask-weighted
        mean of the per-token cross-entropy) and ``aux`` is a 1-tuple holding
        a zero-valued placeholder metric that carries no autograd history.
    """
    tokens, mask, target_images, mutated_images = batch

    logits = model(tokens, target_images, mutated_images)

    # The logit at position i is scored against token i + 1, so drop the last
    # logit and the first target/mask entry to align them.
    logits = logits[:, :-1]
    targets = tokens[:, 1:]
    mask = mask[:, 1:]

    per_token_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none"
    ).reshape(targets.shape)

    # A batch whose shifted mask is entirely zero would otherwise yield 0 / 0
    # = NaN and corrupt the weights on backward(); divide by 1 in that case so
    # the loss (and its gradient) is simply zero.
    mask_total = mask.sum()
    denominator = torch.where(
        mask_total != 0, mask_total, torch.ones_like(mask_total)
    )
    loss = (per_token_loss * mask).sum() / denominator

    # Placeholder auxiliary metric; detached so that callers who store it do
    # not keep a reference to the autograd graph.
    return loss, (loss.detach() * 0.0,)


def generate_uuid():
    """Return a newly generated random (version 4) UUID as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)


def batch_to_torch(batch, device="cpu"):
    """Move a 4-element batch onto `device` and normalise its dtypes.

    Args:
        batch: 4-tuple ``(tokens, mask, target_images, mutated_images)`` of
            objects exposing ``.to(device)``.
        device: Target device passed to ``.to``; defaults to ``"cpu"``.

    Returns:
        A 4-tuple in the same order, with ``tokens`` cast to ``long`` and the
        remaining three entries cast to ``float``.
    """
    tokens, mask, target_images, mutated_images = batch

    token_ids = tokens.to(device).long()
    float_parts = tuple(
        part.to(device).float() for part in (mask, target_images, mutated_images)
    )
    return (token_ids, *float_parts)


def _checkpoint_step(filename, env_name):
    """Return the step encoded in a checkpoint filename, or None if not one.

    Only names of the exact form ``{env_name}_step_{N}.pt`` (the pattern that
    `main` writes) are recognised; anything else yields None instead of
    raising, so stray files in the checkpoint directory are ignored.
    """
    prefix, suffix = f"{env_name}_step_", ".pt"
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        return None
    digits = filename[len(prefix) : -len(suffix)]
    return int(digits) if digits.isdecimal() else None


def _restore_latest_checkpoint(model, checkpoint_dir, env_name):
    """Load the highest-step checkpoint from `checkpoint_dir` into `model`.

    Returns the step stored in the checkpoint's filename, or 0 when
    checkpointing is disabled (`checkpoint_dir` is falsy), the directory does
    not exist, or it contains no matching checkpoint file.
    """
    if not checkpoint_dir or not os.path.exists(checkpoint_dir):
        return 0

    steps_by_file = {}
    for name in os.listdir(checkpoint_dir):
        file_step = _checkpoint_step(name, env_name)
        if file_step is not None:
            steps_by_file[name] = file_step
    if not steps_by_file:
        return 0

    latest_checkpoint = max(steps_by_file, key=steps_by_file.get)
    checkpoint_filename = os.path.join(checkpoint_dir, latest_checkpoint)
    with open(checkpoint_filename, "rb") as f:
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only resume from checkpoint directories you trust.
        state = pickle.load(f)

    # Checkpoints are either {"model": state_dict, ...} or a bare state dict.
    model.load_state_dict(state["model"] if "model" in state else state)

    step = steps_by_file[latest_checkpoint]
    logging.info(
        f"Loaded checkpoint from {checkpoint_filename}, starting at step {step}"
    )
    return step


def main(argv):
    """Train a `TreeDiffusion` model on `ComposeFlowDataset` batches.

    Builds the environment, tokenizer and model from command-line flags,
    optionally resumes from the latest checkpoint of the run, then runs the
    optimisation loop with periodic logging and checkpointing.

    Args:
        argv: Positional command-line arguments supplied by ``app.run``;
            unused.
    """
    assert FLAGS.env == "csg2da", "Only csg2da is supported for now."

    env = environments[FLAGS.env]()
    # NOTE(review): max_token_length is fed from the max_sequence_length flag
    # as well -- confirm this is intended.
    tokenizer = Tokenizer(
        env.grammar,
        max_token_length=FLAGS.max_sequence_length,
        max_sequence_length=FLAGS.max_sequence_length,
    )

    random.seed(1)

    env_name = FLAGS.env
    local_run_id = FLAGS.resume_from or generate_uuid()
    # An empty --checkpoint_dir disables both resuming and checkpoint writing.
    checkpoint_dir = (
        os.path.join(FLAGS.checkpoint_dir, local_run_id)
        if FLAGS.checkpoint_dir
        else None
    )

    config = {
        "notes": "compose-flow",
        "batch_size": FLAGS.batch_size,
        "learning_rate": FLAGS.learning_rate,
        "env": FLAGS.env,
        "local_run_id": local_run_id,
        "max_sequence_length": FLAGS.max_sequence_length,
        "max_primitives": FLAGS.max_primitives,
        "min_primitives": FLAGS.min_primitives,
        "n_layers": FLAGS.n_layers,
        "d_model": FLAGS.d_model,
        "num_heads": FLAGS.num_heads,
        "image_model": FLAGS.image_model,
        "target_observation": FLAGS.target_observation,
    }

    if FLAGS.wandb:
        wandb.init(
            project="tree-diffusion",
            config=config,
        )

    model = TreeDiffusion(
        TransformerConfig(
            vocab_size=tokenizer.vocabulary_size,
            max_seq_len=tokenizer.max_sequence_length,
            n_layer=FLAGS.n_layers,
            n_head=FLAGS.num_heads,
            n_embd=FLAGS.d_model,
        ),
        input_channels=env.compiled_shape[-1],
        image_model_name=FLAGS.image_model,
    ).to(FLAGS.device)

    step = _restore_latest_checkpoint(model, checkpoint_dir, env_name)
    if FLAGS.resume_from and step == 0:
        # Written checkpoints always carry a step >= 1, so 0 means that
        # nothing was restored.
        logging.warning(
            f"No checkpoint found for run {local_run_id}, starting from scratch."
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.learning_rate)

    logging.info("Starting to train!")

    if (
        checkpoint_dir
        and FLAGS.checkpoint_steps > 0
        and not os.path.exists(checkpoint_dir)
    ):
        logging.info(
            f"Local run ID: {local_run_id}, saving checkpoints to {checkpoint_dir}"
        )
        os.makedirs(checkpoint_dir)

    dataset = ComposeFlowDataset(
        batch_size=FLAGS.batch_size,
        env_name=env_name,
        max_sequence_length=FLAGS.max_sequence_length,
        min_primitives=FLAGS.min_primitives,
        max_primitives=FLAGS.max_primitives,
        target_observation=FLAGS.target_observation,
    )

    # batch_size=None: the dataset is given the batch size itself, so the
    # loader's automatic batching is turned off.
    dataloader = DataLoader(dataset, batch_size=None, num_workers=FLAGS.num_workers)

    # Detached losses collected since the last log line.
    recent_losses = []

    model.train()

    for batch in dataloader:
        batch = batch_to_torch(batch, FLAGS.device)
        # The auxiliary metrics are currently placeholders and are not logged.
        loss, _ = loss_fn(model, batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a detached copy so the list holds no autograd graph
        # references between log intervals.
        recent_losses.append(loss.detach())

        if step % FLAGS.print_every == FLAGS.print_every - 1:
            mean_loss = torch.stack(recent_losses).mean().item()
            logging.info(f"Step {step + 1}, Loss: {mean_loss:.4f}")

            if FLAGS.wandb:
                wandb.log(
                    {
                        "loss": mean_loss,
                    },
                    step=step + 1,
                )

            recent_losses.clear()

        if (
            checkpoint_dir
            and FLAGS.checkpoint_steps > 0
            and step % FLAGS.checkpoint_steps == FLAGS.checkpoint_steps - 1
        ):
            checkpoint_filename = os.path.join(
                checkpoint_dir, f"{env_name}_step_{step + 1}.pt"
            )
            if os.path.exists(checkpoint_filename):
                logging.warning(
                    f"Checkpoint file {checkpoint_filename} already exists, skipping."
                )
            else:
                # Write to a temporary name and rename, so an interrupted
                # write never leaves a truncated "*.pt" file that a later
                # resume would select as the latest checkpoint.
                tmp_filename = checkpoint_filename + ".tmp"
                with open(tmp_filename, "wb") as f:
                    pickle.dump({"model": model.state_dict(), "config": config}, f)
                os.replace(tmp_filename, checkpoint_filename)
                logging.info(f"Checkpointed state to {checkpoint_filename}")

        step += 1

        if FLAGS.training_steps > 0 and step >= FLAGS.training_steps:
            break


if __name__ == "__main__":
    app.run(main)
