from typing import Literal, Tuple

import einops
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from ema_pytorch import EMA
from torchmetrics.classification import MulticlassCalibrationError

from ds import (
    ARC2Dataset,
    ARCDataset,
    ExtremeSudokuAugDataset,
    HardSudokuDataset,
    MazeDataset,
    SudokuDataset,
)
from gta import make_2dcoord, make_SO2mats, rep_mul_x
from mlp import SwiGLU
from model_utils import board_accuracy, digit_accuracy

# Ask TorchDynamo to trace ops that produce Python scalars (e.g. ``Tensor.item()``)
# into the graph instead of graph-breaking; relevant when the model is compiled
# (``use_compile=True`` below).
torch._dynamo.config.capture_scalar_outputs = True


class CrossAttention(nn.Module):
    """
    Cross-attention with conv or fc projections and optional GTA / RoPE rotations.

    - forward(x_q, x_kv, attn_mask=None, key_padding_mask=None)
    - x_q: query input, x_kv: key/value input
      conv: x_* shape = (B, C=ch, H, W)
      fc  : x_* shape = (B, K_or_Q, C=ch)  -- the reference (H, W) is taken from ``hw``
    - gta=True rotates q, k, v and the attention output; rope=True rotates q and k.
      The rotation matrices are learnable parameters initialised from
      ``make_SO2mats`` at the reference size ``hw`` and bilinearly resized to the
      actual grid size at run time.
    """

    def __init__(
        self,
        ch: int,
        heads: int = 8,
        weight: str = "conv",  # "conv" or "fc"
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        gta: bool = False,
        rope: bool = False,
        hw=None,  # list [H, W] = reference grid size; other values go to make_SO2mats as-is
    ):
        super().__init__()

        assert weight in ("conv", "fc")
        assert int(gta) + int(rope) <= 1  # at most one of gta / rope
        self.heads = heads
        self.head_dim = ch // heads
        self.weight = weight
        self.stride = stride
        self.gta = gta
        self.rope = rope
        self.hw = hw

        # Projection layers
        if weight == "conv":
            self.W_q = nn.Conv2d(
                ch, ch, kernel_size=kernel_size, stride=stride, padding=padding
            )
            self.W_kv = nn.Conv2d(
                ch, 2 * ch, kernel_size=kernel_size, stride=stride, padding=padding
            )
            self.W_o = nn.Conv2d(ch, ch, kernel_size=kernel_size, stride=1, padding=0)
        else:
            self.W_q = nn.Linear(ch, ch)
            self.W_kv = nn.Linear(ch, 2 * ch)
            self.W_o = nn.Linear(ch, ch)

        # Rotation-matrix parameters (stored at the reference size, resized at run time).
        if gta or rope:
            assert hw is not None, "gta/rope には基準 hw が必要です"
            # Number of 2-D frequencies: ceil(head_dim / 4).
            Fq = self.head_dim // 4
            if self.head_dim % 4 != 0:
                Fq += 1

            if not isinstance(hw, list):
                # NOTE(review): non-list ``hw`` is passed straight to make_SO2mats,
                # presumably as precomputed coordinates -- confirm with gta.py.
                coord = hw
                base = make_SO2mats(coord, Fq).flatten(1, 2)
            else:
                coord = make_2dcoord(hw[0], hw[1])
                base = make_SO2mats(coord, Fq).flatten(2, 3).flatten(0, 1)

            # expected layout: [h*w, head_dim/2, 2, 2]
            base = base[..., : self.head_dim // 2, :, :]

            # Give every matrix its own storage. ``nn.Parameter(t)`` aliases ``t``,
            # so wrapping the same ``base`` several times would make an in-place
            # optimizer update of one matrix silently change all the others.
            # gta rotates q, k, v, o; rope rotates only q, k.
            self.mat_q = nn.Parameter(base.clone())  # query coordinates
            self.mat_k = nn.Parameter(base.clone())  # key coordinates
            if gta:
                # value rotation (same coordinates as the keys)
                self.mat_v = nn.Parameter(base.clone())
                # output rotation (query coordinates), initialised as the transpose
                self.mat_o = nn.Parameter(base.transpose(-2, -1).clone())

    def _infer_hw(self, x):
        """Return the (H, W) grid size the projected tokens of ``x`` live on."""
        if self.weight == "conv":
            # NOTE(review): H // stride equals the conv output size only when
            # kernel_size=1, padding=0 and H is divisible by stride.
            h = x.shape[2] // self.stride
            w = x.shape[3] // self.stride
            return (h, w)
        else:
            # fc: treat the fixed-length sequence as a self.hw grid
            assert self.hw is not None
            return (self.hw[0], self.hw[1])

    def _rearrange_in(self, x, kind: str):
        """Project ``x`` and split it into heads.

        kind: 'q' or 'kv'
        conv:
          q : (B,C,H,W) -> (B, heads, H*W, head_dim)
          kv: (B,C,H,W) -> k, v each (B, heads, H*W, head_dim)
        fc:
          q : (B,Q,C) -> (B, heads, Q, head_dim)
          kv: (B,K,C) -> k, v each (B, heads, K, head_dim)
        """
        if self.weight == "conv":
            if kind == "q":
                x = self.W_q(x)
                x = einops.rearrange(x, "b (nh c) h w -> b nh (h w) c", nh=self.heads)
                return x
            else:
                kv = self.W_kv(x)
                k, v = kv.chunk(2, dim=1)  # split along channels
                k = einops.rearrange(k, "b (nh c) h w -> b nh (h w) c", nh=self.heads)
                v = einops.rearrange(v, "b (nh c) h w -> b nh (h w) c", nh=self.heads)
                return k, v
        else:
            if kind == "q":
                x = self.W_q(x)
                x = einops.rearrange(x, "b k (nh c) -> b nh k c", nh=self.heads)
                return x
            else:
                kv = self.W_kv(x)
                k, v = kv.chunk(2, dim=-1)  # split along features
                k = einops.rearrange(k, "b k (nh c) -> b nh k c", nh=self.heads)
                v = einops.rearrange(v, "b k (nh c) -> b nh k c", nh=self.heads)
                return k, v

    def _rearrange_out(self, x, hq, wq):
        """Merge heads back and apply the output projection."""
        if self.weight == "conv":
            x = einops.rearrange(x, "b nh (h w) c -> b (nh c) h w", h=hq, w=wq)
            x = self.W_o(x)
            return x
        else:
            x = einops.rearrange(x, "b nh q c -> b q (nh c)")
            x = self.W_o(x)
            return x

    def _rescale_mat(self, mat, from_hw, to_hw):
        """Bilinearly resize a per-position matrix field from ``from_hw`` to ``to_hw``.

        mat: [(h*w), f, c, d]. Returns ``mat`` itself when the sizes already match.
        """
        # ``self.hw`` is a list while run-time sizes are tuples; normalise so that
        # equal sizes are recognised and the identity resize is skipped.
        if isinstance(from_hw, list):
            from_hw = tuple(from_hw)
        if isinstance(to_hw, list):
            to_hw = tuple(to_hw)
        if from_hw == to_hw:
            return mat
        f, c, d = mat.shape[1:]
        src_h, src_w = from_hw
        dst_h, dst_w = to_hw
        grid = einops.rearrange(mat, "(h w) f c d -> (f c d) h w", h=src_h, w=src_w)
        grid = F.interpolate(
            grid[None], size=(dst_h, dst_w), mode="bilinear", align_corners=False
        )[0]
        grid = einops.rearrange(grid, "(f c d) h w -> (h w) f c d", f=f, c=c, d=d)
        return grid

    def _apply_posrot(self, x, mat, hw_from, hw_to):
        """Rotate ``x`` (B, heads, L, head_dim) by ``mat`` resized to ``hw_to``.

        L is expected to equal h*w of ``hw_to`` (in fc mode the sequence is treated
        as that grid).
        """
        scaled = self._rescale_mat(mat, hw_from, hw_to)
        return rep_mul_x(scaled, x)

    def _merge_masks(self, attn_mask, key_padding_mask, B, Q, K, dtype):
        """Combine ``attn_mask`` and ``key_padding_mask`` into one SDPA mask.

        Args:
            attn_mask: None, an SDPA-style mask, or a (B*heads, Q, K) mask in
                nn.MultiheadAttention layout (reshaped to (B, heads, Q, K)).
                A boolean mask follows SDPA semantics (True = attend).
            key_padding_mask: None or (B, K); True/nonzero marks keys to ignore.
            B, Q, K: batch size, query length and key length.
            dtype: float dtype of the additive mask (should match q).

        Returns:
            None when both inputs are None; otherwise a mask broadcastable to
            (B, heads, Q, K). A query whose keys are all masked may yield NaN.
        """
        if (
            attn_mask is not None
            and attn_mask.dim() == 3
            and attn_mask.size(0) == B * self.heads
        ):
            attn_mask = attn_mask.reshape(B, self.heads, Q, K)
        if key_padding_mask is None:
            return attn_mask

        # (B, 1, 1, K) additive mask: -inf on padded keys, broadcast over heads/queries.
        pad = torch.zeros((B, 1, 1, K), dtype=dtype, device=key_padding_mask.device)
        pad = pad.masked_fill(key_padding_mask.bool()[:, None, None, :], float("-inf"))
        if attn_mask is None:
            return pad
        if attn_mask.dtype == torch.bool:
            # SDPA bool masks mean "may attend"; convert to additive form to combine.
            attn_mask = torch.zeros(
                attn_mask.shape, dtype=dtype, device=attn_mask.device
            ).masked_fill(~attn_mask, float("-inf"))
        return attn_mask + pad

    def forward(
        self,
        x_q,  # conv: (B,C,H,W) / fc: (B,Q,C)
        x_kv,  # conv: (B,C,H,W) / fc: (B,K,C)
        attn_mask=None,  # SDPA-compatible, or (B*heads, Q, K)
        key_padding_mask=None,  # (B, K), True = masked
    ):
        """Attend from ``x_q`` to ``x_kv``; output has the layout of ``x_q``."""
        B = x_q.size(0)
        # 1) shapes and projections
        hq, wq = self._infer_hw(x_q)
        hk, wk = self._infer_hw(x_kv)

        q = self._rearrange_in(x_q, kind="q")  # (B, H, Q, D)
        k, v = self._rearrange_in(x_kv, kind="kv")  # (B, H, K, D) x2
        Q = q.size(2)
        K = k.size(2)

        # 2) positional rotations; matrices are stored at the reference size self.hw
        base_hw = self.hw
        if self.gta or self.rope:
            q = self._apply_posrot(q, self.mat_q, base_hw, (hq, wq))
            k = self._apply_posrot(k, self.mat_k, base_hw, (hk, wk))
            if self.gta:
                v = self._apply_posrot(v, self.mat_v, base_hw, (hk, wk))

        # 3) mask merge
        attn_mask = self._merge_masks(attn_mask, key_padding_mask, B, Q, K, q.dtype)

        # 4) attention
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (B, H, Q, D)

        # 5) output rotation (GTA only; output lives in query coordinates)
        if self.gta:
            x = self._apply_posrot(x, self.mat_o, base_hw, (hq, wq))

        # 6) merge heads + output projection
        return self._rearrange_out(x, hq, wq)  # conv: (B,C,H,W) / fc: (B,Q,C)


class Step(nn.Module):
    def __init__(self, model, embed_dim):
        super().__init__()
        self.model = model
        self.norm_attn = nn.RMSNorm(embed_dim)

    def forward(self, q, kv):
        V, B, L, D = q.shape
        L_sqrt = int(L**0.5)
        q = self.norm_attn(q)
        kv = self.norm_attn(kv)

        q = q.reshape(V * B, L_sqrt, L_sqrt, D)
        kv = kv.reshape(V * B, L_sqrt, L_sqrt, D)

        z = self.model(
            q.permute(0, 3, 1, 2), kv.permute(0, 3, 1, 2)
        )  # (VB, D, L_sqrt, L_sqrt)
        z = z.reshape(V, B, D, L).permute(0, 1, 3, 2)  # (V, B, L, D)
        z = z.reshape(V, B, L, D)
        return q.reshape(V, B, L, D) + z


class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip connection: ``x + fn(x)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        delta = self.fn(x)
        return x + delta


class TransformerBlock(nn.Module):
    """One transformer layer: a GTA self-attention ``Step`` then an RMSNorm + residual SwiGLU.

    ``use_cross_attn`` is accepted for signature compatibility and is not used.
    """

    def __init__(
        self, embed_dim, num_heads, ffn_dim_multiplier, L_sqrt, use_bias, use_cross_attn
    ):
        super().__init__()
        attn_kwargs = dict(
            ch=embed_dim,
            heads=num_heads,
            weight="conv",
            kernel_size=1,
            stride=1,
            padding=0,
            gta=True,
            rope=False,
            hw=[L_sqrt, L_sqrt],
        )
        self.attn = CrossAttention(**attn_kwargs)

        pre_norm = nn.RMSNorm(embed_dim)
        ffn_core = SwiGLU(
            embed_dim, use_bias=use_bias, ffn_dim_multiplier=ffn_dim_multiplier
        )
        self.ffn = nn.Sequential(pre_norm, Residual(ffn_core))

        # Step wraps the same attention module (also registered as ``attn``).
        self._self = Step(self.attn, embed_dim)

    def forward(self, z):
        """Self-attend over ``z`` then apply the FFN; shape is preserved."""
        return self.ffn(self._self(z, z))


class Transformer(nn.Module):
    """Single-pass alternative to the recurrent ``Block`` stack.

    The latent ``z`` and input embedding ``x`` are concatenated on the feature axis,
    mixed back to ``embed_dim`` by a linear layer, then passed once through
    ``num_layers`` ``TransformerBlock``s.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        ffn_dim_multiplier,
        L_sqrt,
        use_bias,
        use_cross_attn,
        num_layers,
    ):
        super().__init__()
        # Input injection: concat(z, x) -> embed_dim.
        self.mix = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim, bias=use_bias),
        )
        # ``use_cross_attn`` is forwarded for signature compatibility only;
        # TransformerBlock does not use it.
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim,
                    num_heads,
                    ffn_dim_multiplier,
                    L_sqrt,
                    use_bias,
                    use_cross_attn,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, z, x):
        """z, x: (V, B, L, D). Returns (V, B, L, D)."""
        z = torch.cat([z, x], dim=-1)  # (V, B, L, 2D)
        z = self.mix(z)  # (V, B, L, D)
        for block in self.blocks:
            z = block(z)
        return z


class Block(nn.Module):
    """One recurrent refinement step.

    The input embedding ``x`` is injected into the latent ``z`` either with a GTA
    cross-attention ``Step`` (``use_cross_attn=True``) or by concatenation plus a
    linear mix. The same self-attention ``Step`` (shared weights) is then applied
    ``num_rep_attn`` times, followed by an RMSNorm + residual SwiGLU FFN.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ffn_dim_multiplier: int,
        L_sqrt: int,
        use_bias: bool = False,
        use_cross_attn: bool = True,
        num_rep_attn: int = 4,
    ):
        super().__init__()
        self.num_rep_attn = num_rep_attn

        self.model_self = CrossAttention(
            ch=embed_dim,
            heads=num_heads,
            weight="conv",
            kernel_size=1,
            stride=1,
            padding=0,
            gta=True,
            rope=False,
            hw=[L_sqrt, L_sqrt],
        )
        self._self = Step(self.model_self, embed_dim)
        self.mlp = nn.Sequential(
            nn.RMSNorm(embed_dim),
            Residual(
                SwiGLU(
                    embed_dim, use_bias=use_bias, ffn_dim_multiplier=ffn_dim_multiplier
                )
            ),
        )
        self.use_cross_attn = use_cross_attn
        if use_cross_attn:
            self.model_cross = CrossAttention(
                ch=embed_dim,
                heads=num_heads,
                weight="conv",
                kernel_size=1,
                stride=1,
                padding=0,
                gta=True,
                rope=False,
                hw=[L_sqrt, L_sqrt],
            )
            self._cross = Step(self.model_cross, embed_dim=embed_dim)
        else:
            self.mix = nn.Sequential(
                nn.Linear(embed_dim * 2, embed_dim, bias=use_bias),
            )

    def forward(self, z, x) -> torch.Tensor:
        """Refine ``z`` conditioned on ``x``; both (V, B, L, D). Returns (V, B, L, D)."""
        if self.use_cross_attn:
            z = self._cross(z, x)  # queries from z, keys/values from x
        else:
            z = torch.cat([z, x], dim=-1)  # (V, B, L, 2D)
            z = self.mix(z)  # (V, B, L, D)
        for _ in range(self.num_rep_attn):
            z = self._self(z, z)
        z = self.mlp(z)
        return z


class Body(nn.Module):
    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        use_bias: bool = False,
        L_sqrt: int = 30,
        num_vocab: int = 4,
        num_classes: int = 5,
        num_remain_grad: int = 8,
        ffn_dim_multiplier: int = 4,
        use_cross_attn: bool = True,
        num_rep_attn: int = 4,
        num_layers: int = 1,
        use_mpc: bool = False,
        use_transformer: bool = False,
        no_truncation: bool = False,
        confidence_type: Literal["max_prob", "entropy", "log_prob"] = "max_prop",
    ):
        super().__init__()
        self.L = L_sqrt * L_sqrt
        self.num_vocab = num_vocab
        self.num_classes = num_classes
        self.num_remain_grad = num_remain_grad
        self.no_truncation = no_truncation
        self.confidence_type = confidence_type

        # Normalize z before passing to transformer
        self.embedding = nn.Embedding(num_vocab, embed_dim)
        if use_mpc:
            # self.norm = nn.RMSNorm(embed_dim)
            self.noise_level = nn.Parameter(torch.tensor(1.0))

        self.out_proj = nn.Sequential(
            nn.RMSNorm(embed_dim), nn.Linear(embed_dim, num_classes, bias=use_bias)
        )
        self.use_transformer = use_transformer
        if not use_transformer:
            self.iter_sa = nn.ModuleList(
                [
                    Block(
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        ffn_dim_multiplier=ffn_dim_multiplier,
                        L_sqrt=L_sqrt,
                        use_bias=use_bias,
                        use_cross_attn=use_cross_attn,
                        num_rep_attn=num_rep_attn,
                    )
                    for _ in range(num_layers)
                ]
            )
        else:
            self.iter_sa = Transformer(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ffn_dim_multiplier=ffn_dim_multiplier,
                L_sqrt=L_sqrt,
                use_bias=use_bias,
                use_cross_attn=use_cross_attn,
                num_layers=num_layers,
            )  # type: ignore

    def _loop(
        self, x_emb_v: torch.Tensor, num_iter: int, return_intermediate: bool = False
    ) -> torch.Tensor:
        V, B, L, D = x_emb_v.shape
        z = torch.randn((V, B, L, D), device=x_emb_v.device)
        if return_intermediate:
            zs = []
            for _ in range(num_iter - self.num_remain_grad):
                for layer in self.iter_sa:
                    z = layer(z, x_emb_v)
                zs.append(z)
            if not self.no_truncation:
                z = z.detach()
            for _ in range(self.num_remain_grad):
                for layer in self.iter_sa:
                    z = layer(z, x_emb_v)
                zs.append(z)
            return torch.stack(zs, dim=2)  # (V, B, num_iter, L, D)
        else:
            for _ in range(num_iter - self.num_remain_grad):
                for layer in self.iter_sa:
                    z = layer(z, x_emb_v)
            if not self.no_truncation:
                z = z.detach()
            for _ in range(self.num_remain_grad):
                for layer in self.iter_sa:
                    z = layer(z, x_emb_v)
            return z  # (V, B, L, D)

    def _loop_mpc(
        self,
        x_emb_v: torch.Tensor,
        num_iter: int,
        filled_v: torch.Tensor,
        mpc_every: int = 4,
    ) -> torch.Tensor:
        V, B, L, D = x_emb_v.shape
        z = torch.randn((V, B, L, D), device=x_emb_v.device)

        for i in range(num_iter - self.num_remain_grad):
            for layer in self.iter_sa:
                z = layer(z, x_emb_v)
            if i % mpc_every == mpc_every - 1:
                if V > 1:
                    logits = self.out_proj(z)
                    log_prob = F.log_softmax(logits, dim=-1)  # (V, B, L, C)
                    prob = log_prob.exp()  # (V, B, L, C)
                    entropy = -(prob * log_prob * (~filled_v)).sum(
                        dim=(-2, -1)
                    )  # (V, B)
                    best_v = entropy.argmin(dim=0)  # (B,)
                    z = z.permute(1, 0, 2, 3)[
                        torch.arange(B, device=z.device), best_v
                    ]  # (B, L, D)
                    z = z.unsqueeze(0).expand(V, -1, -1, -1)  # (V, B, L, D)
                z = (
                    z
                    + torch.randn((V, B, L, D), device=x_emb_v.device)
                    * self.noise_level
                )

        z = z.detach()
        for j in range(self.num_remain_grad):
            i = j + (num_iter - self.num_remain_grad)
            for layer in self.iter_sa:
                z = layer(z, x_emb_v)
            if i % mpc_every == mpc_every - 1 and i != num_iter - 1:
                if V > 1:
                    logits = self.out_proj(z)
                    log_prob = F.log_softmax(logits, dim=-1)  # (V, B, L, C)
                    prob = log_prob.exp()  # (V, B, L, C)
                    entropy = -(prob * log_prob * (~filled_v)).sum(
                        dim=(-2, -1)
                    )  # (V, B)
                    best_v = entropy.argmin(dim=0)  # (B,)
                    z = z.permute(1, 0, 2, 3)[
                        torch.arange(B, device=z.device), best_v
                    ]  # (B, L, D)
                    z = z.unsqueeze(0).expand(V, -1, -1, -1)  # (V, B, L, D)
                z = (
                    z
                    + torch.randn((V, B, L, D), device=x_emb_v.device)
                    * self.noise_level
                )

        return z

    def forward(
        self,
        x_tokens: torch.Tensor,
        filled: torch.Tensor,
        num_iter: int,
        num_votes: int = 1,
        use_mpc: bool = False,
        mpc_every: int = 4,
        return_intermediate: bool = False,
    ):
        B, L = x_tokens.size(0), x_tokens.size(1)
        D = self.embedding.weight.size(1)
        V = num_votes

        x_emb = self.embedding(x_tokens)
        x_emb = x_emb.reshape(B, L, D)
        filled_v = filled.unsqueeze(0).unsqueeze(-1).expand(V, -1, -1, -1)  # (V, B, L)

        x_emb_v = x_emb.unsqueeze(0).expand(V, -1, -1, -1)  # (V,B,L,D)
        if self.use_transformer:
            z = torch.randn_like(x_emb_v, device=x_emb_v.device)
            z = self.iter_sa(z, x_emb_v)
        elif use_mpc:
            z = self._loop_mpc(
                x_emb_v, num_iter, filled_v=filled_v, mpc_every=mpc_every
            )
        else:
            z = self._loop(x_emb_v, num_iter, return_intermediate=return_intermediate)
        logits = self.out_proj(z)

        if return_intermediate:
            return logits  # (V, B, num_iter, L, C)

        if self.training:
            return logits[0]

        prob = F.softmax(logits, dim=-1)  # (V, B, L, C)

        match self.confidence_type:
            case "max_prob":
                max_prob_digit = prob.max(dim=-1).values  # (V, B, L)
                certainty = (
                    (max_prob_digit * (~filled_v.squeeze(-1))).sum(-1).max(dim=0).values
                )  # (B, )
                best_v = (
                    (max_prob_digit * (~filled_v.squeeze(-1))).sum(-1).argmax(dim=0)
                )  # (B,)

            case "entropy":
                log_prob = F.log_softmax(logits, dim=-1)  # (V
                prob = log_prob.exp()  # (V, B, L, C)
                entropy = -(prob * log_prob * (~filled_v)).sum(dim=(-2, -1))  # (V, B)
                certainty = (-entropy).max(dim=0).values  # (B, )
                best_v = (-entropy).argmax(dim=0)  # (B,)

            case "log_prob":
                log_prob = F.log_softmax(logits, dim=-1)  # (V, B, L, C)
                log_prob_filled = (log_prob * (~filled_v)).sum(dim=-1)  # (V, B, L)
                log_prob_board = log_prob_filled.sum(dim=-1)  # (V, B)
                certainty = log_prob_board.max(dim=0).values  # (B, )
                best_v = log_prob_board.argmax(dim=0)  # (B,)

            case _:
                raise ValueError(f"Unknown confidence_type: {confidence_type}")

        logits = logits.permute(1, 0, 2, 3)[
            torch.arange(B, device=logits.device), best_v
        ]  # (B, L, C)

        return logits, certainty


class ItrAttention(L.LightningModule):
    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        use_bias: bool = False,
        betas: Tuple[float, float] = (0.9, 0.95),
        weight_decay: float = 0.01,
        lr: float = 1.2e-3,
        batch_size: int = 128,
        num_workers: int = 4,
        num_proc_preprocessing: int = 20,
        beta: float = 0.995,
        num_iter: int = 1,
        train_dataset_name: str = "sudoku",
        test_dataset_name: str = "sudoku-hard",
        L_sqrt: int = 9,
        num_vocab: int = 10,
        num_classes: int = 9,
        use_compile: bool = False,
        num_iter_test: int = 16,
        num_votes_test: int = 10,
        num_remain_grad: int = 8,
        update_after_step: int = 100,
        update_every: int = 10,
        ffn_dim_multiplier: int = 4,
        use_cross_attn: bool = True,
        num_rep_attn: int = 4,
        num_layers: int = 1,
        use_mpc: bool = False,
        mpc_every: int = 4,
        use_transformer: bool = False,
        no_truncation: bool = False,
        confidence_type: Literal["max_prob", "entropy", "log_prob"] = "max_prob",
    ):
        """Lightning wrapper around ``Body`` plus an EMA copy used for val/test.

        Every constructor argument is recorded with ``save_hyperparameters``.
        Model-shape arguments are forwarded to ``Body``. ``beta``,
        ``update_after_step`` and ``update_every`` configure the EMA wrapper.
        ``betas``, ``weight_decay`` and ``lr`` are stored for the optimizer.
        """
        super().__init__()
        # Must run before any other local is created: Lightning reads the init
        # arguments from this frame.
        self.save_hyperparameters()

        # --- data / loader settings
        self.train_dataset_name = train_dataset_name
        self.test_dataset_name = test_dataset_name
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_proc_preprocessing = num_proc_preprocessing

        # --- problem geometry (L = number of cells on the board)
        self.L = L_sqrt * L_sqrt
        self.context_length = self.L
        self.num_vocab = num_vocab
        self.num_classes = num_classes

        # --- iteration / decoding settings
        self.num_iter = num_iter
        self.num_iter_test = num_iter_test
        self.num_votes_test = num_votes_test
        self.use_mpc = use_mpc
        self.mpc_every = mpc_every
        self.use_transformer = use_transformer

        # --- networks (construction order kept: Body, then its EMA copy)
        self.model_body = Body(
            num_heads=num_heads,
            embed_dim=embed_dim,
            use_bias=use_bias,
            L_sqrt=L_sqrt,
            num_vocab=num_vocab,
            num_classes=num_classes,
            num_remain_grad=num_remain_grad,
            ffn_dim_multiplier=ffn_dim_multiplier,
            use_cross_attn=use_cross_attn,
            num_rep_attn=num_rep_attn,
            num_layers=num_layers,
            use_mpc=use_mpc,
            use_transformer=use_transformer,
            no_truncation=no_truncation,
            confidence_type=confidence_type,
        )
        if use_compile:
            self.model_body.compile()
        self.model_body_ema = EMA(
            self.model_body,
            beta=beta,
            update_after_step=update_after_step,
            update_every=update_every,
        )

        # --- optimizer settings
        self.betas = betas
        self.weight_decay = weight_decay
        self.lr = lr

        # --- per-group best-vote buffers filled during val / test
        self.val_results = {}
        self.test_results = {}

    def training_step(self, batch, batch_idx):
        """Single-vote training step with cross-entropy over every cell.

        ``batch`` is ``(x, y, filled, _)``; each part is flattened to (B, L).
        Loss, digit accuracy and board accuracy are logged. Returns the loss.
        """
        x, y, filled, _ = batch
        B = x.shape[0]
        x = x.reshape(B, self.L).long()
        target = y.reshape(B, self.L).long()
        filled = filled.reshape(B, self.L)

        # In training mode Body returns the first vote's logits: (B, L, num_classes).
        logits = self.model_body(
            x,
            filled,
            num_iter=self.num_iter,
            num_votes=1,
            use_mpc=self.use_mpc,
            mpc_every=self.mpc_every,
        )

        loss = F.cross_entropy(
            logits.reshape(-1, self.num_classes),
            target.reshape(-1),
        )

        # Metrics are for logging only; keep both out of the autograd graph.
        logits_detached = logits.detach()
        digit_acc = digit_accuracy(logits_detached, target)
        board_acc = board_accuracy(logits_detached, x, target, filled)

        self.log_dict(
            {
                "train_loss": loss,
                "train_digit_acc": digit_acc,
                "train_board_acc": board_acc,
            },
            prog_bar=True,
            on_step=True,
            sync_dist=True,
        )

        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Ask the EMA wrapper to update after every training batch.

        The wrapper was built with ``update_after_step`` / ``update_every`` in the
        constructor, which presumably gate when the averaged weights change.
        """
        ema = self.model_body_ema
        ema.update()

    def on_validation_epoch_start(self):
        """Reset the per-group best-vote buffer filled by ``validation_step``."""
        self.val_results = dict()

    def on_test_epoch_start(self):
        """Reset the per-group best-vote buffer filled by ``test_step``."""
        self.test_results = dict()

    def validation_step(self, batch, batch_idx):
        """Run the EMA model on one batch and keep the most confident sample per group.

        Uses the training iteration count (``self.num_iter``) and a single vote.
        For each ``group_id`` only the entry with the highest confidence seen so far
        this epoch is stored in ``self.val_results``; metrics are computed at epoch end.
        """
        x, y, filled, group_id = batch
        n = x.shape[0]
        x = x.reshape(n, self.L).long()
        target = y.reshape(n, self.L).long()
        filled = filled.reshape(n, self.L)
        logits, confidence = self.model_body_ema(
            x,
            filled,
            num_iter=self.num_iter,
            num_votes=1,
            use_mpc=self.use_mpc,
            mpc_every=self.mpc_every,
        )

        gids = group_id.reshape(n).cpu().numpy()
        confs = confidence.cpu().numpy()
        for idx, (gid, conf) in enumerate(zip(gids, confs)):
            incumbent = self.val_results.get(gid)
            # keep the incumbent unless the new sample is strictly more confident
            if incumbent is not None and not conf > incumbent["confidence"]:
                continue
            self.val_results[gid] = {
                "confidence": conf,
                "logits": logits[idx],  # (L, num_classes)
                "target": target[idx],  # (L,)
                "x": x[idx],  # (L,)
                "filled": filled[idx],  # (L,)
            }

        return None

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch with the EMA model and keep the best vote per group.

        Uses ``num_iter_test`` iterations and ``num_votes_test`` votes. For each
        ``group_id`` only the entry with the highest confidence seen so far is kept
        in ``self.test_results``. Every 100 batches, loss, digit accuracy and board
        accuracy over everything kept so far are logged as running ``*_step`` metrics.
        """
        x, y, filled, group_id = batch
        B = x.shape[0]
        x = x.reshape(B, self.L).long()
        target = y.reshape(B, self.L).long()
        filled = filled.reshape(B, self.L)

        # Debug switch (off): set to True to dump the per-iteration logits of the
        # first batch (rank 0, plain recurrent loop, 32 votes) into ``outputs/``.
        dump_intermediate = False
        if (
            dump_intermediate
            and not self.use_mpc
            and not self.use_transformer
            and batch_idx == 0
            and self.local_rank == 0
            and self.num_votes_test == 32
        ):
            inter_logits = self.model_body_ema(
                x,
                filled,
                num_iter=self.num_iter_test,
                num_votes=self.num_votes_test,
                use_mpc=self.use_mpc,
                mpc_every=self.mpc_every,
                return_intermediate=True,
            )  # (V, B, T, L, C)
            for suffix, tensor in (
                ("intermediate_z", inter_logits),
                ("x", x),
                ("filled", filled),
                ("target", target),
            ):
                with open(f"outputs/{self.test_dataset_name}_{suffix}.pt", "wb") as f:
                    torch.save(tensor, f)

        logits, confidence = self.model_body_ema(
            x,
            filled,
            num_iter=self.num_iter_test,
            num_votes=self.num_votes_test,
            use_mpc=self.use_mpc,
            mpc_every=self.mpc_every,
        )

        gids = group_id.reshape(B).cpu().numpy()
        for i, (gid, conf) in enumerate(zip(gids, confidence.cpu().numpy())):
            best = self.test_results.get(gid)
            # replace only when strictly more confident than the kept entry
            if best is None or conf > best["confidence"]:
                self.test_results[gid] = {
                    "confidence": conf,
                    "logits": logits[i],  # (L, num_classes)
                    "target": target[i],  # (L,)
                    "x": x[i],  # (L,)
                    "filled": filled[i],  # (L,)
                }

        if batch_idx % 100 == 0:
            kept = list(self.test_results.values())
            logits_all = torch.stack([r["logits"] for r in kept], dim=0)  # (N, L, num_classes)
            targets_all = torch.stack([r["target"] for r in kept], dim=0)  # (N, L)
            x_all = torch.stack([r["x"] for r in kept], dim=0)  # (N, L)
            filled_all = torch.stack([r["filled"] for r in kept], dim=0)  # (N, L)

            loss = F.cross_entropy(
                logits_all.reshape(-1, self.num_classes), targets_all.reshape(-1)
            )
            digit_acc = digit_accuracy(logits_all, targets_all)
            board_acc = board_accuracy(logits_all, x_all, targets_all, filled_all)
            self.log_dict(
                {
                    f"test_loss_{self.num_iter_test}_{self.num_votes_test}_step": loss,
                    f"test_digit_acc_{self.num_iter_test}_{self.num_votes_test}_step": digit_acc,
                    f"test_board_acc_{self.num_iter_test}_{self.num_votes_test}_step": board_acc,
                },
                prog_bar=True,
                on_epoch=False,
                on_step=True,
                sync_dist=True,
            )

        return None

    @torch.no_grad()
    def on_validation_epoch_end(self):
        """Compute and log epoch-level validation loss and accuracies.

        Per-sample results accumulated in ``self.val_results`` are merged (and,
        when running on several processes, gathered and de-duplicated) by
        ``_gather_val_results``, then scored once. The same metric keys are
        logged for single- and multi-process runs so that callbacks monitoring
        them (e.g. checkpointing on ``val_board_acc``) work in both settings.
        """
        logits_all, targets_all, x_all, filled_all = self._gather_val_results()

        loss = F.cross_entropy(
            logits_all.reshape(-1, self.num_classes), targets_all.reshape(-1)
        )
        digit_acc = digit_accuracy(logits_all, targets_all)
        board_acc = board_accuracy(logits_all, x_all, targets_all, filled_all)
        self.log_dict(
            {
                "val_loss": loss,
                "val_digit_acc": digit_acc,
                "val_board_acc": board_acc,
            },
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
            rank_zero_only=True,
        )
        return None

    def _gather_val_results(self):
        """Stack ``self.val_results`` into batch tensors, merging across ranks.

        Each value of ``self.val_results`` is expected to hold ``"logits"``,
        ``"target"``, ``"x"``, ``"filled"`` and ``"confidence"`` entries, keyed
        by an int-convertible sample id.

        Returns:
            ``(logits, targets, x, filled)`` stacked along a new leading sample
            axis. With more than one process, every sample id appears once; if
            several ranks evaluated the same id, the row with the highest
            confidence wins (the first one on ties).
        """
        results = self.val_results
        world_size = self.trainer.world_size if self.trainer else 1

        values = list(results.values())
        logits = torch.stack([v["logits"] for v in values], dim=0)  # (N, L, C)
        targets = torch.stack([v["target"] for v in values], dim=0)  # (N, L)
        x = torch.stack([v["x"] for v in values], dim=0)  # (N, L)
        filled = torch.stack([v["filled"] for v in values], dim=0)  # (N, L)
        if world_size == 1:
            return logits, targets, x, filled

        device = self.device
        gids = torch.tensor([int(k) for k in results], dtype=torch.long, device=device)
        conf = torch.tensor([float(v["confidence"]) for v in values], device=device)

        # Ranks may hold different numbers of samples: pad every tensor to the
        # longest local count so all_gather sees equal shapes, and keep the
        # per-rank counts to tell real rows from padding afterwards.
        n_local = torch.tensor([len(values)], dtype=torch.long, device=device)
        lengths = self.all_gather(n_local).reshape(-1).to(torch.long)  # (world,)
        max_len = int(lengths.max().item())

        def pad_and_gather(t):
            # Zero padding keeps the dtype, device and trailing shape of ``t``;
            # padded rows are dropped via ``valid`` below, so their values never
            # reach the metrics.
            pad = t.new_zeros((max_len - t.shape[0], *t.shape[1:]))
            out = self.all_gather(torch.cat([t, pad], dim=0))
            return out.reshape(world_size * max_len, *t.shape[1:])

        gids_all = pad_and_gather(gids)  # (world * max_len,)
        conf_all = pad_and_gather(conf)  # (world * max_len,)
        logits_all = pad_and_gather(logits)  # (world * max_len, L, C)
        targets_all = pad_and_gather(targets)  # (world * max_len, L)
        x_all = pad_and_gather(x)  # (world * max_len, L)
        filled_all = pad_and_gather(filled)  # (world * max_len, L)

        # Gathered row ``rank * max_len + r`` is real iff r < lengths[rank].
        rows = torch.arange(max_len, device=lengths.device)
        valid = (rows.unsqueeze(0) < lengths.unsqueeze(1)).reshape(-1).tolist()

        # De-duplicate on Python ints: 0-d tensors hash by identity, so they
        # would never match each other as dict keys.
        gid_list = gids_all.tolist()
        conf_list = conf_all.tolist()
        best = {}  # sample id -> gathered row index
        for row, is_real in enumerate(valid):
            if not is_real:
                continue
            gid = gid_list[row]
            kept = best.get(gid)
            if kept is None or conf_list[kept] < conf_list[row]:
                best[gid] = row

        keep = torch.tensor(
            list(best.values()), dtype=torch.long, device=logits_all.device
        )
        return logits_all[keep], targets_all[keep], x_all[keep], filled_all[keep]

    @torch.no_grad()
    def on_test_epoch_end(self):
        """Compute and log epoch-level test metrics for this test configuration.

        Per-sample results accumulated in ``self.test_results`` are merged (and,
        when running on several processes, gathered and de-duplicated) by
        ``_gather_test_results``. Logged metrics: cross-entropy loss, digit and
        board accuracy, and multiclass calibration error (``norm="l1"``) over
        positions whose ``filled`` flag is False. Keys are suffixed with
        ``{num_iter_test}_{num_votes_test}_epoch``; the same keys are logged for
        single- and multi-process runs.
        """
        logits_all, targets_all, x_all, filled_all = self._gather_test_results()

        loss = F.cross_entropy(
            logits_all.reshape(-1, self.num_classes), targets_all.reshape(-1)
        )
        digit_acc = digit_accuracy(logits_all, targets_all)
        board_acc = board_accuracy(logits_all, x_all, targets_all, filled_all)

        # Calibration is measured only where ``filled`` is False (presumably the
        # cells the model had to predict -- TODO confirm); other positions are
        # mapped to ignore_index. Assumes ``filled`` is a bool mask. masked_fill
        # returns a copy, so targets_all is left untouched.
        prob = F.softmax(logits_all, dim=-1)  # (N, L, C)
        calib_targets = targets_all.masked_fill(filled_all, -100)
        mce = MulticlassCalibrationError(
            num_classes=self.num_classes, norm="l1", ignore_index=-100
        ).to(prob.device)
        mce_val = mce(prob.reshape(-1, self.num_classes), calib_targets.reshape(-1))

        tag = f"{self.num_iter_test}_{self.num_votes_test}_epoch"
        self.log_dict(
            {
                f"test_loss_{tag}": loss,
                f"test_board_acc_{tag}": board_acc,
                f"test_digit_acc_{tag}": digit_acc,
                f"test_mce_{tag}": mce_val,
            },
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
            rank_zero_only=True,
        )
        return None

    def _gather_test_results(self):
        """Stack ``self.test_results`` into batch tensors, merging across ranks.

        Each value of ``self.test_results`` is expected to hold ``"logits"``,
        ``"target"``, ``"x"``, ``"filled"`` and ``"confidence"`` entries, keyed
        by an int-convertible sample id.

        Returns:
            ``(logits, targets, x, filled)`` stacked along a new leading sample
            axis. With more than one process, every sample id appears once; if
            several ranks evaluated the same id, the row with the highest
            confidence wins (the first one on ties).
        """
        results = self.test_results
        world_size = self.trainer.world_size if self.trainer else 1

        values = list(results.values())
        logits = torch.stack([v["logits"] for v in values], dim=0)  # (N, L, C)
        targets = torch.stack([v["target"] for v in values], dim=0)  # (N, L)
        x = torch.stack([v["x"] for v in values], dim=0)  # (N, L)
        filled = torch.stack([v["filled"] for v in values], dim=0)  # (N, L)
        if world_size == 1:
            return logits, targets, x, filled

        device = self.device
        gids = torch.tensor([int(k) for k in results], dtype=torch.long, device=device)
        conf = torch.tensor([float(v["confidence"]) for v in values], device=device)

        # Ranks may hold different numbers of samples: pad every tensor to the
        # longest local count so all_gather sees equal shapes, and keep the
        # per-rank counts to tell real rows from padding afterwards.
        n_local = torch.tensor([len(values)], dtype=torch.long, device=device)
        lengths = self.all_gather(n_local).reshape(-1).to(torch.long)  # (world,)
        max_len = int(lengths.max().item())

        def pad_and_gather(t):
            # Zero padding keeps the dtype, device and trailing shape of ``t``;
            # padded rows are dropped via ``valid`` below, so their values never
            # reach the metrics.
            pad = t.new_zeros((max_len - t.shape[0], *t.shape[1:]))
            out = self.all_gather(torch.cat([t, pad], dim=0))
            return out.reshape(world_size * max_len, *t.shape[1:])

        gids_all = pad_and_gather(gids)  # (world * max_len,)
        conf_all = pad_and_gather(conf)  # (world * max_len,)
        logits_all = pad_and_gather(logits)  # (world * max_len, L, C)
        targets_all = pad_and_gather(targets)  # (world * max_len, L)
        x_all = pad_and_gather(x)  # (world * max_len, L)
        filled_all = pad_and_gather(filled)  # (world * max_len, L)

        # Gathered row ``rank * max_len + r`` is real iff r < lengths[rank].
        rows = torch.arange(max_len, device=lengths.device)
        valid = (rows.unsqueeze(0) < lengths.unsqueeze(1)).reshape(-1).tolist()

        # De-duplicate on Python ints: 0-d tensors hash by identity, so they
        # would never match each other as dict keys.
        gid_list = gids_all.tolist()
        conf_list = conf_all.tolist()
        best = {}  # sample id -> gathered row index
        for row, is_real in enumerate(valid):
            if not is_real:
                continue
            gid = gid_list[row]
            kept = best.get(gid)
            if kept is None or conf_list[kept] < conf_list[row]:
                best[gid] = row

        keep = torch.tensor(
            list(best.values()), dtype=torch.long, device=logits_all.device
        )
        return logits_all[keep], targets_all[keep], x_all[keep], filled_all[keep]

    # ------------------------------------------------------------------

    def configure_optimizers(self):
        """Return an AdamW optimizer with weight decay applied selectively.

        Only parameters with ``requires_grad`` set are optimized. Tensors of
        rank >= 2 (matmul weights, embedding tables) go into a group decayed by
        ``self.weight_decay``. Rank-0/1 tensors (biases, norm gains) go into a
        group with no decay. ``self.lr`` and ``self.betas`` configure AdamW.
        """
        decay_params = []
        nodecay_params = []
        for _name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() >= 2:
                decay_params.append(param)
            else:
                nodecay_params.append(param)

        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        print(
            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
        )

        return torch.optim.AdamW(
            [
                {"params": decay_params, "weight_decay": self.weight_decay},
                {"params": nodecay_params, "weight_decay": 0.0},
            ],
            betas=self.betas,
            lr=self.lr,
        )

    # ---------------------------- Data ---------------------------------

    def prepare_data(self):
        """Instantiate every dataset split this run needs, then discard the objects.

        Splits for ``self.train_dataset_name`` are built first, then the split
        for ``self.test_dataset_name``. Constructing them here is presumably how
        the ``ds`` classes download or cache their files (TODO confirm).
        Unrecognised names are silently skipped.
        """
        train_splits = {
            "sudoku-extreme": (
                (ExtremeSudokuAugDataset, "./data/", "train"),
                (ExtremeSudokuAugDataset, "./data/", "test"),
            ),
            "sudoku": (
                (SudokuDataset, "./data/satnet", "train"),
                # NOTE(review): setup() loads this split as "valid"; confirm
                # which split name SudokuDataset actually expects.
                (SudokuDataset, "./data/satnet", "validation"),
            ),
            "maze": (
                (MazeDataset, "./data", "train"),
                (MazeDataset, "./data", "test"),
            ),
            "arc": (
                (ARCDataset, "./data", "train"),
                (ARCDataset, "./data", "test"),
            ),
            "arc2": (
                (ARC2Dataset, "./data", "train"),
                (ARC2Dataset, "./data", "test"),
            ),
        }
        test_splits = {
            "sudoku-extreme": ((ExtremeSudokuAugDataset, "./data/", "test"),),
            "sudoku": ((SudokuDataset, "./data/satnet", "test"),),
            "sudoku-hard": ((HardSudokuDataset, "./data/rrn", "test"),),
            "maze": ((MazeDataset, "./data", "test"),),
            "arc": ((ARCDataset, "./data", "test"),),
            "arc2": ((ARC2Dataset, "./data", "test"),),
        }

        wanted = train_splits.get(self.train_dataset_name, ()) + test_splits.get(
            self.test_dataset_name, ()
        )
        for dataset_cls, root, split in wanted:
            dataset_cls(dataset_dir=root, split=split)

    def setup(self, stage: str):
        """Create the train/val/test datasets selected by the configured names.

        Args:
            stage: Lightning stage name. Unused: every configured dataset is
                built regardless of stage.

        Side effects:
            Sets ``self.train_dataset`` and ``self.val_dataset`` from
            ``self.train_dataset_name``, and ``self.test_dataset`` from
            ``self.test_dataset_name``. Unknown names leave the attribute
            untouched.
        """
        if self.train_dataset_name == "sudoku-extreme":
            print("Using extreme train dataset")
            self.train_dataset = ExtremeSudokuAugDataset(
                dataset_dir="./data/", split="train"
            )
            # Validation reuses the "test" split for this dataset.
            self.val_dataset = ExtremeSudokuAugDataset(
                dataset_dir="./data/", split="test"
            )

        elif self.train_dataset_name == "sudoku":
            self.train_dataset = SudokuDataset(
                dataset_dir="./data/satnet", split="train"
            )
            # NOTE(review): prepare_data() builds this split as "validation";
            # confirm which split name SudokuDataset actually expects.
            self.val_dataset = SudokuDataset(dataset_dir="./data/satnet", split="valid")

        elif self.train_dataset_name == "maze":
            self.train_dataset = MazeDataset(dataset_dir="./data", split="train")
            self.val_dataset = MazeDataset(dataset_dir="./data", split="test")
        elif self.train_dataset_name == "arc":
            self.train_dataset = ARCDataset(dataset_dir="./data", split="train")
            self.val_dataset = ARCDataset(dataset_dir="./data", split="test")
        elif self.train_dataset_name == "arc2":
            self.train_dataset = ARC2Dataset(dataset_dir="./data", split="train")
            self.val_dataset = ARC2Dataset(dataset_dir="./data", split="test")

        if self.test_dataset_name == "sudoku-extreme":
            print("Using extreme test dataset")
            self.test_dataset = ExtremeSudokuAugDataset(
                dataset_dir="./data/", split="test"
            )
        elif self.test_dataset_name == "sudoku":
            # Same split prepare_data() prepares for this test dataset name.
            self.test_dataset = SudokuDataset(dataset_dir="./data/satnet", split="test")
        elif self.test_dataset_name == "sudoku-hard":
            self.test_dataset = HardSudokuDataset(
                dataset_dir="./data/rrn",
                split="test",
            )
        elif self.test_dataset_name == "maze":
            self.test_dataset = MazeDataset(dataset_dir="./data/", split="test")
        elif self.test_dataset_name == "arc":
            self.test_dataset = ARCDataset(dataset_dir="./data", split="test")
        elif self.test_dataset_name == "arc2":
            self.test_dataset = ARC2Dataset(dataset_dir="./data", split="test")

        print(
            f"Train dataset size: {len(self.train_dataset)}"
            f", Val dataset size: {len(self.val_dataset)}"
        )

    def train_dataloader(self):
        """Return a shuffled training loader that drops the last partial batch.

        ``persistent_workers`` is enabled only when worker processes exist.
        ``torch.utils.data.DataLoader`` raises ``ValueError`` for
        ``persistent_workers=True`` combined with ``num_workers=0``.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
            drop_last=True,
        )

    def val_dataloader(self):  # -> DataLoader[Any]:
        """Return the validation loader, which is unshuffled and keeps every sample.

        ``persistent_workers`` is enabled only when worker processes exist.
        ``torch.utils.data.DataLoader`` raises ``ValueError`` for
        ``persistent_workers=True`` combined with ``num_workers=0``.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self):  # -> DataLoader[Any]:
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=True,
        )
