import os
import argparse
import torch
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.data import get_dataset, prepare_batch
from src.metrics import calculate_nfn_scores, calculate_projection_scores, infer_task_from_scores, infer_task_from_projection_scores, infer_task_probs_from_scores
import numpy as np
import random

from src.helper import setup_logging, set_seed, ResultsDB
import logging


# Baseline tasks for the L1 setting, in label-index order.
_L1_TASKS = ('code', 'text', 'math')

# Test-split size per programming language. Key order also fixes the
# baseline-task (label-index) order for the L2:PLang setting.
_PLANG_TEST_SIZES = {
    'cpp': 1945,
    'csharp': 1638,
    'java': 2158,
    'php': 894,
    'python': 35752,
    'rust': 2067,
    'shell': 420,
    'swift': 1761,
    'typescript': 2245,
}

# Test-split size per math topic. Key order also fixes the baseline-task
# (label-index) order for the L2:Math setting.
_MATH_TEST_SIZES = {
    'Algebra': 2131,
    'Counting_&_Probability': 445,
    'Geometry': 549,
    'Intermediate_Algebra': 1398,
    'Number_Theory': 609,
    'Prealgebra': 1276,
    'Precalculus': 492,
}

# L2 setting -> (dataset prefix, per-subset test sizes, record DB file name).
_L2_SETTINGS = {
    'L2:PLang': ('magicoder:', _PLANG_TEST_SIZES, 'L2_lang_record.duckdb'),
    'L2:Math': ('math:', _MATH_TEST_SIZES, 'L2_math_record.duckdb'),
}


def _resolve_setting(args):
    """Validate the setting-related CLI arguments and return the baseline tasks.

    Side effects on ``args``: sets ``args.record_db`` and, for L2 settings,
    caps ``args.nbsamples`` at the size of the subset's test split.

    Args:
        args: Parsed CLI namespace. Reads ``setting``, ``dataset``, ``label``,
            ``nbsamples`` and ``record_dir``.

    Returns:
        list[str]: Baseline task names, in label-index order.

    Raises:
        KeyError: If ``args.setting`` is not a known setting.
        ValueError: If the dataset lacks the setting's prefix, names an
            unknown subset, or ``args.label`` does not match it.
    """
    if args.setting == 'L1':
        if args.label not in _L1_TASKS:
            raise ValueError(f'Label {args.label!r} must be one of {list(_L1_TASKS)}')
        args.record_db = os.path.join(args.record_dir, 'L1_record.duckdb')
        return list(_L1_TASKS)

    if args.setting not in _L2_SETTINGS:
        raise KeyError(f'Unknown setting {args.setting}')
    prefix, test_sizes, record_name = _L2_SETTINGS[args.setting]

    parts = args.dataset.split(prefix)
    if len(parts) < 2:
        raise ValueError(f'Dataset {args.dataset!r} must contain {prefix!r} followed by a subset name')
    subset = parts[1]
    if subset not in test_sizes:
        raise ValueError(f'Unknown subset {subset!r} for setting {args.setting}; expected one of {list(test_sizes)}')
    # The test samples all come from one subset, so the label must name it.
    if args.label != subset:
        raise ValueError(f'Label {args.label!r} does not match dataset subset {subset!r}')

    args.nbsamples = min(test_sizes[subset], args.nbsamples)
    args.record_db = os.path.join(args.record_dir, record_name)
    return list(test_sizes)


def _select_layers(num_layers, n_layers=None):
    """Return indices of the hidden layers to monitor.

    Args:
        num_layers: Total number of hidden layers in the model.
        n_layers: If given, keep only the last ``n_layers`` layers.

    Returns:
        list[int]: Selected layer indices in ascending order.

    Raises:
        ValueError: If ``n_layers`` is not in ``[1, num_layers]``. A slice of
            ``[-0:]`` would silently select every layer, and negative values
            would drop leading layers instead.
    """
    layer_indices = list(range(num_layers))
    if n_layers is None:
        return layer_indices
    if not 1 <= n_layers <= num_layers:
        raise ValueError(f'n_layers must be in [1, {num_layers}], got {n_layers}')
    return layer_indices[-n_layers:]


def _load_baselines(baseline_dir, model_short, tasks, seed, method_suffix):
    """Load the per-task baseline metric JSON files.

    Files are expected at
    ``{baseline_dir}/{model_short}_{task}_{seed}{method_suffix}_metrics.json``.

    Returns:
        dict: task name -> parsed JSON, in the order of ``tasks``.

    Raises:
        FileNotFoundError: If any task's baseline file is missing.
        ValueError: If ``tasks`` is empty.
    """
    baseline_scores = {}
    for task in tasks:
        baseline_path = os.path.join(baseline_dir, f"{model_short}_{task}_{seed}{method_suffix}_metrics.json")
        if not os.path.exists(baseline_path):
            raise FileNotFoundError(f"Baseline file not found for {task}: {baseline_path}")
        with open(baseline_path, 'r') as f:
            baseline_scores[task] = json.load(f)
    if not baseline_scores:
        raise ValueError("No baseline scores found. Please provide baseline metrics in the baseline_dir.")
    return baseline_scores


def main():
    """Classify test samples into tasks by comparing metrics against baselines.

    Parses CLI arguments, then loads the model, tokenizer and test split. For
    each batch it computes metrics with the norm or projection method and
    predicts the closest baseline task. Results are saved to a ``.pt`` file
    in ``--output_dir`` and logged to the setting's DuckDB record.
    """
    parser = argparse.ArgumentParser(description="Classify tasks using alignment metrics.")
    parser.add_argument('--model', type=str, required=True, help='HuggingFace model handle')
    parser.add_argument('--setting', type=str, required=True, help='Setting')
    parser.add_argument('--batchsize', type=int, default=8, help='Batch size')
    parser.add_argument('--nbsamples', type=int, default=100, help='Number of samples to use from dataset')
    parser.add_argument('--seqlen', type=int, default=256, help='Sequence length')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save results')
    parser.add_argument('--record_dir', type=str, default='exp_records', help='Directory to save experiment records')
    parser.add_argument('--dataset', type=str, default='gsm8k', help='Dataset to use for testing')
    parser.add_argument('--label', type=str, default='code', help='Label to use for classification')
    parser.add_argument('--baseline_dir', type=str, default='anchors', help='Directory containing baseline metrics')
    parser.add_argument('--method', type=str, default='mean', help='Method to use for distance calculation: mean, median, KL')
    parser.add_argument('--method_type', type=str, default='norm', choices=['norm', 'projection'], help='Method type: norm (Method 1) or projection (Method 2)')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--n_layers', type=int, default=None, help='Number of hidden layers to collect information.')
    args = parser.parse_args()

    # Validate cheap arguments before the expensive model load.
    baseline_tasks = _resolve_setting(args)

    os.makedirs(args.output_dir, exist_ok=True)

    set_seed(args.seed)
    setup_logging(os.path.join('logs', 'classifier'))
    logging.info(args)
    logging.info(f'Baseline tasks: {baseline_tasks}')

    # Load model and tokenizer
    logging.info(f"Loading model: {args.model}")
    model = AutoModelForCausalLM.from_pretrained(args.model,
                                                 torch_dtype=torch.float16,
                                                 device_map="cuda")
    tokenizer = AutoTokenizer.from_pretrained(args.model,
                                              trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batched padding needs a pad token; reuse EOS when the model has none.
        tokenizer.pad_token = tokenizer.eos_token

    layer_indices = _select_layers(model.config.num_hidden_layers, args.n_layers)
    args.n_layers = len(layer_indices)

    # Load test set
    logging.info(f"Loading test set: {args.dataset}")
    problems = get_dataset(args.dataset,
                           num_samples=args.nbsamples,
                           tokenizer=tokenizer,
                           split="test",
                           seed=args.seed)
    logging.info(f"Loaded {len(problems)} problems from {args.dataset}")
    if not problems:
        raise ValueError(f"No test problems loaded from {args.dataset}")
    logging.info(f'Sample Input: {problems[0]}')

    # Prepare batches
    batches = [problems[i:i + args.batchsize] for i in range(0, len(problems), args.batchsize)]
    logging.info(f"Processing {len(batches)} batches of size up to {args.batchsize}")

    # Load one baseline metrics file per task
    model_short = args.model.split('/')[-1]
    method_suffix = '_proj' if args.method_type == 'projection' else ''
    baseline_scores = _load_baselines(args.baseline_dir, model_short,
                                      baseline_tasks, args.seed, method_suffix)

    # Label indices follow baseline insertion order. This presumably matches
    # the index order used by the infer_task_* helpers; TODO confirm.
    label_to_idx = {task: idx for idx, task in enumerate(baseline_scores)}

    # For each batch, calculate metrics and classify
    y_true = [label_to_idx[args.label]] * len(problems)
    y_pred, dist_list, metrics = [], [], None
    for batch_problems in tqdm(batches, total=len(batches)):
        batch = prepare_batch(batch_problems, tokenizer, max_length=args.seqlen)

        # NOTE(review): assumed shapes are pred_task (B,) and
        # distances (B, len(baseline_tasks)); TODO confirm.
        if args.method_type == 'norm':
            metrics = calculate_nfn_scores(model=model,
                                           batch=batch,
                                           mode='test',
                                           allowed_layers=layer_indices)
            pred_task, distances = infer_task_from_scores(metrics, baseline_scores, method=args.method)
        else:  # 'projection' -- argparse `choices` admits no other value
            metrics = calculate_projection_scores(model=model,
                                                  batch=batch,
                                                  mode='test',
                                                  allowed_layers=layer_indices)
            pred_task, distances = infer_task_from_projection_scores(metrics, baseline_scores, method=args.method)

        y_pred.extend(pred_task.cpu().tolist())
        dist_list.append(distances)

    logging.info(f'Monitored Layers: {metrics.keys()}')

    # Compute accuracy
    y_true = torch.tensor(y_true).int()
    y_pred = torch.tensor(y_pred).int()
    accuracy = torch.mean((y_true == y_pred).float())
    # Plain float for logging / the results DB; the .pt file keeps the tensor.
    accuracy_value = accuracy.item()
    logging.info(f"Classification accuracy: {accuracy_value:.3f}")

    # Save results
    results = {
        'accuracy': accuracy,
        'n_samples': len(y_true),
        'y_true': y_true,
        'y_pred': y_pred,
        'dist_list': torch.cat(dist_list, dim=0),
    }
    out_path = os.path.join(args.output_dir, f"{model_short}_{args.dataset}_{args.seed}_L{args.n_layers}_S{args.seqlen}{method_suffix}_classification_results.pt")
    torch.save(results, out_path)
    logging.info(f"Saved classification results to {out_path}")

    db = ResultsDB(args.record_db)
    db.log(model=args.model,
           dataset=args.dataset,
           method=args.method,
           method_type=args.method_type,
           n_layers=args.n_layers,
           seqlen=args.seqlen,
           seed=args.seed,
           accuracy=accuracy_value,
           n_samples=len(y_true),
           batchsize=args.batchsize,
           nbsamples=len(problems),)

if __name__ == "__main__":
    # Run the classifier CLI only when executed as a script, not on import.
    main()
