import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
import numpy as np
import random
from tqdm import tqdm
import json
import csv
import os
from datetime import datetime

class Config:
    """Hyper-parameters and vocabulary for the looped-transformer addition run.

    Every setting is a class attribute, so values can be read from the class
    itself or from an instance.
    """

    # ---- task ------------------------------------------------------------
    num_digits = 4

    # ---- vocabulary: ten digits, operators, space, then pad and EOS symbols
    vocab = list("0123456789") + ["+", "=", " ", "Ʌ", "§"]
    pad_token_id = vocab.index("Ʌ")  # id used to right-pad batches
    eos_token_id = vocab.index("§")  # id appended after every answer

    # ---- model size ------------------------------------------------------
    d_model = 512
    n_heads = 8
    d_ff = 2 * d_model
    dropout = 0.1
    max_len = 512  # size of the learned position-embedding table

    # ---- normalization layout --------------------------------------------
    norm_structure = "post_norm"
    norm_type = "DeepNorm"
    deepnorm_alpha = 10

    # ---- looping: k_layers distinct blocks re-applied l_loops times --------
    k_layers = 1
    l_loops = 4

    # ---- optimisation ----------------------------------------------------
    batch_size = 512
    accumulation_steps = 1
    total_steps = 10000
    learning_rate = 1e-4
    warmup_steps = 2000

    # ---- hardware --------------------------------------------------------
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

config = Config()
if torch.cuda.is_available():
    print(f"Using {config.num_gpus} GPU(s): {[torch.cuda.get_device_name(i) for i in range(config.num_gpus)]}")
else:
    # Config.num_gpus falls back to 1 without CUDA, so querying a device name
    # would raise here; report the CPU fallback instead.
    print(f"CUDA not available; running on {config.device}")

class CharTokenizer:
    """Character-level tokenizer: each vocabulary symbol maps to its list index."""

    def __init__(self, vocab):
        self.vocab = vocab
        self.c2i = {c: i for i, c in enumerate(vocab)}
        self.i2c = {i: c for i, c in enumerate(vocab)}

    def encode(self, s):
        """Return the token ids for ``s``.

        Raises:
            KeyError: if ``s`` contains a character outside the vocabulary.
        """
        return [self.c2i[c] for c in s]

    def decode(self, indices):
        """Join the characters for ``indices``.

        Ids outside ``[0, len(vocab))`` are skipped, so sequences containing
        e.g. the ``-100`` ignore label decode instead of raising KeyError.
        """
        return "".join(self.i2c[i] for i in indices if 0 <= i < len(self.vocab))

# Shared character-level tokenizer built from the configured vocabulary.
tokenizer = CharTokenizer(vocab=config.vocab)

class AdditionDataset(IterableDataset):
    """Addition problems read from a JSONL file, usable as a stream or by index.

    Each JSONL line must provide ``num1``, ``num2`` and ``answer`` (optionally
    ``expression``). It becomes the character sequence
    ``"<num1>+<num2>= <answer reversed>§"``. Labels are next-token targets for
    the answer part only; prompt positions and the final position get -100.

    When ``total_samples`` is below ``preload_threshold`` the first
    ``total_samples`` valid lines are held in memory. Otherwise lines are
    streamed from disk and the file is re-read cyclically.

    Iteration is sharded across ``DataLoader`` workers. With ``num_workers > 0``,
    each worker yields a disjoint subset of samples instead of every worker
    replaying the same stream, which would duplicate each batch.
    """

    def __init__(self, tokenizer, config, mode='train', total_samples=1000000, data_dir='data'):
        self.tokenizer = tokenizer
        self.config = config
        self.mode = mode
        self.total_samples = total_samples
        self.data_dir = data_dir

        file_name = "train_4digit.jsonl" if mode == "train" else "test_4digit.jsonl"
        file_path = os.path.join(data_dir, file_name)

        self.data_files = []
        if os.path.exists(file_path):
            self.data_files.append(file_path)
        else:
            print(f"Warning: Data file {file_path} not found")

        if not self.data_files:
            raise ValueError(f"No data files found for mode={mode} in {data_dir}")

        print(f"Loading data from {len(self.data_files)} file(s): {self.data_files}")

        # Requests smaller than this are held in memory; larger ones are streamed.
        self.preload_threshold = 1000000
        self.data = None
        self.data_size = None
        if total_samples < self.preload_threshold:
            print(f"Preloading {total_samples} samples into memory...")
            self._preload_data(total_samples)
            print(f"Preloaded {len(self.data)} samples.")
        else:
            print(f"Using streaming mode for {total_samples} samples (memory efficient)")

    def _parse_jsonl_line(self, line):
        """Parse one JSONL line into model inputs; return None if the line is invalid.

        Returns:
            dict with ``input_ids`` (LongTensor), ``labels`` (LongTensor, same
            length), ``input_len`` (prompt length incl. "=") and
            ``ground_truth_num`` (the raw ``answer`` value).
        """
        try:
            data = json.loads(line.strip())
            num1 = data["num1"]
            num2 = data["num2"]
            answer = data["answer"]

            expression = data.get("expression", f"{num1} + {num2}")
            input_str = expression.replace(" ", "") + "="

            # The answer is emitted least-significant digit first, after a space.
            target_str = f" {str(answer)[::-1]}"
            full_text = input_str + target_str + "§"

            tokenized = self.tokenizer.encode(full_text)
            input_len = len(self.tokenizer.encode(input_str))

            # Position i predicts token i + 1, but only from the last prompt
            # token ("=") onward; earlier positions and the final position
            # (which has no successor) are ignored with -100.
            first_target = input_len - 1
            last = len(tokenized) - 1
            labels = [
                tokenized[i + 1] if first_target <= i < last else -100
                for i in range(len(tokenized))
            ]

            return {
                "input_ids": torch.tensor(tokenized, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
                "input_len": input_len,
                "ground_truth_num": answer
            }
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            print(f"Error parsing line: {line[:50]}... Error: {e}")
            return None

    def _preload_data(self, max_samples):
        """Load up to ``max_samples`` valid samples into ``self.data``, split evenly over the files.

        Raises:
            ValueError: if ``max_samples > 0`` but no valid sample could be loaded
                (otherwise indexing would later fail with ZeroDivisionError).
        """
        self.data = []
        samples_per_file, remainder = divmod(max_samples, len(self.data_files))

        for file_idx, file_path in enumerate(self.data_files):
            quota = samples_per_file + (1 if file_idx < remainder else 0)
            count = 0
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        if count >= quota:
                            break
                        sample = self._parse_jsonl_line(line)
                        if sample is not None:
                            self.data.append(sample)
                            count += 1
            except FileNotFoundError:
                print(f"Warning: File {file_path} not found, skipping")

        self.data_size = len(self.data)
        if max_samples > 0 and self.data_size == 0:
            raise ValueError(f"No valid samples could be loaded from {self.data_files}")

    def _read_sample_from_file(self, file_path, line_offset=0):
        """Parse the line at zero-based ``line_offset`` of ``file_path``; None if missing or unreadable."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Skip the preceding lines.
                for _ in range(line_offset):
                    if not f.readline():
                        return None
                line = f.readline()
                if not line:
                    return None
                return self._parse_jsonl_line(line)
        except (FileNotFoundError, IOError) as e:
            print(f"Error reading file {file_path} at offset {line_offset}: {e}")
            return None

    def _get_file_line_count(self, file_path):
        """Return the number of lines in ``file_path`` (0 if it cannot be read)."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return sum(1 for _ in f)
        except (FileNotFoundError, IOError):
            return 0

    def generate_sample(self, index):
        """Return the sample at ``index``, wrapping around the available data.

        In streaming mode the file is re-read from the start on each call, so
        this is intended for small evaluation runs.
        """
        if self.data is not None:
            return self.data[index % self.data_size]

        if not hasattr(self, '_file_line_counts'):
            self._file_line_counts = [self._get_file_line_count(f) for f in self.data_files]
            self._total_file_lines = sum(self._file_line_counts)

        if self._total_file_lines == 0:
            raise ValueError("All data files are empty")

        # Map the global index onto (file, line-within-file).
        file_idx = 0
        line_offset = index % self._total_file_lines
        for i, line_count in enumerate(self._file_line_counts):
            if line_offset < line_count:
                file_idx = i
                break
            line_offset -= line_count

        sample = self._read_sample_from_file(self.data_files[file_idx], line_offset)
        if sample is None:
            print("======Failed to read from corresponding file, returning first sample===============")
            sample = self._read_sample_from_file(self.data_files[0], 0)
        return sample

    def __iter__(self):
        """Yield samples forever, cycling through the data.

        Inside a DataLoader worker only every ``num_workers``-th sample (offset
        by the worker id) is yielded, so workers partition the data.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        if self.data is not None:
            index = worker_id
            while True:
                yield self.data[index % self.data_size]
                index += num_workers
        else:
            yield from self._stream_samples(worker_id, num_workers)

    def _stream_samples(self, worker_id=0, num_workers=1):
        """Stream parsed samples from disk forever, reading the files round-robin.

        Every reader walks the same global sequence of lines. A line at position
        ``p`` is parsed and yielded only when ``p % num_workers == worker_id``.
        A file that reaches EOF wraps to its first line. All files rewind once
        ``total_samples`` lines have been consumed (when ``total_samples > 0``).
        Lines are counted rather than successful parses so that all workers stay
        in lockstep.

        Raises:
            RuntimeError: if no data file can be opened or every open file is empty.
        """
        file_handles = []
        try:
            for file_path in self.data_files:
                try:
                    file_handles.append(open(file_path, 'r', encoding='utf-8'))
                except (FileNotFoundError, IOError) as e:
                    print(f"Warning: Cannot open file {file_path}: {e}")
                    file_handles.append(None)
            if all(fh is None for fh in file_handles):
                raise RuntimeError(f"Cannot open any data file for streaming: {self.data_files}")

            current_file_idx = 0
            lines_consumed = 0  # advances identically in every worker
            empty_reads = 0     # consecutive visits to a file that produced no line at all
            while True:
                if self.total_samples > 0 and lines_consumed >= self.total_samples:
                    for fh in file_handles:
                        if fh is not None:
                            fh.seek(0)
                    lines_consumed = 0

                fh = file_handles[current_file_idx]
                current_file_idx = (current_file_idx + 1) % len(file_handles)
                if fh is None:
                    continue

                line = fh.readline()
                if not line:
                    # EOF: wrap this file around to its first line.
                    fh.seek(0)
                    line = fh.readline()
                if not line:
                    # Guard against spinning forever on empty files.
                    empty_reads += 1
                    if empty_reads >= len(file_handles):
                        raise RuntimeError("All data files are empty")
                    continue
                empty_reads = 0

                position = lines_consumed
                lines_consumed += 1
                if position % num_workers != worker_id:
                    continue  # this line belongs to another worker
                sample = self._parse_jsonl_line(line)
                if sample is not None:
                    yield sample
        finally:
            for fh in file_handles:
                if fh is not None:
                    fh.close()

def collate_fn(batch):
    """Right-pad a list of samples to the longest one and stack them.

    ``input_ids`` are padded with ``config.pad_token_id`` and ``labels`` with
    -100, so padded positions are ignored by the loss.

    Returns:
        dict with ``input_ids`` and ``labels`` LongTensors of shape
        (len(batch), longest sequence length).
    """
    longest = max(len(sample["input_ids"]) for sample in batch)

    def _pad(seq, fill):
        # Append (longest - len(seq)) copies of `fill`.
        return torch.cat([seq, torch.full((longest - len(seq),), fill, dtype=torch.long)])

    padded_ids = [_pad(sample["input_ids"], config.pad_token_id) for sample in batch]
    padded_labels = [_pad(sample["labels"], -100) for sample in batch]

    return {
        "input_ids": torch.stack(padded_ids),
        "labels": torch.stack(padded_labels),
    }
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention written out by hand (no F.scaled_dot_product_attention).

    One ``c_attn`` projection produces queries, keys and values. The scaled
    scores are masked, softmaxed and dropped out, then applied to the values.
    ``c_proj`` maps the concatenated heads back to ``d_model``.
    """
    def __init__(self, config):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.c_attn = nn.Linear(config.d_model, 3 * config.d_model)
        self.c_proj = nn.Linear(config.d_model, config.d_model)
        self.n_heads = config.n_heads
        self.d_head = config.d_model // config.n_heads
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, d_model).

        Args:
            x: input activations, shape (B, T, C).
            mask: optional attention mask, True = may attend. Shape (T, T) or
                (B, T, T); any dtype is accepted and converted to bool. When
                None, a lower-triangular causal mask is used.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        # Split by the input's own width rather than a module-global config.
        q, k, v = self.c_attn(x).split(C, dim=2)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / (self.d_head ** 0.5))

        if mask is None:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            att = att.masked_fill(~causal_mask, float('-inf'))
        else:
            # `~` on an integer mask is bitwise NOT (~1 == -2), which would mask
            # every position, so normalise to bool first.
            mask = mask.to(torch.bool)
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            att = att.masked_fill(~mask, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        # Registration order is kept (c_fc, gelu, c_proj, dropout) so parameter
        # names and initialisation order match existing checkpoints.
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply GELU, project back to ``d_model``, then dropout."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class Block(nn.Module):
    """One Transformer layer (attention + MLP) with a configurable normalization layout.

    ``config.norm_structure`` selects where the norms sit (f = attn or mlp):

    * ``"pre_norm"``        -- ``x + f(LN(x))``
    * ``"post_norm"``       -- ``LN(x + f(x))``; with ``norm_type == "DeepNorm"``
      the residual is scaled: ``LN(alpha * x + f(x))``
    * ``"sandwich_branch"`` -- ``x + LN_post(f(LN_pre(x)))``
    * ``"sandwich_dual"``   -- ``LN_post(x + f(LN_pre(x)))``

    ``config.norm_type`` picks the norm module. ``"RMSNorm"`` gives nn.RMSNorm,
    ``"SimpleNorm"`` gives LayerNorm without affine parameters, and ``"LayerNorm"``,
    ``"DeepNorm"`` or any other value gives a regular LayerNorm.

    Raises:
        ValueError: if ``config.norm_structure`` is not one of the four layouts.
            Such a block would otherwise silently return its input unchanged.
    """

    _NORM_STRUCTURES = ("pre_norm", "post_norm", "sandwich_branch", "sandwich_dual")

    def __init__(self, config):
        super().__init__()
        self.config = config
        structure = self.config.norm_structure
        if structure not in self._NORM_STRUCTURES:
            raise ValueError(
                f"Unknown norm_structure {structure!r}; expected one of {self._NORM_STRUCTURES}"
            )
        d_model = self.config.d_model

        def get_norm():
            if self.config.norm_type == "LayerNorm":
                return nn.LayerNorm(d_model)
            elif self.config.norm_type == "RMSNorm":
                return nn.RMSNorm(d_model)
            elif self.config.norm_type == "SimpleNorm":
                return nn.LayerNorm(d_model, elementwise_affine=False)
            elif self.config.norm_type == "DeepNorm":
                return nn.LayerNorm(d_model)
            else:
                return nn.LayerNorm(d_model)

        # Same registration order as before (pre norms, then post norms, then
        # attn, mlp) so state_dict keys and init RNG consumption are unchanged.
        if structure in ("pre_norm", "sandwich_branch", "sandwich_dual"):
            self.ln_1_pre = get_norm()
            self.ln_2_pre = get_norm()
        if structure in ("post_norm", "sandwich_branch", "sandwich_dual"):
            self.ln_1_post = get_norm()
            self.ln_2_post = get_norm()

        self.attn = CausalSelfAttention(self.config)
        self.mlp = MLP(self.config)

    def forward(self, x):
        """Apply attention then MLP with the configured residual/norm layout."""
        if self.config.norm_structure == "pre_norm":
            x = x + self.attn(self.ln_1_pre(x))
            x = x + self.mlp(self.ln_2_pre(x))
        elif self.config.norm_structure == "post_norm":
            if self.config.norm_type == "DeepNorm":
                # DeepNorm: up-weight the residual stream before the norm.
                alpha = getattr(self.config, "deepnorm_alpha", 1.0)
                x = self.ln_1_post(x * alpha + self.attn(x))
                x = self.ln_2_post(x * alpha + self.mlp(x))
            else:
                x = self.ln_1_post(x + self.attn(x))
                x = self.ln_2_post(x + self.mlp(x))
        elif self.config.norm_structure == "sandwich_branch":
            x = x + self.ln_1_post(self.attn(self.ln_1_pre(x)))
            x = x + self.ln_2_post(self.mlp(self.ln_2_pre(x)))
        elif self.config.norm_structure == "sandwich_dual":
            x = self.ln_1_post(x + self.attn(self.ln_1_pre(x)))
            x = self.ln_2_post(x + self.mlp(self.ln_2_pre(x)))

        return x

class LoopedTransformer(nn.Module):
    """Decoder-only Transformer whose block stack is applied repeatedly.

    The same ``config.k_layers`` Block modules are run ``num_loops`` times
    (default ``config.l_loops``). This gives an effective depth of
    ``k_layers * num_loops`` while holding the parameters of only ``k_layers``
    blocks. Token embeddings and learned absolute position embeddings are
    summed at the input, and the output head shares its weight with the token
    embedding.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(len(config.vocab), config.d_model)
        self.position_embedding = nn.Embedding(config.max_len, config.d_model)

        self.dropout = nn.Dropout(config.dropout)
        # k_layers distinct blocks; forward() re-applies this stack num_loops times.
        self.layers = nn.ModuleList([Block(config) for _ in range(config.k_layers)])

        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, len(config.vocab), bias=False)

        # Weight tying: the output projection reuses the token-embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

        if self.config.norm_type == "DeepNorm":
            self._deepnorm_init()

    def _deepnorm_init(self):
        """Re-initialise the attention and MLP output projections with Xavier-normal.

        NOTE(review): the gain here is sqrt(alpha / 2), i.e. about 2.24 for
        alpha = 10, which scales these weights *up*. The DeepNet paper instead
        uses a gain beta < 1 that depends on depth. Confirm this is intended.
        """
        alpha = getattr(self.config, "deepnorm_alpha", 1.0)
        gain = (0.5 * alpha) ** 0.5

        for layer in self.layers:
            if hasattr(layer.attn, 'c_proj'):
                nn.init.xavier_normal_(layer.attn.c_proj.weight, gain=gain)
            if hasattr(layer.mlp, 'c_proj'):
                nn.init.xavier_normal_(layer.mlp.c_proj.weight, gain=gain)

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) for Linear/Embedding weights, zero biases, unit norm scales."""
        # nn.RMSNorm only exists in newer torch releases; referencing it
        # unconditionally would crash on older versions for every module.
        rms_norm_cls = getattr(nn, "RMSNorm", None)

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            if getattr(module, 'weight', None) is not None:
                torch.nn.init.ones_(module.weight)
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.zeros_(module.bias)

        elif rms_norm_cls is not None and isinstance(module, rms_norm_cls):
            if getattr(module, 'weight', None) is not None:
                torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None, num_loops=None):
        """
        Args:
            idx: input token ids, shape (B, T).
            targets: optional label ids aligned with ``idx`` (shape (B, T));
                positions equal to -100 are ignored by the loss.
            num_loops: number of passes through the layer stack. None means
                ``config.l_loops``. (Previously this argument was silently
                ignored.)

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is the mean
            cross-entropy over non-ignored positions, or None without targets.
        """
        B, T = idx.size()
        device = idx.device

        pos = torch.arange(0, T, dtype=torch.long, device=device)
        # Positions past the embedding table reuse the last learned position.
        pos = torch.clamp(pos, 0, self.config.max_len - 1)

        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(pos).unsqueeze(0).expand(B, -1, -1)
        x = self.dropout(tok_emb + pos_emb)

        if num_loops is None:
            num_loops = self.config.l_loops

        for _ in range(num_loops):
            for layer in self.layers:
                x = layer(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100,
                reduction='mean'
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Greedy-decode exactly ``max_new_tokens`` tokens after ``idx``.

        - No EOS stopping: always emits ``max_new_tokens`` tokens.
        - Greedy (argmax) decoding only, so evaluation is deterministic.
        - Uses ``config.l_loops`` passes through the layer stack.
        - The context is cropped to the last ``config.max_len`` tokens.
        - For the first three steps, prints top-5 tokens and the chosen token
          for the first sequence in the batch (debug output).

        Returns:
            ``idx`` with the generated tokens appended along dim 1.
        """
        # Move the prompt to wherever the model's weights live.
        idx = idx.to(self.lm_head.weight.device)

        for step in range(max_new_tokens):
            idx_cond = idx[:, -self.config.max_len:]

            logits, _ = self(idx_cond, targets=None, num_loops=self.config.l_loops)
            logits = logits[:, -1, :]

            if step < 3:
                probs = F.softmax(logits, dim=-1)
                top_k_probs, top_k_indices = torch.topk(probs, k=5, dim=-1)
                top_k_tokens = [tokenizer.i2c[token_id.item()] for token_id in top_k_indices[0]]
                print(f"[Debug Step {step}] Top 5 tokens: {list(zip(top_k_tokens, top_k_probs[0].cpu().tolist()))}")

            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

            if step < 3:
                generated_token = tokenizer.i2c[idx_next[0, 0].item()]
                print(f"[Debug Step {step}] Generated token: '{generated_token}' (idx={idx_next[0, 0].item()})")

            idx = torch.cat((idx, idx_next), dim=1)
        return idx

def save_training_loss(step, loss, avg_loss, lr, loss_file):
    """Append one row to the training-loss CSV, writing the header when the file is new.

    Columns: step, loss (6 dp), avg_loss (6 dp), learning_rate (8 dp), ISO timestamp.
    """
    needs_header = not os.path.exists(loss_file)
    with open(loss_file, 'a', newline='') as handle:
        csv_out = csv.writer(handle)
        if needs_header:
            csv_out.writerow(['step', 'loss', 'avg_loss', 'learning_rate', 'timestamp'])
        row = [step, f'{loss:.6f}', f'{avg_loss:.6f}', f'{lr:.8f}', datetime.now().isoformat()]
        csv_out.writerow(row)

def save_evaluation_result(step, input_str, generated_text, generated_num, gt_text, gt_num, is_correct, eval_file):
    """Append one intermediate evaluation record to the JSON array in ``eval_file``.

    The file is created on first use; on every call the whole array is read,
    extended and rewritten.
    """
    record = dict(
        step=step,
        input=input_str,
        generated_text=generated_text,
        generated_num=generated_num,
        ground_truth_text=gt_text,
        ground_truth_num=gt_num,
        is_correct=is_correct,
        timestamp=datetime.now().isoformat(),
    )

    history = []
    if os.path.exists(eval_file):
        with open(eval_file, 'r') as handle:
            history = json.load(handle)
    history.append(record)

    with open(eval_file, 'w') as handle:
        json.dump(history, handle, indent=2)

def save_final_accuracy(accuracy, correct, total, acc_file):
    """Append one final-accuracy record to the JSON array in ``acc_file``.

    The file is created on first use; earlier records are preserved.
    """
    record = dict(
        accuracy=accuracy,
        correct=correct,
        total=total,
        timestamp=datetime.now().isoformat(),
    )

    history = []
    if os.path.exists(acc_file):
        with open(acc_file, 'r') as handle:
            history = json.load(handle)
    history.append(record)

    with open(acc_file, 'w') as handle:
        json.dump(history, handle, indent=2)

def _print_debug_batch(step, input_ids, targets, logits, loss):
    """Print accuracy and label-alignment diagnostics for the first sample of a batch."""
    first_input = input_ids[0]
    first_targets = targets[0]
    valid_mask = first_targets != -100
    valid_targets = first_targets[valid_mask]
    if len(valid_targets) == 0:
        return

    preds = torch.argmax(logits[0][valid_mask], dim=-1)
    acc = (preds == valid_targets).float().mean()
    print(f"[Debug Step {step}] Loss: {loss.item():.6f}, Total Loss: {loss.item():.6f}, Accuracy on valid targets (first sample): {acc.item():.4f}")

    # Position p should predict targets[p] from the prefix input[:p + 1].
    print(f"[Debug Step {step}] Label alignment check (first 3 valid positions):")
    for pos in torch.where(valid_mask)[0][:3].tolist():
        context_str = tokenizer.decode(first_input[:pos + 1].tolist())
        target_id = first_targets[pos].item()
        print(f"  Position {pos}: context='{context_str}' -> should predict '{tokenizer.i2c[target_id]}' (id={target_id})")

    print(f"[Debug Step {step}] First 5 predictions: {[tokenizer.i2c[p.item()] for p in preds[:5]]}")
    print(f"[Debug Step {step}] First 5 targets: {[tokenizer.i2c[t.item()] for t in valid_targets[:5]]}")


def _config_snapshot(timestamp):
    """Return the run's hyper-parameters (plus ``timestamp``) as a JSON-serialisable dict."""
    keys = (
        'k_layers', 'l_loops', 'norm_structure', 'norm_type', 'deepnorm_alpha',
        'd_model', 'n_heads', 'd_ff', 'dropout', 'batch_size', 'total_steps',
        'learning_rate', 'warmup_steps', 'num_digits', 'num_gpus',
    )
    snapshot = {key: getattr(config, key) for key in keys}
    snapshot['timestamp'] = timestamp
    return snapshot


def train():
    """Train a LoopedTransformer on 4-digit addition and save all artefacts.

    Everything goes into a new timestamped directory under ``formal_logs/``:
    ``training_loss.csv`` (every 10 steps), ``evaluation_samples.json`` (one
    sample every 500 steps), the final ``state_dict`` and ``config.json``.

    Returns:
        tuple: ``(model, save_dir)``. ``model`` is wrapped in ``nn.DataParallel``
        when more than one GPU is used.
    """
    from collections import deque

    print(f"Initializing Looped Transformer: {config.k_layers} layers x {config.l_loops} loops (Effective depth: {config.k_layers * config.l_loops})")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"formal_logs/training_results_k{config.k_layers}_L{config.l_loops}_{config.norm_structure}__{config.norm_type}_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    loss_file = os.path.join(save_dir, "training_loss.csv")
    eval_file = os.path.join(save_dir, "evaluation_samples.json")

    print(f"Training results will be saved to: {save_dir}")

    model = LoopedTransformer(config)
    if config.num_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(config.num_gpus))).to(config.device)
    else:
        model = model.to(config.device)
    # Unwrapped module, used for evaluation and checkpointing.
    raw_model = model.module if config.num_gpus > 1 else model

    train_dataset = AdditionDataset(tokenizer, config, mode='train', total_samples=config.total_steps*config.batch_size, data_dir='data')
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
        pin_memory=True,
        num_workers=4
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
    )
    # NOTE(review): the scheduler is stepped once per optimizer step but sized
    # in micro-steps (total_steps); this only lines up when accumulation_steps == 1.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps
    )

    model.train()
    progress_bar = tqdm(range(config.total_steps), desc="Training")
    data_iter = iter(train_loader)
    loss_window = deque(maxlen=100)  # running window for the averaged loss
    for step in progress_bar:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)

        input_ids = batch["input_ids"].to(config.device, non_blocking=True)
        targets = batch["labels"].to(config.device, non_blocking=True)

        logits, loss = model(input_ids, targets)
        if config.num_gpus > 1:
            loss = loss.mean()  # DataParallel returns one loss per replica

        if step < 5:
            _print_debug_batch(step, input_ids, targets, logits, loss)

        (loss / config.accumulation_steps).backward()
        if (step + 1) % config.accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        step_loss = loss.item()
        loss_window.append(step_loss)
        avg_loss = np.mean(loss_window)
        progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.6f}"})
        if step % 10 == 0:
            save_training_loss(step, step_loss, avg_loss, scheduler.get_last_lr()[0], loss_file)

        if step % 500 == 0 and step > 0:
            evaluate_sample(raw_model, step, eval_file)
            model.train()  # evaluate_sample leaves the model in eval mode

    model_save_path = os.path.join(save_dir, f"looped_transformer_k{config.k_layers}_L{config.l_loops}_{config.norm_structure}_{config.norm_type}_4digit_addition.pt")
    print(f"\nSaving model to {model_save_path}...")
    torch.save(raw_model.state_dict(), model_save_path)
    print("Model saved successfully.")

    config_file = os.path.join(save_dir, "config.json")
    with open(config_file, 'w') as f:
        json.dump(_config_snapshot(timestamp), f, indent=2)

    return model, save_dir
def extract_number(text):
    """Recover the integer encoded (least-significant digit first) in generated text.

    All digit characters are collected, the resulting string is reversed back
    to normal order, and it is converted to ``int``.

    Args:
        text: model output such as ``" 4031"`` or ``"4031§"`` (digits reversed).

    Returns:
        int or None: the decoded number, or None when ``text`` has no digits
        or the digits cannot be converted.

    >>> extract_number(" 4031")
    1304
    """
    digit_str = "".join(filter(str.isdigit, text))
    if not digit_str:
        return None
    try:
        return int(digit_str[::-1])
    except ValueError:
        # str.isdigit() also accepts characters (e.g. superscripts) that int() rejects.
        return None

def evaluate_sample(model, step=None, eval_file=None):
    """Greedy-decode the first test problem, print a comparison and optionally log it.

    Args:
        model: unwrapped LoopedTransformer (must provide ``generate``).
        step: training step to record. Logging happens only when both ``step``
            and ``eval_file`` are given.
        eval_file: JSON file passed to ``save_evaluation_result``.

    Returns:
        dict with keys ``input``, ``generated_text``, ``generated_num``,
        ``gt_text``, ``gt_num`` and ``is_correct``.

    The model is left in eval mode; callers must switch it back to train().
    """
    model.eval()
    # Only sample 0 is used, so preload a single line instead of 1000.
    ds = AdditionDataset(tokenizer, config, mode='test', total_samples=1, data_dir='data')
    sample = ds.generate_sample(0)

    input_str_len = sample["input_len"]
    ground_truth_ids = sample["input_ids"][input_str_len:]
    ground_truth_ids = ground_truth_ids[ground_truth_ids != config.eos_token_id]
    gt_text = tokenizer.decode(ground_truth_ids.tolist())
    gt_num = extract_number(gt_text)

    input_ids = sample["input_ids"][:input_str_len].unsqueeze(0).to(config.device)
    # Generate exactly as many tokens as the reference answer (space + digits).
    generated_ids = model.generate(input_ids, max_new_tokens=len(ground_truth_ids))

    generated_result_ids = generated_ids[0, input_str_len:]
    generated_result_ids = generated_result_ids[generated_result_ids != config.eos_token_id]
    generated_text = tokenizer.decode(generated_result_ids.tolist())
    generated_num = extract_number(generated_text)

    # Correct only when a number was extracted and it matches the reference.
    is_correct = generated_num is not None and generated_num == gt_num

    input_str = tokenizer.decode(input_ids[0].tolist())
    print(f"\n[Step Check] Input: {input_str}")
    print(f"[Step Check] Model Output Text: {generated_text} → Extracted Number: {generated_num}")
    print(f"[Step Check] Ground Truth Text: {gt_text} → Extracted Number: {gt_num}")
    print(f"[Step Check] Result: {'CORRECT' if is_correct else 'INCORRECT'}")
    print("-" * 60)

    if step is not None and eval_file is not None:
        save_evaluation_result(step, input_str, generated_text, generated_num, gt_text, gt_num, is_correct, eval_file)

    return {
        'input': input_str,
        'generated_text': generated_text,
        'generated_num': generated_num,
        'gt_text': gt_text,
        'gt_num': gt_num,
        'is_correct': is_correct
    }

def evaluate_accuracy(model, n_samples=1000, acc_file=None):
    """Greedy-decode ``n_samples`` test problems and report exact-match accuracy.

    A sample counts as correct only when both the reference and the generated
    number can be extracted and they are equal. Skipped samples still count in
    the denominator.

    Args:
        model: LoopedTransformer, possibly wrapped in ``nn.DataParallel``.
        n_samples: number of test problems to evaluate (indices 0..n_samples-1).
        acc_file: optional JSON file to which the result is appended.

    Returns:
        float: ``correct / n_samples`` (0.0 when ``n_samples`` is 0).
    """
    if config.num_gpus > 1 and hasattr(model, 'module'):
        model = model.module
    model.eval()
    model.to(config.device)
    num_correct = 0
    test_set = AdditionDataset(tokenizer, config, mode='test', total_samples=n_samples, data_dir='data')

    print(f"Evaluating on 4-digit addition (Total {n_samples} samples)...")
    with torch.no_grad():
        for sample_idx in tqdm(range(n_samples)):
            sample = test_set.generate_sample(sample_idx)
            prompt_len = sample["input_len"]

            answer_ids = sample["input_ids"][prompt_len:]
            answer_ids = answer_ids[answer_ids != config.eos_token_id]
            expected = extract_number(tokenizer.decode(answer_ids.tolist()))
            if expected is None:
                continue

            prompt = sample["input_ids"][:prompt_len].unsqueeze(0).to(config.device)
            output_ids = model.generate(prompt, max_new_tokens=len(answer_ids))

            predicted_ids = output_ids[0, prompt_len:]
            predicted_ids = predicted_ids[predicted_ids != config.eos_token_id]
            predicted = extract_number(tokenizer.decode(predicted_ids.tolist()))
            if predicted is not None and predicted == expected:
                num_correct += 1

    correct = num_correct
    acc = correct / n_samples if n_samples > 0 else 0.0
    print(f"[Final Accuracy] 4-digit addition: {acc*100:.2f}% (Correct: {correct}/{n_samples})")

    if acc_file is not None:
        save_final_accuracy(acc, correct, n_samples, acc_file)

    return acc

if __name__ == "__main__":
    SEED = 1
    # Seed every RNG the run touches: Python, NumPy, PyTorch CPU, then all GPUs.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    print("Starting Training...")
    model, save_dir = train()
    # To evaluate an existing checkpoint instead of training, build
    # LoopedTransformer(config) and call load_state_dict(torch.load(<path to .pt>)).

    print("\n=== Final Evaluation ===")
    acc_file = os.path.join(save_dir, "final_accuracy.json")
    evaluate_accuracy(model, n_samples=100, acc_file=acc_file)
    print(f"\nAll results saved to: {save_dir}")