import torch
import torch.nn as nn

class Policy(nn.Module):
    """Learnable mask logits, one vector per eligible weight in ``base_params``.

    For every key of ``base_params`` that does not contain one of
    ``_SKIP_MARKERS``, a bfloat16 parameter of length ``min(v.shape)`` is
    created on device ``gpu``, initialised as ``randn * init_val``. The
    parameters are registered with the module through an ``nn.ParameterList``
    so ``self.parameters()`` / optimisers see them. ``get_mask`` turns a logit
    vector into a multiplicative mask ``sigmoid(p) * max_mult`` (or all ones
    when ``enable_mask`` is False).
    """

    # Substrings of parameter names the policy does not learn a mask for.
    # "layrnorm" is intentionally misspelled — presumably to match checkpoints
    # that use that spelling (e.g. HF CLIP's "pre_layrnorm"); keep it.
    _SKIP_MARKERS = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")

    def __init__(self, base_params, gpu, init_val=0.1, max_mult=1, **kwargs):
        """Create one learnable logit vector per eligible weight.

        Args:
            base_params: mapping from parameter name to tensor; only the name
                and ``shape`` of each value are used.
            gpu: device on which the logits are allocated.
            init_val: standard deviation of the Gaussian initialisation.
            max_mult: upper bound of the mask (``sigmoid`` is scaled by it).
            **kwargs: accepted and ignored (kept for caller compatibility).
        """
        super().__init__()
        self.learnable_params = {}
        self.num_params = 0
        self.max_mult = max_mult
        self.enable_mask = True
        for k, v in base_params.items():
            if any(marker in k for marker in self._SKIP_MARKERS):
                continue
            # Length min(rows, cols) — presumably one logit per singular value
            # of a thin SVD of this weight (compose_new_params multiplies S by it).
            self.learnable_params[k] = nn.Parameter(
                torch.randn(min(v.shape), device=gpu, dtype=torch.bfloat16) * init_val,
                requires_grad=True,
            )
            self.num_params += self.learnable_params[k].numel()
        print(f"#params={self.num_params}")
        self.learnable_params_list = list(self.learnable_params.values())
        self.trainable_params = self.learnable_params_list
        # Registering through ParameterList makes the dict's Parameters visible
        # to nn.Module machinery (parameters(), state_dict(), .to()).
        self.learnable_params_module_list = nn.ParameterList(self.learnable_params_list)

    def get_learnable_params(self, detach=False):
        """Return the mapping from weight name to its mask logits.

        Args:
            detach: when False (default) return the live dict of Parameters
                itself; when True return a new dict of detached tensors that
                share storage with the Parameters but carry no autograd history.
        """
        if detach:
            return {k: v.detach() for k, v in self.learnable_params.items()}
        return self.learnable_params

    def set_trainable_params_values(self, new_values):
        """Copy ``new_values`` into the trainable parameters, in order, without autograd.

        Pairs are formed with ``zip``, so extra items on either side are ignored.
        """
        with torch.no_grad():
            for p, v in zip(self.trainable_params, new_values):
                p.data.copy_(v)

    def get_mask(self, p):
        """Map logits ``p`` to a bfloat16 mask: ``sigmoid(p) * max_mult``, or ones if disabled."""
        if not self.enable_mask:
            return torch.ones_like(p, dtype=torch.bfloat16)
        return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult
        
def compose_new_params(
    policy,
    param_name,
    decomposed_params,
    learnable_params,
):
    """Compose new parameters from decomposed parameters.

    Computes ``U @ diag(S * m) @ V.T`` with
    ``m = policy.get_mask(learnable_params[param_name])``, then multiplies the
    result by ``S.sum() / (S * m).sum()`` so the summed singular values match
    the unmasked ``S``.

    Args:
        policy: object exposing ``get_mask(tensor) -> tensor``.
        param_name: weight name; factors are read from
            ``decomposed_params[f"{param_name}.U"]``, ``".S"`` and ``".V"``.
        decomposed_params: mapping holding the factors. Assumes 2-D ``U`` of
            shape (m, k), 1-D ``S`` of length k and 2-D ``V`` of shape (n, k)
            — presumably a thin SVD; TODO confirm with the producer.
        learnable_params: mapping from ``param_name`` to the mask logits.

    Returns:
        The composed weight of shape ``(U.shape[0], V.shape[0])``; it stays
        differentiable w.r.t. the mask logits when autograd is enabled.
    """
    u = decomposed_params[f"{param_name}.U"]
    s = decomposed_params[f"{param_name}.S"]
    v = decomposed_params[f"{param_name}.V"]
    mm = policy.get_mask(learnable_params[param_name])
    masked_s = s * mm
    # U @ diag(x) only scales column j of U by x[j]; broadcasting does that
    # directly instead of materialising a k x k diagonal and an extra matmul.
    composed = (u * masked_s) @ v.T
    return composed * (s.sum() / masked_s.sum())

def backward(
    policy,
    model,
    base_params,
    decomposed_params,
    learnable_params,
):
    """Backward pass: push model weight gradients into the mask logits.

    For each key of ``base_params`` not containing a skip marker, the weight is
    recomposed with ``compose_new_params`` and ``.backward`` is called on it
    with the gradient currently stored on the matching model parameter as the
    upstream gradient. Resulting gradients accumulate into the autograd leaves
    the composition reaches (e.g. the tensors in ``learnable_params``).

    Keys whose model parameter has ``.grad is None`` are skipped: a missing
    gradient means zero upstream gradient, so they contribute nothing. When no
    key qualifies, the function does nothing.

    Args:
        policy: object exposing ``get_mask``; forwarded to ``compose_new_params``.
        model: module whose parameters hold ``.grad`` — presumably populated by
            a prior ``loss.backward()``; TODO confirm with the caller.
        base_params: mapping whose keys name the weights; only keys are used.
        decomposed_params: factors forwarded to ``compose_new_params``.
        learnable_params: mask logits forwarded to ``compose_new_params``.
    """
    # "layrnorm" is an intentional spelling; it matches names that use it.
    skip_markers = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")
    pending = []
    for k in base_params:
        if any(marker in k for marker in skip_markers):
            continue
        grad = model.get_parameter(k).grad
        if grad is None:
            continue
        pending.append((k, grad))
    if not pending:
        return
    last = len(pending) - 1
    for i, (k, grad) in enumerate(pending):
        # Keep the graph for every call except the last one that runs, in case
        # the compositions share non-leaf inputs; the final call releases it.
        compose_new_params(policy, k, decomposed_params, learnable_params).backward(
            grad, retain_graph=i < last
        )


def apply_policy_to_model(model, policy, base_params, decomposed_params, learnable_params):
    """Load policy-composed weights into ``model`` and return the same model.

    Keys containing a skip marker are loaded unchanged from ``base_params``;
    every other key is replaced by ``compose_new_params(...)``. Loading uses
    ``load_state_dict(..., strict=False)``, so keys missing from the state dict
    are tolerated.

    Args:
        model: module to update in place.
        policy: object exposing ``get_mask``; forwarded to ``compose_new_params``.
        base_params: mapping from parameter name to base tensor.
        decomposed_params: factors forwarded to ``compose_new_params``.
        learnable_params: mask logits forwarded to ``compose_new_params``.

    Returns:
        ``model`` (the same object that was passed in).
    """
    # "layrnorm" is an intentional spelling; it matches names that use it.
    skip_markers = ("layernorm", "bias", "embeddings", "layer_norm", "layrnorm")
    updated_params = {}
    # load_state_dict copies values into the existing parameters without
    # keeping autograd history, so building the composition graph here would
    # only hold intermediate buffers alive for every weight.
    with torch.no_grad():
        for k, v in base_params.items():
            if any(marker in k for marker in skip_markers):
                updated_params[k] = v
            else:
                updated_params[k] = compose_new_params(
                    policy, k, decomposed_params, learnable_params
                )
    model.load_state_dict(updated_params, strict=False)
    return model