"""
Testing script for sequence models.
"""
import os
import sys
import argparse
import pickle
import torch
from tqdm import tqdm

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model import GPTConfig, GPT

def encode(s, stoi):
    """Translate a space-delimited string into a list of token indices.

    Args:
        s: Text whose tokens are separated by single spaces.
        stoi: Mapping from token string to integer index.

    Returns:
        The indices of the tokens of ``s``, in order.

    Raises:
        KeyError: If a token (including the empty token produced by
            consecutive spaces) is missing from ``stoi``.
    """
    token_ids = []
    # Split on a literal space only, so other whitespace characters remain
    # part of their token instead of being treated as separators.
    for piece in s.split(" "):
        token_ids.append(stoi[piece])
    return token_ids

def decode(token_ids, itos):
    """Turn a sequence of token indices back into a space-joined string.

    Args:
        token_ids: Iterable of indices to look up.
        itos: Lookup (indexable by token id) from index to token string.

    Returns:
        The looked-up tokens joined by single spaces.
    """
    pieces = []
    for token_id in token_ids:
        pieces.append(itos[token_id])
    return " ".join(pieces)

def check_prediction(generated, expected):
    """Check whether the generated text ends with the expected answer.

    The answer is taken to be everything after the last ``%`` in
    ``generated``, with surrounding whitespace removed.

    Args:
        generated: Decoded model output (prompt plus continuation).
        expected: Reference answer; compared verbatim, so it should already
            be stripped.

    Returns:
        True if ``generated`` contains a ``%`` and the stripped text after
        the last one equals ``expected``; False otherwise.
    """
    # rpartition returns an empty separator when '%' is absent, which covers
    # the "no answer marker" case in one step.
    _, marker, answer = generated.rpartition('%')
    if not marker:
        return False
    return answer.strip() == expected

def main():
    """Evaluate a saved GPT checkpoint on the dataset's test split.

    Parses the command-line options, loads ``meta.pkl`` and the checkpoint
    they select, generates a continuation for every ``prompt % answer`` line
    of ``test.txt`` and compares the text after the last ``%`` of the
    generated line with the expected answer.  Per-example predictions and a
    summary are written to ``test_results_<ckpt_iter>.txt`` inside the
    checkpoint directory, and the accuracies are printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_iter', type=int, default=1000)
    parser.add_argument('--dataset', type=str, default='sequences')
    parser.add_argument('--config', type=str, default='1_1_120')
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--min_value', type=int, default=0)
    parser.add_argument('--max_value', type=int, default=100)
    parser.add_argument('--is_sorted', type=str, default="True")
    parser.add_argument('--num_copies', type=int, default=1)
    parser.add_argument('--embedding_config', type=str, default='')
    parser.add_argument('--test_samples', type=int, default=-1)
    parser.add_argument('--permutation_type', type=str, default="reversal")
    args = parser.parse_args()

    # Setup paths
    # Only the exact string "True" selects the sorted variant.
    sequence_type = "sorted" if args.is_sorted == "True" else "unsorted"
    embedding_config = args.embedding_config
    full_config = f'{args.config}_{embedding_config}' if embedding_config else args.config

    data_path = f'data/{args.dataset}/{sequence_type}/{args.min_value}-{args.max_value}/{args.permutation_type}'
    meta_path = f'{data_path}/meta.pkl'

    # Load metadata
    # NOTE: unpickling can execute arbitrary code; only use metadata files
    # produced by a trusted data-preparation step.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)

    stoi, itos = meta['stoi'], meta['itos']
    # The generation budget per prompt is the block size stored in the metadata.
    max_new_tokens = meta['block_size']

    # Load model
    out_dir = f'out/{args.dataset}_{sequence_type}_{args.permutation_type}_{full_config}_{args.min_value}-{args.max_value}'

    # Checkpoints carry a "_<num_copies>" suffix unless num_copies is 0.
    if args.num_copies == 0:
        ckpt_path = os.path.join(out_dir, f'{args.ckpt_iter}_ckpt.pt')
    else:
        ckpt_path = os.path.join(out_dir, f'{args.ckpt_iter}_ckpt_{args.num_copies}.pt')

    print(f"Loading checkpoint: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=args.device)

    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    # Clean state dict: drop the '_orig_mod.' key prefix so the keys match the
    # freshly built model (presumably added by torch.compile -- confirm).
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.eval()
    model.to(args.device)

    print(f"Model: {gptconf.n_layer}L-{gptconf.n_head}H-{gptconf.n_embd}D")
    print(f"Identity embeddings: {getattr(gptconf, 'use_identity_embeddings', False)}")

    # Load test data: every line containing '%' is "<prompt> % <answer>".
    test_file = f'{data_path}/test.txt'
    test_prompts = []
    test_expected = []

    with open(test_file, 'r') as f:
        for line in f:
            line = line.strip()
            # A line containing '%' is necessarily non-empty.
            if '%' in line:
                parts = line.split('%')
                test_prompts.append(parts[0].strip() + ' %')
                test_expected.append(parts[1].strip())

    print(f"Loaded {len(test_prompts)} test examples")

    # Subsample if requested
    if 0 < args.test_samples < len(test_prompts):
        import random
        # A private, fixed-seed generator keeps the subset reproducible
        # without reseeding the process-wide random state.
        rng = random.Random(42)
        indices = rng.sample(range(len(test_prompts)), args.test_samples)
        test_prompts = [test_prompts[i] for i in indices]
        test_expected = [test_expected[i] for i in indices]
        print(f"Testing on {args.test_samples} samples")

    # Encode prompts
    encoded_prompts = [encode(prompt, stoi) for prompt in test_prompts]

    # Test model
    total = 0
    correct = 0
    unique_sequences = set()
    unique_correct = set()

    results = []

    print("Running inference...")
    examples = zip(test_prompts, test_expected, encoded_prompts)
    # zip() has no length, so pass the total explicitly for a useful progress bar.
    for prompt, expected, encoded in tqdm(examples, total=len(test_prompts)):
        # Generate from a batch of size one.
        input_tensor = torch.tensor(encoded, dtype=torch.long, device=args.device).unsqueeze(0)

        with torch.no_grad():
            generated_ids = model.generate(
                input_tensor,
                max_new_tokens,
                temperature=args.temperature,
                # top_k equals the vocabulary size; presumably this disables
                # top-k filtering -- confirm against GPT.generate.
                top_k=len(stoi),
            )

        # Decode, keeping only the text before the first newline.
        generated_text = decode(generated_ids[0].tolist(), itos).split('\n')[0]

        # Check correctness
        is_correct = check_prediction(generated_text, expected)

        # Track statistics per distinct input sequence (prompt without the marker).
        unique_seq = prompt.replace(' %', '')
        unique_sequences.add(unique_seq)

        total += 1
        if is_correct:
            correct += 1
            unique_correct.add(unique_seq)

        # Extract prediction
        prediction = generated_text.split('%')[-1].strip() if '%' in generated_text else 'no_output'

        results.append({
            'prompt': prompt,
            'expected': expected,
            'prediction': prediction,
            'correct': is_correct,
        })

    # Calculate metrics (guarding against an empty test set).
    accuracy = (correct / total * 100) if total > 0 else 0
    unique_accuracy = (len(unique_correct) / len(unique_sequences) * 100) if unique_sequences else 0

    # Save results
    output_file = os.path.join(out_dir, f'test_results_{args.ckpt_iter}.txt')
    with open(output_file, 'w') as f:
        for result in results:
            status = "" if result['correct'] else " [WRONG]"
            f.write(f"{result['prompt']} {result['prediction']}{status}\n")

        f.write(f"\n{'='*50}\n")
        f.write("RESULTS SUMMARY\n")
        f.write(f"{'='*50}\n")
        f.write(f"Total samples: {total}\n")
        f.write(f"Correct: {correct}\n")
        f.write(f"Accuracy: {accuracy:.2f}%\n")
        f.write(f"Unique sequences: {len(unique_sequences)}\n")
        f.write(f"Unique correct: {len(unique_correct)}\n")
        f.write(f"Unique accuracy: {unique_accuracy:.2f}%\n")
        f.write(f"Temperature: {args.temperature}\n")

    print(f"\nResults:")
    print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")
    print(f"Unique accuracy: {unique_accuracy:.2f}% ({len(unique_correct)}/{len(unique_sequences)})")
    print(f"Results saved to: {output_file}")

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()