import torch
from torch import nn
from transformers import GenerationMixin, StoppingCriteria, StoppingCriteriaList, AutoModelForCausalLM, AutoTokenizer
from utils import calculate_prompt_length
import time
import copy


class EOSStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the newest token is the EOS token.

    Only the first row of ``input_ids`` is inspected.
    """

    def __init__(self, eos_token_id):
        # Token id the last generated token is compared against.
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores, **kwargs):
        """Return whether the last token of the first sequence equals EOS.

        ``scores`` and any extra keyword arguments are accepted but not used.
        """
        newest_token = input_ids[0, -1]
        reached_eos = newest_token == self.eos_token_id
        return reached_eos.item()


class LlamaBiLDModel(nn.Module, GenerationMixin):
    def __init__(
        self,
        large,
        small,
        large_tokenizer,
        small_tokenizer,
        num_small_iters=10,
        fallback_threshold=0.6,
        rollback_threshold=5.0,
    ):
        """Wrap a large and a small causal LM behind one generation interface.

        Args:
            large: Model used while the active model type is ``'large'``.
            small: Model used while the active model type is ``'small'``.
            large_tokenizer: Tokenizer returned while ``large`` is active.
            small_tokenizer: Tokenizer returned while ``small`` is active.
            num_small_iters: Iteration budget handed to the small model each
                time the schedule switches to it.
            fallback_threshold: Stored as ``self.fallback_threshold``; the
                decoding loop compares the small model's top probability to it.
            rollback_threshold: Stored as ``self.rollback_threshold``; the
                decoding loop compares per-position losses to it.
        """
        super().__init__()

        # Plain scalar settings; none of these depend on the models.
        self.fallback_threshold = fallback_threshold
        self.rollback_threshold = rollback_threshold
        self.num_large_iters = 1
        self.num_small_iters = num_small_iters

        # Models and their tokenizers.
        self.large = large
        self.small = small
        self.large_tokenizer, self.small_tokenizer = large_tokenizer, small_tokenizer

        # The schedule must be initialised before get_decoder(), which reads
        # self.model_type.
        self.init_iters(model_kwargs={}, init_with="large")
        self.decoder = self.get_decoder()
        self.crossentropy_loss = nn.CrossEntropyLoss(reduction="none")

        # Attributes mirrored from the large model.
        self.main_input_name = self.large.main_input_name
        self.generation_config = self.large.generation_config

    @classmethod
    def from_pretrained(
        cls,
        large_model_name_or_path,
        small_model_name_or_path,
        large_tokenizer_name_or_path,
        small_tokenizer_name_or_path,
        *model_args,
        **kwargs,
    ):
        """Load both models and both tokenizers, then build the wrapper.

        ``*model_args`` and ``**kwargs`` are forwarded unchanged to
        ``AutoModelForCausalLM.from_pretrained`` for both models; the tokenizers
        are loaded from their paths with no extra arguments. The wrapper itself
        is constructed with its default iteration count and thresholds.
        """
        # Load order: large model, small model, large tokenizer, small tokenizer.
        models = [
            AutoModelForCausalLM.from_pretrained(model_path, *model_args, **kwargs)
            for model_path in (large_model_name_or_path, small_model_name_or_path)
        ]
        tokenizers = [
            AutoTokenizer.from_pretrained(tokenizer_path)
            for tokenizer_path in (large_tokenizer_name_or_path, small_tokenizer_name_or_path)
        ]
        return cls(*models, *tokenizers)

    def get_decoder(self):
        """Return the model that is currently active (large or small)."""
        if self.is_large():
            return self.large
        return self.small

    def get_tokenizer(self):
        """Return the tokenizer paired with the currently active model."""
        if self.is_large():
            return self.large_tokenizer
        return self.small_tokenizer


    def init_iters(self, model_kwargs=None, init_with='large'):
        """Reset the model schedule so decoding starts with ``init_with``.

        Sets the active model type, gives it its iteration budget, and replaces
        both per-model kwargs dicts with fresh empty dicts.

        Args:
            model_kwargs: Accepted for interface compatibility; not used here,
                both per-model kwargs dicts start out empty.
            init_with: ``'large'`` or ``'small'`` -- the model that runs first.
        """
        assert init_with in ['large', 'small']
        starts_large = init_with == 'large'
        self.model_type = init_with
        # Give the starting model its own budget, matching what schedule_iters
        # does on every later switch (previously the large budget was used even
        # when starting with the small model).
        self.iter_count = self.num_large_iters if starts_large else self.num_small_iters

        self.large_kwargs = {}
        self.small_kwargs = {}

        # model_kwargs aliases the dict of whichever model is active.
        self.model_kwargs = self.large_kwargs if starts_large else self.small_kwargs


    def schedule_iters(self, fall_back_to_large=False, fall_back_to_small=False):
        """Consume one iteration and switch the active model when due.

        A switch happens when the active model's remaining budget reaches zero,
        or when the matching ``fall_back_to_*`` flag asks for it. Switching
        resets the budget for the new model and re-points ``self.model_kwargs``
        at that model's kwargs dict.

        Args:
            fall_back_to_large: Force a switch to the large model if the small
                model is currently active.
            fall_back_to_small: Force a switch to the small model if the large
                model is currently active.
        """
        self.iter_count -= 1
        budget_spent = self.iter_count == 0

        if self.is_large():
            if budget_spent or fall_back_to_small:
                self.model_type = 'small'
                self.iter_count = self.num_small_iters
                self.model_kwargs = self.small_kwargs
        elif budget_spent or fall_back_to_large:
            self.model_type = 'large'
            self.iter_count = self.num_large_iters
            self.model_kwargs = self.large_kwargs

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs):
        """Run the currently active model and return its output unchanged.

        Every argument is forwarded by keyword to ``self.large`` or
        ``self.small``, depending on ``self.is_large()``.
        """
        active_model = self.large if self.is_large() else self.small
        return active_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

    @property
    def config(self):
        """Configuration object of the large model."""
        large_model = self.large
        return large_model.config

    def is_large(self):
        """Return ``True`` while the large model is the active one."""
        active_type = self.model_type
        return active_type == 'large'

    def resize_token_embeddings(self, n):
        """Resize the token embeddings of both models to ``n`` entries."""
        # Large first, then small -- same order as before.
        for model in (self.large, self.small):
            model.resize_token_embeddings(n)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, head_mask=None, use_cache=None, **kwargs):
        """Assemble the keyword inputs for one forward pass.

        When ``past`` is given, ``input_ids`` is cut down to the columns that
        come after the length read from ``past[0][0].shape[2]``. ``head_mask``
        and any extra keyword arguments are accepted but not used.

        Returns:
            dict with keys ``input_ids``, ``past_key_values``,
            ``attention_mask`` and ``use_cache``.
        """
        if past is None:
            step_input_ids = input_ids
        else:
            cached_len = past[0][0].shape[2]
            step_input_ids = input_ids[:, cached_len:]
        return {
            "input_ids": step_input_ids,
            "past_key_values": past,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reset_kwargs_past_to_new_length(self, new_len):
        """Truncate the cached ``'past'`` entries of both models.

        For every layer, the first two cached tensors are sliced to
        ``new_len - 1`` along dimension 2; any further entries are kept as they
        are. Both ``self.large_kwargs`` and ``self.small_kwargs`` must already
        hold a ``'past'`` entry.
        """
        for cache_holder in (self.large_kwargs, self.small_kwargs):
            cache_holder['past'] = tuple(
                tuple(
                    entry[:, :, :new_len - 1, :] if position < 2 else entry
                    for position, entry in enumerate(layer_cache)
                )
                for layer_cache in cache_holder['past']
            )

    def _decode_stream_chunk(self, token_ids, start, stop=None):
        """Decode ``token_ids[start:stop]`` into a text fragment for streaming.

        The same span is decoded a second time with one extra leading token. If
        the character just before the fragment in that longer text is a space,
        a space is prepended to the fragment.

        Returns:
            The decoded fragment, possibly empty.
        """
        tokenizer = self.get_tokenizer()
        with_context = tokenizer.decode(token_ids[start - 1:stop], skip_special_tokens=True)
        fragment = tokenizer.decode(token_ids[start:stop], skip_special_tokens=True)
        # Position of the character preceding the fragment, treating the
        # fragment as a suffix of the longer decode.
        boundary = len(with_context) - len(fragment) - 1
        if boundary >= 0 and with_context[boundary] == " ":
            fragment = " " + fragment
        return fragment

    def _greedy_search_body(self, input_ids, model_kwargs, output_attentions, output_hidden_states, stopping_criteria, pad_token_id, eos_token_id, synced_gpus, unfinished_sequences, prompt_length, max_length, temperature):
        """Run the interleaved small/large decoding loop.

        The small model proposes tokens greedily; the large model samples with
        temperature and top-p (0.8) and is also used to check the tokens that
        are already in ``input_ids``, rolling the sequence back when a position
        is rejected. The code indexes row 0 throughout, so it effectively
        handles a single sequence.

        Args:
            input_ids: Token ids of the prompt.
            model_kwargs: Forwarded to ``init_iters``.
            output_attentions, output_hidden_states, stopping_criteria,
                pad_token_id, synced_gpus: Accepted for interface
                compatibility; not consulted by the loop.
            eos_token_id: If not ``None``, used to update
                ``unfinished_sequences`` (which is itself not used to stop).
            unfinished_sequences: Per-sequence "still running" flags.
            prompt_length: Index of the first token not yet streamed; advanced
                as text is emitted.
            max_length: The loop ends once ``prompt_length`` reaches this value.
            temperature: Divisor applied to the large model's logits before
                sampling; ``None`` is treated as 1.0.

        Returns:
            Tuple ``(input_ids, generation_times, model_types,
            token_confidences, streamed_token_list)``.
        """
        # Without a temperature the logit scaling below would raise TypeError.
        if temperature is None:
            temperature = 1.0

        # Initialize large model iterations
        self.init_iters(model_kwargs=model_kwargs, init_with='large')
        self.rollback_signal = None
        generation_times = []
        streamed_token_list = []
        model_types = []  # Which model was active for each recorded entry
        token_confidences = []  # (decoded token, probability) pairs
        last_prediction_was_same = False
        top_p = 0.8

        # NOTE(review): the only exit from this loop is prompt_length reaching
        # max_length; stopping_criteria and unfinished_sequences are not used
        # to terminate it.
        while True:
            self.model_kwargs['past_key_values'] = None
            if self.rollback_signal:
                new_len = input_ids.shape[-1]
                # model_inputs is from the previous iteration; a rollback can
                # only be signalled after at least one pass through the loop.
                if model_inputs.get('past_key_values') is not None:
                    self._reset_kwargs_past_to_new_length(new_len)
                self.rollback_signal = None

            model_inputs = self.prepare_inputs_for_generation(input_ids, **self.model_kwargs)
            start_time = time.time()
            outputs = self(
                input_ids=model_inputs["input_ids"],
                past_key_values=model_inputs.get("past_key_values"),
                attention_mask=model_inputs.get("attention_mask"),
                use_cache=model_inputs.get("use_cache"),
                output_attentions=model_inputs.get("output_attentions"),
                output_hidden_states=model_inputs.get("output_hidden_states"),
                return_dict=model_inputs.get("return_dict"),
                **model_inputs.get("kwargs", {})
            )
            if not self.is_large():
                # Small model: greedy pick from the last position.
                last_token_logits = outputs.logits[:, -1, :]
                next_tokens = torch.argmax(torch.softmax(last_token_logits, dim=-1), dim=-1)
                # Derived from the logits, so no hard-coded device is needed.
                output_predicted_token_ids = torch.cat((input_ids, next_tokens[:, None]), dim=1)
            else:
                # Large model: top-p sampling with temperature.
                last_token_logits = (outputs.logits / temperature)[:, -1, :]
                sorted_logits, sorted_indices = torch.sort(last_token_logits, descending=True)
                cumulative_probs = torch.cumsum(nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                top_p_mask = cumulative_probs < top_p
                # Shift the mask right by one so the token that crosses the
                # threshold is kept and the top token is always kept. ones_like
                # keeps the new column on the same device as the logits.
                top_p_mask = torch.cat((torch.ones_like(top_p_mask[:, :1]), top_p_mask[:, :-1]), dim=1)
                filtered_logits = sorted_logits.masked_fill(~top_p_mask, float('-inf'))
                # Undo the sort. The *filtered* logits must be scattered back,
                # otherwise the top-p mask has no effect on the sampling.
                reverted_logits = torch.empty_like(filtered_logits)
                reverted_logits.scatter_(dim=-1, index=sorted_indices, src=filtered_logits)
                probabilities = nn.functional.softmax(reverted_logits, dim=-1)
                next_tokens = torch.multinomial(probabilities, num_samples=1)
                output_predicted_token_ids = torch.cat((input_ids, next_tokens), dim=1)
                next_tokens = next_tokens[0]

            end_time = time.time()

            # Spread the step time evenly over every token past prompt_length
            # and record one entry per such token.
            num_pending = len(output_predicted_token_ids[0][prompt_length:])
            time_per_token = (end_time - start_time) / max(num_pending, 1)
            generation_times.extend([time_per_token] * num_pending)
            model_types.extend([self.model_type] * num_pending)

            # Confidence is read from the unscaled logits of the last position.
            next_token_logits = outputs.logits[:, -1, :]
            scores = torch.softmax(next_token_logits, dim=-1)

            for i in range(next_tokens.size(0)):
                token_id = next_tokens[i].item()
                token_confidence = scores[i, token_id].item()
                token_confidences.append((self.get_tokenizer().decode(token_id), token_confidence))

            # Fallback: the small model is unsure, so hand this step to the
            # large model without accepting the small model's token.
            fallback_cond = (
                self.model_type == 'small' and scores.max().item() < self.fallback_threshold
            )

            if fallback_cond:
                self.schedule_iters(fall_back_to_large=True)
                continue

            if self.is_large():
                large_model_logits = outputs.logits[0, :, :]
                if prompt_length != outputs.logits.shape[1] - 1:
                    # Skip the check once right after a "same prediction" result.
                    if last_prediction_was_same:
                        last_prediction_was_same = False  # Reset after skipping
                        continue

                    # NOTE(review): logits at position i are compared with the
                    # token at the same position i (no shift), while the index
                    # arithmetic below treats the loss as shifted by one --
                    # confirm the intended alignment.
                    loss = self.crossentropy_loss(large_model_logits, input_ids[0])
                    loss = loss[:-1]
                    loss_above_thres = loss > self.rollback_threshold

                    # NOTE(review): any() looks at the whole sequence (prompt
                    # included) while the index search only looks from
                    # prompt_length on; if nothing there exceeds the threshold,
                    # argmax yields 0 and the rollback point is prompt_length
                    # itself. prompt_length only advances here and in the
                    # single-pending-token branch below, so changing this would
                    # also change when the loop terminates.
                    if loss_above_thres.any() and input_ids.shape[1] > prompt_length:
                        # Exclude the prompt tokens from the index search.
                        losses_after_last_above_thres = loss_above_thres[prompt_length:]
                        if losses_after_last_above_thres.numel() == 0:
                            raise Exception("Unexpected array length for losses_after_last_above_thres")

                        first_idx_loss_above_thres = losses_after_last_above_thres.to(torch.int).argmax() + prompt_length

                        # Large model's most likely token at that position.
                        new_pred = nn.functional.softmax(
                            large_model_logits[first_idx_loss_above_thres:first_idx_loss_above_thres+1, :],
                            dim=-1,
                        ).argmax(-1).unsqueeze(0)
                        decoded_new_pred = self.get_tokenizer().decode([new_pred.item()])

                        # Token currently in the sequence one position later.
                        token_above_thres = input_ids[0, first_idx_loss_above_thres + 1]
                        decoded_error_token = self.get_tokenizer().decode([token_above_thres.item()])

                        if decoded_new_pred != decoded_error_token:
                            # Different prediction: cut the sequence back and put
                            # the large model's token in the last kept slot.
                            new_len = first_idx_loss_above_thres + 1
                            new_input_ids = input_ids[:, :new_len]
                            input_ids = torch.cat([new_input_ids[:, :-1], new_pred], dim=-1)
                            self.rollback_signal = True
                        else:
                            # Same prediction: nothing to replace.
                            last_prediction_was_same = True

                        # Stream everything up to the checked position.
                        decoded_pos = min(max_length, first_idx_loss_above_thres + 1)
                        decoded_text = self._decode_stream_chunk(input_ids[0], prompt_length, decoded_pos)
                        if len(decoded_text) > 0:
                            streamed_token_list.append(decoded_text)
                        prompt_length = first_idx_loss_above_thres + 1  # Start next checks from that place
                else:
                    # Exactly one token is pending: stream it and advance.
                    decoded_text = self._decode_stream_chunk(input_ids[0], prompt_length)
                    if len(decoded_text) > 0:
                        streamed_token_list.append(decoded_text)
                    prompt_length += 1

            self.model_kwargs = self._update_model_kwargs_for_generation(
                outputs, self.model_kwargs, is_encoder_decoder=False
            )

            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            # Enforce max_length limit
            if prompt_length >= max_length:
                input_ids = input_ids[:, :max_length]
                break

            self.schedule_iters()

        # Flush whatever has not been streamed yet.
        if prompt_length < len(input_ids[0]):
            decoded_text = self._decode_stream_chunk(input_ids[0], prompt_length)
            if len(decoded_text) > 0:
                streamed_token_list.append(decoded_text)
            print(f"New Streamed Tokens at the end: {decoded_text}")

        return input_ids, generation_times, model_types, token_confidences, streamed_token_list

    
    def generate(self, input_ids, max_length=None, max_new_tokens=None, temperature=None, **kwargs):
        """Generate a continuation of ``input_ids`` with the two-model loop.

        Args:
            input_ids: Prompt token ids; ``input_ids.shape[1]`` is taken as the
                prompt length.
            max_length: Total length limit. Ignored when ``max_new_tokens`` is
                given; falls back to ``self.generation_config.max_length`` when
                both are ``None``.
            max_new_tokens: If given, the limit becomes prompt length plus this
                value.
            temperature: Sampling temperature for the large model. ``None`` is
                treated as 1.0, because the decoding loop divides logits by it.
            **kwargs: ``output_attentions``, ``output_hidden_states``,
                ``pad_token_id``, ``eos_token_id`` and ``synced_gpus`` are
                extracted; the remainder is passed on as ``model_kwargs``.

        Returns:
            Whatever ``_greedy_search_body`` returns.
        """
        if temperature is None:
            temperature = 1.0

        model_kwargs = kwargs.copy()
        stopping_criteria = StoppingCriteriaList(
            [EOSStoppingCriteria(eos_token_id=self.get_tokenizer().eos_token_id)]
        )

        prompt_length = input_ids.shape[1]

        # max_new_tokens takes precedence over max_length.
        if max_new_tokens is not None:
            max_length = prompt_length + max_new_tokens
        elif max_length is None:
            max_length = self.generation_config.max_length

        # Pull the recognised options out of model_kwargs before handing the
        # remainder to the decoding loop.
        output_attentions = model_kwargs.pop("output_attentions", None)
        output_hidden_states = model_kwargs.pop("output_hidden_states", None)
        pad_token_id = model_kwargs.pop(
            "pad_token_id", self.generation_config.pad_token_id
        )
        eos_token_id = model_kwargs.pop(
            "eos_token_id", self.generation_config.eos_token_id
        )
        synced_gpus = model_kwargs.pop("synced_gpus", False)
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )

        return self._greedy_search_body(
            input_ids=input_ids,
            model_kwargs=model_kwargs,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            synced_gpus=synced_gpus,
            unfinished_sequences=unfinished_sequences,
            prompt_length=prompt_length,
            max_length=max_length,
            temperature=temperature,
        )

    def can_generate(self):
        """Report that this wrapper supports generation; always ``True``."""
        supports_generation = True
        return supports_generation

    @property
    def device(self):
        """Device of the first parameter yielded by ``self.parameters()``."""
        first_param = next(self.parameters())
        return first_param.device


