# calc_utils.py

import gc
from collections.abc import Mapping

import torch
from sklearn.linear_model import Ridge
from sklearn.preprocessing import OneHotEncoder
from torch.amp import autocast
from torch.func import functional_call, jacrev
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import BatchEncoding

def _prepare_batch_for_model(batch, device):
    """Move the contents of a mapping-style batch onto ``device``.

    Args:
        batch: Mapping from field name to value, e.g. a plain ``dict`` or a
            Hugging Face ``BatchEncoding`` (a ``UserDict`` subclass, so the
            ``Mapping`` check covers it).
        device: Target device, passed straight to each value's ``.to``.

    Returns:
        A new plain ``dict`` with the same keys in the same order. Values
        that expose a ``.to`` method (tensors) are moved to ``device``; any
        other value (lists, strings, ``None``) is passed through unchanged.

    Raises:
        TypeError: If ``batch`` is not a mapping.
    """
    if not isinstance(batch, Mapping):
        raise TypeError(f"Unsupported batch type: {type(batch)}. Expected a dict or BatchEncoding.")

    # Collators may add non-tensor metadata; it has no device, so leave it
    # alone instead of failing with AttributeError on `.to`.
    return {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in batch.items()
    }

def compute_kernel_gpu(model, initial_trainable_params, frozen_params,
                       dataset1, dataset2, data_collator, device, dtype, use_autocast,
                       chunk_size=16, transform_fn_J1=None, transform_fn_J2=None):
    """
    Computes a kernel matrix K = transform_J1(J1) @ transform_J2(J2).T in chunks.

    J1 / J2 are the per-sample Jacobians of the model logits with respect to
    ``initial_trainable_params`` for ``dataset1`` / ``dataset2``. Each sample
    is pushed through the model as a batch of one and its Jacobian is
    flattened into a single row (all logit entries and all trainable
    parameters together), so each (sample1, sample2) pair contributes one
    scalar to K. This assumes the logits' leading dimension is the batch
    dimension.

    Jacobians and block products are computed on ``device``. Finished blocks
    are staged on the CPU, and the assembled matrix is moved back to
    ``device`` at the end. The ``dataset2`` Jacobians are recomputed for every
    ``dataset1`` chunk rather than cached.

    Args:
        model: Module called through ``functional_call``; its output must
            expose a ``.logits`` attribute.
        initial_trainable_params: Mapping of parameter name to tensor that the
            Jacobian is taken with respect to.
        frozen_params: Mapping of the remaining parameter names to tensors;
            merged with the trainable ones for every forward call.
        dataset1: Dataset providing the rows of K.
        dataset2: Dataset providing the columns of K.
        data_collator: ``collate_fn`` handed to both ``DataLoader`` objects.
        device: ``torch.device`` on which the computation runs.
        dtype: Autocast dtype.
        use_autocast: Whether autocast is enabled for the forward pass and for
            the block matrix products.
        chunk_size: Number of samples per Jacobian chunk.
        transform_fn_J1: Optional callable applied to each ``dataset1``
            Jacobian chunk before the product.
        transform_fn_J2: Optional callable applied to each ``dataset2``
            Jacobian chunk before the product.

    Returns:
        The assembled kernel matrix on ``device``. Without transforms that
        change the row count, it has one row per ``dataset1`` sample and one
        column per ``dataset2`` sample.

    Raises:
        ValueError: If a collated batch contains no model inputs besides
            ``'labels'``.
    """
    loader1 = DataLoader(dataset1, batch_size=chunk_size, collate_fn=data_collator)
    loader2 = DataLoader(dataset2, batch_size=chunk_size, collate_fn=data_collator)

    # --- Reusable stateless function for Jacobian calculation ---
    def fnet_classifier_stateless(params, sample_batch):
        all_params = {**frozen_params, **params}
        with autocast(device_type=device.type, dtype=dtype, enabled=use_autocast):
            output = functional_call(model, all_params, kwargs=sample_batch)
        return output.logits  # Hugging Face models return an object with a .logits attribute

    jac_computer = jacrev(fnet_classifier_stateless, argnums=0)

    def get_jacobian_chunk(batch):
        prepared_batch = _prepare_batch_for_model(batch, device)
        # Labels are not a forward input; drop them once, not once per sample.
        model_inputs = {key: val for key, val in prepared_batch.items() if key != 'labels'}

        # Prefer 'input_ids' for the chunk size, but fall back to the first
        # input so batches keyed differently (e.g. 'pixel_values') also work.
        reference = model_inputs.get('input_ids')
        if reference is None:
            reference = next(iter(model_inputs.values()), None)
        if reference is None:
            raise ValueError("Batch contains no model inputs besides 'labels'.")

        jac_rows = []
        for i in range(reference.size(0)):
            # Create a batch for a single sample
            sample_input = {key: val[i].unsqueeze(0) for key, val in model_inputs.items()}
            jac_dict = jac_computer(initial_trainable_params, sample_input)
            # Flatten and concatenate jacobians for all trainable parameters
            flat_jac = torch.cat([j.flatten(start_dim=1) for j in jac_dict.values()], dim=1)
            # Detach so that, if the supplied parameters require grad, the
            # stored row does not keep an autograd graph (and the GPU tensors
            # it references) alive across chunks.
            jac_rows.append(flat_jac.detach())

        return torch.cat(jac_rows, dim=0)

    kernel_row_chunks_cpu = []

    for batch1 in tqdm(loader1, desc="Computing Kernel (Outer loop - J1)"):
        J1_chunk_gpu = get_jacobian_chunk(batch1)
        if transform_fn_J1 is not None:
            J1_chunk_gpu = transform_fn_J1(J1_chunk_gpu)

        kernel_col_chunks_cpu = []
        for batch2 in loader2:
            J2_chunk_gpu = get_jacobian_chunk(batch2)
            if transform_fn_J2 is not None:
                J2_chunk_gpu = transform_fn_J2(J2_chunk_gpu)

            with autocast(device_type=device.type, dtype=dtype, enabled=use_autocast):
                kernel_block_gpu = J1_chunk_gpu @ J2_chunk_gpu.T

            # Stage the finished block on the CPU so only the two Jacobian
            # chunks currently in use occupy device memory.
            kernel_col_chunks_cpu.append(kernel_block_gpu.cpu())
            del J2_chunk_gpu, kernel_block_gpu

        kernel_row_chunks_cpu.append(torch.cat(kernel_col_chunks_cpu, dim=1))

        del J1_chunk_gpu
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return torch.cat(kernel_row_chunks_cpu, dim=0).to(device)


@torch.no_grad()
def get_initial_predictions(model, initial_params, dataset, data_collator, device, dtype, use_autocast, batch_size=32):
    """Return the model's logits (f0) on ``dataset`` evaluated at ``initial_params``.

    The dataset is walked in batches of ``batch_size``. Each batch is moved to
    ``device`` and stripped of its ``'labels'`` entry, then run through
    ``functional_call`` under autocast. Per-batch logits are collected on the
    CPU; the concatenated result is returned on ``device``.
    """
    batches = DataLoader(dataset, batch_size=batch_size, collate_fn=data_collator)
    logits_per_batch = []

    for raw_batch in batches:
        on_device = _prepare_batch_for_model(raw_batch, device)
        with autocast(device_type=device.type, dtype=dtype, enabled=use_autocast):
            # 'labels' is dropped so it is not passed to the forward call.
            forward_kwargs = {name: value for name, value in on_device.items() if name != 'labels'}
            model_output = functional_call(model, initial_params, kwargs=forward_kwargs)
        # Hugging Face models return an object with a .logits attribute
        logits_per_batch.append(model_output.logits.cpu())

    return torch.cat(logits_per_batch).to(device)


def perform_kernel_regression(K_eval_train, K_train_train, f0_eval, f0_train, y_train_cpu, device, dtype, use_autocast,
                              ridge_alpha=1.0):
    """Predict eval classes by ridge-regressing the training residuals on the kernel.

    The residual targets are the one-hot training labels minus ``f0_train``.
    They are fitted with ``sklearn.linear_model.Ridge`` (no intercept, SVD
    solver) using ``K_train_train`` as the design matrix. The fitted
    coefficients are applied to ``K_eval_train`` and added to ``f0_eval``;
    the arg-max over the last dimension is returned.

    Args:
        K_eval_train: Kernel between eval and train samples; its columns must
            line up with the columns of ``K_train_train``.
        K_train_train: Kernel between the train samples.
        f0_eval: Initial model outputs for the eval samples.
        f0_train: Initial model outputs for the train samples; its last
            dimension defines the number of classes.
        y_train_cpu: Tensor of integer class ids for the train samples. Each
            id is used directly as the logit column index, so it must lie in
            ``[0, f0_train.shape[-1])``.
        device: Device on which the targets and coefficients are placed.
        dtype: Autocast dtype for the final prediction.
        use_autocast: Whether autocast is enabled for the final prediction.
        ridge_alpha: L2 regularisation strength passed to ``Ridge``.

    Returns:
        Tensor of predicted class indices, one per eval sample.

    Raises:
        RuntimeError: From ``torch.nn.functional.one_hot`` if a label is
            negative or not smaller than the number of classes.
    """
    # Size the one-hot targets from the logits, not from the labels present
    # in the training set: a class missing from y_train would otherwise make
    # the targets narrower than f0_train and shift the column order.
    num_classes = f0_train.shape[-1]
    labels = y_train_cpu.reshape(-1).to(torch.long)
    y_train_one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(
        device=device, dtype=torch.float32
    )

    residual_targets = y_train_one_hot - f0_train.to(torch.float32)

    ridge = Ridge(alpha=ridge_alpha, fit_intercept=False, solver='svd')
    ridge.fit(
        K_train_train.to(torch.float32).cpu().numpy(),
        residual_targets.cpu().numpy()
    )

    # coef_ is (n_targets, n_features); transpose so it can be right-multiplied
    # onto the eval/train kernel.
    alpha = torch.tensor(ridge.coef_.T, dtype=K_eval_train.dtype).to(device)

    with autocast(device_type=device.type, dtype=dtype, enabled=use_autocast):
        f_inf_eval = f0_eval + (K_eval_train @ alpha)

    y_pred_entk = torch.argmax(f_inf_eval, dim=-1)
    return y_pred_entk
