import torch
import numpy as np
import tiktoken
import argparse
import random
from model import GPT, eval_loss
import logging
import sys
import os

def setup_logging(log_file):
    """Configure root logging to emit INFO-and-above records to a file and stdout.

    Args:
        log_file: Path passed to ``logging.FileHandler``.

    Any exception raised while configuring logging is reported on stderr and
    the process exits with status 1.
    """
    record_format = "%(asctime)s [%(levelname)s] %(message)s"
    try:
        # The handlers are created inside the try so that an unwritable log
        # path goes through the same stderr-and-exit route as other failures.
        file_sink = logging.FileHandler(log_file)
        console_sink = logging.StreamHandler(sys.stdout)
        logging.basicConfig(
            level=logging.INFO,
            format=record_format,
            handlers=[file_sink, console_sink],
        )
        logging.info(f"Logging initialized to {log_file}")
    except Exception as err:
        print(f"Error setting up logging: {err}", file=sys.stderr)
        sys.exit(1)

def tokenize(s, enc):
    """Tokenize a string and return the ids as a ``numpy.uint16`` array.

    The returned array starts with the fixed id 50304, followed by the ids
    produced by ``enc.encode_ordinary(s)``.
    NOTE(review): 50304 is presumably a document-delimiter id expected by the
    model -- confirm against the model's vocabulary.

    Args:
        s: Text to tokenize.
        enc: Tokenizer object exposing ``encode_ordinary(str)``.

    Returns:
        1-D ``numpy.ndarray`` of dtype ``uint16``.

    Raises:
        AssertionError: If any token id falls outside ``[0, 2**16)``.
    """
    tokens_np = np.array([50304, *enc.encode_ordinary(s)])
    fits_uint16 = bool(((tokens_np >= 0) & (tokens_np < 2**16)).all())
    if not fits_uint16:
        # Raised explicitly rather than via an `assert` statement so the check
        # still runs under `python -O`; otherwise the uint16 cast below would
        # silently wrap out-of-range ids. The exception type is kept as-is.
        raise AssertionError("token dictionary too large for uint16")
    return tokens_np.astype(np.uint16)

def sample_lines(file_path, n_samples=100):
    """Return up to ``n_samples`` randomly chosen, whitespace-stripped lines.

    The whole file is read as UTF-8. When it holds ``n_samples`` lines or
    fewer, every line is returned in file order; otherwise ``random.sample``
    picks ``n_samples`` of them without replacement. Blank lines are not
    filtered out, so they come back as empty strings.

    Any exception is logged and then re-raised to the caller.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            all_lines = handle.readlines()
        if len(all_lines) <= n_samples:
            chosen = all_lines
        else:
            chosen = random.sample(all_lines, n_samples)
        return [entry.strip() for entry in chosen]
    except Exception as e:
        logging.error(f"Failed to read file {file_path}: {str(e)}")
        raise

def main():
    """Estimate a GPT model's loss over randomly sampled lines of a text file.

    Command-line arguments:
        -m / --model: path the model is loaded from (required).
        -f / --file_path: path to the profiles text file (required).

    Batches of up to 100 random lines are drawn repeatedly (each batch is
    sampled independently, so a line may appear in several batches) and scored
    with ``eval_loss`` until about one million tokens have been counted. The
    mean, minimum and maximum of the negated values are then logged.

    Exits with status 1 if either path does not exist or the model fails to
    load. A failure inside ``eval_loss`` stops the loop and whatever was
    collected so far is still reported.
    """
    parser = argparse.ArgumentParser(description="Attribute Prediction Evaluation with GPT")
    parser.add_argument("-m", "--model", type=str, required=True, help="Load model from this path")
    parser.add_argument("-f", "--file_path", type=str, required=True, help="Path to profiles text file")
    args = parser.parse_args()

    # Validate inputs before logging is configured, hence plain stderr prints.
    if not os.path.exists(args.model):
        print(f"Error: Model file {args.model} does not exist", file=sys.stderr)
        sys.exit(1)
    if not os.path.exists(args.file_path):
        print(f"Error: Text file {args.file_path} does not exist", file=sys.stderr)
        sys.exit(1)

    setup_logging("default.log")
    logging.info(f"Arguments: model={args.model}, file_path={args.file_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    logging.info(f"Loading model from {args.model}...")
    try:
        model = GPT.from_pretrained(args.model, device)
        logging.info("Model loaded successfully")
    except Exception as e:
        logging.error(f"Failed to load model: {str(e)}")
        sys.exit(1)

    # Load tokenizer
    enc = tiktoken.get_encoding("gpt2")
    logging.info("Tokenizer initialized")

    total_tokens = 0
    all_losses = []
    target_tokens = 1_000_000  # 1M tokens

    logging.info(f"Starting sampling process to reach {target_tokens:,} tokens...")

    while total_tokens < target_tokens:
        sampled_texts = sample_lines(args.file_path, 100)
        if not sampled_texts:
            # An empty file yields an empty batch with zero tokens, so
            # total_tokens could never reach the target; stop instead of
            # looping forever.
            logging.error(f"No lines could be sampled from {args.file_path}; stopping.")
            break
        # Token count is computed locally with the same tokenizer; it is used
        # only for progress accounting, not passed to eval_loss.
        batch_tokens = sum(len(tokenize(text, enc)) for text in sampled_texts)

        logging.info(f"Calculating loss for batch (current tokens: {total_tokens:,})...")
        try:
            # eval_loss returns a pair; only the first element is used here.
            losses, _ = eval_loss(model, sampled_texts, enc)
            all_losses.extend(losses)
            total_tokens += batch_tokens
            logging.info(f"Batch processed. Total tokens so far: {total_tokens:,}")
        except Exception as e:
            logging.error(f"Error calculating loss: {str(e)}")
            break

    if all_losses:
        # NOTE(review): values are negated before reporting -- presumably
        # eval_loss returns log-probabilities rather than losses; confirm.
        all_losses = np.array(all_losses) * -1
        mean_loss = np.mean(all_losses)
        min_loss = np.min(all_losses)
        max_loss = np.max(all_losses)

        # Log results
        logging.info("\nFinal Results:")
        logging.info(f"Total number of samples processed: {len(all_losses)}")
        logging.info(f"Total tokens processed: {total_tokens:,}")
        logging.info(f"Mean Loss: {mean_loss:.4f}")
        logging.info(f"Min Loss: {min_loss:.4f}")
        logging.info(f"Max Loss: {max_loss:.4f}")
    else:
        logging.warning("No losses calculated due to processing errors.")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()