from functools import partial
from typing import Any, Callable, Optional, Sequence

import torch
from torch import nn
from transformers import BatchEncoding

from layskip.utils.dictionaries import NAME2TRANSLATORS

# Target device for tensors moved explicitly in this module (SkipModel sends the
# precomputed embeddings here before fitting translators): CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class HFwrapper(nn.Module):
    """Chain a feature encoder with a classification head.

    The encoder is always executed under ``torch.no_grad()``, so gradients produced by
    ``forward`` only flow into ``classifier``.
    """

    def __init__(self, encoder, classifier):
        super().__init__()

        self.encoder = encoder
        self.classifier = classifier

    def freeze_encoder(self):
        """Switch the encoder to eval mode and turn off ``requires_grad`` on its parameters."""
        self.encoder.eval()
        for weight in self.encoder.parameters():
            weight.requires_grad = False

    def encode(self, x: torch.Tensor):
        """Compute encoder features without gradient tracking.

        A plain tensor output is returned directly; any other output object (for example
        an HF model output) is reduced to its ``pooler_output`` attribute.
        """
        with torch.no_grad():
            features = self.encoder(x)
            if isinstance(features, torch.Tensor):
                return features
            return features.pooler_output

    def decode(self, x: torch.Tensor):
        """Apply the classification head to encoder features."""
        return self.classifier(x)

    def forward(self, x: BatchEncoding):
        """Encode ``x`` (no grad) and classify the resulting features."""
        return self.decode(self.encode(x))


class NoEncoder(nn.Module):
    """Identity stand-in for an encoder: inputs are returned unchanged.

    ``embeddings`` is kept as an attribute but is not read by ``encode`` or ``forward``.
    """

    def __init__(self, embeddings):
        super().__init__()
        self.embeddings = embeddings

    def encode(self, x: torch.Tensor):
        """Return the input object itself (no transformation)."""
        return x

    def forward(self, x: BatchEncoding):
        """Delegate to ``encode``; the result is the input object unchanged."""
        return self.encode(x)


class SkipModel(nn.Module):
    """Run an HF encoder with some layer ranges replaced by fitted "translator" layers.

    Each skip ``(a, b)`` removes encoder layers ``a + 1 .. b`` (inclusive) and inserts,
    at position ``a + 1``, a layer that applies translator(s) fitted to map
    ``precomputed_embeddings[a]`` onto ``precomputed_embeddings[b]``. The resulting
    stack is therefore: layers ``0 .. a``, translator, layers ``b + 1 ..``.

    Args:
        encoder: model exposing ``embeddings``, ``encoder.layer``, ``layernorm`` and,
            optionally, ``pooler``.
        skips: iterable of ``(skip_from, skip_to)`` pairs; see ``check_skip_consistency``.
        mode: ``1`` fits a single translator on all sequence positions flattened
            together; ``2`` fits one translator per sequence position (dim 1).
        precomputed_embeddings: indexable by layer index; each entry is a tensor whose
            dim 1 is treated as the sequence axis in mode 2 — presumably
            (samples, seq, dim); TODO confirm against the caller.
        translator_name: key into ``NAME2TRANSLATORS`` selecting the translator factory.

    Raises:
        ValueError: from ``check_skip_consistency`` for invalid skips, or when ``mode``
            is unsupported and at least one skip has to be fitted.
    """

    # Modes understood by fit_translators / transform_similar_spaces.
    _SUPPORTED_MODES = (1, 2)

    def __init__(self, encoder, skips, mode, precomputed_embeddings, translator_name):
        super().__init__()

        self.encoder = encoder
        self.skips = skips
        self.mode = mode
        self.precomputed_embeddings = precomputed_embeddings
        self.translator_name = translator_name

        # Validate before any translator is fitted.
        self.check_skip_consistency()

        # Only the transformer blocks (e.g. the 12 ViT encoder layers) are filtered;
        # embeddings, layernorm and pooler are applied separately in encode/forward.
        self.filtered_layers_list: Sequence[IndexedLayer] = self.filter_layers(self.encoder.encoder.layer, self.skips)

        self.computed_skips: Sequence[IndexedLayer] = self.compute_skipping(
            self.precomputed_embeddings, self.skips, self.mode
        )

        # A translator for skip (a, b) carries index a + 1, the slot of the first removed
        # layer, so sorting by index splices it into the right place.
        self.final_layers_list = sorted(
            (self.filtered_layers_list + self.computed_skips), key=lambda layer: layer.index
        )

    def encode(self, x: BatchEncoding):
        """Embed ``x`` and run it through the (partially translated) layer stack."""
        hidden_states = self.encoder.embeddings(x)

        for layer in self.final_layers_list:
            hidden_states = layer(hidden_states)

        return hidden_states

    def forward(self, x: BatchEncoding):
        """Return the pooled representation of ``x``.

        Uses the encoder's pooler when it exists and is not ``None``; otherwise (e.g.
        DINOv2) falls back to the first token of the layer-normed sequence.
        """
        hidden_states = self.encode(x)
        sequence_output = self.encoder.layernorm(hidden_states)

        pooler = getattr(self.encoder, "pooler", None)
        if pooler is not None:
            return pooler(sequence_output)
        return sequence_output[:, 0, :]

    def check_skip_consistency(self):
        """Validate ``self.skips``.

        Skips are checked in sorted order. Adjacent skips such as ``(0, 3), (3, 5)`` are
        allowed.

        Raises:
            ValueError: if a skip starts and ends at the same index, goes backwards
                (``skip_from > skip_to``), or overlaps an earlier skip (a negative start
                is also reported as an overlap).
        """
        max_val = 0

        for a, b in sorted(self.skips):

            if a == b:
                raise ValueError(f"Skipping from {a} to {b} is invalid")

            # A backward skip removes no layers (empty range in filter_layers) yet still
            # inserts a translator, silently corrupting the layer stack.
            if a > b:
                raise ValueError(f"Skipping backwards from {a} to {b} is invalid")

            if (a < max_val) or (b <= max_val):
                raise ValueError(f"Skips {sorted(self.skips)} overlaps")

            max_val = b

    def compute_skipping(self, precomputed_embeddings, skips, mode):
        """Fit translators for every skip and wrap each set as an ``IndexedLayer``.

        Returns:
            list of IndexedLayer: for skip ``(a, b)``, a layer with index ``a + 1`` and
            name ``skip_{a}_{b}`` that calls ``transform_similar_spaces``.
        """
        computed_skips: Sequence[IndexedLayer] = []

        for skip_from, skip_to in skips:
            translators = self.fit_translators(
                spaces_to_fit=precomputed_embeddings,
                skip_from=skip_from,
                skip_to=skip_to,
                mode=mode,
            )

            computed_skips.append(
                IndexedLayer(
                    index=skip_from + 1,
                    layer=partial(
                        self.transform_similar_spaces,
                        translators=translators,
                        mode=mode,
                    ),
                    layer_name=f"skip_{skip_from}_{skip_to}",
                )
            )

        return computed_skips

    def fit_translators(self, spaces_to_fit, skip_from, skip_to, mode):
        """Fit the translator(s) mapping space ``skip_from`` onto space ``skip_to``.

        Both spaces are cast to float64 and moved to the module-level ``device`` first.

        Returns:
            list: a single fitted translator in mode 1; one fitted translator per
            position along dim 1 in mode 2.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {self._SUPPORTED_MODES}")

        dtype = torch.double

        x = spaces_to_fit[skip_from].to(dtype).to(device)
        y = spaces_to_fit[skip_to].to(dtype).to(device)

        translator_factory = NAME2TRANSLATORS[self.translator_name]
        translators = []

        if mode == 1:
            # One translator for all positions: flatten everything but the last dim.
            translator = translator_factory()
            translator.fit(x=x.reshape(-1, x.shape[-1]), y=y.reshape(-1, y.shape[-1]))
            translators.append(translator)
        else:
            # One translator per sequence position (dim 1).
            for i in range(x.shape[1]):
                translator = translator_factory()
                # Keep the translator object itself (as in mode 1) instead of whatever
                # ``fit`` returns, which need not be ``self``.
                translator.fit(x=x[:, i, :], y=y[:, i, :])
                translators.append(translator)

        return translators

    def transform_similar_spaces(self, current_space, translators, mode):
        """Translate ``current_space`` with translators produced by ``fit_translators``.

        The work is done in float64 and the result is cast back to ``current_space``'s
        dtype, so the next encoder layer receives the dtype it would normally get.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {self._SUPPORTED_MODES}")

        x = current_space.to(torch.double)

        # ``transform(...)[0]``: the translator's transform output is indexed; presumably
        # a tuple whose first item is the translated tensor — confirm in NAME2TRANSLATORS.
        if mode == 1:
            flat = translators[0].transform(x.reshape(-1, x.shape[-1]))[0]
            transformed_space = flat.reshape(x.shape)
        else:
            transformed_space = torch.stack(
                [translators[i].transform(x[:, i, :])[0] for i in range(x.shape[1])],
                dim=1,
            )

        return transformed_space.to(current_space.dtype)

    def filter_layers(self, layers, skips):
        """Wrap every layer that no skip removes as an ``IndexedLayer``.

        Skip ``(a, b)`` removes layers ``a + 1 .. b`` (inclusive). Kept layers are called
        as ``layer(x)[0]``, i.e. only the first element of their output is propagated.
        """
        skipped = {i for start, end in skips for i in range(start + 1, end + 1)}

        filtered_layers: Sequence[IndexedLayer] = []
        for i, layer in enumerate(layers):
            if i in skipped:
                continue
            # Bind ``layer`` as a default argument; a plain closure would capture the loop
            # variable and every lambda would call the last layer.
            # https://stackoverflow.com/questions/2295290/what-do-lambda-function-closures-capture
            filtered_layers.append(
                IndexedLayer(index=i, layer=lambda x, layer=layer: layer(x)[0], layer_name=f"layer_{i}")
            )

        return filtered_layers


class IndexedLayer:
    """Callable wrapper that tags a layer (or any callable) with a position index.

    Args:
        index: position used to order layers (``SkipModel`` sorts its stack by it).
        layer: the callable invoked by ``__call__``.
        layer_name: optional label shown in ``repr``; when ``None`` the wrapped
            callable's class name is shown instead.
    """

    def __init__(self, index: int, layer: Callable, layer_name: Optional[str] = None):
        self.index = index
        self.layer = layer
        self.layer_name = layer_name

    # Annotated with ``typing.Any``: ``torch.Any`` is not a documented PyTorch API (only an
    # incidental re-export of typing.Any that newer releases drop), and since annotations
    # are evaluated when the class is defined, a missing name breaks importing the module.
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward all positional and keyword arguments to the wrapped callable."""
        return self.layer(*args, **kwargs)

    def __repr__(self) -> str:
        label = self.layer.__class__.__name__ if self.layer_name is None else self.layer_name
        return f"IndexedLayer(index={self.index}, layer={label})"
