import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
import numpy as np
import random
from tqdm import tqdm
import json
import csv
import os
from datetime import datetime

class Config:
    """Static hyperparameters and runtime settings shared across this script."""

    # Not read by any code visible in this file; presumably the operand width
    # of the addition task (the data files are named "*_4digit.jsonl").
    num_digits = 4

    # Character-level vocabulary: a token's id is its position in this list.
    vocab = list("0123456789") + ["+", "=", " ", "Ʌ", "§"]
    pad_token_id = vocab.index("Ʌ")  # "Ʌ" is the padding symbol
    eos_token_id = vocab.index("§")  # "§" is appended to every encoded sample

    # Model dimensions.
    d_model = 512
    n_heads = 8
    d_ff = 2 * d_model  # hidden width of the feed-forward sub-layer
    dropout = 0.1
    max_len = 512  # size of the learned position-embedding table

    # Normalization layout: Block understands "pre_norm", "post_norm",
    # "sandwich_branch" and "sandwich_dual" for the structure, and
    # "LayerNorm", "RMSNorm", "SimpleNorm" and "DeepNorm" for the type.
    norm_structure = "pre_norm"
    norm_type = "LayerNorm"

    # Residual scale; only read when norm_type is "DeepNorm".
    deepnorm_alpha = 10

    # The stack of k_layers blocks is applied l_loops times per forward pass
    # unless a caller passes an explicit loop count.
    k_layers = 1
    l_loops = 32
    # Loop counts swept by the loop-sensitivity evaluation (1 through 100).
    test_loops = list(range(1, 101))

    # Optimization settings (no training loop is visible in this file).
    batch_size = 512
    accumulation_steps = 1
    total_steps = 10000
    learning_rate = 1e-4
    warmup_steps = 2000

    # Runtime placement: first CUDA device when available, otherwise CPU.
    # num_gpus falls back to 1 (not 0) when CUDA is unavailable.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

# Single shared configuration instance used by the rest of this module.
config = Config()

# Config.num_gpus falls back to 1 when CUDA is unavailable, so device names
# may only be queried when CUDA is actually present; asking CUDA for the name
# of device 0 on a CPU-only host raises instead of printing.
if torch.cuda.is_available():
    print(f"Using {config.num_gpus} GPU(s): {[torch.cuda.get_device_name(i) for i in range(config.num_gpus)]}")
else:
    print("CUDA is not available; using CPU")

class CharTokenizer:
    """Character-level tokenizer: each vocabulary symbol maps to its list index."""

    def __init__(self, vocab):
        """Build both lookup tables from `vocab`, a sequence of single characters."""
        self.vocab = vocab
        self.c2i = {c: i for i, c in enumerate(vocab)}
        self.i2c = {i: c for i, c in enumerate(vocab)}

    def encode(self, s):
        """Return the list of token ids for string `s`.

        Raises:
            KeyError: if `s` contains a character that is not in the vocabulary.
        """
        return [self.c2i[c] for c in s]

    def decode(self, indices):
        """Return the string for `indices`, silently skipping unknown ids.

        Any id without a vocabulary entry is dropped. This covers ids past the
        end of the vocabulary as well as negative sentinels such as the -100
        ignore label, which previously raised KeyError.
        """
        return "".join(self.i2c[i] for i in indices if i in self.i2c)

# Module-level tokenizer built from the configured vocabulary; the generation
# debug output and the evaluation helpers in this file read this global.
tokenizer = CharTokenizer(config.vocab)

class AdditionDataset(IterableDataset):
    """Endless iterable dataset of addition problems read from a JSONL file.

    Each JSONL record must provide "num1", "num2" and "answer"; an optional
    "expression" overrides the default "<num1> + <num2>" prompt text. Samples
    are encoded as "<expression without spaces>= <answer digits reversed>§".

    When `total_samples` is below `preload_threshold` the samples are parsed
    once and kept in memory; otherwise the file is streamed line by line.

    NOTE(review): no per-worker sharding is performed, so with a multi-worker
    DataLoader each worker would presumably yield the same stream — confirm
    before using num_workers > 0.
    """

    def __init__(self, tokenizer, config, mode='train', total_samples=1000000, data_dir='data'):
        """Locate the data file for `mode` and optionally preload it.

        Args:
            tokenizer: object exposing `encode(str) -> list[int]`.
            config: stored on the instance; not otherwise read by this class.
            mode: "train" selects train_4digit.jsonl, anything else selects
                test_4digit.jsonl.
            total_samples: number of samples to preload, or (in streaming
                mode) the number of yielded samples after which the file is
                rewound.
            data_dir: directory containing the JSONL files.

        Raises:
            ValueError: if the data file for `mode` does not exist.
        """
        self.tokenizer = tokenizer
        self.config = config
        self.mode = mode
        self.total_samples = total_samples
        self.data_dir = data_dir

        file_name = "train_4digit.jsonl" if mode == "train" else "test_4digit.jsonl"
        file_path = os.path.join(data_dir, file_name)

        self.data_files = []
        if os.path.exists(file_path):
            self.data_files.append(file_path)
        else:
            print(f"Warning: Data file {file_path} not found")

        if not self.data_files:
            raise ValueError(f"No data files found for mode={mode} in {data_dir}")

        print(f"Loading data from {len(self.data_files)} file(s): {self.data_files}")

        self.preload_threshold = 1000000
        self.data = None
        self.data_size = None
        # Filled lazily by generate_sample() in streaming mode.
        self._file_line_counts = None
        self._total_file_lines = 0
        if total_samples < self.preload_threshold:
            print(f"Preloading {total_samples} samples into memory...")
            self._preload_data(total_samples)
            print(f"Preloaded {len(self.data)} samples.")
        else:
            print(f"Using streaming mode for {total_samples} samples (memory efficient)")

    def _parse_jsonl_line(self, line):
        """Parse one JSONL line into a training sample.

        Returns:
            A dict with "input_ids" and "labels" (LongTensors of equal length),
            "input_len" (number of prompt tokens including "=") and
            "ground_truth_num" (the raw "answer" value), or None when the line
            is blank or cannot be parsed/encoded.
        """
        stripped = line.strip()
        if not stripped:
            # Blank lines (e.g. a trailing newline) are not errors.
            return None
        try:
            data = json.loads(stripped)
            num1 = data["num1"]
            num2 = data["num2"]
            answer = data["answer"]

            expression = data.get("expression", f"{num1} + {num2}")
            input_str = expression.replace(" ", "") + "="

            # The answer is emitted least-significant digit first.
            target_str = f" {str(answer)[::-1]}"
            full_text = input_str + target_str + "§"

            tokenized = self.tokenizer.encode(full_text)
            input_len = len(self.tokenizer.encode(input_str))

            # Next-token labels: prompt positions before "=" are ignored, each
            # position from "=" onward predicts the following token, and the
            # final position has no successor.
            labels = [-100] * (input_len - 1) + tokenized[input_len:] + [-100]

            return {
                "input_ids": torch.tensor(tokenized, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
                "input_len": input_len,
                "ground_truth_num": answer
            }
        except (json.JSONDecodeError, KeyError, ValueError, TypeError, AttributeError) as e:
            # TypeError/AttributeError cover valid JSON that is not an object
            # with the expected fields (e.g. a bare number or a list).
            print(f"Error parsing line: {line[:50]}... Error: {e}")
            return None

    def _preload_data(self, max_samples):
        """Parse up to `max_samples` valid samples into `self.data`.

        The budget is split evenly across `self.data_files`; unparseable lines
        are skipped and do not count against it.
        """
        self.data = []
        samples_per_file, remainder = divmod(max_samples, len(self.data_files))

        for file_idx, file_path in enumerate(self.data_files):
            quota = samples_per_file + (1 if file_idx < remainder else 0)
            count = 0

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        if count >= quota:
                            break
                        sample = self._parse_jsonl_line(line)
                        if sample is not None:
                            self.data.append(sample)
                            count += 1
            except FileNotFoundError:
                print(f"Warning: File {file_path} not found, skipping")
                continue

        self.data_size = len(self.data)

    def _read_sample_from_file(self, file_path, line_offset=0):
        """Return the sample parsed from the zero-based line `line_offset`.

        Returns None if the file is shorter than the offset, cannot be read,
        or the line does not parse. Each call rescans the file from the start.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for _ in range(line_offset):
                    if not f.readline():
                        return None
                line = f.readline()
                if not line:
                    return None
                return self._parse_jsonl_line(line)
        except (FileNotFoundError, IOError) as e:
            print(f"Error reading file {file_path} at offset {line_offset}: {e}")
            return None

    def _get_file_line_count(self, file_path):
        """Return the number of lines in `file_path`, or 0 if it cannot be read."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return sum(1 for _ in f)
        except (FileNotFoundError, IOError):
            return 0

    def generate_sample(self, index):
        """Return the sample at `index`, wrapping around the available data.

        In preload mode this indexes the in-memory list. In streaming mode the
        index is mapped to a (file, line) pair; if that line cannot be parsed
        the first line of the first file is returned instead (which may itself
        be None).

        Raises:
            ValueError: if there is no data to index into.
        """
        if self.data is not None:
            if not self.data_size:
                raise ValueError("No samples were preloaded; the dataset is empty")
            return self.data[index % self.data_size]

        if self._file_line_counts is None:
            self._file_line_counts = [self._get_file_line_count(f) for f in self.data_files]
            self._total_file_lines = sum(self._file_line_counts)

        if self._total_file_lines == 0:
            raise ValueError("All data files are empty")

        file_idx = 0
        line_offset = index % self._total_file_lines
        for i, line_count in enumerate(self._file_line_counts):
            if line_offset < line_count:
                file_idx = i
                break
            line_offset -= line_count

        sample = self._read_sample_from_file(self.data_files[file_idx], line_offset)
        if sample is None:
            print("======Failed to read from corresponding file, returning first sample===============")
            sample = self._read_sample_from_file(self.data_files[0], 0)
        return sample

    def __iter__(self):
        """Yield samples forever, cycling through the data.

        Raises:
            ValueError: (on first `next`) if no sample can ever be produced.
        """
        if self.data is not None:
            yield from self._iter_preloaded()
        else:
            yield from self._iter_streaming()

    def _iter_preloaded(self):
        """Cycle endlessly over the in-memory samples."""
        if not self.data_size:
            raise ValueError("No samples were preloaded; the dataset is empty")
        index = 0
        while True:
            yield self.data[index % self.data_size]
            index += 1

    def _iter_streaming(self):
        """Cycle endlessly over the data files, reading one line per file in turn.

        A file is rewound when it reaches EOF, and all files are rewound after
        `total_samples` samples have been yielded. A file that reaches EOF
        without having produced a single sample since its last rewind is
        dropped; once every file is dropped a ValueError is raised instead of
        spinning forever.
        """
        handles = []
        try:
            for file_path in self.data_files:
                try:
                    handles.append(open(file_path, 'r', encoding='utf-8'))
                except (FileNotFoundError, IOError) as e:
                    print(f"Warning: Cannot open file {file_path}: {e}")
                    handles.append(None)

            # produced[i]: file i has yielded a sample since its last rewind.
            produced = [False] * len(handles)
            current = 0
            samples_yielded = 0

            while True:
                if all(h is None for h in handles):
                    raise ValueError("No readable samples found in any data file")

                if self.total_samples > 0 and samples_yielded >= self.total_samples:
                    for i, handle in enumerate(handles):
                        if handle is not None:
                            handle.seek(0)
                            produced[i] = False
                    samples_yielded = 0

                f = handles[current]
                if f is not None:
                    line = f.readline()
                    if not line:
                        if not produced[current]:
                            # A full pass from the start yielded nothing.
                            f.close()
                            handles[current] = None
                            current = (current + 1) % len(handles)
                            continue
                        f.seek(0)
                        produced[current] = False
                        line = f.readline()
                    if line:
                        sample = self._parse_jsonl_line(line)
                        if sample is not None:
                            produced[current] = True
                            yield sample
                            samples_yielded += 1
                current = (current + 1) % len(handles)
        finally:
            for handle in handles:
                if handle is not None:
                    handle.close()

def collate_fn(batch, pad_token_id=None):
    """Right-pad a list of samples to a common length and stack them.

    Args:
        batch: non-empty list of dicts, each holding 1-D LongTensors
            "input_ids" and "labels" of equal length.
        pad_token_id: id used to pad "input_ids"; defaults to the module-level
            `config.pad_token_id` so existing DataLoader usage is unchanged.

    Returns:
        Dict with "input_ids" and "labels", both of shape
        (len(batch), longest sample). Padded label positions are -100 so the
        loss ignores them.

    Raises:
        ValueError: if `batch` is empty.
    """
    if pad_token_id is None:
        pad_token_id = config.pad_token_id

    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids = []
    labels = []

    for item in batch:
        pad_len = max_len - len(item["input_ids"])
        input_ids.append(torch.cat([
            item["input_ids"],
            torch.full((pad_len,), pad_token_id, dtype=torch.long)
        ]))
        labels.append(torch.cat([
            item["labels"],
            torch.full((pad_len,), -100, dtype=torch.long)
        ]))

    return {
        "input_ids": torch.stack(input_ids),
        "labels": torch.stack(labels)
    }

class CausalSelfAttention(nn.Module):
    """GPT-style multi-head self-attention, implemented manually.

    Does not use F.scaled_dot_product_attention. A lower-triangular causal
    mask is applied only when the caller passes no mask of its own.
    """

    def __init__(self, config):
        """Create the fused QKV projection and the output projection.

        `config` must provide `d_model`, `n_heads` (dividing `d_model`) and
        `dropout`.
        """
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.c_attn = nn.Linear(config.d_model, 3 * config.d_model)
        self.c_proj = nn.Linear(config.d_model, config.d_model)
        self.n_heads = config.n_heads
        self.d_head = config.d_model // config.n_heads
        # Remembered so forward() splits QKV by this module's own width rather
        # than by whatever the module-level config happens to say.
        self.d_model = config.d_model
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over `x` of shape (B, T, d_model) and return the same shape.

        Args:
            x: input activations.
            mask: optional boolean mask, True where attention is allowed. A
                2-D mask is broadcast over batch and heads, a 3-D mask over
                heads. When given, it replaces the causal mask entirely.
        """
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.d_model, dim=2)
        # (B, T, C) -> (B, n_heads, T, d_head)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / (self.d_head ** 0.5))

        if mask is None:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            att = att.masked_fill(~causal_mask, float('-inf'))
        else:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            att = att.masked_fill(~mask, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v

        # (B, n_heads, T, d_head) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        """Build the layers from `config.d_model`, `config.d_ff` and `config.dropout`."""
        super().__init__()
        self.c_fc = nn.Linear(config.d_model, config.d_ff)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand to d_ff, apply GELU, project back to d_model, then dropout."""
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))

class Block(nn.Module):
    """GPT-style Transformer block with a configurable normalization layout.

    `config.norm_structure` selects where normalization is applied:
      - "pre_norm":        x + f(norm(x))
      - "post_norm":       norm(x + f(x)), or norm(alpha * x + f(x)) when
                           `config.norm_type` is "DeepNorm"
      - "sandwich_branch": x + norm(f(norm(x)))
      - "sandwich_dual":   norm(x + f(norm(x)))
    where f is first the attention sub-layer and then the MLP sub-layer.
    """

    _PRE_NORM_STRUCTURES = ("pre_norm", "sandwich_branch", "sandwich_dual")
    _POST_NORM_STRUCTURES = ("post_norm", "sandwich_branch", "sandwich_dual")

    def __init__(self, config):
        """Create the norm layers required by the chosen structure, then attn and MLP.

        Raises:
            ValueError: if `config.norm_structure` is not a supported layout.
                Previously an unknown value silently produced a block whose
                forward pass returned its input unchanged.
        """
        super().__init__()
        self.config = config

        structure = self.config.norm_structure
        if structure not in self._PRE_NORM_STRUCTURES + self._POST_NORM_STRUCTURES:
            raise ValueError(f"Unsupported norm_structure: {structure!r}")

        if structure in self._PRE_NORM_STRUCTURES:
            self.ln_1_pre = self._make_norm()
            self.ln_2_pre = self._make_norm()

        if structure in self._POST_NORM_STRUCTURES:
            self.ln_1_post = self._make_norm()
            self.ln_2_post = self._make_norm()

        self.attn = CausalSelfAttention(self.config)
        self.mlp = MLP(self.config)

    def _make_norm(self):
        """Return a fresh normalization layer for `config.norm_type`.

        "RMSNorm" uses nn.RMSNorm, "SimpleNorm" is a LayerNorm without affine
        parameters, and every other value (including "LayerNorm" and
        "DeepNorm") uses a standard LayerNorm.
        """
        d_model = self.config.d_model
        norm_type = self.config.norm_type
        if norm_type == "RMSNorm":
            return nn.RMSNorm(d_model)
        if norm_type == "SimpleNorm":
            return nn.LayerNorm(d_model, elementwise_affine=False)
        return nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with the configured norm layout."""
        structure = self.config.norm_structure
        if structure == "pre_norm":
            x = x + self.attn(self.ln_1_pre(x))
            x = x + self.mlp(self.ln_2_pre(x))
        elif structure == "post_norm":
            if self.config.norm_type == "DeepNorm":
                # DeepNorm scales the residual stream before the post-norm.
                alpha = getattr(self.config, "deepnorm_alpha", 1.0)
                x = self.ln_1_post(x * alpha + self.attn(x))
                x = self.ln_2_post(x * alpha + self.mlp(x))
            else:
                x = self.ln_1_post(x + self.attn(x))
                x = self.ln_2_post(x + self.mlp(x))
        elif structure == "sandwich_branch":
            x = x + self.ln_1_post(self.attn(self.ln_1_pre(x)))
            x = x + self.ln_2_post(self.mlp(self.ln_2_pre(x)))
        elif structure == "sandwich_dual":
            x = self.ln_1_post(x + self.attn(self.ln_1_pre(x)))
            x = self.ln_2_post(x + self.mlp(self.ln_2_pre(x)))
        else:
            # Only reachable if the config was mutated after construction.
            raise ValueError(f"Unsupported norm_structure: {structure!r}")

        return x

class LoopedTransformer(nn.Module):
    """Decoder-only Transformer whose stack of blocks is applied repeatedly.

    The `config.k_layers` blocks are run `num_loops` times in sequence with
    shared weights. The output head shares its weight with the token embedding.
    """

    def __init__(self, config):
        """Build embeddings, blocks, final norm and tied output head, then initialize."""
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(len(config.vocab), config.d_model)
        self.position_embedding = nn.Embedding(config.max_len, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([Block(config) for _ in range(config.k_layers)])

        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, len(config.vocab), bias=False)

        # Weight tying: output projection and token embedding share one tensor.
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

        if self.config.norm_type == "DeepNorm":
            self._deepnorm_init()

    def _deepnorm_init(self):
        """Special initialization used when norm_type is "DeepNorm".

        Re-initializes the output projections of every attention and MLP
        sub-layer with Xavier-normal weights using gain sqrt(0.5 * alpha).
        """
        alpha = getattr(self.config, "deepnorm_alpha", 1.0)
        gain = (0.5 * alpha) ** 0.5

        for layer in self.layers:
            if hasattr(layer.attn, 'c_proj'):
                nn.init.xavier_normal_(layer.attn.c_proj.weight, gain=gain)

            if hasattr(layer.mlp, 'c_proj'):
                nn.init.xavier_normal_(layer.mlp.c_proj.weight, gain=gain)

    def _init_weights(self, module):
        """Default initializer applied to every submodule via `self.apply`.

        Linear and Embedding weights are drawn from N(0, 0.02) with zero
        biases; LayerNorm/RMSNorm scales are set to one and LayerNorm biases
        to zero. Other module types are left untouched.
        """
        # nn.RMSNorm does not exist on older torch builds; looking it up
        # unconditionally would break initialization of every module that
        # falls through to this check, even when RMSNorm is not in use.
        rms_norm_cls = getattr(nn, "RMSNorm", None)

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False.
            if getattr(module, 'weight', None) is not None:
                torch.nn.init.ones_(module.weight)
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.zeros_(module.bias)

        elif rms_norm_cls is not None and isinstance(module, rms_norm_cls):
            if getattr(module, 'weight', None) is not None:
                torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None, num_loops=None):
        """Run the looped stack and optionally compute the loss.

        Args:
            idx: token ids of shape (B, T).
            targets: optional label ids of shape (B, T); positions equal to
                -100 are ignored by the loss.
            num_loops: how many times to apply the block stack; defaults to
                `config.l_loops` when None.

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size) and the mean
            cross-entropy loss, or None for the loss when `targets` is None.
        """
        B, T = idx.size()
        device = idx.device

        pos = torch.arange(0, T, dtype=torch.long, device=device)
        # Positions beyond the embedding table reuse its last entry.
        pos = torch.clamp(pos, 0, self.config.max_len - 1)

        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(pos)
        pos_emb = pos_emb.unsqueeze(0).expand(B, -1, -1)
        x = tok_emb + pos_emb
        x = self.dropout(x)

        if num_loops is None:
            num_loops = self.config.l_loops

        for _ in range(num_loops):
            for layer in self.layers:
                x = layer(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100,
                reduction='mean'
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, num_loops=None, debug_steps=3):
        """Greedily append exactly `max_new_tokens` tokens to `idx`.

        Generation never stops early on EOS; callers strip EOS afterwards.
        Decoding is pure argmax, so the result is deterministic.

        Args:
            idx: prompt token ids of shape (B, T); moved to the model's device.
            max_new_tokens: number of tokens to append.
            num_loops: loop count for each forward pass; defaults to
                `config.l_loops` when None.
            debug_steps: number of initial steps for which the top-5
                candidates and the chosen token of the first batch element
                are printed. The default of 3 keeps the previous output; pass
                0 to silence it.

        Returns:
            Tensor of shape (B, T + max_new_tokens) with the prompt followed by
            the generated tokens.
        """
        if num_loops is None:
            num_loops = self.config.l_loops
        # Follow the model's own device rather than the module-level config so
        # generation works wherever the model currently lives.
        idx = idx.to(next(self.parameters()).device)
        for step in range(max_new_tokens):
            idx_cond = idx[:, -self.config.max_len:]
            logits, _ = self(idx_cond, targets=None, num_loops=num_loops)

            logits = logits[:, -1, :]

            if step < debug_steps:
                probs = F.softmax(logits, dim=-1)
                top_k_probs, top_k_indices = torch.topk(probs, k=5, dim=-1)
                top_k_tokens = [tokenizer.i2c[token_id.item()] for token_id in top_k_indices[0]]
                print(f"[Debug Step {step}] Top 5 tokens: {list(zip(top_k_tokens, top_k_probs[0].cpu().tolist()))}")
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

            if step < debug_steps:
                generated_token = tokenizer.i2c[idx_next[0, 0].item()]
                print(f"[Debug Step {step}] Generated token: '{generated_token}' (idx={idx_next[0, 0].item()})")

            idx = torch.cat((idx, idx_next), dim=1)
        return idx

def save_training_loss(step, loss, avg_loss, lr, loss_file):
    """Append one training-loss row to the CSV at `loss_file`.

    A header row is written first when the file does not exist yet. Losses are
    stored with 6 decimals, the learning rate with 8, followed by an ISO
    timestamp taken at call time.
    """
    header_needed = not os.path.exists(loss_file)
    with open(loss_file, 'a', newline='') as f:
        rows = []
        if header_needed:
            rows.append(['step', 'loss', 'avg_loss', 'learning_rate', 'timestamp'])
        rows.append([step, f'{loss:.6f}', f'{avg_loss:.6f}', f'{lr:.8f}', datetime.now().isoformat()])
        csv.writer(f).writerows(rows)

def _append_json_record(path, record):
    """Append `record` to the JSON list stored at `path`, creating it if needed.

    A missing or zero-length file is treated as an empty history. The updated
    list is written to a sibling temporary file and then moved into place, so
    an interrupted write cannot leave `path` truncated and unreadable for
    later calls.
    """
    history = []
    if os.path.exists(path) and os.path.getsize(path) > 0:
        with open(path, 'r', encoding='utf-8') as f:
            history = json.load(f)

    history.append(record)

    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)
    os.replace(tmp_path, path)


def save_evaluation_result(step, input_str, generated_text, generated_num, gt_text, gt_num, is_correct, eval_file):
    """Append one intermediate-checkpoint evaluation record to `eval_file`.

    The file holds a JSON list; each record stores the step, the prompt, the
    generated and ground-truth text and numbers, the correctness flag and an
    ISO timestamp taken at call time.
    """
    _append_json_record(eval_file, {
        'step': step,
        'input': input_str,
        'generated_text': generated_text,
        'generated_num': generated_num,
        'ground_truth_text': gt_text,
        'ground_truth_num': gt_num,
        'is_correct': is_correct,
        'timestamp': datetime.now().isoformat()
    })


def save_final_accuracy(accuracy, correct, total, acc_file):
    """Append one final-accuracy record to the JSON list stored in `acc_file`.

    Each record stores the accuracy, the correct and total counts and an ISO
    timestamp taken at call time.
    """
    _append_json_record(acc_file, {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'timestamp': datetime.now().isoformat()
    })
def extract_number(text):
    """Recover the integer encoded in a least-significant-digit-first string.

    Every non-digit character is discarded, the remaining digits are reversed
    back into normal reading order and the result is parsed as an integer.

    Args:
        text: decoded model output or target text, e.g. " 4031" or "4031§".

    Returns:
        The integer value (1304 for " 4031"), or None when `text` contains no
        digits or the digits cannot be converted by int().
    """
    digit_str = "".join(filter(str.isdigit, text))
    if digit_str == "":
        return None
    try:
        # Undo the reversed (least-significant-first) digit order.
        return int(digit_str[::-1])
    except ValueError:
        return None

def evaluate_sample(model, step=None, eval_file=None):
    """Run greedy generation on the first test sample and report the outcome.

    The model's previous train/eval mode is restored before returning, so
    calling this between training steps does not leave dropout disabled.

    Args:
        model: the model, or a wrapper exposing it as `.module`.
        step: optional step number recorded alongside the result.
        eval_file: optional JSON path; the result is appended only when both
            `step` and `eval_file` are given.

    Returns:
        Dict with keys 'input', 'generated_text', 'generated_num', 'gt_text',
        'gt_num' and 'is_correct'.
    """
    # Unwrap multi-GPU wrappers the same way the other evaluation helpers do,
    # since the wrapper itself has no generate() method.
    model = model.module if (config.num_gpus > 1 and hasattr(model, 'module')) else model

    was_training = model.training
    model.eval()
    try:
        # Only the first sample is used, so only one needs to be preloaded.
        ds = AdditionDataset(tokenizer, config, mode='test', total_samples=1, data_dir='data')
        sample = ds.generate_sample(0)

        input_str_len = sample["input_len"]
        ground_truth_ids = sample["input_ids"][input_str_len:]
        ground_truth_ids = ground_truth_ids[ground_truth_ids != config.eos_token_id]
        gt_text = tokenizer.decode(ground_truth_ids.tolist())
        gt_num = extract_number(gt_text)

        input_ids = sample["input_ids"][:input_str_len].unsqueeze(0).to(config.device)
        # Generate exactly as many tokens as the EOS-stripped target has.
        max_new_tokens = len(ground_truth_ids)

        generated_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    finally:
        model.train(was_training)

    generated_result_ids = generated_ids[0, input_str_len:]
    generated_result_ids = generated_result_ids[generated_result_ids != config.eos_token_id]
    generated_text = tokenizer.decode(generated_result_ids.tolist())
    generated_num = extract_number(generated_text)

    # A missing number on either side counts as incorrect.
    is_correct = generated_num == gt_num if (generated_num is not None and gt_num is not None) else False

    input_str = tokenizer.decode(input_ids[0].tolist())
    print(f"\n[Step Check] Input: {input_str}")
    print(f"[Step Check] Model Output Text: {generated_text} → Extracted Number: {generated_num}")
    print(f"[Step Check] Ground Truth Text: {gt_text} → Extracted Number: {gt_num}")
    print(f"[Step Check] Result: {'CORRECT' if is_correct else 'INCORRECT'}")
    print("-" * 60)

    if step is not None and eval_file is not None:
        save_evaluation_result(step, input_str, generated_text, generated_num, gt_text, gt_num, is_correct, eval_file)

    return {
        'input': input_str,
        'generated_text': generated_text,
        'generated_num': generated_num,
        'gt_text': gt_text,
        'gt_num': gt_num,
        'is_correct': is_correct
    }

def evaluate_accuracy(model, n_samples=1000, acc_file=None):
    """Measure exact-match accuracy over the first `n_samples` test samples.

    Uses the default loop count of `model.generate`. Samples whose ground
    truth or prediction contains no digits are skipped but still count in the
    denominator, i.e. they are scored as incorrect.

    Args:
        model: the model, or a wrapper exposing it as `.module`.
        n_samples: number of test samples to evaluate.
        acc_file: optional JSON path to which the result is appended.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when `n_samples` is not positive.
    """
    if config.num_gpus > 1 and hasattr(model, 'module'):
        model = model.module
    model.eval()
    model.to(config.device)

    n_correct = 0
    dataset = AdditionDataset(tokenizer, config, mode='test', total_samples=n_samples, data_dir='data')

    print(f"Evaluating on 4-digit addition (Total {n_samples} samples)...")
    with torch.no_grad():
        for sample_idx in tqdm(range(n_samples)):
            sample = dataset.generate_sample(sample_idx)
            prompt_len = sample["input_len"]
            token_ids = sample["input_ids"]

            # Target tokens after the prompt, with EOS removed.
            answer_ids = token_ids[prompt_len:]
            answer_ids = answer_ids[answer_ids != config.eos_token_id]
            expected_text = tokenizer.decode(answer_ids.tolist())
            expected_num = extract_number(expected_text)
            if expected_num is None:
                continue

            prompt = token_ids[:prompt_len].unsqueeze(0).to(config.device)
            output_ids = model.generate(prompt, max_new_tokens=len(answer_ids))

            predicted_ids = output_ids[0, prompt_len:]
            predicted_ids = predicted_ids[predicted_ids != config.eos_token_id]
            predicted_text = tokenizer.decode(predicted_ids.tolist())
            predicted_num = extract_number(predicted_text)
            if predicted_num is not None and predicted_num == expected_num:
                n_correct += 1

    acc = n_correct / n_samples if n_samples > 0 else 0.0
    print(f"[Final Accuracy] 4-digit addition: {acc*100:.2f}% (Correct: {n_correct}/{n_samples})")

    if acc_file is not None:
        save_final_accuracy(acc, n_correct, n_samples, acc_file)

    return acc

import matplotlib.pyplot as plt

def evaluate_multiple_loops(model, test_loops, n_samples=1000):
    """Measure exact-match accuracy for each loop count in `test_loops`.

    Args:
        model: the model, or a wrapper exposing it as `.module`.
        test_loops: iterable of loop counts passed to `model.generate`.
        n_samples: number of test samples evaluated per loop count.

    Returns:
        Dict mapping each loop count to its accuracy in [0, 1]. Samples whose
        ground truth has no digits are skipped but still count in the
        denominator; accuracy is 0.0 when `n_samples` is not positive.
    """
    model = model.module if (config.num_gpus > 1 and hasattr(model, 'module')) else model
    model.eval()
    model.to(config.device)

    loop_acc = {}
    # The test set does not depend on the loop count, so it is loaded once (on
    # first use) instead of being re-read from disk for every loop count.
    ds = None

    for L in test_loops:
        if ds is None:
            ds = AdditionDataset(tokenizer, config, mode='test',
                                 total_samples=n_samples, data_dir='data')
        correct = 0

        print(f"\nEvaluating l_loops = {L}")
        with torch.no_grad():
            for idx in tqdm(range(n_samples)):
                sample = ds.generate_sample(idx)
                input_len = sample["input_len"]

                gt_ids = sample["input_ids"][input_len:]
                gt_ids = gt_ids[gt_ids != config.eos_token_id]
                gt_num = extract_number(tokenizer.decode(gt_ids.tolist()))
                if gt_num is None:
                    continue

                prompt = sample["input_ids"][:input_len].unsqueeze(0).to(config.device)
                # Generate exactly as many tokens as the EOS-stripped target has.
                max_new_tokens = len(gt_ids)

                gen_ids = model.generate(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    num_loops=L
                )

                gen_ids = gen_ids[0, input_len:]
                gen_ids = gen_ids[gen_ids != config.eos_token_id]
                gen_num = extract_number(tokenizer.decode(gen_ids.tolist()))

                # gen_num is None when no digits were produced; None never
                # equals gt_num, so that case is scored as incorrect.
                if gen_num == gt_num:
                    correct += 1

        # Same guard as evaluate_accuracy: no ZeroDivisionError for n_samples == 0.
        acc = correct / n_samples if n_samples > 0 else 0.0
        loop_acc[L] = acc
        print(f"l_loops={L} | acc={acc:.4f}")

    return loop_acc
def plot_loop_accuracy(loop_acc, save_path=None):
    """Plot accuracy against loop count and optionally save the figure and data.

    Args:
        loop_acc: dict mapping loop count to accuracy.
        save_path: optional image path. When given, the figure is saved there
            and the plotted points are also written to a CSV with the same
            base name. The parent directory is created if it does not exist.
            A failure to write the CSV is reported but does not abort the
            plot.
    """
    loops = sorted(loop_acc)
    accs = [loop_acc[n] for n in loops]
    plt.figure()
    plt.plot(loops, accs, marker='o')
    plt.xlabel("l_loops")
    plt.ylabel("Accuracy")
    plt.title("Accuracy vs Loop Count")
    plt.grid(False)
    if save_path is not None:
        # Saving into a directory that does not exist yet would otherwise
        # fail after the whole evaluation has already been computed.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        file_root, _ = os.path.splitext(save_path)
        csv_path = file_root + ".csv"

        try:
            with open(csv_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(['Loop_Count', 'Accuracy'])
                writer.writerows(zip(loops, accs))
            print(f"数据已保存至: {csv_path}")
        except (OSError, csv.Error) as e:
            print(f"数据保存失败: {e}")
    plt.show()

if __name__ == "__main__":
    # Seed every RNG in use so repeated evaluation runs are reproducible.
    seed = 2
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    print("Starting Training...")
    model = LoopedTransformer(config)
    model_path = "XXXXX.pt"
    # Map the checkpoint onto the CPU first: the model is still on the CPU
    # here, and a checkpoint saved from a GPU cannot otherwise be loaded on a
    # CPU-only host. The evaluation helper moves the model to config.device.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    print("\n=== Final Evaluation ===")
    save_dir = "XXXXX"
    # Make sure the output directory exists before the long evaluation runs.
    os.makedirs(save_dir, exist_ok=True)
    print("\n=== Loop Sensitivity Evaluation ===")
    # Tag output files with the architecture settings being evaluated.
    exp_tag = (
        f"k{config.k_layers}_"
        f"{config.norm_structure}_"
        f"{config.norm_type}"
    )
    plot_path = os.path.join(save_dir, f"acc_vs_loops_{exp_tag}.png")
    loop_acc = evaluate_multiple_loops(
        model,
        test_loops=config.test_loops,
        n_samples=100
    )
    plot_loop_accuracy(
        loop_acc,
        save_path=plot_path
    )
    print("saved to:", plot_path)
