""" Base class for all conditional graph generators from text"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from accelerate import Accelerator
import numpy as np
import numpy.typing as npt
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    PreTrainedTokenizerFast
)
import yaml

from peft import LoraConfig, get_peft_model

from text2graph.data.base_dataset import TextGraph


@dataclass
class GeneratedText:
    """ Token ids returned by a generation call, paired with the mask that marks which of those
        positions hold real (non-padding) tokens
    """
    # Generated token ids; presumably one row per sequence in the batch (confirm with callers)
    ids: torch.Tensor
    # Integer mask aligned with `ids`: BaseModel._generate_language fills it with 0 wherever the
    # id equals the tokenizer's padding token and 1 everywhere else
    mask: torch.Tensor


class BaseModel(torch.nn.Module, ABC):
    """ Base class for all conditional graph generators from text

        Wraps a causal language model and its tokenizer, both loaded from
        ``metadata['language_model_name']``, optionally quantized to 4 bits and/or wrapped with
        LoRA adapters. Subclasses supply the forward pass, graph generation and the conversions
        between text-graph pairs and model inputs.

        Metadata keys read by this class: ``name``, ``model_dir``, ``randomize_sequence``,
        ``language_model_name``, ``quantize``, ``use_lora`` and, when ``use_lora`` is truthy,
        ``lora_alpha``, ``lora_dropout`` and ``lora_r``. The keys ``vocab_size`` and
        ``embedding_size`` are written back into the same dictionary.
    """
    def __init__(self, metadata: Dict[str, Any]) -> None:
        """ Validates the metadata, then loads the tokenizer and the language model

            Args:
                metadata: model configuration (see the class docstring for the keys used). The
                    dictionary is stored by reference and updated in place.

            Raises:
                AssertionError: if ``randomize_sequence`` is absent, or ``name`` / ``model_dir``
                    are absent or empty.
                KeyError: if another required metadata key is missing.
        """
        # Raised explicitly instead of through `assert` so the checks still run under
        # `python -O`; AssertionError is kept because that is what these checks always raised.
        if 'randomize_sequence' not in metadata:
            raise AssertionError("model metadata must provide randomize_sequence")
        if not metadata.get("name"):
            raise AssertionError("model metadata must provide a name")
        if not metadata.get("model_dir"):
            raise AssertionError("model metadata must provide a local directory")
        super().__init__()
        self.metadata = metadata

        lm_name = metadata['language_model_name']
        self.tokenizer = AutoTokenizer.from_pretrained(lm_name, trust_remote_code=True)
        # The end-of-sequence token doubles as the padding token, and padding goes on the left
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'left'
        self.metadata['vocab_size'] = len(self.tokenizer)

        quantization_config = None
        if self.metadata['quantize']:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
        self.language_model = AutoModelForCausalLM.from_pretrained(
            lm_name,
            quantization_config=quantization_config,
            trust_remote_code=True
        )

        if self.metadata['use_lora']:
            self.peft_config = LoraConfig(
                lora_alpha=self.metadata["lora_alpha"],
                lora_dropout=self.metadata["lora_dropout"],
                r=self.metadata['lora_r'],
                bias="none",
                task_type="CAUSAL_LM",
                # Names of the sub-modules that receive LoRA adapters
                target_modules=[
                    "query_key_value",
                    "dense",
                    "dense_h_to_4h",
                    "dense_4h_to_h",
                ]
            )
            self.language_model = get_peft_model(self.language_model, self.peft_config)

        self.has_encoder = hasattr(self.language_model, 'encoder')
        # Configs expose the embedding width either as `d_model` or as `hidden_size`
        lm_config = self.language_model.config
        self.metadata['embedding_size'] = (
            lm_config.d_model if hasattr(lm_config, 'd_model') else lm_config.hidden_size
        )

    @property
    def name(self) -> str:
        """ The model's name as given in its metadata """
        return self.metadata["name"]

    def _state_dict(self, model_dir: str) -> str:
        """ Returns the path of this model's parameter file inside `model_dir` """
        return os.path.join(model_dir, f"{self.name}_state_dict.pkl")

    def load_parameters(self, model_dir: str, accelerator: Accelerator) -> None:
        """ Loads all the parameters in the model from a directory

            Args:
                model_dir: directory containing ``<name>_state_dict.pkl``.
                accelerator: supplies the device the tensors are mapped to and the print function
                    used for the status message.
        """
        ckpt_path = self._state_dict(model_dir)
        # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted checkpoints
        with open(ckpt_path, mode="rb") as model_params:
            state_dict = torch.load(model_params, map_location=accelerator.device)
        self.load_state_dict(state_dict)
        accelerator.print(f"Loading model from : {ckpt_path}")

    def save(self, accelerator: Accelerator) -> None:
        """ Saves the model to ``metadata['model_dir']`` as a pickle of the model parameters and
            a yaml file containing the metadata. The directory is created if it does not exist.

            The save is unconditional: deciding whether a checkpoint is worth keeping (for
            example on improved validation performance) is left to the caller.
        """
        model_dir = self.metadata['model_dir']
        # Without this the very first save into a fresh directory fails
        os.makedirs(model_dir, exist_ok=True)
        ckpt_path = self._state_dict(model_dir)
        accelerator.save(self.state_dict(), ckpt_path)
        metadata_path = os.path.join(model_dir, f"{self.name}_metadata.yaml")
        with open(metadata_path, 'w', encoding="utf-8") as f:
            yaml.dump(self.metadata, f)
        accelerator.print(f"Saving model weights to : {ckpt_path}")
        accelerator.print(f"Saving model metadata to : {metadata_path}")

    def _add_new_tokens(self, tokens: List[str]) -> None:
        """ Adds a set of tokens required for sequencing and desequencing graph features to the
            tokenizer and language model, then refreshes ``metadata['vocab_size']``
        """
        self.tokenizer.add_tokens(tokens)
        # NOTE(review): when the language model exposes `base_model`, the resize is issued on
        # that object rather than on the language model itself; confirm that this also resizes
        # the output head when LoRA is disabled.
        if hasattr(self.language_model, "base_model"):
            self.language_model.base_model.resize_token_embeddings(len(self.tokenizer))
        else:
            self.language_model.resize_token_embeddings(len(self.tokenizer))
        self.metadata['vocab_size'] = len(self.tokenizer)

    def _generate_language(
        self,
        text_sequence: torch.Tensor,
        text_attn_mask: torch.Tensor,
        start_token_id: int,
        do_sample: bool,
        max_new_tokens: int,
        num_beams: int = 1
    ) -> GeneratedText:
        """ Generates an output sequence from the model's language model for each point in a batch

            Args:
                text_sequence: token ids of the input text, indexed as (batch, position).
                text_attn_mask: attention mask matching `text_sequence`.
                start_token_id: id of the token that starts generation. Without an encoder it is
                    appended to every input row; with an encoder it is passed as the decoder
                    start token instead.
                do_sample: whether to sample instead of decoding greedily.
                max_new_tokens: maximum number of tokens to generate.
                num_beams: number of beams; only used when `do_sample` is True.

            Returns:
                The sequences returned by the language model together with an integer mask that
                is 0 at padding-token positions and 1 elsewhere.
        """
        generate = (
            self.language_model.generate if hasattr(self.language_model, 'generate')
            else self.language_model.base_model.generate
        )
        with torch.no_grad():
            if self.has_encoder:
                input_sequence, input_attn_mask = text_sequence, text_attn_mask
            else:
                # Append one extra column holding the start token (attended to) to every row
                column_shape = (text_sequence.shape[0], 1)
                input_sequence = torch.cat([
                    text_sequence,
                    text_sequence.new_full(column_shape, start_token_id)
                ], dim=1)
                input_attn_mask = torch.cat([
                    text_attn_mask,
                    text_sequence.new_ones(column_shape)
                ], dim=1)
            sampled_outputs = generate(
                input_sequence,
                attention_mask=input_attn_mask,
                return_dict_in_generate=True,
                max_new_tokens=max_new_tokens,
                generation_config=GenerationConfig(
                    do_sample=do_sample,
                    # NOTE(review): without sampling a single beam is always used, so
                    # `num_beams` is ignored in that case; confirm this is intended
                    num_beams=1 if not do_sample else num_beams,
                    decoder_start_token_id=start_token_id if self.has_encoder else None
                )
            )
        sequences = sampled_outputs.sequences
        return GeneratedText(
            ids=sequences,
            mask=(sequences != self.tokenizer.pad_token_id).int()
        )

    def get_collate_fn(self):
        """ Returns a batching function for loading data to train and evaluate the model

            The function is `text_graph2inputs_batch` with the model's tokenizer, its
            `randomize_sequence` setting and its `has_encoder` flag already bound, so that it
            only needs the list of text-graph pairs.
        """
        return partial(
            self.text_graph2inputs_batch,
            tokenizer=self.tokenizer,
            randomize=self.metadata['randomize_sequence'],
            # Previously never forwarded, so the batching function always saw the default False
            has_encoder=self.has_encoder
        )

    @abstractmethod
    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ Performs a forward pass through the model and returns the outputs as a dictionary of
            tensors
        """

    @abstractmethod
    def generate_graph(
        self,
        *,
        text_sequence: torch.Tensor,
        text_attn_mask: torch.Tensor,
        do_sample: bool,
        max_new_tokens: int,
        num_beams: int = 1,
        **kwargs
    ) -> Tuple[List[TextGraph], GeneratedText]:
        """ Generates a graph for each data point in a batch given the point's text input and
            returns the graphs together with the generated text they were built from
        """

    @staticmethod
    @abstractmethod
    def text_graph2inputs(
        text_graph_pair: TextGraph,
        randomize: bool = False
    ) -> Dict[str, Union[str, npt.NDArray[np.int_]]]:
        """ Processes a data point's text and graph into a dictionary with the required inputs
            for the model (strings and integer arrays) and returns the dictionary
        """

    @staticmethod
    @abstractmethod
    def text_graph2inputs_batch(
        text_graph_pairs: List[TextGraph],
        tokenizer: PreTrainedTokenizerFast,
        randomize: bool = False,
        has_encoder: bool = False
    ) -> Dict[str, torch.Tensor]:
        """ Processes a batch of text-graph pairs into a dictionary of tensors with the required
            inputs for a model and returns the dictionary
        """

    @staticmethod
    @abstractmethod
    def inputs2graph(
        inputs: Dict[str, torch.Tensor],
        tokenizer: PreTrainedTokenizerFast,
        file_paths: Optional[List[str]] = None
    ) -> List[TextGraph]:
        """ Processes a batch of model outputs, possibly generated, into a batch of graphs """
