from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from math import log


class Generator:
    """Common state shared by the generators defined in this file.

    Keeps the default number of samples, the sampling temperature, the
    tokenizer, the answer-extraction callable, the generation length limit
    and two running counters (sequences and tokens generated so far).
    """

    def __init__(self, model, answer_extractor, n, temperature):
        """Store the configuration and load the tokenizer for ``model``.

        If ``n`` is truthy and ``temperature`` is ``None``, the temperature is
        derived from ``n`` with :meth:`compute_temperature`; in every other
        case the given ``temperature`` is stored unchanged (it may be ``None``).
        """
        self.n = n
        self.temperature = (
            self.compute_temperature(n)
            if n and temperature is None
            else temperature
        )

        self.answer_extractor = answer_extractor
        self.tokenizer = AutoTokenizer.from_pretrained(model)

        # Upper bound on the number of tokens generated per sample.
        self.max_new_tokens = 1024

        # Running totals, reported by inference_cost().
        self.generated_sequences = 0
        self.generated_tokens = 0

    @staticmethod
    def compute_temperature(n):
        """Return ``log(n) / log(64)`` capped at 1; ``n`` must be positive.

        A single sample therefore gets temperature 0, and 64 or more samples
        get temperature 1.
        """
        assert n > 0
        scaled = log(n) / log(64)
        return min(scaled, 1)

    def inference_cost(self):
        """Return the ``(generated_sequences, generated_tokens)`` counters."""
        sequences, tokens = self.generated_sequences, self.generated_tokens
        return sequences, tokens


class FastGenerator(Generator):
    """Generator backed by vLLM; fast, but cannot output hidden states."""

    def __init__(self, model, answer_extractor, n=None, temperature=None):
        """Set up the shared generator state and load ``model`` into vLLM."""
        super().__init__(model, answer_extractor, n, temperature)
        self.model = LLM(model=model, gpu_memory_utilization=0.6, swap_space=4)

    def __call__(self, input_text, n=None, with_full_outputs=False):
        """Sample completions for ``input_text`` and extract an answer from each.

        Args:
            input_text: prompt handed to ``self.model.generate``; exactly one
                request output is expected back (asserted).
            n: number of samples. When ``None`` the defaults given at
                construction time (``self.n`` / ``self.temperature``) are
                used; otherwise ``n`` must be positive and the temperature is
                recomputed from it with ``compute_temperature``.
            with_full_outputs: when true, also return the raw completions.

        Returns:
            The list of extracted answers, or ``(output_texts, answers)`` when
            ``with_full_outputs`` is true.

        Side effects:
            Increments ``generated_sequences`` and ``generated_tokens``.
        """
        if n is None:
            assert self.n
            n, temperature = self.n, self.temperature
            print(f'using default number of samples {self.n}')
            print(f'using default temperature {self.temperature}')
        else:
            assert n > 0
            temperature = self.compute_temperature(n)
            print(f'overwriting default number of samples ({self.n}) with {n}')
        print(f'generating (fast) {n} samples with temperature {temperature}')
        sampling_params = SamplingParams(
            n=n, temperature=temperature,
            max_tokens=self.max_new_tokens)
        output = self.model.generate(input_text, sampling_params)
        assert len(output) == 1
        outputs = output[0].outputs  # n samples
        output_texts = [o.text for o in outputs]
        print('updating generations stats...')
        self.generated_sequences += len(output_texts)
        # Count the token ids vLLM actually sampled. Re-tokenizing the decoded
        # text is slower and may not reproduce the sampled ids (for example if
        # the tokenizer adds special tokens), so it can miscount.
        self.generated_tokens += sum(len(o.token_ids) for o in outputs)
        answers = [self.answer_extractor(text) for text in output_texts]
        if with_full_outputs:
            return output_texts, answers
        return answers


class SlowGenerator(Generator):
    """Generator based on HuggingFace's ``generate``; slow, but outputs hidden states."""

    def __init__(self, model, answer_extractor, n, temperature=None):
        """Load ``model`` with ``device_map="auto"`` and configure extraction.

        ``n`` must be positive. Sampling is enabled only when the resulting
        temperature is greater than zero.
        """
        super().__init__(model, answer_extractor, n, temperature)
        assert self.n > 0
        self.model = AutoModelForCausalLM.from_pretrained(model, device_map="auto")
        self.do_sample = self.temperature > 0
        # Layers (indexed from the end) whose hidden states are recorded.
        self.hidden_states_layers = [-1, -2, -4, -8, -16]
        self.hidden_states_tokens = range(-16, 0)  # last 16 tokens
        # Number of decimals kept for every hidden-state value.
        self.hidden_states_rounding = 2

    def process_output(self, output):
        """Turn one ``generate`` result into answer, text and hidden states.

        Args:
            output: result of ``self.model.generate`` called with
                ``return_dict_in_generate=True`` and
                ``output_hidden_states=True`` for a single sequence.

        Returns:
            ``(answer, output_text, hidden_states)`` where ``hidden_states`` is
            a list of rows ``[layer_num, token_num, v0, v1, ...]`` with values
            rounded to ``self.hidden_states_rounding`` decimals.

        Side effects:
            Increments ``generated_sequences`` by one and ``generated_tokens``
            by the length of ``output.sequences[0]``.
        """
        assert len(output.sequences) == 1
        # output.sequences.shape == (1, n_tokens of input + output)
        # len(output.hidden_states) == n_tokens of output
        output_text = self.tokenizer.decode(output.sequences[0])
        answer = self.answer_extractor(output_text)

        steps = output.hidden_states
        num_steps = len(steps)
        hidden_states = []
        for layer_num in self.hidden_states_layers:
            # https://huggingface.co/docs/transformers/v4.47.1/en/main_classes/output#transformers.utils.ModelOutput
            for token_num in self.hidden_states_tokens:
                # Explicit bounds check instead of a bare ``except``: a short
                # generation is reported and skipped, while any other error
                # (or a Ctrl-C) is no longer silently swallowed.
                if not -num_steps <= token_num < num_steps:
                    print(f'failed to extract a hidden state for token {token_num}')
                    print(f'length of the sequence: {num_steps}')
                    continue
                # Only the requested generation steps are converted to Python
                # lists, rather than every step for every layer.
                vector = steps[token_num][layer_num][0, -1, :].tolist()
                hs_lt = [round(h, self.hidden_states_rounding) for h in vector]
                hidden_states.append([layer_num, token_num] + hs_lt)
        print('updating generations stats...')
        self.generated_sequences += 1
        self.generated_tokens += len(output.sequences[0])
        # TODO should you include both the input and output tokens?
        return answer, output_text, hidden_states

    def __call__(self, input_text, with_full_outputs=False):
        """Generate ``self.n`` samples one at a time for ``input_text``.

        Returns:
            ``(answers, hidden_states)``, or
            ``(output_texts, answers, hidden_states)`` when
            ``with_full_outputs`` is true. Each list has one entry per sample.
        """
        print(f'generating (slowly) {self.n} samples with temperature {self.temperature}')
        # The model is loaded with device_map="auto", so follow the device it
        # reports instead of assuming 'cuda'.
        input_tokens = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        generate_kwargs = {
            'max_new_tokens': self.max_new_tokens,
            'output_hidden_states': True,
            'return_dict_in_generate': True,
            'do_sample': self.do_sample,
        }
        if self.do_sample:
            # Temperature only matters when sampling; passing a zero value in
            # greedy mode is at best ignored by ``generate``.
            generate_kwargs['temperature'] = self.temperature
        answers = []
        output_texts = []
        hidden_states = []
        for _ in range(self.n):  # make it faster using num_return_sequences in model.generate()?
            output = self.model.generate(**input_tokens, **generate_kwargs)
            answer, output_text, hidden_states_from_one_output = self.process_output(output)
            answers.append(answer)
            output_texts.append(output_text)
            hidden_states.append(hidden_states_from_one_output)
        if with_full_outputs:
            return output_texts, answers, hidden_states
        return answers, hidden_states


class VariedTemperature(SlowGenerator):
    """Slow generator that draws each sample at a different temperature.

    The temperatures go linearly from ``max_temperature`` down to 0 over ``n``
    samples (a single sample uses ``max_temperature``).
    """

    def __init__(self, model, answer_extractor, n, max_temperature):
        """Build the temperature schedule; ``n`` and ``max_temperature`` must be positive."""
        assert 0 < max_temperature
        assert 0 < n
        super().__init__(model, answer_extractor, n, None)
        if n == 1:
            self.temperature_list = [max_temperature]
        else:
            last = n - 1
            # Scale by an exact fraction so the schedule starts exactly at
            # max_temperature and ends exactly at 0. Repeatedly subtracting a
            # rounded step can leave a tiny non-zero residue on the last
            # element, which would wrongly toggle sampling on or off.
            self.temperature_list = [
                max_temperature * ((last - i) / last) for i in range(n)
            ]

    def __call__(self, input_text, with_full_outputs=False):
        """Generate one sample per scheduled temperature for ``input_text``.

        Returns:
            ``(answers, hidden_states)``, or
            ``(output_texts, answers, hidden_states)`` when
            ``with_full_outputs`` is true. Each list has one entry per sample.
        """
        print(f'generating (slowly) {self.n} samples with varied temperature')
        # Follow the device reported by the model instead of assuming 'cuda'.
        input_tokens = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        answers = []
        output_texts = []
        hidden_states = []
        for temperature in self.temperature_list:
            print(f'generating with temperature {temperature}')
            do_sample = temperature > 0
            generate_kwargs = {
                'max_new_tokens': self.max_new_tokens,
                'output_hidden_states': True,
                'return_dict_in_generate': True,
                'do_sample': do_sample,
            }
            if do_sample:
                # Temperature only matters when sampling; a zero temperature
                # means greedy decoding, so it is not forwarded.
                generate_kwargs['temperature'] = temperature
            output = self.model.generate(**input_tokens, **generate_kwargs)
            answer, output_text, hidden_states_from_one_output = self.process_output(output)
            answers.append(answer)
            output_texts.append(output_text)
            hidden_states.append(hidden_states_from_one_output)
        if with_full_outputs:
            return output_texts, answers, hidden_states
        return answers, hidden_states
