#!/usr/bin/env python
"""
Train GPT2/LLaMA models on unary sequence prediction tasks without positional encoding,
with configurable dropout.
"""

import os
import argparse
import random
import json
import math
from typing import List, Dict, Set, Optional, Tuple
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
import torch
import torch.nn as nn
import torch.distributed as dist

import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import GPT2Config, GPT2LMHeadModel, LlamaConfig, LlamaForCausalLM
from dataloder import *  # expects SPECIAL, IGNORE_INDEX, ARLMGptDataset, create_dataloaders, etc.


def set_seeds(seed: int = 42):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module and PyTorch (CPU generator plus every
    CUDA device), and exports ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python reads PYTHONHASHSEED at interpreter start-up, so exporting it here
    # can only influence processes spawned after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


class ZeroPositionalEmbedding(nn.Module):
    """Parameter-free stand-in for a positional embedding that always yields zeros.

    Installed in place of a learned position-embedding table so that the model
    receives no positional information. The module owns no parameters or
    buffers, so nothing is trained or written to the checkpoint for it.

    Args:
        hidden_size: Size of the trailing (embedding) dimension of the output.
        max_positions: Stored for reference only; not used by ``forward``.
    """
    def __init__(self, hidden_size, max_positions):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_positions = max_positions

    def forward(self, position_ids):
        """Return float32 zeros of shape ``position_ids.shape + (hidden_size,)``.

        Mirrors the shape contract of ``nn.Embedding``: any index shape is
        accepted (not only ``(batch, seq)``), and the result lives on the same
        device as ``position_ids``.
        """
        return torch.zeros(
            (*position_ids.shape, self.hidden_size),
            dtype=torch.float32,
            device=position_ids.device,
        )


def causal_token_acc(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Count correct next-token predictions for a causal LM.

    The logits at position ``t`` are compared against the label at ``t + 1``
    (the last logit and the first label are dropped). Label positions equal to
    ``ignore_index`` are excluded.

    Args:
        logits: Tensor indexed as ``[batch, position, vocab]``.
        labels: Tensor indexed as ``[batch, position]``.
        ignore_index: Label value marking positions that are not scored.

    Returns:
        ``(num_correct, num_scored)`` as Python numbers.
    """
    targets = labels[:, 1:]
    predictions = logits[:, :-1, :].argmax(dim=-1)
    scored = targets != ignore_index
    hits = (predictions == targets) & scored
    return hits.sum().item(), scored.sum().item()


def causal_seq_acc(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    """Count sequences whose every scored token is predicted correctly.

    Uses the same causal shift as token accuracy: logits at position ``t`` are
    compared with the label at ``t + 1``. A sequence counts as correct only if
    all of its non-ignored positions match; sequences with no scored position
    are left out of both the numerator and the denominator.

    Args:
        logits: Tensor indexed as ``[batch, position, vocab]``.
        labels: Tensor indexed as ``[batch, position]``.
        ignore_index: Label value marking positions that are not scored.

    Returns:
        ``(num_correct_sequences, num_scored_sequences)`` as Python ints.
    """
    targets = labels[:, 1:]
    preds = logits[:, :-1, :].argmax(dim=-1)
    scored = targets.ne(ignore_index)

    # Vectorised over the batch: one reduction per quantity instead of a
    # Python loop that pulls a scalar off the device for every sample.
    has_target = scored.any(dim=1)
    # Ignored positions are treated as matching so they cannot fail a row.
    row_exact = (preds.eq(targets) | ~scored).all(dim=1)

    num = int((row_exact & has_target).sum().item())
    den = int(has_target.sum().item())
    return num, den


@torch.no_grad()
def evaluate(model, dl, device, use_amp: bool = True):
    """Run ``model`` over ``dl`` in eval mode and aggregate loss and accuracies.

    Each batch is a dict of tensors containing at least ``"input_ids"`` and
    ``"labels"``; it is moved to ``device`` and passed to the model as keyword
    arguments. The model output must expose ``.loss`` and ``.logits``.

    Args:
        model: Model to evaluate; left in eval mode on return.
        dl: Iterable of batch dicts.
        device: Target device for the batch tensors.
        use_amp: Whether to run the forward pass under CUDA autocast.

    Returns:
        ``(avg_loss, token_acc, seq_acc)``. The loss is the mean of per-batch
        losses weighted by batch size. Any quantity with a zero denominator
        is reported as ``0.0``.
    """
    def ratio(numerator, denominator):
        return (numerator / denominator) if denominator else 0.0

    model.eval()
    loss_sum = 0.0
    n_samples = 0
    tok_hits = tok_total = 0
    seq_hits = seq_total = 0

    for raw_batch in dl:
        batch = {
            key: tensor.to(device, non_blocking=True)
            for key, tensor in raw_batch.items()
        }
        batch_size = batch["input_ids"].size(0)
        with torch.amp.autocast("cuda", enabled=use_amp):
            output = model(**batch)
            loss, logits = output.loss, output.logits

        loss_sum += loss.item() * batch_size
        n_samples += batch_size

        hits, total = causal_token_acc(logits, batch["labels"], IGNORE_INDEX)
        tok_hits += hits
        tok_total += total

        hits, total = causal_seq_acc(logits, batch["labels"], IGNORE_INDEX)
        seq_hits += hits
        seq_total += total

    return (
        ratio(loss_sum, n_samples),
        ratio(tok_hits, tok_total),
        ratio(seq_hits, seq_total),
    )


def save_checkpoint(model, ds, save_dir: str, save_name: str = "model"):
    """Write the model and its vocabulary to ``<save_dir>/<save_name>/``.

    The model is stored through its own ``save_pretrained`` method. The
    vocabulary is stored next to it as ``itos.json``: a JSON list whose i-th
    entry is ``ds.itos[i]``.

    Args:
        model: Object exposing ``save_pretrained(directory)``.
        ds: Dataset exposing ``itos``, indexable by ``0 .. len(itos) - 1``.
        save_dir: Parent directory for checkpoints.
        save_name: Sub-directory name for this checkpoint.
    """
    target_dir = os.path.join(save_dir, save_name)
    os.makedirs(target_dir, exist_ok=True)
    model.save_pretrained(target_dir)

    # Index by position (not iteration order) so list index == token id.
    vocab_size = len(ds.itos)
    tokens_by_id = [ds.itos[token_id] for token_id in range(vocab_size)]

    vocab_path = os.path.join(target_dir, "itos.json")
    with open(vocab_path, "w", encoding="utf-8") as handle:
        json.dump(tokens_by_id, handle, ensure_ascii=False)


def load_vocab(save_dir: str, save_name: str = "model") -> Tuple[Dict[str, int], Dict[int, str]]:
    """Load the vocabulary stored with a checkpoint.

    Prefers ``<save_dir>/<save_name>/itos.json`` (a JSON list of tokens ordered
    by id). If that file does not exist, falls back to the ``stoi.pt`` /
    ``itos.pt`` pair in the same directory.

    Args:
        save_dir: Parent directory for checkpoints.
        save_name: Sub-directory name of the checkpoint.

    Returns:
        ``(stoi, itos)``: token -> id and id -> token mappings.
    """
    model_dir = os.path.join(save_dir, save_name)
    json_path = os.path.join(model_dir, "itos.json")

    if not os.path.exists(json_path):
        # Legacy checkpoints. NOTE: torch.load is pickle-based, so only point
        # this at checkpoint directories you trust.
        stoi = torch.load(os.path.join(model_dir, "stoi.pt"))
        itos = torch.load(os.path.join(model_dir, "itos.pt"))
        return stoi, itos

    with open(json_path, "r", encoding="utf-8") as handle:
        tokens = json.load(handle)
    itos = dict(enumerate(tokens))
    stoi = {token: index for index, token in itos.items()}
    return stoi, itos


def refresh_ds_special_ids(ds):
    """Recompute ``ds.pad_id``/``sep_id``/``eos_id``/``bos_id`` from ``ds.stoi``.

    Every token in ``SPECIAL`` must be present in ``ds.stoi``.

    Raises:
        ValueError: If any special token is missing from the vocabulary
            (reported for the first missing one in ``SPECIAL`` order).
    """
    missing = [tok for tok in SPECIAL.values() if tok not in ds.stoi]
    if missing:
        tok = missing[0]
        raise ValueError(f"Special token {repr(tok)} missing from vocab.")

    for attr_name, special_key in (
        ("pad_id", "PAD"),
        ("sep_id", "SEP"),
        ("eos_id", "EOS"),
        ("bos_id", "BOS"),
    ):
        setattr(ds, attr_name, ds.stoi[SPECIAL[special_key]])


def _disable_positional_embeddings(model_obj, model_type):
    """Replace learned absolute position embeddings with an all-zero module."""
    if model_type == 'llama':
        # NOTE(review): LlamaForCausalLM is believed to apply rotary position
        # embeddings inside attention and not to expose `embed_positions`, in
        # which case this branch is a no-op and positions stay enabled —
        # confirm against the installed transformers version.
        if hasattr(model_obj.model, 'embed_positions'):
            model_obj.model.embed_positions = ZeroPositionalEmbedding(
                model_obj.config.hidden_size, model_obj.config.max_position_embeddings
            )
    else:
        model_obj.transformer.wpe = ZeroPositionalEmbedding(
            model_obj.config.n_embd, model_obj.config.n_positions
        )


def train_gpt2(
    n_embd=128,
    n_layer=2,
    n_head=4,
    batch_size=32,
    n_epochs=10,
    lr=3e-4,
    model='mul',
    save_dir="checkpoints",
    save_name="saved_model",
    seed=42,
    use_amp=True,
    resume=False,
    ddp=False,
    local_rank=0,
    world_size=1,
    max_steps: int = 30000,
    model_type='gpt2',
    dropout: float = 0.0,
):
    """Train a GPT-2 or LLaMA causal LM with positional embeddings disabled.

    Builds (or, with ``resume``, reloads) the model, trains it with AdamW,
    gradient clipping at norm 1.0 and optional CUDA AMP, evaluates on three
    validation splits after every epoch, and saves a checkpoint whenever the
    mean validation sequence accuracy strictly improves. Training stops early
    when that mean reaches 1.0 or when ``max_steps`` optimizer steps have been
    taken. Rank 0 appends a summary line to ``logs/<model>.txt`` and writes a
    loss curve to ``logs/<model>_loss_curve.png``.

    Args:
        n_embd: Hidden size.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        batch_size: Batch size forwarded to ``create_dataloaders``.
        n_epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        model: Task name; forwarded to ``create_dataloaders`` and used in the
            log file names.
        save_dir: Parent directory for checkpoints.
        save_name: Checkpoint sub-directory name.
        seed: Seed passed to ``set_seeds``.
        use_amp: Enable CUDA autocast and gradient scaling.
        resume: Load an existing checkpoint from ``save_dir/save_name`` if
            that path exists.
        ddp: Wrap the model in DistributedDataParallel (NCCL backend).
        local_rank: Local rank; overridden by ``$LOCAL_RANK`` when ``ddp``.
        world_size: Accepted for backward compatibility and no longer used:
            DDP already averages gradients across ranks, so the loss must not
            be divided by the world size again.
        max_steps: Maximum number of optimizer steps; ``None`` or a
            non-positive value disables the limit.
        model_type: ``'llama'`` or anything else for GPT-2.
        dropout: Dropout probability passed to the model config.

    Returns:
        ``(model_ref, train_ds)``: the trained model (DDP-wrapped when
        ``ddp``) and the training dataset.
    """
    if ddp:
        if 'LOCAL_RANK' in os.environ:
            local_rank = int(os.environ['LOCAL_RANK'])
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    set_seeds(seed)

    train_dl, val_dls, train_ds, (val0_ds, val1_ds, val2_ds) = create_dataloaders(
        batch_size, model, ddp=ddp, local_rank=local_rank
    )
    val0_dl, val1_dl, val2_dl = val_dls

    # Longest sample over ALL splits (train + val) determines the context size.
    max_len = max(
        len(s)
        for split in (train_ds, val0_ds, val1_ds, val2_ds)
        for s in split.samples
    )
    print('Max fused length across splits:', max_len)

    # At least 1024 positions; beyond that, round up to a multiple of 64.
    ctx = 1024 if max_len <= 1024 else int(math.ceil(max_len / 64) * 64)
    model_path = os.path.join(save_dir, save_name)

    if resume and os.path.exists(model_path):
        if local_rank == 0:
            print(f"[INFO] Loading pretrained model from {model_path}")
        if model_type == 'llama':
            model_obj = LlamaForCausalLM.from_pretrained(model_path).to(device)
        else:
            model_obj = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    else:
        if model_type == 'llama':
            config = LlamaConfig(
                vocab_size=len(train_ds.stoi),
                max_position_embeddings=ctx,
                hidden_size=n_embd,
                num_hidden_layers=n_layer,
                num_attention_heads=n_head,
                intermediate_size=4 * n_embd,
                pad_token_id=train_ds.pad_id,
                eos_token_id=train_ds.eos_id,
                bos_token_id=train_ds.bos_id,
                # ---- dropout ----
                attention_dropout=dropout,
                # NOTE(review): `hidden_dropout` does not appear to be a
                # LlamaConfig option and may be silently ignored — verify.
                hidden_dropout=dropout,
            )
            model_obj = LlamaForCausalLM(config).to(device)
        else:  # default to gpt2
            config = GPT2Config(
                vocab_size=len(train_ds.stoi),
                n_positions=ctx,
                n_ctx=ctx,
                n_embd=n_embd,
                n_layer=n_layer,
                n_head=n_head,
                pad_token_id=train_ds.pad_id,
                # ---- dropout ----
                embd_pdrop=dropout,
                attn_pdrop=dropout,
                resid_pdrop=dropout,
            )
            model_obj = GPT2LMHeadModel(config).to(device)

    # Applied to fresh AND resumed models: the zero module has no weights, so a
    # saved checkpoint contains none for it and from_pretrained would otherwise
    # bring back a freshly initialised positional embedding table.
    _disable_positional_embeddings(model_obj, model_type)

    if ddp:
        model_ddp = DDP(model_obj, device_ids=[local_rank], output_device=local_rank)
        model_ref = model_ddp
        model_for_eval = model_ddp.module
    else:
        model_ref = model_obj
        model_for_eval = model_obj

    optimizer = AdamW(model_ref.parameters(), lr=lr, weight_decay=0.01)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        # Deliberately best-effort: these are optional speed knobs.
        pass

    step_limit = max_steps if (max_steps is not None and max_steps > 0) else None

    # Starts at 0.0 with a strict '>' test below, so a checkpoint is written
    # only once the mean validation sequence accuracy rises above zero.
    best_val_avg_seq_acc = 0.0
    best_epoch = -1
    best_stats = {}
    train_losses = []
    val0_losses = []
    global_step = 0

    for epoch in tqdm(range(n_epochs), disable=(local_rank != 0)):
        if ddp:
            train_dl.sampler.set_epoch(epoch)

        model_ref.train()
        run_loss = 0.0
        run_cnt = 0
        tok_n = tok_d = 0
        seq_n = seq_d = 0
        hit_step_limit = False

        for batch in train_dl:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            bsz = batch["input_ids"].size(0)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                out = model_ref(**batch)
                loss, logits = out.loss, out.logits
            # No division by world_size: DDP averages gradients across ranks
            # during backward, and the logged loss should stay unscaled.
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_ref.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            global_step += 1

            run_loss += loss.item() * bsz
            run_cnt += bsz
            n, d = causal_token_acc(logits.detach(), batch["labels"], IGNORE_INDEX)
            tok_n += n; tok_d += d
            n, d = causal_seq_acc(logits.detach(), batch["labels"], IGNORE_INDEX)
            seq_n += n; seq_d += d

            if step_limit is not None and global_step >= step_limit:
                hit_step_limit = True
                break

        train_loss = run_loss / run_cnt if run_cnt else 0.0
        train_tok = (tok_n / tok_d) if tok_d else 0.0
        train_seq = (seq_n / seq_d) if seq_d else 0.0

        # ----- validation on all three splits -----
        val0_loss, val0_tok, val0_seq = evaluate(model_for_eval, val0_dl, device, use_amp=use_amp)
        val1_loss, val1_tok, val1_seq = evaluate(model_for_eval, val1_dl, device, use_amp=use_amp)
        val2_loss, val2_tok, val2_seq = evaluate(model_for_eval, val2_dl, device, use_amp=use_amp)

        stop_training = hit_step_limit

        if local_rank == 0:
            train_losses.append(train_loss)
            val0_losses.append(val0_loss)

            avg_val_seq_acc = (val0_seq + val1_seq + val2_seq) / 3.0

            if avg_val_seq_acc > best_val_avg_seq_acc:
                best_val_avg_seq_acc = avg_val_seq_acc
                best_epoch = epoch
                best_stats = {
                    'train_loss': train_loss,
                    'train_tok': train_tok,
                    'train_seq': train_seq,
                    'val0_loss': val0_loss, 'val0_tok': val0_tok, 'val0_seq': val0_seq,
                    'val1_loss': val1_loss, 'val1_tok': val1_tok, 'val1_seq': val1_seq,
                    'val2_loss': val2_loss, 'val2_tok': val2_tok, 'val2_seq': val2_seq,
                }
                save_checkpoint(model_for_eval, train_ds, save_dir, save_name=save_name)
                print(f"Epoch {epoch} | step {global_step} | *** NEW BEST AVG SEQ ACC: {avg_val_seq_acc:.4f} ***")
                print(f"Train | loss {train_loss:.4f} | tok_acc {train_tok:.4f} | seq_acc {train_seq:.4f}")
                print(f"Val0/1/2 | {val0_seq:.4f}/{val1_seq:.4f}/{val2_seq:.4f}")
                print(f"[INFO] Saved checkpoint to {os.path.join(save_dir, save_name)}")
            else:
                print(
                    f"Epoch {epoch} | step {global_step} | "
                    f"loss {train_loss:.4f} | avg_val_seq_acc: {avg_val_seq_acc:.4f} "
                    f"(best: {best_val_avg_seq_acc:.4f})"
                )

            if avg_val_seq_acc >= 1.0:
                print(f"Early stopping: Average validation seq accuracy reached 100% at epoch {epoch}.")
                stop_training = True
            elif hit_step_limit:
                print(f"[INFO] Reached max_steps={max_steps} at epoch {epoch}; stopping training.")

        if ddp:
            # Rank 0 decides; every rank must leave the loop together, or the
            # remaining ranks block forever in the next gradient all-reduce.
            stop_flag = torch.tensor([int(stop_training)], device=device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    # Post-training logging and plotting
    if local_rank == 0:
        summary_path = os.path.join('logs', f"{model}.txt")
        os.makedirs(os.path.dirname(summary_path), exist_ok=True)

        # Re-evaluate val0, val1 and val2 on the final (not necessarily best) model
        val0_loss, val0_tok, val0_seq = evaluate(model_for_eval, val0_dl, device, use_amp=use_amp)
        val1_loss, val1_tok, val1_seq = evaluate(model_for_eval, val1_dl, device, use_amp=use_amp)
        val2_loss, val2_tok, val2_seq = evaluate(model_for_eval, val2_dl, device, use_amp=use_amp)

        avg_final_seq_acc = (val0_seq + val1_seq + val2_seq) / 3.0
        print(
            f"[INFO] Final avg val seq_acc: {avg_final_seq_acc:.4f} "
            f"(val0/1/2: {val0_seq:.4f}/{val1_seq:.4f}/{val2_seq:.4f})"
        )

        with open(summary_path, "a") as f:
            f.write(
                f'{n_embd}_{n_layer}_{n_head}_{lr}: '
                f"Best epoch: {best_epoch} | "
                f"Best avg seq_acc: {best_val_avg_seq_acc:.4f} | "
                f"Best val seq_acc (0/1/2): "
                f"{best_stats.get('val0_seq', 0):.4f}/"
                f"{best_stats.get('val1_seq', 0):.4f}/"
                f"{best_stats.get('val2_seq', 0):.4f}\n"
            )

        plt.figure()
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val0_losses, label='Val0 Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'Training and Validation Loss ({model})')
        plt.legend()
        plt.grid(True)
        plot_path = os.path.join('logs', f"{model}_loss_curve.png")
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        plt.savefig(plot_path)
        plt.close()
        print(f"[INFO] Saved loss curve to {plot_path}")

    if ddp:
        dist.destroy_process_group()
    return model_ref, train_ds


@torch.no_grad()
def greedy_decode(
    model, ds, input_str: str,
    max_new_tokens: Optional[int] = None,
    stop_on_eos: bool = True,
    min_tokens: int = 1
):
    """Greedily generate the model's continuation of ``BOS + input_str + SEP``.

    The prompt is tokenised character by character through ``ds.stoi``. At
    each step the full sequence is fed to the model again and the arg-max
    token at the last position is appended. Generation ends on EOS (only once
    ``min_tokens`` tokens have been generated and ``stop_on_eos`` is set),
    after ``max_new_tokens`` steps, or when the context window is full. A
    stopping EOS is not appended to the sequence.

    Args:
        model: Causal LM; switched to eval mode. Its config supplies the
            context size via ``n_positions`` or ``max_position_embeddings``
            (2048 if neither attribute exists).
        ds: Dataset providing ``stoi``/``itos``; its special-token ids are
            refreshed before use.
        input_str: Raw input; every character must be in ``ds.stoi``.
        max_new_tokens: Generation budget; defaults to the remaining context.
        stop_on_eos: Whether a predicted EOS ends generation.
        min_tokens: Number of generated tokens before EOS may end generation.

    Returns:
        ``(full_decoded, pred_output)``: the whole decoded sequence, and the
        text after the first SEP truncated at the first EOS (if any).

    Raises:
        ValueError: On out-of-vocabulary input characters, or when the prompt
            already fills the context window.
    """
    model.eval()
    device = next(model.parameters()).device
    refresh_ds_special_ids(ds)

    # Reject characters the vocabulary does not know.
    unknown = [ch for ch in input_str if ch not in ds.stoi]
    if unknown:
        raise ValueError(f"OOV chars {unknown}. Not in training vocab: {unknown}")

    # Prompt layout: BOS, the input characters, then SEP.
    prompt = SPECIAL["BOS"] + input_str + SPECIAL["SEP"]
    x = torch.tensor([[ds.stoi[ch] for ch in prompt]], dtype=torch.long, device=device)
    attn = torch.ones_like(x, device=device)

    # Context size: GPT-2 style attribute first, then the LLaMA style one.
    cfg = model.config
    for attr_name in ('n_positions', 'max_position_embeddings'):
        if hasattr(cfg, attr_name):
            max_ctx = getattr(cfg, attr_name)
            break
    else:
        max_ctx = 2048

    if x.size(1) >= max_ctx:
        raise ValueError(f"Input prefix length {x.size(1)} exceeds model capacity max_ctx={max_ctx}.")

    if max_new_tokens is None:
        max_new_tokens = max_ctx - x.size(1)

    # The loop index equals the number of tokens appended so far.
    for n_generated in range(max_new_tokens):
        step_logits = model(input_ids=x, attention_mask=attn).logits
        next_id = step_logits[:, -1, :].argmax(dim=-1)

        eos_may_stop = stop_on_eos and n_generated >= min_tokens
        if eos_may_stop and next_id.item() == ds.eos_id:
            break

        x = torch.cat((x, next_id.unsqueeze(-1)), dim=1)
        attn = torch.cat((attn, attn.new_ones((1, 1))), dim=1)

        if x.size(1) >= max_ctx:
            break

    # Map ids back to text, then isolate what follows SEP up to the first EOS.
    full_decoded = "".join(ds.itos[int(token_id)] for token_id in x[0].tolist())
    _, sep_found, tail = full_decoded.partition(SPECIAL["SEP"])
    after_sep = tail if sep_found else ""
    pred_output = after_sep.partition(SPECIAL["EOS"])[0]
    return full_decoded, pred_output


def get_model_alphabet(model):
    """Return the input alphabet for task ``model``.

    ``'mul'`` and ``'gcd'`` use three symbols; ``'dvd'``, ``'exp'`` and
    ``'prime'`` use two. Any other task name yields an empty list. A new list
    is returned on every call.
    """
    if model in {'mul', 'gcd'}:
        return ['a', 'b', 'c']
    if model in {'dvd', 'exp', 'prime'}:
        return ['a', 'b']
    return []


def get_automaton(model):
    """Build and return the reference automaton for task ``model``.

    The builders come from the ``kcm`` module, imported on first use. Task
    names without a dedicated builder fall back to the ``'prime'`` builder.
    """
    from kcm import build_dvd_kcm, build_mul_kcm, build_gcd_kcm, build_exp_kcm, build_prime_unary_kcm

    builders = {
        'mul': build_mul_kcm,
        'dvd': build_dvd_kcm,
        'gcd': build_gcd_kcm,
        'exp': build_exp_kcm,
        'prime': build_prime_unary_kcm,
    }
    try:
        build = builders[model]
    except KeyError:
        # Unknown task name: same default as the 'prime' task.
        build = build_prime_unary_kcm
    return build()


def main():
    """Command-line entry point: train a model or run interactive inference.

    In ``train`` mode the parsed options are forwarded to ``train_gpt2`` and
    the checkpoint is stored under ``<save_dir>/<model>/``. In ``inference``
    mode the checkpoint and vocabulary are loaded from that same directory and
    the user is prompted for input strings until an empty line, EOF or Ctrl-C.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='mul', choices=['mul', 'dvd', 'exp', 'prime', 'gcd'])
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'inference'])
    parser.add_argument('--save_dir', type=str, default='checkpoints')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--nlayer', type=int, default=2)
    parser.add_argument('--nembd', type=int, default=128)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision')
    parser.add_argument('--resume', action='store_true', help='Continue training from checkpoints/{name}')
    parser.add_argument('--ddp', action='store_true', help='Enable DistributedDataParallel training')
    parser.add_argument('--local_rank', type=int, default=0, help='Local rank for DDP')
    parser.add_argument('--world_size', type=int, default=1, help='World size for DDP')
    parser.add_argument('--max_steps', type=int, default=30000, help='Maximum training steps')
    parser.add_argument('--model_type', type=str, default='llama', choices=['gpt2', 'llama'],
                        help='Model architecture type')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='Dropout rate for Transformer layers (attn/resid/embd or attention/hidden)')
    args = parser.parse_args()

    set_seeds(args.seed)
    use_amp = not args.no_amp

    if args.mode == 'train':
        train_gpt2(
            n_embd=args.nembd,
            n_layer=args.nlayer,
            n_head=args.nhead,
            batch_size=args.batch_size,
            n_epochs=args.epochs,
            lr=args.lr,
            model=args.model,
            save_dir=args.save_dir,
            save_name=args.model,  # save under checkpoints/<model>/
            seed=args.seed,
            use_amp=use_amp,
            resume=args.resume,
            ddp=args.ddp,
            local_rank=args.local_rank,
            world_size=args.world_size,
            max_steps=args.max_steps,
            model_type=args.model_type,
            dropout=args.dropout,
        )
        return

    # ----- Inference mode -----
    # The alphabet and reference automaton are only needed here, so training
    # runs no longer import or build them.
    alp = get_model_alphabet(args.model)
    auto = get_automaton(args.model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_dir = os.path.join(args.save_dir, args.model)

    # Load model
    if args.model_type == 'llama':
        model = LlamaForCausalLM.from_pretrained(
            model_dir, torch_dtype=torch.float32, low_cpu_mem_usage=True
        )
    else:
        model = GPT2LMHeadModel.from_pretrained(
            model_dir, torch_dtype=torch.float32, low_cpu_mem_usage=True
        )
        # Training swaps `wpe` for a weight-less zero module, so the checkpoint
        # holds no weights for it and from_pretrained creates a freshly
        # initialised table. Restore the zero module so inference sees the
        # same position-free model that was trained.
        model.transformer.wpe = ZeroPositionalEmbedding(
            model.config.n_embd, model.config.n_positions
        )

    model = model.to(device).eval()
    # use args.model as save_name (matches training)
    stoi, itos = load_vocab(args.save_dir, args.model)
    ds = ARLMGptDataset([""], [""], stoi, itos)  # Dummy dataset for tokenization

    print(f"Model loaded. Valid alphabet: {alp}")
    while True:
        try:
            s = input(f"Input string (only {','.join(alp)}). Empty to quit: ").strip()
            if not s:
                break
            if not all(ch in alp for ch in s):
                print(f"Invalid input. Only {alp} allowed.")
                continue

            full, yhat = greedy_decode(model, ds, s)
            if auto is not None:
                gold = auto.output_generator(s, alp)
                print(f"Pred: {yhat}")
                print(f"Gold: {gold}")
            else:
                print(f"Pred: {yhat}")
        except (EOFError, KeyboardInterrupt):
            break


if __name__ == "__main__":
    main()
