from functools import partial
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Lazily-populated cache of hyper-connections classes. Each name stays None
# until _get_hyper_connections() imports the matching class on first use, so
# the optional hyper_connections dependency is only needed when requested.
HyperConnections = None
ManifoldConstrainedHyperConnections = None
MHCLite = None
KromHC = None
# NOTE(review): HCFix and MHCDiffBeta are never assigned anywhere in the
# visible code — confirm whether they are still needed.
HCFix = None
MHCDiffBeta = None
def _get_hyper_connections(hc_type="HC"):
    """Return the hyper-connections class selected by ``hc_type``.

    The class is imported on first request and cached in the matching
    module-level global, so later calls return the cached class directly.

    Args:
        hc_type: ``"mHC"``, ``"mHC-lite"`` or ``"KromHC"`` select the
            corresponding variant; any other value selects the plain
            ``HyperConnections`` class.

    Returns:
        The requested hyper-connections class (not an instance).

    Raises:
        ImportError: if the module providing the requested class cannot be
            imported. The original import failure is attached as the cause.
    """
    global HyperConnections, ManifoldConstrainedHyperConnections, MHCLite, KromHC
    if hc_type == "mHC":
        if ManifoldConstrainedHyperConnections is None:
            try:
                from hyper_connections.mHC_fix import ManifoldConstrainedHyperConnections as mHC
            except ImportError as err:
                raise ImportError(
                    "hyper_connections.mHC_fix module not found. Ensure the hyper-connections-main\n"
                    "directory is in your PYTHONPATH with the mHC_fix.py file."
                ) from err
            ManifoldConstrainedHyperConnections = mHC
        return ManifoldConstrainedHyperConnections
    elif hc_type == "mHC-lite":
        if MHCLite is None:
            try:
                from hyper_connections.mhc_lite import MHCLite as MHCLiteClass
            except ImportError as err:
                raise ImportError(
                    "hyper_connections.mhc_lite module not found. Ensure the hyper-connections-main\n"
                    "directory is in your PYTHONPATH with the mhc_lite.py file."
                ) from err
            MHCLite = MHCLiteClass
        return MHCLite
    elif hc_type == "KromHC":
        if KromHC is None:
            try:
                from hyper_connections.Kromhc_optimized_SharedAlpha import KromHC as KromHCClass
            except ImportError as err:
                # The message names the module that is actually imported above.
                raise ImportError(
                    "hyper_connections.Kromhc_optimized_SharedAlpha module not found. Ensure the\n"
                    "hyper-connections-main directory is in your PYTHONPATH with the\n"
                    "Kromhc_optimized_SharedAlpha.py file."
                ) from err
            KromHC = KromHCClass
        return KromHC
    else:
        # Any unrecognised hc_type falls back to the plain implementation.
        if HyperConnections is None:
            try:
                from hyper_connections import HyperConnections as HC
            except ImportError as err:
                raise ImportError(
                    "hyper_connections package not found. Install it with:\n"
                    "  pip install hyper-connections\n"
                    "Or add the hyper-connections-main directory to your PYTHONPATH."
                ) from err
            HyperConnections = HC
        return HyperConnections
import os
# Module-level side effect: sets this environment variable for the whole process.
# NOTE(review): "0" for a DISABLE_* flag presumably leaves the Hugging Face Hub
# progress bars enabled — confirm this is the intended value.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0"
# Optional dependency probe: use the flash_attn kernels when the package
# imports, otherwise CausalSelfAttention falls back to
# F.scaled_dot_product_attention. USE_FA2 records which backend is active.
try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache
    USE_FA2 = True
    print0("Using Flash Attention 2")
except ImportError:
    print0("Flash Attention 2 not found. Falling back to PyTorch SDPA.")
    USE_FA2 = False
@dataclass
class GPTConfig:
    """Hyper-parameters consumed by GPT and its sub-modules."""
    # Training context length; the rotary tables in GPT are built for 10x this.
    sequence_len: int = 1024
    # Number of real tokens; GPT pads the embedding/lm_head rows up to a multiple
    # of pad_vocab_size_to and crops the logits back to this size.
    vocab_size: int = 50304
    # Number of transformer layers (each is one attention + one MLP sublayer).
    n_layer: int = 12
    # Number of query heads; must divide n_embd.
    n_head: int = 6
    # Number of key/value heads; must divide n_head (equal means no grouping).
    n_kv_head: int = 6
    # Model width (embedding dimension).
    n_embd: int = 768
    # Per-layer attention window pattern made of "L" (full sequence_len) and
    # "S" (half of it), tiled across layers; the last layer is always "L".
    window_pattern: str = "L"
    # Values > 1 switch GPT to the hyper-connections code path.
    num_residual_streams: int = 1
    # Hyper-connections variant passed to _get_hyper_connections().
    hc_type: str = "HC"
def norm(x):
    return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by the given angles.

    ``x`` must be 4-dimensional; its last dimension is split at the midpoint
    and each (first-half, second-half) pair is rotated using ``cos``/``sin``,
    which must broadcast against one half.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    neg_sin = -sin
    rotated_first = first_half * cos + second_half * sin
    rotated_second = first_half * neg_sin + second_half * cos
    return torch.cat([rotated_first, rotated_second], 3)
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and QK-norm.

    Queries use ``n_head`` heads and keys/values use ``n_kv_head`` heads
    (grouped-query attention when they differ). Attention runs through Flash
    Attention 2 when ``USE_FA2`` is set and through PyTorch SDPA otherwise;
    both backends apply the same causal + sliding-window rule.
    """

    def __init__(self, config, layer_idx):
        """Create the bias-free q/k/v/output projections for layer ``layer_idx``."""
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.head_dim = self.n_embd // self.n_head
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    @staticmethod
    def _sdpa_attention(q, k, v, window_size):
        """SDPA fallback matching the Flash Attention call used in ``forward``.

        Args:
            q: queries, shape (B, Tq, n_head, head_dim).
            k, v: keys/values, shape (B, Tk, n_kv_head, head_dim) with Tk >= Tq;
                the queries are taken to be the last Tq positions of the keys.
            window_size: ``(left, right)`` tuple; only ``left`` is used. A
                negative ``left`` means an unlimited look-back.

        Returns:
            Attention output of shape (B, Tq, n_head, head_dim).
        """
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # Grouped-query attention: replicate each kv head for its query group
        # (query head h reads kv head h // n_rep).
        n_rep = q.size(1) // k.size(1)
        if n_rep > 1:
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
        Tq, Tk = q.size(2), k.size(2)
        left = window_size[0]
        # The window only restricts anything if it is shorter than the
        # farthest possible look-back (Tk - 1).
        has_window = 0 <= left < Tk - 1
        if Tq == Tk and not has_window:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # is_causal=True aligns the mask to the top-left corner, which is
            # wrong when the queries are the *last* Tq of Tk positions (KV-cache
            # decoding), so build the bottom-right aligned mask explicitly.
            rows = torch.arange(Tk - Tq, Tk, device=q.device).unsqueeze(1)
            cols = torch.arange(Tk, device=q.device).unsqueeze(0)
            mask = cols <= rows
            if has_window:
                mask = mask & ((rows - cols) <= left)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return y.transpose(1, 2)

    def forward(self, x, cos_sin, window_size, kv_cache):
        """Attend over ``x`` and return the projected output.

        Args:
            x: input of shape (B, T, n_embd).
            cos_sin: ``(cos, sin)`` rotary tables for the T positions of ``x``.
            window_size: ``(left, right)`` attention window for this layer.
            kv_cache: ``None`` for a plain forward pass, otherwise a cache object
                providing ``get_layer_cache``, ``get_pos``, ``cache_seqlens``,
                ``n_layers`` and ``advance``; the new keys/values are written
                into it and the last layer advances its position by T.

        Returns:
            Tensor of shape (B, T, n_embd).

        Raises:
            ValueError: if a cache is given but holds no buffers for this layer.
        """
        B, T, _ = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)  # QK-norm
        if kv_cache is None:
            if USE_FA2:
                y = flash_attn_func(q, k, v, causal=True, window_size=window_size)
            else:
                y = self._sdpa_attention(q, k, v, window_size)
        else:
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            if k_cache is None or v_cache is None:
                raise ValueError(f"Layer {self.layer_idx} has no cache")
            if USE_FA2:
                # flash_attn_with_kvcache receives the new k/v together with the cache.
                y = flash_attn_with_kvcache(
                    q, k_cache, v_cache,
                    k=k, v=v,
                    cache_seqlens=kv_cache.cache_seqlens,
                    causal=True,
                    window_size=window_size,
                )
            else:
                start_pos = kv_cache.get_pos()
                end_pos = start_pos + T
                k_cache[:, start_pos:end_pos, :, :] = k
                v_cache[:, start_pos:end_pos, :, :] = v
                y = self._sdpa_attention(
                    q,
                    k_cache[:, :end_pos, :, :],
                    v_cache[:, :end_pos, :, :],
                    window_size,
                )
            # Only the final layer moves the shared cache position forward.
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
class MLP(nn.Module):
    """Bias-free feed-forward block: expand 4x, squared ReLU, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        activated = F.relu(self.c_fc(x))
        return self.c_proj(activated.square())
class NormedAttention(nn.Module):
    """Attention branch that normalises its input first.

    Used as the ``branch`` of a hyper-connection module; the wrapped
    attention is reachable as ``.attn``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)

    def forward(self, x, cos_sin, window_size, kv_cache):
        """Run ``norm`` on ``x`` and forward everything to the attention module."""
        normed = norm(x)
        return self.attn(normed, cos_sin, window_size, kv_cache)
class NormedMLP(nn.Module):
    """MLP branch that normalises its input first.

    Used as the ``branch`` of a hyper-connection module; the wrapped MLP is
    reachable as ``.mlp``.
    """

    def __init__(self, config):
        super().__init__()
        self.mlp = MLP(config)

    def forward(self, x):
        """Run ``norm`` on ``x`` and feed the result to the MLP."""
        normed = norm(x)
        return self.mlp(normed)
class Block(nn.Module):
    """Standard pre-norm transformer layer: attention then MLP, each residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, window_size, kv_cache):
        """Add the attention output, then the MLP output, to the residual ``x``."""
        attn_out = self.attn(norm(x), cos_sin, window_size, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
class GPT(nn.Module):
    def __init__(self, config, pad_vocab_size_to=64):
        """Build the model structure (weights are set later by ``init_weights``).

        Args:
            config: a GPTConfig-like object.
            pad_vocab_size_to: the embedding and lm_head row counts are rounded
                up to a multiple of this value.

        With ``config.num_residual_streams > 1`` every attention/MLP sublayer is
        wrapped in a hyper-connection module of type ``config.hc_type``;
        otherwise plain ``Block`` layers plus the per-layer ``resid_lambdas`` /
        ``x0_lambdas`` scalars are created.
        """
        super().__init__()
        self.config = config
        self.window_sizes = self._compute_window_sizes(config)

        # Round the vocabulary up to the next multiple of pad_vocab_size_to.
        n_pad_units = (config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to
        padded_vocab_size = n_pad_units * pad_vocab_size_to
        if padded_vocab_size != config.vocab_size:
            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")

        stream_count = config.num_residual_streams
        self.use_hyper_connections = stream_count > 1
        self.hc_type = getattr(config, 'hc_type', "HC")

        if self.use_hyper_connections:
            hc_class = _get_hyper_connections(self.hc_type)
            stream_fns = hc_class.get_init_and_expand_reduce_stream_functions(stream_count)
            make_hyper_conn, self.expand_streams, self.reduce_streams = stream_fns
            hc_layers = nn.ModuleList()
            for layer_idx in range(config.n_layer):
                # Two sublayers per layer: attention gets the even index, MLP the odd one.
                attn_index = 2 * layer_idx
                attn_branch = NormedAttention(config, layer_idx)
                mlp_branch = NormedMLP(config)
                hc_layers.append(nn.ModuleDict({
                    "attn": make_hyper_conn(dim=config.n_embd, branch=attn_branch, layer_index=attn_index),
                    "mlp": make_hyper_conn(dim=config.n_embd, branch=mlp_branch, layer_index=attn_index + 1),
                }))
            self.transformer = nn.ModuleDict({
                "wte": nn.Embedding(padded_vocab_size, config.n_embd),
                "h": hc_layers,
            })
        else:
            self.expand_streams = None
            self.reduce_streams = None
            token_embedding = nn.Embedding(padded_vocab_size, config.n_embd)
            blocks = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)])
            self.transformer = nn.ModuleDict({
                "wte": token_embedding,
                "h": blocks,
            })

        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)

        # The per-layer residual / x0 mixing scalars only exist on the plain path.
        if self.use_hyper_connections:
            self.resid_lambdas = None
            self.x0_lambdas = None
        else:
            self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
            self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))

        # Rotary tables cover 10x the configured sequence length and are kept as
        # non-persistent buffers (not saved in the state dict).
        self.rotary_seq_len = config.sequence_len * 10
        rotary_head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, rotary_head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
    def init_weights(self):
        """Initialise every parameter of the model in place.

        - Token embedding: normal, std 1.0; lm_head: normal, std 0.001.
        - q/k/v and MLP ``c_fc`` weights: uniform in [-s, s] with
          s = sqrt(3) / sqrt(n_embd), i.e. standard deviation 1 / sqrt(n_embd).
        - Attention and MLP output projections: zeros.
        - Hyper-connection path: static alpha/beta tensors are filled per
          ``hc_type``, the dynamic alpha/beta functions are zeroed and their
          scales set to 0.01, and the HC norm gamma is zeroed.
        - Plain path: ``resid_lambdas`` are set to 1 and ``x0_lambdas`` to 0.

        Finally the rotary tables are recomputed and, when the embedding lives
        on a CUDA device, it is cast to bfloat16.
        """
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        n_embd = self.config.n_embd
        # Bound of the uniform init; uniform(-s, s) has std s / sqrt(3) = n_embd ** -0.5.
        s = 3**0.5 * n_embd**-0.5
        if self.use_hyper_connections:
            n = self.config.num_residual_streams
            for layer_idx, layer in enumerate(self.transformer.h):
                # The real attention / MLP modules sit two levels down:
                # HC module -> .branch (NormedAttention / NormedMLP) -> .attn / .mlp
                attn = layer["attn"].branch.attn
                torch.nn.init.uniform_(attn.c_q.weight, -s, s)
                torch.nn.init.uniform_(attn.c_k.weight, -s, s)
                torch.nn.init.uniform_(attn.c_v.weight, -s, s)
                torch.nn.init.zeros_(attn.c_proj.weight)
                mlp = layer["mlp"].branch.mlp
                torch.nn.init.uniform_(mlp.c_fc.weight, -s, s)
                torch.nn.init.zeros_(mlp.c_proj.weight)
                # Sublayers are numbered 2*layer_idx (attention) and 2*layer_idx + 1 (MLP).
                sublayer_idx = layer_idx * 2
                for sublayer_offset, hc_module in enumerate([layer["attn"], layer["mlp"]]):
                    with torch.no_grad():
                        # Each sublayer favours one residual stream, chosen round-robin.
                        init_residual_index = (sublayer_idx + sublayer_offset) % n
                        # NOTE(review): the -1 / 1 / -8 / 0 fill values below are presumably
                        # pre-activation logits interpreted by the HC modules — confirm
                        # against the hyper_connections implementation.
                        if self.hc_type == "mHC":
                            num_fracs = getattr(hc_module, 'num_fracs', 1)
                            num_input_views = getattr(hc_module, 'num_input_views', 1)
                            # alpha0: -1 everywhere, +1 on the row of the favoured stream.
                            init_alpha0 = torch.ones(n * num_fracs, num_input_views * num_fracs,
                                                    device=hc_module.static_alpha.device) * -1
                            init_alpha0[init_residual_index, :] = 1.0
                            # alpha1: square matrix with 0 on the diagonal and -8 elsewhere.
                            init_alpha1 = torch.ones(n * num_fracs, n * num_fracs,
                                                    device=hc_module.static_alpha.device) * -8
                            init_alpha1.fill_diagonal_(0.0)
                            hc_module.static_alpha.copy_(torch.cat([init_alpha0, init_alpha1], dim=1))
                            # beta: -1 everywhere, +1 for the favoured stream.
                            beta_init = torch.ones(n * num_fracs, device=hc_module.static_beta.device) * -1.0
                            beta_init[init_residual_index] = 1.0
                            hc_module.static_beta.copy_(beta_init)
                        elif self.hc_type == "mHC-lite":
                            import math
                            # One residual coefficient per permutation of the n streams.
                            num_perms = math.factorial(n)
                            num_fracs = getattr(hc_module, 'num_fracs', 1)
                            num_input_views = getattr(hc_module, 'num_input_views', 1)
                            init_alpha0 = torch.ones(n * num_fracs, num_input_views * num_fracs,
                                                    device=hc_module.static_alpha.device) * -1
                            init_alpha0[init_residual_index, :] = 1.0
                            # alpha1: -8 everywhere except index 0, which is set to 0.
                            init_alpha1 = torch.ones(num_perms * num_fracs,
                                                    device=hc_module.static_alpha.device) * -8
                            init_alpha1[0] = 0.0
                            # static_alpha is stored flat here: flattened alpha0 followed by alpha1.
                            hc_module.static_alpha.copy_(torch.cat([init_alpha0.view(-1), init_alpha1]))
                            beta_init = torch.ones(n * num_fracs, device=hc_module.static_beta.device) * -1.0
                            beta_init[init_residual_index] = 1.0
                            hc_module.static_beta.copy_(beta_init)
                        elif self.hc_type == "KromHC":
                            num_fracs = getattr(hc_module, 'num_fracs', 1)
                            num_input_views = getattr(hc_module, 'num_input_views', 1)
                            # Fallbacks (4 coefficients, factors [2, 2]) apply only when the
                            # module does not expose these attributes.
                            total_res_coeffs = getattr(hc_module, 'total_res_coeffs', 4)
                            factor_perms = getattr(hc_module, 'factor_perms', [2, 2])
                            init_alpha0 = torch.ones(n * num_fracs, num_input_views * num_fracs,
                                                    device=hc_module.static_alpha.device) * -1
                            init_alpha0[init_residual_index, :] = 1.0
                            init_alpha1 = torch.ones(total_res_coeffs * num_fracs,
                                                    device=hc_module.static_alpha.device) * -8
                            # Set the first coefficient of each factor's group to 0; groups
                            # are laid out back to back with sizes given by factor_perms.
                            coeff_idx = 0
                            for num_perms in factor_perms:
                                init_alpha1[coeff_idx] = 0.0
                                coeff_idx += num_perms
                            hc_module.static_alpha.copy_(torch.cat([init_alpha0.view(-1), init_alpha1]))
                            beta_init = torch.ones(n * num_fracs, device=hc_module.static_beta.device) * -1.0
                            beta_init[init_residual_index] = 1.0
                            hc_module.static_beta.copy_(beta_init)
                        else:
                            # Plain HC: one-hot column selecting the favoured stream,
                            # followed by an identity matrix; beta is all ones.
                            init_alpha0 = torch.zeros(n, 1, device=hc_module.static_alpha.device)
                            init_alpha0[init_residual_index, 0] = 1.0
                            eye_n = torch.eye(n, device=hc_module.static_alpha.device)
                            hc_module.static_alpha.copy_(torch.cat([init_alpha0, eye_n], dim=1))
                            hc_module.static_beta.fill_(1.0)
                        # Dynamic (input-dependent) parts start at zero with small scales.
                        # The scale attribute names differ between HC variants, hence the
                        # hasattr probing below.
                        hc_module.dynamic_alpha_fn.zero_()
                        if hasattr(hc_module, 'dynamic_alpha_scale'):
                            hc_module.dynamic_alpha_scale.fill_(0.01)
                        else:
                            hc_module.pre_branch_scale.fill_(0.01)
                            if hasattr(hc_module, 'residual_scales'):
                                hc_module.residual_scales.fill_(0.01)
                            else:
                                hc_module.residual_scale.fill_(0.01)
                        hc_module.dynamic_beta_fn.zero_()
                        if hasattr(hc_module, 'dynamic_beta_scale'):
                            hc_module.dynamic_beta_scale.fill_(0.01)
                        else:
                            hc_module.h_post_scale.fill_(0.01)
                        hc_module.norm.gamma.zero_()
        else:
            for block in self.transformer.h:
                torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
                torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
                torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
                torch.nn.init.zeros_(block.attn.c_proj.weight)
                torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
                torch.nn.init.zeros_(block.mlp.c_proj.weight)
            with torch.no_grad():
                self.resid_lambdas.fill_(1.0)
                self.x0_lambdas.fill_(0.0)
        # Rebuild the rotary tables on the device the embedding currently lives on.
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # On CUDA the embedding table is stored in bfloat16.
        if self.transformer.wte.weight.device.type == "cuda":
            self.transformer.wte.to(dtype=torch.bfloat16)
    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Build the rotary cos/sin lookup tables.

        Args:
            seq_len: number of positions to tabulate.
            head_dim: attention head dimension; one frequency per channel pair.
            base: base of the geometric frequency progression.
            device: target device; defaults to the token embedding's device.

        Returns:
            ``(cos, sin)``, each a bfloat16 tensor of shape
            (1, seq_len, 1, head_dim // 2).
        """
        target_device = self.transformer.wte.weight.device if device is None else device
        even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=target_device)
        inv_freq = 1.0 / (base ** (even_channels / head_dim))
        positions = torch.arange(seq_len, dtype=torch.float32, device=target_device)
        angles = torch.outer(positions, inv_freq)
        # Computed in float32, stored in bfloat16 with broadcast dims for batch and heads.
        cos = angles.cos().bfloat16()[None, :, None, :]
        sin = angles.sin().bfloat16()[None, :, None, :]
        return cos, sin
    def _compute_window_sizes(self, config):
        """Expand ``config.window_pattern`` into one attention window per layer.

        The pattern (case-insensitive) is tiled across the layers: ``"L"`` means
        a window of ``sequence_len`` and ``"S"`` half of that. The last layer
        always gets the long window regardless of the pattern.

        Args:
            config: object with ``window_pattern``, ``sequence_len`` and ``n_layer``.

        Returns:
            A list of ``(left, right)`` tuples of length ``n_layer`` (``right``
            is always 0); empty when ``n_layer`` is 0.

        Raises:
            AssertionError: if the pattern is empty or contains characters other
                than S and L.
        """
        pattern = config.window_pattern.upper()
        # An empty pattern would otherwise fail later with a ZeroDivisionError.
        assert pattern, "Invalid window_pattern: must contain at least one of S or L."
        assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
        long_window = config.sequence_len
        short_window = long_window // 2
        char_to_window = {
            "L": (long_window, 0),
            "S": (short_window, 0),
        }
        window_sizes = [
            char_to_window[pattern[layer_idx % len(pattern)]]
            for layer_idx in range(config.n_layer)
        ]
        # The final layer always sees the full context (nothing to do for 0 layers).
        if window_sizes:
            window_sizes[-1] = (long_window, 0)
        return window_sizes
    def get_device(self):
        """Return the device that holds the token-embedding weight."""
        embedding_weight = self.transformer.wte.weight
        return embedding_weight.device
    def estimate_flops(self):
        """Estimate the FLOPs spent per token.

        The result is ``6 * counted_params + attention_term`` where
        ``counted_params`` is every parameter except the token embedding, the
        ``resid_lambdas`` / ``x0_lambdas`` scalars and the hyper-connection
        parameters outside ``branch.``, and the attention term adds
        ``12 * n_head * head_dim * effective_seq`` for each layer, with
        ``effective_seq`` the layer window capped at ``sequence_len``.
        """
        total_params = sum(p.numel() for p in self.parameters())

        excluded = self.transformer.wte.weight.numel()
        for scalars in (self.resid_lambdas, self.x0_lambdas):
            if scalars is not None:
                excluded += scalars.numel()
        if self.use_hyper_connections:
            for layer in self.transformer.h:
                for key in ("attn", "mlp"):
                    # Only the wrapped branch (attention / MLP) counts; the
                    # hyper-connection module's own parameters are excluded.
                    excluded += sum(
                        param.numel()
                        for name, param in layer[key].named_parameters()
                        if not name.startswith("branch.")
                    )

        cfg = self.config
        n_head = cfg.n_head
        head_dim = cfg.n_embd // cfg.n_head
        seq_len = cfg.sequence_len
        attn_flops = 0
        for layer_window in self.window_sizes:
            left = layer_window[0]
            # A negative window means "unlimited", i.e. the full sequence.
            span = seq_len if left < 0 else min(left, seq_len)
            attn_flops += 12 * n_head * head_dim * span

        return 6 * (total_params - excluded) + attn_flops
    def num_scaling_params(self):
        """Return the total number of parameter elements in the model."""
        return sum(param.numel() for param in self.parameters())
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        """Create the AdamW and Muon optimizers for this model.

        Matrix weights of the attention/MLP layers go to Muon; the embedding,
        the lm_head and the scalar / hyper-connection parameters go to AdamW.
        The lm_head and embedding learning rates are scaled by
        ``(n_embd / 768) ** -0.5``. Distributed variants are used when
        ``get_dist_info()`` reports DDP.

        Args:
            unembedding_lr: base AdamW learning rate for ``lm_head``.
            embedding_lr: base AdamW learning rate for the token embedding.
            matrix_lr: Muon learning rate for the matrix parameters.
            weight_decay: weight decay passed to Muon (AdamW always uses 0.0).
            adam_betas: AdamW betas.
            scalar_lr: learning rate for ``x0_lambdas``; ``resid_lambdas`` and
                hyper-connection parameters use ``scalar_lr * 0.01``.

        Returns:
            ``[adamw_optimizer, muon_optimizer]``; every param group also
            carries its starting rate under ``"initial_lr"``.
        """
        model_dim = self.config.n_embd
        ddp, _rank, _local_rank, _world_size = get_dist_info()

        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())

        if self.use_hyper_connections:
            matrix_params = []
            hyper_conn_params = []
            for layer in self.transformer.h:
                attn = layer["attn"].branch.attn
                mlp = layer["mlp"].branch.mlp
                matrix_params += [attn.c_q.weight, attn.c_k.weight, attn.c_v.weight, attn.c_proj.weight]
                matrix_params += [mlp.c_fc.weight, mlp.c_proj.weight]
                # Everything in the HC wrappers outside the wrapped branch is an HC parameter.
                for hc_module in (layer["attn"], layer["mlp"]):
                    for name, param in hc_module.named_parameters():
                        if name.startswith("branch."):
                            continue
                        hyper_conn_params.append(param)
            # The expand / reduce stream callables may or may not be modules with parameters.
            for stream_fn in (self.expand_streams, self.reduce_streams):
                if hasattr(stream_fn, 'parameters'):
                    hyper_conn_params.extend(list(stream_fn.parameters()))
            # Sanity check: every parameter element landed in exactly one group.
            all_param_count = sum(p.numel() for p in self.parameters())
            grouped_param_count = sum(p.numel() for p in matrix_params + embedding_params + lm_head_params + hyper_conn_params)
            assert all_param_count == grouped_param_count, f"Parameter count mismatch: {all_param_count} vs {grouped_param_count}"
        else:
            matrix_params = list(self.transformer.h.parameters())
            resid_params = [self.resid_lambdas]
            x0_params = [self.x0_lambdas]
            n_grouped = len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
            assert len(list(self.parameters())) == n_grouped

        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")

        adam_groups = [
            {"params": lm_head_params, "lr": unembedding_lr * dmodel_lr_scale},
            {"params": embedding_params, "lr": embedding_lr * dmodel_lr_scale},
        ]
        if not self.use_hyper_connections:
            adam_groups.append({"params": resid_params, "lr": scalar_lr * 0.01})
            adam_groups.append({"params": x0_params, "lr": scalar_lr})
        elif hyper_conn_params:
            adam_groups.append({"params": hyper_conn_params, "lr": scalar_lr * 0.01})

        if ddp:
            AdamWFactory = DistAdamW
            MuonFactory = DistMuon
        else:
            AdamWFactory = partial(torch.optim.AdamW, fused=True)
            MuonFactory = Muon
        adamw_optimizer = AdamWFactory(adam_groups, betas=adam_betas, eps=1e-10, weight_decay=0.0)
        muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)

        optimizers = [adamw_optimizer, muon_optimizer]
        # Remember the starting rate of each group alongside the live one.
        for optimizer in optimizers:
            for group in optimizer.param_groups:
                group["initial_lr"] = group["lr"]
        return optimizers
    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
        """Run the model on token ids and return logits or a loss.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional target ids shaped like ``idx``; positions equal to
                -1 are ignored by the loss.
            kv_cache: optional KV cache for incremental decoding; its current
                position offsets the rotary embeddings.
            loss_reduction: ``reduction`` argument passed to ``F.cross_entropy``.

        Returns:
            The cross-entropy loss when ``targets`` is given, otherwise float32
            logits of shape (B, T, vocab_size), soft-capped to (-15, 15).

        Raises:
            AssertionError: if the positions ``[T0, T0 + T)`` do not fit in the
                precomputed rotary tables, if ``idx`` is on a different device
                than the tables, or if the tables are not bfloat16.
        """
        _, T = idx.size()
        # With a KV cache the new tokens start at the cache position, so the
        # bound check has to include that offset; otherwise the slice below
        # would silently come back shorter than T.
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        assert T0 + T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T0 + T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
        x = self.transformer.wte(idx)
        x = norm(x)
        if self.use_hyper_connections:
            # Fan out into the residual streams, run every HC-wrapped sublayer,
            # then merge the streams back into one.
            x = self.expand_streams(x)
            for i, layer in enumerate(self.transformer.h):
                x = layer["attn"](x, cos_sin, self.window_sizes[i], kv_cache)
                x = layer["mlp"](x)
            x = self.reduce_streams(x)
        else:
            x0 = x
            for i, block in enumerate(self.transformer.h):
                # Per-layer learned mix of the running residual and the initial embedding.
                x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
                x = block(x, cos_sin, self.window_sizes[i], kv_cache)
        x = norm(x)
        softcap = 15
        logits = self.lm_head(x)
        # Drop the padded vocabulary rows before computing anything on the logits.
        logits = logits[..., :self.config.vocab_size]
        logits = logits.float()
        logits = softcap * torch.tanh(logits / softcap)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            return logits
    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """Autoregressively sample tokens, yielding them one at a time.

        This is a simple batch-size-1 loop that re-runs the full forward pass
        on the growing sequence at every step (no KV cache).

        Args:
            tokens: prompt as a list of token ids.
            max_tokens: number of tokens to generate.
            temperature: sampling temperature; ``0`` (or less) selects the
                arg-max token instead of sampling.
            top_k: if a positive integer, sample only among the ``top_k`` highest
                logits; ``None`` or a non-positive value disables the filter.
            seed: seed of the sampling generator (used when ``temperature > 0``).

        Yields:
            The next token id as a Python int.
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        # top_k <= 0 would otherwise ask torch.topk for zero entries and crash
        # on the threshold lookup, so treat it as "no filtering".
        use_top_k = top_k is not None and top_k > 0
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        for _ in range(max_tokens):
            logits = self.forward(ids)
            logits = logits[:, -1, :]
            if use_top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit is removed from the draw.
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token