import json
import os
import pickle
from typing import Optional, List

import torch.nn
import transformers
from transformers import AutoTokenizer


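# Background note: a model can be fed embeddings instead of token ids:
#   input_embs = model.get_input_embeddings()(x["input_ids"])
#
# and then
#   model(**other_inputs, inputs_embeds=input_embs)  # keyword arguments, not *x
# This should also work for model.generate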

def create_struct_prefix(*args, **kwargs):
    """Factory for :class:`StructuredPrefixEmbeddingModel`.

    All positional and keyword arguments are forwarded unchanged to the
    ``StructuredPrefixEmbeddingModel`` constructor.

    :return: the newly constructed model.
    """
    prefix_model = StructuredPrefixEmbeddingModel(*args, **kwargs)
    return prefix_model

def load_struct_prefix_with_init(model_str: str,
                                 fst_tokenizer_path: str,
                                 tokenizer: AutoTokenizer,
                                 num_examples: int,
                                 prefix_length: int,
                                 random_selection: bool = False,
                                 fst_file_path:str = None,
                                 map_location = None,
                                 *args, **kwargs):
    """
    Builds a StructuredPrefixEmbeddingModel whose prefix is initialised from FST data.

    The pickled machine embedder stored next to the model is loaded, one batch of
    FST examples is read from a jsonl file, and the embedder's output for that
    batch is averaged over the batch dimension to form the initial prefix.

    :param model_str: directory containing the model and "machine_embedder_params.pt".
    :param fst_tokenizer_path: forwarded to ``load_fst_jsonl``.
    :param tokenizer: forwarded to ``load_fst_jsonl``.
    :param num_examples: batch size used when reading the FST examples.
    :param prefix_length: passed to ``load_fst_jsonl`` as ``max_len``.
    :param random_selection: passed to ``load_fst_jsonl`` as ``random_order``.
    :param fst_file_path: jsonl file to read; defaults to
        "pretraining_sample.jsonl" inside ``model_str``.
    :param map_location: forwarded to ``torch.load`` for the machine embedder.
    :param args: extra positional arguments for StructuredPrefixEmbeddingModel.
    :param kwargs: extra keyword arguments for StructuredPrefixEmbeddingModel.
    :return: the model with its prefix embedding replaced by the computed initialisation.
    """
    # NOTE(review): sip.fst_pretrain is imported but not referenced below; presumably
    # it has to be importable so torch.load can unpickle the embedder -- confirm.
    import sip.fst_pretrain
    from sip.meta_loading import load_fst_jsonl

    embedder_file = os.path.join(model_str, "machine_embedder_params.pt")
    machine_embedder = torch.load(embedder_file, map_location=map_location)
    state_count = machine_embedder.state_embeddings.num_embeddings

    # Fall back to the sample file in the standard location.
    if fst_file_path is None:
        jsonl_path = os.path.join(model_str, "pretraining_sample.jsonl")
    else:
        jsonl_path = fst_file_path

    batches = load_fst_jsonl(jsonl_path,
                             tokenizer,
                             fst_tokenizer_path=fst_tokenizer_path,
                             num_states=state_count,
                             batch_size=num_examples,
                             random_order=random_selection,
                             max_len=prefix_length)
    # Only the first batch is used for the initialisation.
    first_batch = next(iter(batches))

    # Per the original author's note: activations are (batch, prefix length, embed dim).
    activations, _ = machine_embedder.prepare_input(first_batch)
    # Average over the batch while keeping a leading dimension of size 1.
    init = activations.mean(dim=0, keepdim=True)

    prefix_model = StructuredPrefixEmbeddingModel(model_str, init.shape[1], *args, **kwargs)
    prefix_model.prefix_embedding = torch.nn.Parameter(init.detach(), requires_grad=True)
    return prefix_model


class StructuredPrefixEmbeddingModel(torch.nn.Module):
    """
    Seq2seq model wrapper that prepends a trainable prefix to every input.

    The wrapped model is loaded with ``transformers.AutoModelForSeq2SeqLM``.
    ``prefix_embedding`` is a parameter of shape (1, prefix length, embedding dim)
    that is broadcast over the batch and concatenated in front of the embedded
    ``input_ids`` before the wrapped model is called.

    NOTE(review): the adapter-related calls (``add_adapter``, ``save_all_adapters``,
    ``load_adapter``, ``active_adapters``) presumably require a transformers build
    with adapter support -- confirm against the project's environment.
    """

    def __init__(self, model_str: str,
                 prefix_length: int,
                 adapter_str: str = None,
                 init_strs: Optional[List[str]] = None,
                 tokenizer: Optional["transformers.PreTrainedTokenizer"] = None,
                 ignore_mismatched_sizes: bool = False,
                 *args,
                 **kwargs):
        """
        :param model_str: name or path given to ``AutoModelForSeq2SeqLM.from_pretrained``.
        :param prefix_length: number of prefix positions.
        :param adapter_str: if truthy, an adapter named "task_adapter" with this
            configuration is added and activated.
        :param init_strs: optional strings; the prefix is initialised with the mean
            (over the strings) of their token embeddings, padded/truncated to
            ``prefix_length``. Without it the prefix is drawn from a normal distribution.
        :param tokenizer: tokenizer used for ``init_strs``; required when ``init_strs`` is given.
        :param ignore_mismatched_sizes: forwarded to ``from_pretrained``.
        :param args: forwarded to ``torch.nn.Module.__init__``.
        :param kwargs: forwarded to ``torch.nn.Module.__init__``.
        :raises ValueError: if ``init_strs`` is given without a tokenizer.
        """
        super().__init__(*args, **kwargs)

        self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_str,
                                                                        ignore_mismatched_sizes=ignore_mismatched_sizes)
        self.model_str = model_str

        self.prefix_length = prefix_length

        self.adapter_str = adapter_str

        if adapter_str:
            self.model.add_adapter("task_adapter", adapter_str, set_active=True)

        if init_strs is not None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if tokenizer is None:
                raise ValueError("A tokenizer is required when init_strs is given.")
            toks = tokenizer(init_strs, padding="max_length", max_length=prefix_length,
                             return_tensors="pt", truncation=True)["input_ids"]  # shape (batch, seq_len)
            embeded_toks = self.model.get_input_embeddings()(toks).mean(0, keepdim=True)  # shape (1, seq_len, embedding dim)
            self.prefix_embedding = torch.nn.Parameter(embeded_toks.detach())
        else:
            embedding_dim = self.model.get_input_embeddings().embedding_dim
            self.prefix_embedding = torch.nn.Parameter(torch.empty(1, self.prefix_length, embedding_dim))
            torch.nn.init.normal_(self.prefix_embedding)

    def save_pretrained(self, path):
        """
        Saves all adapters of the wrapped model, a small ``config.json`` and the
        prefix embedding (``prefix_embedding.pt``) into ``path``.

        :param path: target directory (``str`` or path-like); created if missing.
        :return: None
        """
        # Make sure the directory exists even if no adapter gets written.
        os.makedirs(path, exist_ok=True)
        self.model.save_all_adapters(path)

        active = self.model.active_adapters
        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump({"model_str": self.model_str,
                       "prefix_length": self.prefix_length,
                       "adapters_to_load": list(active.flatten()) if active is not None else [],
                       }, f)
        # os.path.join (rather than string concatenation) also accepts path-like objects.
        torch.save(self.prefix_embedding, os.path.join(path, "prefix_embedding.pt"))

    @staticmethod
    def from_pretrained(path, add_adapter: Optional[str] = None):
        """
        Restores a model written by :meth:`save_pretrained`.

        :param path: directory containing ``config.json``, ``prefix_embedding.pt``
            and (optionally) one saved adapter.
        :param add_adapter: if given, a fresh adapter named "task_adapter" with this
            configuration is added and activated after loading.
        :return: the restored model.
        :raises ValueError: if the config lists more than one adapter, or the stored
            prefix does not have ``prefix_length`` positions.
        """
        with open(os.path.join(path, "config.json")) as f:
            config = json.load(f)

        adapter_names = config["adapters_to_load"]
        # Validate before the (expensive) base model is instantiated.
        if len(adapter_names) > 1:
            raise ValueError(f"Expected at most one adapter to load, got {len(adapter_names)}.")

        model = StructuredPrefixEmbeddingModel(model_str=config["model_str"], prefix_length=config["prefix_length"])

        for adapter_name in adapter_names:
            model.model.load_adapter(os.path.join(path, adapter_name), set_active=True)

        # Without CUDA, tensors saved from a GPU have to be remapped to the CPU.
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        prefix = torch.load(os.path.join(path, "prefix_embedding.pt"),
                            map_location=map_location)  # (1, seq_length, embedding dim)
        # A plain tensor cannot be assigned over a registered parameter; wrap it if needed.
        if not isinstance(prefix, torch.nn.Parameter):
            prefix = torch.nn.Parameter(prefix)
        model.prefix_embedding = prefix

        if model.prefix_length != model.prefix_embedding.shape[1]:
            raise ValueError(
                f"Stored prefix has {model.prefix_embedding.shape[1]} positions "
                f"but config.json specifies prefix_length={model.prefix_length}.")

        if add_adapter is not None:
            model.model.add_adapter("task_adapter", add_adapter, set_active=True)

        return model

    def get_output_embeddings(self):
        """Returns the output embeddings of the wrapped model."""
        return self.model.get_output_embeddings()

    def dump_reps(self, group_name, h5f, **kwargs):
        """
        Dumps representations for encoder/decoder states.

        Not implemented yet: runs a forward pass with ``kwargs`` and then raises.
        :raises NotImplementedError: always.
        """
        self(**kwargs, )
        raise NotImplementedError()

    @property
    def device(self):
        """Device of the wrapped model."""
        return self.model.device

    def prepare_input(self, kwargs):
        """
        Prepends the prefix to the given input.

        ``input_ids`` is embedded with the wrapped model's input embeddings and
        replaced by ``inputs_embeds`` (prefix followed by the token embeddings).
        If an ``attention_mask`` is present it is extended with ones covering the
        prefix. The caller's dictionary is not modified.

        :param kwargs: model inputs; must contain ``input_ids`` of shape (batch, seq length).
        :return: a new dict of keyword arguments for the wrapped model.
        """
        input_ids = kwargs["input_ids"]
        embedded_inputs = self.model.get_input_embeddings()(input_ids)
        batch_size = input_ids.shape[0]

        # Broadcast the single prefix over the batch; expand is a view and torch.cat
        # below copies anyway. Shape (batch, prefix length, embed dim).
        prefix = self.prefix_embedding.expand(batch_size, -1, -1)

        # Work on a copy without input_ids so the caller's dict stays untouched.
        prepared = {key: value for key, value in kwargs.items() if key != "input_ids"}
        # shape (batch, prefix + seq length, embed dim)
        prepared["inputs_embeds"] = torch.cat([prefix, embedded_inputs], dim=1)

        attention_mask = prepared.get("attention_mask")
        if attention_mask is not None:
            # Size the mask from the tensor actually prepended so the two can never disagree.
            ones = torch.ones((batch_size, prefix.shape[1]),
                              device=embedded_inputs.device,
                              dtype=attention_mask.dtype)
            prepared["attention_mask"] = torch.cat([ones, attention_mask], dim=1)

        return prepared

    def forward(self, **kwargs):
        """Runs the wrapped model on the prefix-extended input."""
        return self.model(**self.prepare_input(kwargs))

    def generate(self, **kwargs):
        """Calls ``generate`` of the wrapped model on the prefix-extended input."""
        return self.model.generate(**self.prepare_input(kwargs))
