import math
import os
import shutil
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence

import coolname
import hydra
import numpy as np
import pydantic
import torch
import torch.distributed as dist
import tqdm
import yaml
from adam_atan2 import AdamATan2
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import DataLoader
from utils.functions import get_model_source_path, load_model_class

import wandb
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata


@torch.no_grad()
def board_accuracy(
    logits: torch.Tensor,  # (..., L, C), e.g. (B, 81, 9) or (N, B, 81, 9)
    y: torch.Tensor,  # (..., L) integer class targets
    filled: torch.Tensor,  # (..., L) bool; True marks cells that are not scored
) -> float:
    """Return the fraction of boards whose every scored cell is predicted correctly.

    A cell counts as solved when it is excluded from scoring (``filled`` is True)
    or when its argmax prediction equals the target. A board is solved when all
    of its cells (the last dimension after the argmax) are solved. The result is
    the mean over every remaining leading dimension.

    Args:
        logits: Class scores with the class axis last and the cell axis second
            to last.
        y: Targets with the same leading shape as ``logits`` minus the class axis.
        filled: Boolean mask of the same shape as ``y``.

    Returns:
        A float in [0, 1]; 0.0 when there are no boards, which matches
        ``digit_accuracy``'s empty-input behaviour.
    """
    pred = logits.argmax(dim=-1)  # (..., L)
    tgt = y.long()  # (..., L)

    solved_cells = filled | (pred == tgt)
    # Reduce over the cell axis (last dim). The former dim=1 was only correct for
    # 2-D input; for (N, B, L) it reduced over the batch axis instead.
    solved_boards = solved_cells.all(dim=-1)

    if solved_boards.numel() == 0:
        return 0.0
    return solved_boards.float().mean().item()


@torch.no_grad()
def digit_accuracy(
    logits: torch.Tensor,  # class scores, class axis last; assumed (B, L, C) — confirm
    target: torch.Tensor,  # class ids with the logits' leading shape; -100 = ignore
):
    """Return the fraction of labelled cells whose argmax prediction matches ``target``.

    Positions whose target equals -100 are left out of both the numerator and
    the denominator. When no position is labelled, 0.0 is returned.
    """
    labelled = target.ne(-100)
    n_labelled = labelled.sum().item()
    if not n_labelled > 0:
        return 0.0

    predicted = logits.argmax(dim=-1)
    hits = (predicted.eq(target) & labelled).sum().item()
    return hits / n_labelled


class LossConfig(pydantic.BaseModel):
    """Loss-head section of the architecture config.

    ``name`` is resolved to a class with ``load_model_class`` in ``create_model``.
    All other keys are kept (``extra="allow"``) and passed as keyword arguments
    to that loss-head class.
    """

    # Keep unknown keys; create_model reads them back via ``__pydantic_extra__``.
    model_config = pydantic.ConfigDict(extra="allow")

    # Identifier handed to load_model_class.
    name: str


class ArchConfig(pydantic.BaseModel):
    """Model architecture config.

    ``name`` selects the model class via ``load_model_class``. Every extra key
    (allowed by ``extra="allow"``) is merged into the model config dict built in
    ``create_model``, alongside the dataset-derived fields.
    """

    # Keep unknown keys so they can be forwarded to the model constructor.
    model_config = pydantic.ConfigDict(extra="allow")

    # Model identifier; load_synced_config also uses the text after the last "@"
    # as the default run-name prefix.
    name: str
    # Config of the loss head that wraps the model.
    loss: LossConfig


class PretrainConfig(pydantic.BaseModel):
    """Top-level pretraining configuration, validated from the Hydra config on rank 0."""

    # Config
    arch: ArchConfig
    # Data
    # Dataset root; passed to PuzzleDatasetConfig and used for the default project name.
    data_path: str

    # Hyperparams
    # Summed over all ranks; each rank builds its model with global_batch_size // world_size.
    global_batch_size: int
    # Must be a multiple of eval_interval when eval_interval is set.
    epochs: int

    # Base lr for AdamATan2, shaped by the warmup + cosine schedule.
    lr: float
    # Floor of the cosine schedule as a fraction of the base lr.
    lr_min_ratio: float
    # Number of linear warmup steps.
    lr_warmup_steps: int

    # AdamATan2 weight decay and betas.
    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    # Base lr / weight decay for the sparse SignSGD optimizer over puzzle-embedding buffers.
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    # When left as None, defaults are derived in load_synced_config.
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    # Dataset seed; the torch RNG is seeded with seed + rank.
    seed: int = 0
    # Save a checkpoint after every evaluation, not only after the last one.
    checkpoint_every_eval: bool = False
    # Training epochs between evaluations; None means evaluate once after all epochs.
    eval_interval: Optional[int] = None
    # Output keys requested from the model during evaluation (its return_keys).
    eval_save_outputs: List[str] = []

    # Forward calls per voting round in evaluate().
    num_votes: int = 4


@dataclass
class TrainState:
    """Mutable bundle of everything the training loop advances between steps."""

    model: nn.Module
    # Paired index-by-index with ``optimizer_lrs``.
    optimizers: Sequence[torch.optim.Optimizer]
    # Base learning rate per optimizer; compute_lr scales it for the current step.
    optimizer_lrs: Sequence[float]
    # Recurrent model state; None until train_batch initializes it from the first batch.
    carry: Any

    # Number of train_batch calls so far (incremented before each step).
    step: int
    # Estimated step budget; train_batch does no work once step exceeds it.
    total_steps: int


def create_dataloader(
    config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs
):
    """Build a DataLoader over one split of the puzzle dataset for this rank.

    The dataset produces whole batches itself (``batch_size=None`` disables
    DataLoader batching). Extra keyword arguments go to ``PuzzleDatasetConfig``.

    Returns:
        ``(dataloader, dataset.metadata)``.
    """
    dataset_config = PuzzleDatasetConfig(
        seed=config.seed,
        dataset_path=config.data_path,
        rank=rank,
        num_replicas=world_size,
        **kwargs,
    )
    dataset = PuzzleDataset(dataset_config, split=split)

    loader_options = dict(
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True,
    )
    return DataLoader(dataset, **loader_options), dataset.metadata


def create_model(
    config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int
):
    """Build the model, wrap it in its loss head, and create its optimizers.

    Args:
        config: Training config. ``arch`` selects the model and loss classes and
            supplies their extra kwargs; the remaining fields set optimizer
            hyper-parameters.
        train_metadata: Training-split metadata that provides the vocab size,
            sequence length and number of puzzle identifiers.
        world_size: Number of processes. The per-rank batch size is
            ``global_batch_size // world_size``.

    Returns:
        ``(model, optimizers, optimizer_lrs)``. ``optimizers`` is
        ``[sparse puzzle-embedding SignSGD, AdamATan2]`` and ``optimizer_lrs``
        holds the matching base learning rates. Both optimizers start with
        ``lr=0``; the training loop sets the real lr every step.
    """
    # NOTE: dict(**extra, key=...) raises TypeError if the arch config already
    # defines one of the keys below, so they cannot be overridden from config.
    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore
        batch_size=config.global_batch_size // world_size,
        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False,  # Non-autoregressive
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    # Tensors created inside this context default to the CUDA device.
    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        # Setting DISABLE_COMPILE in the environment skips torch.compile.
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model, dynamic=False)  # type: ignore

        # Broadcast parameters from rank 0 so every rank starts from identical
        # weights and buffers.
        if world_size > 1:
            with torch.no_grad():
                for param in list(model.parameters()) + list(model.buffers()):
                    dist.broadcast(param, src=0)

    # Optimizers and lr
    # The puzzle embedding is optimized through its buffers. This assumes the loss
    # head exposes the wrapped model as ``.model`` with a ``puzzle_emb`` submodule
    # (TODO confirm). AdamATan2 covers the regular parameters.
    optimizers = [
        CastedSparseEmbeddingSignSGD_Distributed(
            model.model.puzzle_emb.buffers(),  # type: ignore
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.puzzle_emb_weight_decay,
            world_size=world_size,
        ),
        AdamATan2(
            model.parameters(),
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.weight_decay,
            betas=(config.beta1, config.beta2),
        ),
    ]
    # Base learning rates, in the same order as ``optimizers``.
    optimizer_lrs = [config.puzzle_emb_lr, config.lr]

    return model, optimizers, optimizer_lrs


def cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    base_lr: float,
    num_warmup_steps: int,
    num_training_steps: int,
    min_ratio: float = 0.0,
    num_cycles: float = 0.5,
):
    """Learning rate at ``current_step`` for linear warmup followed by cosine decay.

    During warmup (``current_step < num_warmup_steps``) the rate ramps linearly
    from 0 towards ``base_lr``. After warmup it follows a cosine from ``base_lr``
    down to ``base_lr * min_ratio``. With the default ``num_cycles=0.5`` this is
    a single half-cosine over the remaining ``num_training_steps``. The cosine
    component is clamped at zero.
    """
    if current_step < num_warmup_steps:
        warmup_len = float(max(1, num_warmup_steps))
        return base_lr * float(current_step) / warmup_len

    steps_after_warmup = float(current_step - num_warmup_steps)
    decay_len = float(max(1, num_training_steps - num_warmup_steps))
    progress = steps_after_warmup / decay_len

    angle = math.pi * float(num_cycles) * 2.0 * progress
    cosine_part = (1 - min_ratio) * 0.5 * (1.0 + math.cos(angle))
    return base_lr * (min_ratio + max(0.0, cosine_part))


def init_train_state(
    config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int
):
    """Create a fresh TrainState (step 0, no carry) with the model and optimizers built.

    ``total_steps`` is an estimate: epochs x groups x mean examples per group,
    divided by the global batch size and truncated to an int.
    """
    estimated_steps = (
        config.epochs
        * train_metadata.total_groups
        * train_metadata.mean_puzzle_examples
        / config.global_batch_size
    )
    total_steps = int(estimated_steps)

    model, optimizers, optimizer_lrs = create_model(
        config, train_metadata, world_size=world_size
    )

    state = TrainState(
        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None,
        step=0,
        total_steps=total_steps,
    )
    return state


def save_train_state(config: PretrainConfig, train_state: TrainState):
    """Write the model's state_dict to ``<checkpoint_path>/step_<step>``.

    Does nothing when no checkpoint path is configured. Optimizer state, the
    carry and the step counter are not saved.
    """
    # FIXME: Only saved model.
    ckpt_dir = config.checkpoint_path
    if ckpt_dir is None:
        return

    os.makedirs(ckpt_dir, exist_ok=True)
    weights = train_state.model.state_dict()
    destination = os.path.join(ckpt_dir, f"step_{train_state.step}")
    torch.save(weights, destination)


def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    """Return the scheduled learning rate for ``train_state.step``, scaled from ``base_lr``.

    Thin wrapper that feeds the config's warmup and min-ratio settings and the
    state's step counters into ``cosine_schedule_with_warmup_lr_lambda``.
    """
    schedule_kwargs = {
        "base_lr": base_lr,
        "num_warmup_steps": round(config.lr_warmup_steps),
        "num_training_steps": train_state.total_steps,
        "min_ratio": config.lr_min_ratio,
    }
    return cosine_schedule_with_warmup_lr_lambda(train_state.step, **schedule_kwargs)


def train_batch(
    config: PretrainConfig,
    train_state: TrainState,
    batch: Any,
    global_batch_size: int,
    rank: int,
    world_size: int,
) -> Optional[dict]:
    """Run one optimization step on ``batch``; on rank 0, return the reduced metrics.

    ``train_state.step`` is incremented first. Once it exceeds
    ``train_state.total_steps``, the call returns immediately without touching
    the model. ``train_state.carry`` is created lazily from the first batch and
    is threaded through every later call.

    Args:
        config: Training configuration, used for the learning-rate schedule.
        train_state: Mutable training state. ``step``, ``carry``, gradients and
            optimizer state are updated in place.
        batch: Mapping of name -> tensor; every value is moved to CUDA.
        global_batch_size: Batch size summed over all ranks. It scales the loss
            before ``backward`` and normalizes loss metrics.
        rank: This process's rank; only rank 0 returns metrics.
        world_size: Number of processes; collectives are skipped when it is 1.

    Returns:
        On rank 0, when the model produced metrics, a dict of ``train/<metric>``
        values plus ``train/lr``. Otherwise ``None``.
    """
    train_state.step += 1
    if train_state.step > train_state.total_steps:  # At most train_total_steps
        return

    # To device
    batch = {k: v.cuda() for k, v in batch.items()}

    # Init carry if it is None
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

    # Forward
    train_state.carry, loss, metrics, _, _ = train_state.model(
        carry=train_state.carry, batch=batch, return_keys=[]
    )

    # Scale by the global batch size so the cross-rank gradient SUM below acts as
    # a mean. This presumably assumes `loss` is summed over the local batch —
    # confirm against the loss head.
    ((1 / global_batch_size) * loss).backward()

    # Allreduce
    # Sums dense gradients over ranks (all_reduce defaults to SUM). Only
    # model.parameters() are reduced here. The puzzle-embedding buffers given to
    # the sparse optimizer in create_model are not; presumably that optimizer
    # synchronizes them itself (it receives world_size).
    if world_size > 1:
        for param in train_state.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)

    # Apply optimizer
    # Each optimizer gets its own scheduled lr. After the loop, lr_this_step holds
    # the value of the last optimizer, and that is what gets logged as train/lr.
    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)

        for param_group in optim.param_groups:
            param_group["lr"] = lr_this_step

        optim.step()
        optim.zero_grad()

    # Reduce metrics
    if len(metrics):
        # Metric tensors must already be detached from the graph.
        assert not any(v.requires_grad for v in metrics.values())

        metric_keys = list(
            sorted(metrics.keys())
        )  # Sort keys to guarantee all processes use the same order.
        # Reduce and reconstruct
        metric_values = torch.stack([metrics[k] for k in metric_keys])
        if world_size > 1:
            # Sum onto rank 0 (dist.reduce defaults to SUM).
            dist.reduce(metric_values, dst=0)

        if rank == 0:
            metric_values = metric_values.cpu().numpy()
            reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}

            # Postprocess
            # Keys ending in "loss" are divided by the global batch size; all other
            # keys by the summed "count" (clamped to >= 1). Raises KeyError if the
            # model emits no "count" metric.
            count = max(reduced_metrics["count"], 1)  # Avoid NaNs
            reduced_metrics = {
                f"train/{k}": v / (global_batch_size if k.endswith("loss") else count)
                for k, v in reduced_metrics.items()
            }

            reduced_metrics["train/lr"] = lr_this_step
            return reduced_metrics


def _select_min_entropy_logits(
    logits: List[torch.Tensor], entropies: List[torch.Tensor]
) -> torch.Tensor:
    """Pick, per sample, the logits of the vote with the lowest mean entropy.

    Args:
        logits: One tensor per vote, each (B, L, C). The 4-D permute below
            requires exactly three dims per vote.
        entropies: One (B,) tensor per vote, aligned with ``logits``.

    Returns:
        A (B, L, C) tensor whose row ``b`` comes from the vote with the smallest
        ``entropies[v][b]``.
    """
    entropy = torch.stack(entropies, dim=0)  # (V, B)
    min_entropy_idx = torch.argmin(entropy, dim=0)  # (B,)
    stacked = torch.stack(logits, dim=0).permute(1, 0, 2, 3)  # (B, V, L, C)
    batch_idx = torch.arange(stacked.shape[0], device=stacked.device)
    return stacked[batch_idx, min_entropy_idx]


def evaluate(
    config: PretrainConfig,
    train_state: TrainState,
    eval_loader: torch.utils.data.DataLoader,
    eval_metadata: PuzzleDatasetMetadata,
    rank: int,
    world_size: int,
):
    """Run entropy-selected evaluation over ``eval_loader`` and log accuracies.

    Each batch starts from a fresh carry. The model is then stepped in rounds of
    ``config.num_votes`` consecutive forward calls, and each call advances the
    carry. Rounds repeat until some call in a round reports that everything has
    finished. From that last round, each sample keeps the logits with the lowest
    mean per-cell entropy. Those logits are scored with ``digit_accuracy`` and
    ``board_accuracy``; cells labelled -100 are ignored.

    Only rank 0 logs to wandb. It reuses the active run when there is one (as in
    ``launch``) and otherwise starts a dedicated evaluation run. Other ranks
    never create runs.

    Args:
        config: Supplies ``num_votes`` and ``eval_save_outputs``. The latter is
            forwarded to the model as ``return_keys``, together with "logits".
        train_state: Holds the model to evaluate.
        eval_loader: Yields ``(set_name, batch, global_batch_size)`` tuples.
        eval_metadata: Unused; kept for interface compatibility.
        rank: Process rank; only rank 0 logs.
        world_size: Unused; metrics are not reduced across ranks.

    Returns:
        None. Per-step and epoch metrics are only sent to wandb, and requested
        outputs are not persisted.

    Raises:
        ValueError: if ``config.num_votes`` < 1, because the voting loop could
            never terminate.
    """
    print("num_votes:", config.num_votes)
    if config.num_votes < 1:
        raise ValueError("num_votes must be >= 1")

    run = None
    if rank == 0:
        # Reuse the caller's run instead of calling wandb.init again. Start a
        # standalone eval run only when none is active.
        run = wandb.run
        if run is None:
            run = wandb.init(
                project="sudoku_deq_ebt", name=f"eval_hrm_{config.num_votes}"
            )

    # Vote selection reads preds["logits"], so always request it. Presumably the
    # model only returns the requested keys (eval_save_outputs defaults to []);
    # confirm against the loss head.
    return_keys = list(config.eval_save_outputs)
    if "logits" not in return_keys:
        return_keys.append("logits")

    board_accs = []
    with torch.inference_mode():
        for _set_name, batch, _global_batch_size in tqdm.tqdm(eval_loader):
            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward: rounds of num_votes steps until the model reports completion.
            while True:
                all_finish = False
                entropies = []
                logits = []
                for _ in range(config.num_votes):
                    carry, _, _, preds_, all_finish_ = train_state.model(
                        carry=carry, batch=batch, return_keys=return_keys
                    )
                    step_logits = preds_["logits"]
                    logits.append(step_logits)
                    log_prob = torch.nn.functional.log_softmax(step_logits, dim=-1)
                    # Per-cell predictive entropy averaged over cells -> (B,)
                    entropy = -(log_prob * log_prob.exp()).sum(dim=-1).mean(-1)
                    entropies.append(entropy)
                    all_finish = all_finish or all_finish_.item()

                if all_finish:
                    break

            logits = _select_min_entropy_logits(logits, entropies)  # (B, L, C)
            labels = batch["labels"]
            digit_acc = digit_accuracy(logits, labels)
            board_acc = board_accuracy(logits, labels, labels == -100)
            board_accs.append(board_acc)
            if run is not None:
                run.log({"digit_acc_step": digit_acc, "board_acc_step": board_acc})

            del carry, batch, logits, all_finish

        # np.mean([]) is NaN (with a RuntimeWarning); skip when nothing was evaluated.
        if run is not None and board_accs:
            run.log({"board_acc_epoch": np.mean(board_accs)})


def save_code_and_config(config: PretrainConfig):
    """Snapshot the model/loss source files and the resolved config, then upload them.

    This is a no-op unless a checkpoint path is configured and a wandb run is
    active. The source files (when resolvable) and ``all_config.yaml`` are
    written into the checkpoint directory, which is then passed to
    ``wandb.run.log_code``.
    """
    ckpt_dir = config.checkpoint_path
    if ckpt_dir is None or wandb.run is None:
        return

    os.makedirs(ckpt_dir, exist_ok=True)

    # Source files defining the model and its loss head; None means "not found".
    sources = (
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name),
    )
    for src in (path for path in sources if path is not None):
        shutil.copy(src, os.path.join(ckpt_dir, os.path.basename(src)))

    # Persist the full resolved config next to the code.
    with open(os.path.join(ckpt_dir, "all_config.yaml"), "wt") as fh:
        yaml.dump(config.model_dump(), fh)

    wandb.run.log_code(ckpt_dir)


def load_synced_config(
    hydra_config: DictConfig, rank: int, world_size: int
) -> PretrainConfig:
    """Build the PretrainConfig on rank 0 and share it with every other rank.

    Rank 0 validates ``hydra_config`` and fills in defaults for any missing
    ``project_name``, ``run_name`` and ``checkpoint_path``. The resulting object
    is broadcast so that all ranks agree, including on the random run-name slug.
    """
    shared: List[Any] = [None]

    if rank == 0:
        cfg = PretrainConfig(**hydra_config)  # type: ignore

        if cfg.project_name is None:
            dataset_name = os.path.basename(cfg.data_path).capitalize()
            cfg.project_name = f"{dataset_name} ACT-torch"
        if cfg.run_name is None:
            arch_tag = cfg.arch.name.split('@')[-1]
            cfg.run_name = f"{arch_tag} {coolname.generate_slug(2)}"
        if cfg.checkpoint_path is None:
            cfg.checkpoint_path = os.path.join(
                "checkpoints", cfg.project_name, cfg.run_name
            )

        shared[0] = cfg

    if world_size > 1:
        dist.broadcast_object_list(shared, src=0)

    return shared[0]  # type: ignore


@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    """Hydra entry point: set up (optionally distributed) training and run it.

    When ``LOCAL_RANK`` is set (e.g. under torchrun), an NCCL process group is
    created. Training then alternates between ``eval_interval`` epochs of
    ``train_batch`` (all ``epochs`` at once when no interval is set) and one call
    to ``evaluate``. Rank 0 owns the progress bar, wandb logging and
    checkpointing.

    Raises:
        ValueError: if the eval interval is not a positive divisor of ``epochs``.
    """
    RANK = 0
    WORLD_SIZE = 1

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Create the NCCL process group; rank and world size are read back from it.
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = (
        config.eval_interval if config.eval_interval is not None else config.epochs
    )
    # Validate before dividing. Previously a zero interval crashed with a bare
    # ZeroDivisionError ahead of this check, and the check was an `assert`,
    # which `python -O` strips.
    if train_epochs_per_iter <= 0 or config.epochs % train_epochs_per_iter != 0:
        raise ValueError("Eval interval must be a divisor of total epochs.")
    total_iters = config.epochs // train_epochs_per_iter

    train_loader, train_metadata = create_dataloader(
        config,
        "train",
        test_set_mode=False,
        epochs_per_iter=train_epochs_per_iter,
        global_batch_size=config.global_batch_size,
        rank=RANK,
        world_size=WORLD_SIZE,
    )
    eval_loader, eval_metadata = create_dataloader(
        config,
        "test",
        test_set_mode=True,
        epochs_per_iter=1,
        global_batch_size=config.global_batch_size,
        rank=RANK,
        world_size=WORLD_SIZE,
    )

    # Train state
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)

        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log(
            {"num_params": sum(x.numel() for x in train_state.model.parameters())},
            step=0,
        )
        save_code_and_config(config)

    # Training Loop
    for _iter_id in range(total_iters):
        print(
            f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}"
        )

        ############ Train Iter
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(
                config,
                train_state,
                batch,
                global_batch_size,
                rank=RANK,
                world_size=WORLD_SIZE,
            )

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore

        ############ Evaluation
        train_state.model.eval()
        metrics = evaluate(
            config,
            train_state,
            eval_loader,
            eval_metadata,
            rank=RANK,
            world_size=WORLD_SIZE,
        )

        if RANK == 0 and metrics is not None:
            wandb.log(metrics, step=train_state.step)

        ############ Checkpointing
        if RANK == 0 and (
            config.checkpoint_every_eval or (_iter_id == total_iters - 1)
        ):
            save_train_state(config, train_state)

    # finalize
    if progress_bar is not None:
        progress_bar.close()
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


if __name__ == "__main__":
    # Script entry point; @hydra.main on launch() builds hydra_config from
    # config/cfg_pretrain.
    launch()
