import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.data import get_dataset, prepare_batch
from src.helper import setup_logging, set_seed
from src.metrics import calculate_nfn_scores, calculate_projection_scores
import json
import random
import numpy as np
import logging


def _average_norm_metrics(metrics_list, keys, record_dist):
    """Average the scalar 'mean'/'std'/'median' statistics of each key across batches."""
    averaged = {}
    for key in keys:
        # Only batches that actually reported this key take part in its average.
        entries = [m[key] for m in metrics_list if key in m]
        count = len(entries)  # >= 1: every key originates from some batch
        averaged[key] = {
            stat: sum(entry[stat] for entry in entries) / count
            for stat in ('mean', 'std', 'median')
        }
        if record_dist:
            # Skip batches without a recorded distribution instead of raising
            # KeyError, mirroring how the projection branch treats 'projections'.
            dists = [entry['dist'] for entry in entries if 'dist' in entry]
            if dists:
                averaged[key]['dist'] = [value for dist in dists for value in dist]
    return averaged


def _average_projection_metrics(metrics_list, keys, record_dist):
    """Average the 'mean_vec'/'var_vec' vectors of each key element-wise across batches."""
    averaged = {}
    for key in keys:
        entries = [m[key] for m in metrics_list if key in m]
        # as_tensor accepts tensors, arrays and plain lists alike, so mixed
        # per-batch representations can be stacked along a new batch axis.
        mean_stack = torch.stack([torch.as_tensor(entry['mean_vec']) for entry in entries])
        var_stack = torch.stack([torch.as_tensor(entry['var_vec']) for entry in entries])
        averaged[key] = {
            'mean_vec': mean_stack.mean(dim=0).tolist(),
            'var_vec': var_stack.mean(dim=0).tolist(),
        }
        if record_dist:
            projections = [entry['projections'] for entry in entries if 'projections' in entry]
            if projections:
                averaged[key]['projections'] = np.concatenate(projections, axis=0).tolist()
    return averaged


def average_metrics(metrics_list, record_dist=False, method_type='norm'):
    """Combine per-batch metric dicts into a single dict of averaged metrics.

    Every batch contributes with equal weight (a plain average of per-batch
    statistics), regardless of how many samples the batch contained.

    Args:
        metrics_list: List of per-batch dicts. For ``method_type='norm'`` each
            value is ``{'mean': x, 'std': y, 'median': z}`` (plus ``'dist'``
            when distributions were recorded); for ``'projection'`` each value
            is ``{'mean_vec': x, 'var_vec': y}`` (plus ``'projections'``).
        record_dist: When true, also merge the recorded distributions:
            ``'dist'`` values are flattened into one list and ``'projections'``
            are concatenated along axis 0. Batches lacking the field are
            skipped, and the field is omitted if no batch provided it.
        method_type: ``'norm'`` or ``'projection'``.

    Returns:
        Dict mapping each metric name to its averaged statistics, with keys in
        first-seen order so the serialized output is reproducible. Returns an
        empty dict for an empty ``metrics_list`` or an unrecognised
        ``method_type``.
    """
    if not metrics_list:
        return {}

    # dict.fromkeys de-duplicates while preserving insertion order; iterating a
    # set of strings here would make the output key order vary between runs.
    keys = list(dict.fromkeys(k for m in metrics_list for k in m))

    if method_type == 'norm':
        return _average_norm_metrics(metrics_list, keys, record_dist)
    if method_type == 'projection':
        return _average_projection_metrics(metrics_list, keys, record_dist)
    return {}

def main():
    """Compute baseline metrics for every dataset of a setting and save them as JSON.

    Parses the command line, loads the requested causal LM in float16 on CUDA
    together with its tokenizer, then for each task of the chosen setting
    samples problems, scores them batch by batch with the selected method,
    averages the per-batch metrics and writes one
    ``<model>_<task>_<seed>[_proj]_metrics.json`` file into ``--output_dir``.

    Raises:
        KeyError: If ``--setting`` is not one of 'L1', 'L2:PLang', 'L2:Math'.
    """
    cli = argparse.ArgumentParser(description="Compute and save baseline metrics for each dataset.")
    cli.add_argument('--model', type=str, required=True, help='HuggingFace model handle')
    cli.add_argument('--setting', type=str, required=True, help='Setting')
    cli.add_argument('--output_dir', type=str, required=True, help='Directory to save baseline metrics')
    cli.add_argument('--batchsize', type=int, default=8, help='Batch size')
    cli.add_argument('--nbsamples', type=int, default=100, help='Number of samples to use from each dataset')
    cli.add_argument('--seqlen', type=int, default=256, help='Sequence length')
    cli.add_argument('--record_dist', action='store_true', help='Record distribution of norms')
    cli.add_argument('--method_type', type=str, default='norm', choices=['norm', 'projection'], help='Method type: norm (Method 1) or projection (Method 2)')
    cli.add_argument('--seed', type=int, default=42, help='Random seed')
    opts = cli.parse_args()

    os.makedirs(opts.output_dir, exist_ok=True)
    set_seed(opts.seed)
    setup_logging(os.path.join('logs', 'baselines'))
    logging.info(opts)

    # Model and tokenizer come from the same HuggingFace handle.
    logging.info(f"Loading model: {opts.model}")
    model = AutoModelForCausalLM.from_pretrained(
        opts.model,
        torch_dtype=torch.float16,
        device_map="cuda",
    )
    tokenizer = AutoTokenizer.from_pretrained(opts.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Map each task label to the dataset identifier handed to get_dataset.
    setting = opts.setting
    if setting == 'L1':
        dataset_map = {
            'code': 'magicoder',
            'text': 'mmlu_history',
            'math': 'gsm8k',
        }
    elif setting == 'L2:PLang':
        languages = [
            'cpp', 'csharp', 'java', 'php', 'python',
            'rust', 'shell', 'swift', 'typescript',
        ]
        dataset_map = {}
        for language in languages:
            dataset_map[language] = f'magicoder:{language}'
    elif setting == 'L2:Math':
        topics = [
            'Algebra', 'Counting_&_Probability', 'Geometry',
            'Intermediate_Algebra', 'Number_Theory', 'Prealgebra',
            'Precalculus',
        ]
        dataset_map = {}
        for topic in topics:
            dataset_map[topic] = f'comp_math:{topic}'
    else:
        raise KeyError(f'Unknown setting {setting}')

    logging.info(f'Dataset Map: {dataset_map}')

    # argparse restricts --method_type to exactly these two choices.
    scorers = {
        'norm': calculate_nfn_scores,
        'projection': calculate_projection_scores,
    }
    score_batch = scorers[opts.method_type]
    suffix = '_proj' if opts.method_type == 'projection' else ''
    model_tag = opts.model.split('/')[-1]
    step = opts.batchsize

    for task, dataset_name in dataset_map.items():
        logging.info(f"\nProcessing baseline for task: {task}")
        problems = get_dataset(
            dataset_name=dataset_name,
            num_samples=opts.nbsamples,
            tokenizer=tokenizer,
            split='train',
            seed=opts.seed,
        )

        # Slice the sampled problems into consecutive chunks of at most `step` items.
        chunks = [problems[start:start + step] for start in range(0, len(problems), step)]

        per_batch = []
        for chunk in chunks:
            batch = prepare_batch(chunk, tokenizer, max_length=opts.seqlen)
            per_batch.append(score_batch(model, batch, record_dist=opts.record_dist))

        summary = average_metrics(per_batch, record_dist=opts.record_dist, method_type=opts.method_type)

        metrics_path = os.path.join(
            opts.output_dir,
            f"{model_tag}_{task}_{opts.seed}{suffix}_metrics.json",
        )
        with open(metrics_path, 'w') as out_file:
            json.dump(summary, out_file, indent=2)
        logging.info(f"Saved baseline metrics to {metrics_path}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
