import torch
import copy


class scaled_module(torch.nn.Module):
    """Wrap a module and divide its primary output by random per-position scales.

    The wrapped module is called unchanged; its first output (or its only
    output, when it returns a bare tensor) is then divided by a scale tensor
    produced by :meth:`get_scales`.

    Attributes:
        module: The wrapped callable/module.
        id: Index of this wrapper among the scaled layers.
        total: Number of scaled layers; the wrapper with ``id >= total - 1``
            never perturbs its output.
        alpha: Noise amplitude. Scales are drawn from ``1 + U[0, 1) * alpha``;
            the default ``0.0`` makes the wrapper a numerical no-op.
    """

    def __init__(self, module, id_, total):
        super().__init__()
        self.module = module
        self.id = id_
        self.total = total
        self.alpha = 0.0

    def get_scales(self, x):
        """Return a tensor shaped like ``x`` holding one scale per index of dim 1.

        Every element sharing the same index along dimension 1 gets the same
        scale. The result is all ones when dimension 1 has size 1 or when this
        is the last scaled layer (``id >= total - 1``).

        Args:
            x: Tensor with at least two dimensions.

        Returns:
            Tensor with the shape, dtype and device of ``x``.
        """
        scales = torch.ones_like(x)
        n_pos = x.shape[1]
        if n_pos <= 1 or self.id >= self.total - 1:
            # Nothing to perturb: skip the random draw entirely.
            return scales
        # One draw for all positions (bfloat16 on the CPU generator, as a
        # single host-to-device copy) instead of one draw + copy per position.
        noise = torch.rand(n_pos, dtype=torch.bfloat16).to(x.device) * self.alpha
        factors = (1 + noise).to(x.dtype)
        # Shape (1, n_pos, 1, ..., 1) so the factors broadcast over all other dims.
        broadcast_shape = [1, n_pos] + [1] * (x.dim() - 2)
        return scales * factors.view(broadcast_shape)

    def forward(self, x, **kwargs):
        """Run the wrapped module and rescale its primary output.

        Args:
            x: Input forwarded to the wrapped module.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            A list mirroring the wrapped module's output sequence with element
            0 rescaled, or a single rescaled tensor when the wrapped module
            returns a bare tensor.
        """
        result = self.module(x, **kwargs)
        if isinstance(result, torch.Tensor):
            # Some layer implementations return the hidden states directly
            # rather than a tuple; list() on a tensor would split it along the
            # batch dimension, so handle that case explicitly.
            return result / self.get_scales(result)
        out = list(result)
        # Out-of-place division: never mutate a tensor the wrapped module (or
        # the caller) may still hold a reference to.
        out[0] = out[0] / self.get_scales(out[0])
        return out


class llms_split:
    """Truncate a language model to its first blocks and keep two variants of it.

    ``model`` is expected to expose ``model.model`` (with ``embed_tokens``,
    ``layers`` and ``norm``) and ``model.lm_head``. When ``n_block`` is
    non-zero the inner model is cut down to its first ``n_block`` layers, of
    which the first ``private_layers`` are wrapped in ``scaled_module``;
    ``model_attacker`` is a copy of the inner model that runs only those first
    ``private_layers`` layers, unwrapped.

    Attributes:
        device: Device every sub-model and tokenized input is moved to.
        model: The (possibly truncated) inner model with scaled layers.
        model_attacker: Unscaled counterpart, or ``None`` when ``n_block`` is 0.
        total: The ``private_layers`` value passed to the constructor.
        head: ``model.lm_head`` moved to ``device``.
        embed_layer: ``model.model.embed_tokens``.
        tokenizer: Callable used by :meth:`inference` to encode text.
    """

    def __init__(self, model, tokenizer, n_block=0, device="cuda:0", private_layers=1):
        self.device = device
        self.model = model.model
        self.model_attacker = None
        self.total = private_layers
        # Number of layers actually wrapped for scaling; filled in by
        # _split_model and stays 0 when no split is performed.
        self._n_scaled = 0
        self.head = model.lm_head.to(self.device)
        self.embed_layer = self.model.embed_tokens
        self.tokenizer = tokenizer
        self._split_model(n_block)

    def _split_model(self, n_block):
        """Keep the first ``n_block`` layers and build ``model_attacker``.

        Does nothing when ``n_block`` is falsy. Otherwise replaces
        ``self.model.layers`` and ``self.model.norm`` in place.
        """
        if not n_block:
            return

        kept = []           # layers of the truncated model (some wrapped)
        plain_private = []  # the same leading layers, left unwrapped
        for i in range(n_block):
            layer = self.model.layers[i]
            if i >= self.total:
                kept.append(layer)
            else:
                plain_private.append(layer)
                kept.append(scaled_module(layer, i, self.total))
        self._n_scaled = len(plain_private)

        # Shrink to a single layer before deepcopy so the copy does not
        # duplicate the whole layer stack; the real stacks are attached below.
        self.model.layers = torch.nn.Sequential(*kept[0:1])
        self.model.norm = torch.nn.Identity()
        self.model_attacker = copy.deepcopy(self.model).to(self.device)
        # NOTE: these are the very same layer objects that the scaled wrappers
        # in self.model hold, not copies.
        self.model_attacker.layers = torch.nn.Sequential(*plain_private)
        self.model.layers = torch.nn.Sequential(*kept)
        self.model = self.model.to(self.device)

    def reset_alpha(self, alpha=0):
        """Set the noise amplitude ``alpha`` on every scaled layer.

        Only layers that were actually wrapped during the split are touched,
        so this is a no-op when the model was not split or when no layer was
        wrapped.
        """
        print(f'--------- reset noise to {alpha} ---------')
        # Wrapped layers always occupy the leading indices of self.model.layers.
        for i in range(self._n_scaled):
            self.model.layers[i].alpha = alpha

    def inference(self, x, use_private=False):
        """Tokenize ``x`` and return the selected model's last hidden state.

        Args:
            x: A single string, or a batch accepted by the tokenizer (padded).
            use_private: Run ``self.model`` when true, ``self.model_attacker``
                otherwise.

        Returns:
            The ``'last_hidden_state'`` entry of the model output, computed
            under ``torch.no_grad()``.

        Raises:
            RuntimeError: If ``use_private`` is false but no attacker model
                exists because the instance was created with ``n_block=0``.
        """
        runner = self.model if use_private else self.model_attacker
        if runner is None:
            raise RuntimeError(
                "model_attacker is not available: the model was not split "
                "(n_block=0); call inference(..., use_private=True) instead"
            )

        if isinstance(x, str):
            model_inputs = self.tokenizer([x], return_tensors='pt').to(self.device)
        else:
            model_inputs = self.tokenizer(x, return_tensors='pt', padding=True).to(self.device)
        with torch.no_grad():
            out = runner(**model_inputs)['last_hidden_state']
        return out
