from quantization.transforms.transforms import build_transform



def non_learned_transforms_builder(model, quant_config, transform_kwargs):
    """Create the non-learned input transforms via ``build_transform``.

    Each transform is built with ``quant_config.transform_class`` and the
    extra keyword arguments in ``transform_kwargs``; only ``size`` varies.

    Args:
        model: Object exposing ``config.hidden_size`` and
            ``config.intermediate_size``.
        quant_config: Object exposing ``transform_class``.
        transform_kwargs: Mapping of additional keyword arguments forwarded
            unchanged to every ``build_transform`` call.

    Returns:
        A 5-tuple ``(qkv_in, o_in, gate_up_in, down_in, v_out)``. The first
        three are sized by ``hidden_size``, ``down_in`` by
        ``intermediate_size``, and ``v_out`` is always ``None``.
    """
    def _make(size):
        # One fresh transform per call; instances are never shared here.
        return build_transform(quant_config.transform_class, size=size, **transform_kwargs)

    hidden = model.config.hidden_size

    # Construction order matches the order of the returned tuple.
    qkv_in = _make(hidden)
    o_in = _make(hidden)
    gate_up_in = _make(hidden)
    down_in = _make(model.config.intermediate_size)

    return qkv_in, o_in, gate_up_in, down_in, None


def build_R1_learned_transform(model, opt_config, quant_config, device):
    """Build the R1 learned transform through ``build_transform``.

    Two layouts are supported, selected by
    ``opt_config.single_transform_matrix``:

    * truthy: a message is printed and a single block of size
      ``hidden_size`` is requested (``num_blocks=1``), additionally passing
      ``block_diag_init``, ``divide_num_blocks`` (``hidden_size //
      group_size``) and ``divide_block_size`` (``group_size``);
    * falsy: ``hidden_size // group_size`` blocks of size ``group_size``
      are requested.

    NOTE(review): the floor division presumably expects ``hidden_size`` to
    be a multiple of ``group_size`` -- confirm against callers.

    Args:
        model: Object exposing ``config.hidden_size`` and
            ``config.torch_dtype``.
        opt_config: Object exposing ``single_transform_matrix``,
            ``mat_param``, ``add_rand_noise`` and (single-matrix case)
            ``block_diag_init``.
        quant_config: Object exposing ``transform_class_r1``,
            ``matrix_init`` and ``group_size``.
        device: Forwarded to ``build_transform`` as ``device``.

    Returns:
        Whatever ``build_transform`` returns for the selected layout.
    """
    model_cfg = model.config

    if opt_config.single_transform_matrix:
        print("Using single shared learned transform matrix for all layers.")
        layout = {
            "block_size": model_cfg.hidden_size,
            "num_blocks": 1,
            "block_diag_init": opt_config.block_diag_init,
            "divide_num_blocks": model_cfg.hidden_size // quant_config.group_size,
            "divide_block_size": quant_config.group_size,
        }
    else:
        layout = {
            "block_size": quant_config.group_size,
            "num_blocks": model_cfg.hidden_size // quant_config.group_size,
        }

    # Arguments shared by both layouts are spelled out once.
    return build_transform(
        transform_class=quant_config.transform_class_r1,
        init=quant_config.matrix_init,
        parametrization=opt_config.mat_param,
        device=device,
        dtype=model_cfg.torch_dtype,
        add_rand_noise=opt_config.add_rand_noise,
        **layout,
    )


def build_R2_learned_transform(model, opt_config, quant_config, head_dim, device):
    """Build the R2 learned transform through ``build_transform``.

    A single block (``num_blocks=1``) of size ``head_dim`` is requested,
    with ``block_diag_init`` fixed to ``False``.

    Args:
        model: Object exposing ``config.torch_dtype``.
        opt_config: Object exposing ``mat_param`` and ``add_rand_noise``.
        quant_config: Object exposing ``transform_class_r2`` and
            ``matrix_init``.
        head_dim: Used as ``block_size``.
        device: Forwarded to ``build_transform`` as ``device``.

    Returns:
        Whatever ``build_transform`` returns for these arguments.
    """
    r2_kwargs = {
        "transform_class": quant_config.transform_class_r2,
        "block_size": head_dim,
        "num_blocks": 1,
        "init": quant_config.matrix_init,
        "parametrization": opt_config.mat_param,
        "block_diag_init": False,
        "device": device,
        "dtype": model.config.torch_dtype,
        "add_rand_noise": opt_config.add_rand_noise,
    }
    return build_transform(**r2_kwargs)


def set_block_R1_learned_transforms(model, shared_R1_learned_transform, quant_config, device):
    """Return the R1-related transforms for a block.

    The supplied shared transform is reused, as the same object, for both
    the qkv input and the gate/up input. The down input gets a freshly
    built ``'hadamard'`` transform sized by ``model.config.intermediate_size``
    with ``group_size`` taken from ``quant_config``.

    Args:
        model: Object exposing ``config.intermediate_size``.
        shared_R1_learned_transform: Already-built transform; must not be
            ``None``.
        quant_config: Object exposing ``group_size``.
        device: Forwarded to ``build_transform`` as ``device``.

    Returns:
        A 3-tuple ``(qkv_in, gate_up_in, down_in)``.

    Raises:
        AssertionError: If ``shared_R1_learned_transform`` is ``None``.
    """
    assert shared_R1_learned_transform is not None, "Shared learned transform must be initialized."

    hadamard_down = build_transform(
        'hadamard',
        size=model.config.intermediate_size,
        device=device,
        group_size=quant_config.group_size,
    )

    # First two slots intentionally alias the same shared transform object.
    return shared_R1_learned_transform, shared_R1_learned_transform, hadamard_down


def set_block_R2_learned_transforms(model, quant_config, opt_config, head_dim, device):
    """Build one R2 learned transform and return it for both of its slots.

    A single transform is created with ``build_R2_learned_transform`` and
    the same object is returned twice, once as the o-input transform and
    once as the v-output transform.

    Args:
        model: Forwarded to ``build_R2_learned_transform``.
        quant_config: Forwarded to ``build_R2_learned_transform``.
        opt_config: Forwarded to ``build_R2_learned_transform``.
        head_dim: Must not be ``None``; forwarded as the head dimension.
        device: Forwarded to ``build_R2_learned_transform``.

    Returns:
        A 2-tuple ``(o_in, v_out)`` whose elements are the same object.

    Raises:
        AssertionError: If ``head_dim`` is ``None``.
    """
    assert head_dim is not None, "Head dimension must be initialized."

    # Note the argument order expected by the builder: opt_config precedes quant_config.
    shared_r2 = build_R2_learned_transform(model, opt_config, quant_config, head_dim, device)
    return shared_r2, shared_r2