import torch
from torch.utils.data import DataLoader

from enums import LossFunction
from transform_optimization.loss_funcs import get_loss_function


def prepare_model_for_transform_training(model, shared_R1_learned_transform, R2_transforms):
    """Freeze every model parameter, then make only the learned transforms trainable.

    Intended to run AFTER the quantized blocks and transforms have been built.

    Args:
        model: module whose parameters are all set to ``requires_grad = False``.
        shared_R1_learned_transform: the shared R1 transform; its parameters are
            set trainable and its ``unfreeze()`` hook is called.
        R2_transforms: mapping whose values are per-block R2 transforms; each one
            gets the same treatment as the shared R1 transform.
    """
    def _enable_training(transform):
        # Flag every parameter as trainable, then run the transform's own unfreeze hook.
        for param in transform.parameters():
            param.requires_grad = True
        transform.unfreeze()

    # Freeze everything first. If the transforms are registered as submodules of
    # `model`, doing this before re-enabling them keeps them trainable.
    for param in model.parameters():
        param.requires_grad = False

    _enable_training(shared_R1_learned_transform)
    for block_transform in R2_transforms.values():
        _enable_training(block_transform)


def build_loss_function(model, float_model, opt_config):
    """Create the training loss selected by ``opt_config.loss_function``.

    ``float_model`` is forwarded to ``get_loss_function`` as ``model``. The model
    being trained (``model``) supplies ``pad_id``, and for the unembed / flat-Q
    distillation losses it is also forwarded as ``q_model``.

    Args:
        model: model being trained; must expose ``config.pad_token_id``. For
            UNEMBED_DISTILLATION it must also expose ``model.embed_tokens.weight``.
        float_model: reference model handed to ``get_loss_function``.
        opt_config: supplies ``loss_function``, ``distance_metric`` and, for
            OUTPUT_DISTILLATION, ``temperature``.

    Returns:
        Whatever ``get_loss_function`` returns for the chosen configuration.

    Raises:
        NotImplementedError: if ``opt_config.loss_function`` is not one of the
            three handled loss functions.
    """
    selected = opt_config.loss_function

    # pad_id is optional for some metrics (e.g., MSE) but required for others (e.g., KL).
    # The OutDistill class will validate if pad_id is required for the chosen metric.
    extra_kwargs = dict(pad_id=model.config.pad_token_id)

    # Plain `==` comparisons against the enum values are intentional (they match
    # the stored option the same way a direct comparison would).
    if selected == LossFunction.OUTPUT_DISTILLATION.value:
        extra_kwargs.update(t=opt_config.temperature)
    elif selected == LossFunction.UNEMBED_DISTILLATION.value:
        extra_kwargs.update(q_model=model, embed_weight=model.model.embed_tokens.weight)
    elif selected == LossFunction.FLAT_Q_DISTILLATION.value:
        extra_kwargs.update(q_model=model)
    else:
        raise NotImplementedError(f"Loss function {selected} not implemented for RTN.")

    return get_loss_function(
        loss_function=selected,
        distance_metric=opt_config.distance_metric,
        model=float_model,
        **extra_kwargs,
    )


def prepare_train_dataloader(calibration_data, run_config):
    """Wrap the calibration samples in a ``DataLoader`` for transform training.

    Samples are expected to be tensors of shape ``[1, seqlen]``. Each batch is
    stacked and the singleton dim 1 is removed, so batches come out as
    ``[B, seqlen]``. ``drop_last=False``, so the final batch may be smaller
    than ``run_config.batch_size``.

    Args:
        calibration_data: dataset / sequence of per-sample tensors.
        run_config: supplies ``batch_size`` and ``shuffle_calibration``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    def _stack_and_drop_singleton(samples):
        # [1, seqlen] samples -> stacked [B, 1, seqlen] -> squeeze dim 1 -> [B, seqlen].
        return torch.stack(samples, dim=0).squeeze(1)

    return DataLoader(
        calibration_data,
        batch_size=run_config.batch_size,
        shuffle=run_config.shuffle_calibration,
        drop_last=False,
        collate_fn=_stack_and_drop_singleton,
    )


def wrap_up_training(shared_R1_learned_transform, R2_transforms):
    """Call ``freeze()`` on the shared R1 transform and then on each per-block R2 transform.

    Args:
        shared_R1_learned_transform: the shared R1 transform.
        R2_transforms: mapping whose values are the per-block R2 transforms.
    """
    for transform in (shared_R1_learned_transform, *R2_transforms.values()):
        transform.freeze()