import abc

import torch
from torch import Tensor


class ModelWrapper:
    """Base class for wrappers that evaluate a model at a parameter tensor ``point``.

    Stores the wrapped ``model`` together with ``embeddings``, ``inputs`` and
    ``tokens`` for use by subclasses, which override ``__call__`` and
    ``numel``.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers are not enforced; instantiating the base
    class (or a subclass missing an override) does not raise.
    """

    def __init__(self, model, embeddings, inputs, tokens, *args, **kwargs):
        # Any extra positional / keyword arguments are accepted and ignored.
        self.model, self.embeddings = model, embeddings
        self.inputs, self.tokens = inputs, tokens

    @abc.abstractmethod
    def __call__(self, point: Tensor):
        """Evaluate the wrapped model at ``point`` (base implementation returns None)."""

    @abc.abstractmethod
    def numel(self):
        """Return the expected size of ``point`` (base implementation returns None)."""


class TransformerGeneratorWrapper:
    """Sample text from a generative model conditioned on input embeddings.

    Args:
        model: object exposing a ``generate`` method that accepts
            ``inputs_embeds`` and sampling keyword arguments.
        tokenizer: object exposing ``decode(ids, ...)`` returning a string.
        max_length: forwarded to ``generate`` as ``max_length``.
        samples: number of independent sampling rounds per call.
        temperature: sampling temperature forwarded to ``generate``.
        top_k: top-k cutoff forwarded to ``generate``.
        batch_size: number of rows (along dim 0) sent to ``generate`` per call.
        dtype: dtype the embeddings are cast to before generation.
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_length=500,
        samples=10,
        *,
        temperature=0.7,
        top_k=50,
        batch_size=100,
        dtype=torch.float16,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = samples
        self.temperature = temperature
        self.top_k = top_k
        self.batch_size = batch_size
        self.dtype = dtype

    def __call__(self, embeddings: Tensor):
        """Generate and decode ``self.samples`` texts per generated row.

        ``embeddings`` is split along dim 0 into chunks of ``self.batch_size``
        rows. Each sampling round calls ``generate`` once per chunk.

        Returns:
            A list with one inner list per row returned by ``generate``
            (presumably one per input row, since ``num_return_sequences=1``).
            Element ``j`` of each inner list is the text from sampling round
            ``j``. Returns ``[]`` when ``self.samples`` is 0.
        """
        batches = torch.split(embeddings, self.batch_size)
        rounds = [self._sample_round(batches) for _ in range(self.samples)]
        # Transpose rounds x rows -> rows x rounds; zip also copes with 0 rounds.
        return [list(row) for row in zip(*rounds)]

    def _sample_round(self, batches):
        """Run one sampling pass over every batch and decode the rows in order."""
        texts = []
        for batch in batches:
            generated = self.model.generate(
                inputs_embeds=batch.to(dtype=self.dtype),
                max_length=self.max_length,
                temperature=self.temperature,
                top_k=self.top_k,
                do_sample=True,
                num_return_sequences=1,
            )
            texts.extend(
                self.tokenizer.decode(
                    ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
                for ids in generated.tolist()
            )
        return texts


class SoftPromptTuningWrapper(ModelWrapper):
    """Prepend ``point`` as one extra position (along dim 1) before ``self.inputs``.

    A 1-D ``point`` is treated as a single soft-prompt vector. A ``point``
    with more than one dimension is treated as a batch whose first dimension
    is the batch size; ``self.inputs`` is then expanded to that batch size.
    """

    def __call__(self, point: Tensor):
        if point.dim() > 1:
            prefix = point[:, None]
            context = self.inputs.expand(point.size(0), -1, -1)
        else:
            prefix = point[None, None]
            # assumes self.inputs is already 3-D with a batch dim of 1 — TODO confirm
            context = self.inputs
        return self.model(torch.cat([prefix, context], dim=1))

    def numel(self):
        """Size of one soft-prompt vector: the last dimension of ``self.inputs``."""
        hidden = self.inputs.shape[-1]
        return hidden


class MultipleSoftTuningLora(ModelWrapper):
    """Apply a low-rank additive update to a contiguous window of ``self.inputs``.

    ``point`` holds a ``(tuning_len, rank)`` factor followed by a
    ``(rank, hidden)`` factor, where ``hidden = self.inputs.shape[-1]``.
    Their product is added to positions
    ``[start_idx, start_idx + tuning_len)`` along dim 1 of ``self.inputs``,
    and the result is passed to ``self.model``.
    """

    def __init__(self, start_idx: int, tuning_len: int, rank: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_idx = start_idx
        self.tuning_len = tuning_len
        self.rank = rank

    def __call__(self, point: Tensor):
        hidden = self.inputs.shape[-1]
        split = self.tuning_len * self.rank
        window = slice(self.start_idx, self.start_idx + self.tuning_len)

        # The leading -1 lets a batch of points (one per row) share this path.
        down = point[..., :split].reshape(-1, self.tuning_len, self.rank)
        up = point[..., split:].reshape(-1, self.rank, hidden)

        batch = point.shape[0] if point.dim() > 1 else 1
        # Detached so no gradient flows back into the stored inputs.
        base = self.inputs.detach().expand(batch, -1, -1)

        patched = base[:, window] + torch.matmul(down, up)
        result = base.clone()  # expand() returns a view; copy before writing into it
        result[:, window] = patched
        return self.model(result)

    def numel(self):
        """Total size of both low-rank factors packed into ``point``."""
        return self.rank * (self.tuning_len + self.inputs.shape[-1])


class LoraEmbeddingWrapper(ModelWrapper):
    """Look up ``self.inputs`` in an embedding table plus a low-rank (LoRA) delta.

    ``point`` is a flat vector holding ``A`` of shape ``(in_dim, rank)``
    followed by ``B`` of shape ``(rank, out_dim)``, where ``in_dim`` and
    ``out_dim`` are the two dimensions of ``self.embeddings.weight``.
    The result equals ``self.embeddings(self.inputs) + (A @ B)[self.inputs]``.

    NOTE(review): unlike the sibling wrappers, this returns embeddings
    directly and never calls ``self.model`` — confirm that is intended.
    """

    def __init__(self, *args, rank=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank = rank

    def __call__(self, point: Tensor):
        split = self.in_dim * self.rank
        lora_a = point[:split].reshape(self.in_dim, self.rank)
        lora_b = point[split:].reshape(self.rank, self.out_dim)
        base_embedding = self.embeddings(self.inputs)
        # Gather only the rows of A for the tokens in use before multiplying.
        # This is equal to (A @ B)[inputs] but avoids materialising the full
        # (in_dim, out_dim) delta matrix on every call.
        lora_update = lora_a[self.inputs] @ lora_b
        return base_embedding + lora_update

    @property
    def in_dim(self):
        """First dimension of the embedding weight (number of rows)."""
        return self.embeddings.weight.shape[0]

    @property
    def out_dim(self):
        """Second dimension of the embedding weight (row width)."""
        return self.embeddings.weight.shape[1]

    def numel(self):
        """Total size of the two LoRA factors packed into ``point``."""
        return self.rank * self.in_dim + self.out_dim * self.rank


class ExtractCodeWrapper(ModelWrapper):
    """Post-process a text-producing wrapper by applying ``code_extractor`` to each string.

    The wrapped ``model`` must expose ``embeddings``, ``inputs`` and
    ``tokens`` attributes, which are re-exposed on this wrapper, and its call
    result must be an iterable of iterables of strings.
    """

    def __init__(self, code_extractor, model: ModelWrapper, *args, **kwargs):
        inner = model
        super().__init__(
            inner, inner.embeddings, inner.inputs, inner.tokens, *args, **kwargs
        )
        self.code_extractor = code_extractor

    def __call__(self, point: Tensor):
        extract = self.code_extractor
        return [list(map(extract, group)) for group in self.model(point)]

    def numel(self):
        """Delegate to the wrapped model."""
        inner = self.model
        return inner.numel()


class LoraEmbedLayerWrapper(ModelWrapper):
    """Embed ``self.tokens`` with a LoRA-updated copy of the model's token embeddings.

    ``point`` packs ``A`` of shape ``(vocab, rank)`` followed by ``B`` of
    shape ``(rank, hidden)``, where ``(vocab, hidden)`` is the shape of the
    base embedding weight. A batch of points is accepted via the leading
    ``...`` dimensions. Each candidate's embeddings are computed as
    ``W[tokens] + A[tokens] @ B`` and concatenated along dim 0 before being
    passed to ``self.model``.

    Fixes relative to the previous version:
      * ``B`` was sliced from offset ``hidden * rank`` even though ``A``
        consumed ``vocab * rank`` values.
      * ``embed_tokens.weight`` was overwritten with a plain (and 3-D)
        Tensor, which ``nn.Module`` rejects. The lookup is now functional,
        and the model's parameters are left untouched.

    NOTE(review): because the model's own embedding layer is no longer
    modified, any tokens the model generates afterwards are embedded with the
    original table — confirm this is acceptable.
    """

    def __init__(self, *args, rank=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank = rank
        # presumably self.model wraps a HF model (e.g. TransformerGeneratorWrapper)
        # whose .model exposes base_model.embed_tokens — TODO confirm
        self.initial_embeddings = self.model.model.base_model.embed_tokens.weight

    def __call__(self, point: Tensor):
        vocab_size, hidden_size = self.initial_embeddings.shape
        split = vocab_size * self.rank  # A occupies the first vocab*rank values
        lora_a = point[..., :split].reshape(-1, vocab_size, self.rank)
        lora_b = point[..., split:].reshape(-1, self.rank, hidden_size)

        base = torch.nn.functional.embedding(self.tokens, self.initial_embeddings)
        # One entry per candidate. Gathering rows of A first avoids building
        # the full (vocab, hidden) delta for every candidate.
        # assumes self.tokens carries a leading batch dim of 1 (e.g. (1, seq))
        # so that concatenating along dim 0 yields one row per candidate — TODO confirm
        data_embeddings = torch.cat(
            [
                base + torch.nn.functional.embedding(self.tokens, a) @ b
                for a, b in zip(lora_a, lora_b)
            ]
        )
        return self.model(data_embeddings)

    def numel(self):
        """Total size of the two LoRA factors packed into ``point``."""
        return (
            self.rank * self.initial_embeddings.shape[0]
            + self.rank * self.initial_embeddings.shape[1]
        )
