import os
import torch
from .function_utils import get_paras_dict_by_name
import logging

def kronecker_matmul(x, hadL, hadR):
    """Multiply ``x`` by ``kron(hadL, hadR)`` without building the product.

    The result matches the explicit computation::

        had = torch.kron(hadL, hadR)
        out = x.reshape(-1, had.shape[0]).matmul(had).reshape(x.shape)

    The trailing elements of ``x`` are viewed as a
    ``(hadL.shape[0], hadR.shape[0])`` grid; ``hadR`` is applied to the
    grid's columns and ``hadL`` to its rows. The output has the same shape
    as the input.
    """
    original_shape = x.shape
    n_left, n_right = hadL.shape[0], hadR.shape[0]

    # One (n_left, n_right) grid per flattened leading index.
    grids = x.reshape(-1, n_left, n_right)

    # Right factor acts on the last axis, left factor (transposed) on the
    # second-to-last axis; together this equals a matmul with kron(hadL, hadR).
    right_applied = grids @ hadR
    both_applied = hadL.T @ right_applied

    return both_applied.reshape(original_shape)

def save_parametrized_checkpoint(model, args):
    """Save the state dict of every entry in ``model.model.layers``.

    The state dicts are collected into a mapping keyed by the layer's
    integer position and written with ``torch.save`` to
    ``<args.exp_dir>/parametrized_paras.pth``. The destination path is
    logged at INFO level afterwards. Nothing is returned.
    """
    per_layer_state = {
        idx: block.state_dict() for idx, block in enumerate(model.model.layers)
    }
    checkpoint_path = os.path.join(args.exp_dir, "parametrized_paras.pth")
    torch.save(per_layer_state, checkpoint_path)
    logging.info("saved paramaters at {}".format(checkpoint_path))


def load_s2_parameters(args, model, path=None):
    """Load saved ``s2_parameters.pth`` state dicts into the model's blocks.

    The checkpoint is expected to be a mapping from block index to a state
    dict, which is loaded into ``model.transformer_blocks[index]``.

    Args:
        args: Object providing ``exp_dir``; only read when ``path`` is None.
        model: Object exposing an indexable ``transformer_blocks`` whose
            elements implement ``load_state_dict``.
        path: Optional directory containing ``s2_parameters.pth``. When
            omitted, ``args.exp_dir`` is used instead.

    Returns:
        The same ``model`` object, with its blocks updated in place.
    """
    checkpoint_dir = args.exp_dir if path is None else path
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    s2_parameters = torch.load(os.path.join(checkpoint_dir, "s2_parameters.pth"))
    layers = model.transformer_blocks

    # Walk the stored (index, state_dict) pairs instead of range(len(...)),
    # so a checkpoint that only holds some block indices still loads rather
    # than raising KeyError on the first missing index.
    for layer_idx, s2_param in s2_parameters.items():
        # strict=False: keys that do not match between the saved dict and
        # the block are tolerated instead of raising.
        layers[layer_idx].load_state_dict(s2_param, strict=False)
    return model