# -*- coding:utf-8 -*-
"""
Define the actor model (a wrapper around a causal language model) and a helper for per-token log-probabilities.
"""

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from utils.agent_memory import memory_initialize, memory_initialize_from_file


def log_probs_from_logits(logits, labels):
    """
    Pick out, for every position, the log-probability assigned to its label.

    The logits are normalized with a log-softmax over the last dimension, and
    the entry indexed by the matching label is selected.

    :param logits: unnormalized scores; the last dimension is the one normalized and indexed
    :param labels: integer indices into the last dimension of ``logits``, one per position
    :return: tensor with the same shape as ``labels`` holding the selected log-probabilities
    """
    normalized = logits.log_softmax(dim=-1)
    # gather needs an index of the same rank as its input, hence the trailing axis.
    label_index = labels[..., None]
    return torch.gather(normalized, -1, label_index).squeeze(dim=-1)


class ActorModel(torch.nn.Module):
    """
    Actor model: a causal language model (a fine-tuned LLM in our case) wrapped as a torch module.

    The wrapper loads the model and tokenizer from ``model_path``, optionally
    sets up the agent memory file paths, and exposes text generation,
    per-token log-probability computation and checkpoint saving.
    """

    def __init__(self, model_path, prompt, tool_database=None, memory_root_path=None, model_name=None,
                 require_grad=True, device='auto',
                 resume=False):
        """
        :param model_path: location passed to ``from_pretrained`` for both the model and the tokenizer
        :param prompt: prompt stored on the instance as ``self.prompt``
        :param tool_database: the name of the library used for tool_retrieval
            (not available for math and coding agents)
        :param memory_root_path: root directory of the agent memory files; when None, no memory
            paths are set and memory is not initialized
        :param model_name: sub-directory of ``memory_root_path`` that holds this model's memory
            files; required whenever ``memory_root_path`` is given
        :param require_grad: when False, ``forward`` runs under ``torch.no_grad()``
        :param device: forwarded to ``from_pretrained`` as ``device_map``
        :param resume: when True, ``memory_initialize`` is skipped
        :raises ValueError: if ``memory_root_path`` is given without ``model_name``
        """
        super().__init__()

        # Fail fast, before the (expensive) model load, with a readable message instead of
        # the TypeError that string concatenation with None would raise later.
        if memory_root_path is not None and model_name is None:
            raise ValueError("model_name must be provided when memory_root_path is set")

        self.require_grad = require_grad
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # hf_device_map may be absent when no device map was applied, so read it defensively.
        print("actor.model.device: ", getattr(self.model, 'hf_device_map', None), self.model.device)

        if memory_root_path is not None:
            memory_dir = memory_root_path + '/' + model_name
            self.memory_json_path = memory_dir + '/' + 'agent_memory.json'
            self.memory_delete_json_path = memory_dir + '/' + 'agent_delete_memory.json'
            self.memory_vector_path = memory_dir + '/' + 'agent_memory.npy'
            self.memory_faiss_path = memory_dir + '/' + 'agent_memory.index'
            # If you are resuming training from a checkpoint, memory does not need to be initialized.
            if not resume:
                memory_initialize(self.memory_json_path, self.memory_delete_json_path, self.memory_vector_path,
                                  self.memory_faiss_path)

        self.prompt = prompt
        # The name of the library used for tool_retrieval (not available for math and coding agents)
        self.tool_database = tool_database

        # This step is very important. Otherwise, for large models, it is easy to cause OOM
        # after max_length becomes longer.
        self.model.gradient_checkpointing_enable()
        self.model.enable_input_require_grads()

    def _resolve_pad_token_id(self, gen_kwargs):
        """Return the pad token id from gen_kwargs, the model's generation config or the tokenizer."""
        pad_token_id = gen_kwargs.get('pad_token_id')
        if pad_token_id is None:
            generation_config = getattr(self.model, 'generation_config', None)
            pad_token_id = getattr(generation_config, 'pad_token_id', None)
        if pad_token_id is None:
            pad_token_id = getattr(getattr(self, 'tokenizer', None), 'pad_token_id', None)
        return pad_token_id

    def generate(self, input_ids, **gen_kwargs):
        """
        Generate an answer based on the input query (intention or agent scheduling)

        :param input_ids: token ids of the query; moved to the model's device before generation
        :param gen_kwargs: keyword arguments forwarded unchanged to ``self.model.generate``
        :return: ``(outputs, attention_mask)`` where ``attention_mask`` is a long tensor on the
            same device as ``outputs`` that is 0 at positions equal to the pad token id and 1
            elsewhere; if no pad token id can be determined, the mask is all ones
        """
        input_ids = input_ids.to(self.model.device)
        outputs = self.model.generate(input_ids=input_ids, **gen_kwargs)
        pad_token_id = self._resolve_pad_token_id(gen_kwargs)
        if pad_token_id is None:
            # Comparing against None would raise; without a known pad id nothing can be masked.
            attention_mask = torch.ones_like(outputs, dtype=torch.long)
        else:
            attention_mask = outputs.not_equal(pad_token_id).to(dtype=torch.long, device=outputs.device)
        return outputs, attention_mask

    def _token_log_probs(self, input_ids, attention_mask):
        """Run the model and return the log-probability of each next token in ``input_ids``."""
        logits = self.model(input_ids, attention_mask=attention_mask).logits
        # The logits at position t score token t + 1, so drop the last logit and the first label.
        return log_probs_from_logits(logits[:, :-1, :], input_ids[:, 1:])

    def forward(self, input_ids, attention_mask=None):
        """
        Generate log-probabilities based on the input query + answer

        :param input_ids: token ids of query + answer; moved to the model's device
        :param attention_mask: optional mask passed to the model; moved to the model's device
        :return: log-probabilities of tokens ``input_ids[:, 1:]``, i.e. one fewer position than
            ``input_ids`` along dimension 1. Computed under ``torch.no_grad()`` when
            ``self.require_grad`` is False.
        """
        input_ids = input_ids.to(self.model.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)
        if self.require_grad:
            return self._token_log_probs(input_ids, attention_mask)
        with torch.no_grad():
            return self._token_log_probs(input_ids, attention_mask)

    def save_pretrained(self, model_path, model_name='actor_model.pth', save_weights_only=True):
        """
        Save the wrapped model with ``torch.save``.

        :param model_path: directory the file is written into
        :param model_name: file name suffix; the file is named ``weights_<model_name>`` or
            ``full_<model_name>``
        :param save_weights_only: when True only the ``state_dict`` is saved, otherwise the
            whole model object is saved
        """
        if save_weights_only:
            torch.save(self.model.state_dict(), model_path + '/weights_' + model_name)
        else:
            torch.save(self.model, model_path + '/full_' + model_name)

