# main.py

import argparse
import random
import numpy as np
from tqdm import tqdm
import gc
import torch
from torch.optim import Adam
from torch.amp import GradScaler, autocast
from torch.func import functional_call, grad
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_scheduler, AutoConfig

import data_utils
import calc_utils

def set_seed(seed):
    """Seed every random number generator used by this script.

    PyTorch (CPU), NumPy and Python's ``random`` module are always seeded;
    all CUDA devices are seeded as well when CUDA is available.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_model_and_tokenizer(model_name, num_labels, device, tokenizer=None):
    """
    Builds a sequence-classification model together with its tokenizer.

    A tokenizer supplied by the caller is reused; otherwise one is loaded for
    ``model_name``. If the tokenizer has no pad token, its EOS token is used
    as the pad token. The pad token id and ``num_labels`` are written into the
    config before the model is loaded with it and moved to ``device``.

    Returns the ``(model, tokenizer)`` pair.
    """
    model_config = AutoConfig.from_pretrained(model_name)

    tok = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        # No dedicated padding token: fall back to the EOS token.
        tok.pad_token = tok.eos_token

    model_config.pad_token_id = tok.pad_token_id
    model_config.num_labels = num_labels

    classifier = AutoModelForSequenceClassification.from_pretrained(model_name, config=model_config)
    return classifier.to(device), tok

def _prepare_batch_for_model(batch, device):
    """Returns a new dict with every value of the ``batch`` dict moved to ``device``."""
    on_device = {}
    for name, value in batch.items():
        on_device[name] = value.to(device)
    return on_device

def biclip_transform(J, u, d):
    """ Applies the coordinate-wise BiClip transformation to a Jacobian tensor.

    Per entry, with ``a = |J|`` and ``s = sign(J)``:
      * ``a <= d``      -> ``s * d`` (exact zeros stay zero because ``s`` is 0)
      * ``a >= u``      -> ``s * u`` (takes precedence if both bounds match)
      * ``d < a < u``   -> the entry is kept unchanged
    Entries matching none of the conditions are left at zero.
    """
    magnitude = J.abs()
    direction = J.sign()

    in_band = (magnitude > d) & (magnitude < u)

    # Layer the three cases in the same order the bounds take precedence:
    # middle band first, then the lower bound, and the upper bound last.
    clipped = torch.where(in_band, J, torch.zeros_like(J))
    clipped = torch.where(magnitude <= d, direction * d, clipped)
    clipped = torch.where(magnitude >= u, direction * u, clipped)
    return clipped

def train_one_epoch(model, train_loader, optimizer, scheduler, scaler, device, dtype, use_autocast):
    """ Runs one fine-tuning pass over ``train_loader``.

    For every batch: move it to ``device``, compute the loss (under autocast
    when ``use_autocast`` is set), backpropagate through ``scaler``, step the
    optimizer and the scheduler, then clear the gradients.
    """
    model.train()

    progress = tqdm(train_loader, desc="Training Epoch")
    for raw_batch in progress:
        inputs = _prepare_batch_for_model(raw_batch, device)

        amp_context = autocast(device_type=device.type, dtype=dtype, enabled=use_autocast)
        with amp_context:
            loss = model(**inputs).loss

        # Scaled backward pass followed by the optimizer/scaler/scheduler updates.
        scaled_loss = scaler.scale(loss)
        scaled_loss.backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Reset gradients so the next batch starts from a clean slate.
        optimizer.zero_grad()

def evaluate_model(model, eval_loader, device, dtype, use_autocast):
    """Computes classification accuracy of ``model`` over ``eval_loader``.

    The model is put in eval mode and run without gradient tracking. Labels
    are taken from each batch's ``"labels"`` entry and are not passed to the
    model; the prediction is the argmax over the last logits dimension.
    Returns ``correct / total`` as a float.
    """
    model.eval()
    hits = 0
    seen = 0

    with torch.no_grad():
        for raw_batch in tqdm(eval_loader, desc="Evaluating"):
            tensors = _prepare_batch_for_model(raw_batch, device)
            targets = tensors["labels"]
            # Everything except the labels is forwarded to the model.
            features = {name: value for name, value in tensors.items() if name != 'labels'}

            with autocast(device_type=device.type, dtype=dtype, enabled=use_autocast):
                logits = model(**features).logits

            predicted = torch.argmax(logits, dim=-1)
            hits += (predicted == targets).sum().item()
            seen += len(targets)

    return hits / seen

def run_single_experiment(args, seed):
    """ Runs a complete experiment for a single seed and a single method.

    Args:
        args: Parsed command-line namespace (see ``main``). This function reads
            ``model_name``, ``dataset_name``, ``task_name``, ``method``,
            ``amp_dtype``, ``batch_size``, ``kernel_chunk_size``, ``epochs``,
            ``learning_rate``, ``weight_decay``, ``beta1``, ``beta2``,
            ``biclip_upper`` and ``biclip_lower``; the whole namespace is also
            forwarded to ``data_utils``.
        seed: Random seed, applied through ``set_seed`` before anything else.

    Returns:
        The evaluation accuracy as a float. For ``finetune`` this is
        ``evaluate_model`` on the eval loader after training; for every other
        method it is the fraction of ``eval_subset`` labels matched by the
        kernel-regression predictions.

    Raises:
        ValueError: For an eNTK method when ``args.model_name`` starts with
            neither ``'roberta'`` nor ``'Qwen'``.
    """
    set_seed(seed)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    DTYPE = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[args.amp_dtype]

    print("Initializing tokenizer and config...")
    # Only the tokenizer is needed here: the real label count is not known
    # until the data has been loaded. The placeholder model is still built, so
    # whatever random numbers model construction draws after set_seed() stay
    # the same, but it is released right away instead of occupying device
    # memory for the rest of the run.
    placeholder_model, tokenizer = get_model_and_tokenizer(args.model_name, 3, DEVICE) # Dummy num_labels=3
    del placeholder_model
    gc.collect()

    print("Loading and preprocessing data...")
    train_dataset, eval_dataset, num_labels, data_collator = data_utils.get_data(
        args.dataset_name, args.task_name, tokenizer, args
    )
    train_loader, eval_loader, train_subset, eval_subset = data_utils.create_dataloaders_and_subsets(
        train_dataset, eval_dataset, args, data_collator, num_labels
    )

    print(f"Initializing model with {num_labels} labels...")
    model, tokenizer = get_model_and_tokenizer(args.model_name, num_labels, DEVICE, tokenizer=tokenizer)

    # Autocast is only enabled on CUDA with a half-precision dtype; gradient
    # scaling is only enabled for float16.
    use_autocast_flag = DEVICE.type == 'cuda' and DTYPE in [torch.float16, torch.bfloat16]
    use_scaler_flag = DEVICE.type == 'cuda' and DTYPE == torch.float16

    if args.method == 'finetune':
        print("\n[Executing Method: Fine-tuning]")
        optimizer = Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
        scaler = GradScaler(enabled=use_scaler_flag)
        # The linear schedule is stepped once per batch in train_one_epoch.
        num_training_steps = args.epochs * len(train_loader)
        lr_scheduler = get_scheduler(
            name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
        )

        for epoch in range(args.epochs):
            print(f"\n--- Epoch {epoch + 1}/{args.epochs} ---")
            train_one_epoch(model, train_loader, optimizer, lr_scheduler, scaler, DEVICE, DTYPE, use_autocast_flag)

        print("\n--- Training Finished ---")
        final_accuracy = evaluate_model(model, eval_loader, DEVICE, DTYPE, use_autocast_flag)
        print(f"Final Model Accuracy: {final_accuracy:.4f}")
        return final_accuracy

    else: # eNTK-based methods
        print(f"\n[Executing Method: eNTK with {args.method.upper()} Kernel]")

        # Detached snapshot of every parameter. It is only needed by the
        # kernel methods, so it is taken here rather than before the
        # fine-tuning branch, which never reads it.
        initial_params = {k: v.clone().detach() for k, v in model.named_parameters()}

        model.eval()

        # The Jacobian is restricted to the classification head plus the last
        # transformer layer; the parameter-name patterns are model-specific.
        if args.model_name.startswith('roberta'):
            last_layer_idx = model.config.num_hidden_layers - 1
            trainable_keys_logic = lambda k: 'classifier' in k or f'encoder.layer.{last_layer_idx}' in k
        elif args.model_name.startswith('Qwen'):
            last_layer_idx = model.config.num_hidden_layers - 1
            trainable_keys_logic = lambda k: 'score' in k or f'model.layers.{last_layer_idx}' in k
        else:
            raise ValueError(f"Trainable layer logic not defined for {args.model_name}")

        trainable_params_for_jac = {k: v for k, v in model.named_parameters() if trainable_keys_logic(k)}
        initial_trainable_params = {k: v.clone().detach() for k, v in trainable_params_for_jac.items()}
        # Everything that is not differentiated against stays fixed at its initial value.
        frozen_params = {k: v for k, v in initial_params.items() if k not in initial_trainable_params}

        print(f"Total parameters for Jacobian: {sum(p.numel() for p in initial_trainable_params.values()):,}")

        X_train_subset = train_subset
        y_train_cpu = torch.tensor([train_subset[i]['labels'] for i in range(len(train_subset))])
        X_eval_subset = eval_subset
        y_eval_cpu = torch.tensor([eval_subset[i]['labels'] for i in range(len(eval_subset))])

        print("Computing initial predictions (f0)...")
        f0_train = calc_utils.get_initial_predictions(model, initial_params, X_train_subset, data_collator, DEVICE, DTYPE, use_autocast_flag, args.batch_size)
        f0_eval = calc_utils.get_initial_predictions(model, initial_params, X_eval_subset, data_collator, DEVICE, DTYPE, use_autocast_flag, args.batch_size)

        print(f"Calculating for kernel: {args.method.upper()}")

        # Arguments shared by both kernel computations below.
        kernel_args = {
            "model": model, "initial_trainable_params": initial_trainable_params,
            "frozen_params": frozen_params, "data_collator": data_collator, "device": DEVICE,
            "dtype": DTYPE, "use_autocast": use_autocast_flag, "chunk_size": args.kernel_chunk_size
        }

        transform_fn_J2 = None
        m_avg_gpu, v_avg_gpu = None, None

        # Methods whose transform depends on averaged gradient moments.
        moment_methods = ('adam-ntk', 'adagrad-ntk', 'adamw-ntk', 'sgdm-ntk')

        # Common logic for methods needing moment estimation
        if args.method in moment_methods:
            def get_average_moments():
                """Average the loss gradient and its elementwise square over all batches of train_loader."""
                m_avg, v_avg = None, None
                num_batches = 0

                def fnet_loss(params, batch):
                    # Loss as a function of the trainable parameters only;
                    # the frozen ones are merged back in for the forward pass.
                    all_params = {**frozen_params, **params}
                    with autocast(device_type=DEVICE.type, dtype=DTYPE, enabled=use_autocast_flag):
                        return functional_call(model, all_params, kwargs=batch).loss

                grad_fn = grad(fnet_loss)
                for batch in tqdm(train_loader, desc="Estimating Preconditioner Moments"):
                    batch_gpu = _prepare_batch_for_model(batch, DEVICE)
                    grads = grad_fn(initial_trainable_params, batch_gpu)
                    flat_grads = torch.cat([g.flatten() for g in grads.values()])
                    if m_avg is None:
                        # Accumulators are allocated lazily, once the flattened size is known.
                        m_avg = torch.zeros_like(flat_grads)
                        v_avg = torch.zeros_like(flat_grads)
                    m_avg += flat_grads
                    v_avg += flat_grads**2
                    num_batches += 1
                return m_avg / num_batches, v_avg / num_batches

            m_avg_gpu, v_avg_gpu = get_average_moments()

        # Define the transformation function based on the method.
        # NOTE(review): the moment vectors (and theta_0 for adamw) are tiled
        # num_labels times, which assumes the tensor handed to the transform is
        # laid out as num_labels consecutive copies of the flattened trainable
        # parameters — confirm against calc_utils.compute_kernel_gpu.
        if args.method == 'sgd-ntk':
            transform_fn_J2 = None
        elif args.method == 'signgd-ntk':
            transform_fn_J2 = torch.sign
        elif args.method == 'biclip-ntk':
            transform_fn_J2 = lambda j: biclip_transform(j, u=args.biclip_upper, d=args.biclip_lower)
        elif args.method == 'sgdm-ntk':
            m_avg_repeated = m_avg_gpu.repeat(num_labels)
            def sgdm_transform(j):
                # Momentum-style blend of the averaged gradient and j.
                return args.beta1 * m_avg_repeated + (1 - args.beta1) * j
            transform_fn_J2 = sgdm_transform
        elif args.method == 'adagrad-ntk':
            v_avg_repeated = v_avg_gpu.repeat(num_labels)
            def adagrad_transform(j):
                # Divide by the root of the averaged squared gradient (eps avoids division by zero).
                return j / (torch.sqrt(v_avg_repeated) + 1e-8)
            transform_fn_J2 = adagrad_transform
        elif args.method == 'adam-ntk':
            m_avg_repeated = m_avg_gpu.repeat(num_labels)
            v_avg_repeated = v_avg_gpu.repeat(num_labels)
            def adam_transform(j):
                # First and second moment blends, then the Adam-style ratio.
                m_j = args.beta1 * m_avg_repeated + (1 - args.beta1) * j
                v_j = args.beta2 * v_avg_repeated + (1 - args.beta2) * (j**2)
                return m_j / (torch.sqrt(v_j) + 1e-8)
            transform_fn_J2 = adam_transform
        elif args.method == 'adamw-ntk':
            m_avg_repeated = m_avg_gpu.repeat(num_labels)
            v_avg_repeated = v_avg_gpu.repeat(num_labels)
            theta_0_flat = torch.cat([p.flatten() for p in initial_trainable_params.values()])
            theta_0_flat_repeated = theta_0_flat.repeat(num_labels)
            def adamw_transform(j):
                # Same as adam_transform plus a weight-decay term on the initial parameters.
                m_j = args.beta1 * m_avg_repeated + (1 - args.beta1) * j
                v_j = args.beta2 * v_avg_repeated + (1 - args.beta2) * (j**2)
                adam_part = m_j / (torch.sqrt(v_j) + 1e-8)
                return adam_part + args.weight_decay * theta_0_flat_repeated
            transform_fn_J2 = adamw_transform

        K_train_train = calc_utils.compute_kernel_gpu(**kernel_args, dataset1=X_train_subset, dataset2=X_train_subset, transform_fn_J2=transform_fn_J2)
        K_eval_train = calc_utils.compute_kernel_gpu(**kernel_args, dataset1=X_eval_subset, dataset2=X_train_subset, transform_fn_J2=transform_fn_J2)

        # Drop the un-tiled moment vectors once both kernels are built.
        if args.method in moment_methods:
            if m_avg_gpu is not None: del m_avg_gpu
            if v_avg_gpu is not None: del v_avg_gpu

        y_pred_entk = calc_utils.perform_kernel_regression(K_eval_train, K_train_train, f0_eval, f0_train, y_train_cpu, DEVICE, DTYPE, use_autocast_flag)
        correct_entk = (y_pred_entk.cpu() == y_eval_cpu).sum().item()
        accuracy = correct_entk / len(y_eval_cpu)
        print(f"eNTK Prediction Accuracy: {accuracy:.4f}")

        # Release the kernels and initial predictions before the next seed runs.
        del K_train_train, K_eval_train, f0_train, f0_eval
        gc.collect()
        if DEVICE.type == 'cuda': torch.cuda.empty_cache()

        return accuracy

def main():
    """Parses the command line, runs one experiment per seed and prints the mean and standard deviation of the accuracies."""
    cli = argparse.ArgumentParser(description="Run fine-tuning and eNTK experiments for Language Models.")

    # Model and data selection
    cli.add_argument('--model_name', type=str, default='roberta-base',
                     choices=['roberta-base', 'Qwen/Qwen2.5-0.5B'])
    cli.add_argument('--dataset_name', type=str, default='glue', choices=['glue'])
    cli.add_argument('--task_name', type=str, default='sst2',
                     help="GLUE task name (e.g., 'sst2', 'mnli').")

    # Method and general run configuration
    cli.add_argument('--method', type=str, required=True,
                     choices=['finetune', 'sgd-ntk', 'sgdm-ntk', 'signgd-ntk', 'biclip-ntk', 'adam-ntk', 'adagrad-ntk', 'adamw-ntk'])
    cli.add_argument('--batch_size', type=int, default=8)
    cli.add_argument('--kernel_chunk_size', type=int, default=16, help="Chunk size for GPU kernel calculation.")
    cli.add_argument('--train_set_size', type=int, default=256)
    cli.add_argument('--val_set_size', type=int, default=512)
    cli.add_argument('--seed', type=int, nargs='+', default=[42], help='List of random seeds.')
    cli.add_argument('--amp_dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16', 'float32'])

    # Finetuning specific args (AdamW also uses weight_decay)
    cli.add_argument('--learning_rate', type=float, default=2e-5)
    cli.add_argument('--epochs', type=int, default=10)
    cli.add_argument('--weight_decay', type=float, default=0.01)

    # Kernel-specific args
    cli.add_argument('--beta1', type=float, default=0.9, help="Beta1/Momentum for adam-ntk, adamw-ntk, and sgdm-ntk.")
    cli.add_argument('--beta2', type=float, default=0.999, help="Beta2 for adam-ntk/adamw-ntk.")
    cli.add_argument('--biclip_upper', type=float, default=1e-2, help="Upper bound for biclip-ntk.")
    cli.add_argument('--biclip_lower', type=float, default=1e-4, help="Lower bound for biclip-ntk.")

    cli.add_argument('--scramble_percentage', type=float, default=0.0,
                     help="For MNLI, the percentage of words (0.0 to 1.0) to scramble in the hypothesis sentence "
                          "to simulate task discrepancy.")

    args = cli.parse_args()

    # One full experiment per requested seed.
    per_seed_accuracy = []
    for i, seed in enumerate(args.seed):
        print(f"\n--- Running Exp {i+1}/{len(args.seed)} for '{args.method.upper()}' on '{args.dataset_name}/{args.task_name}' with Seed: {seed} ---")
        per_seed_accuracy.append(run_single_experiment(args, seed))

    mean_accuracy = np.mean(per_seed_accuracy)
    std_accuracy = np.std(per_seed_accuracy)

    heavy_rule = "="*50
    report_lines = [
        "\n" + heavy_rule,
        " " * 18 + "FINAL RESULTS",
        heavy_rule,
        f"Model          : {args.model_name}",
        f"Dataset        : {args.dataset_name.upper()}/{args.task_name.upper()}",
        f"Method         : {args.method.upper()}",
        f"Runs (seeds)   : {len(args.seed)}",
        "-" * 50,
        f"Final Accuracy : {mean_accuracy:.4f} ± {std_accuracy:.4f}",
        heavy_rule,
    ]
    for line in report_lines:
        print(line)

# Run the experiments only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
