import torch
from torch.nn import CrossEntropyLoss
from transformers import LlamaForCausalLM, AutoModelForCausalLM
from peft import PeftModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from peft import LoraConfig, get_peft_model
from copy import deepcopy
from omegaconf import OmegaConf
import logging
from transformers.generation import GenerationMixin

try:
    from transformers import DynamicCache
except ImportError:
    DynamicCache = None

# Code adapted from ULD repo: https://github.com/UCSB-NLP-Chang/ULD

# Shared logger for the helpers and ULDModel in this module. It uses the fixed
# name "ULD Trainer" (not __name__), so configure handlers/levels under that name.
logger = logging.getLogger("ULD Trainer")

def find_all_linear_names(model):
    """Return the distinct leaf names of every ``torch.nn.Linear`` submodule.

    Only the last dotted component of each module path is kept
    (``model.layers.0.self_attn.q_proj`` -> ``q_proj``); these names are used as
    LoRA ``target_modules`` in ``get_assistant_model``. ``lm_head`` is always
    excluded. The order of the returned list is unspecified (built from a set).
    """
    leaf_names = {
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    leaf_names.discard('lm_head')  # needed for 16-bit
    return list(leaf_names)

def copy_weights(model, original_model):
    """Initialise ``model`` from the leading layers of ``original_model``.

    Copies, via ``load_state_dict``, in this order: the token embeddings, the
    final norm, decoder layers ``0 .. model.config.num_hidden_layers - 1``, and
    the LM head. ``model`` is modified in place and also returned.

    Raises:
        AssertionError: if ``original_model`` is not a ``LlamaForCausalLM``.
    """
    assert isinstance(original_model, LlamaForCausalLM), "Only Llama models are supported in ULD."
    source_name = original_model.config._name_or_path.lower()
    n_layers = model.config.num_hidden_layers
    logger.info(f"Copying first {n_layers} layers from {source_name}")

    target_body, source_body = model.model, original_model.model
    for attr in ("embed_tokens", "norm"):
        getattr(target_body, attr).load_state_dict(getattr(source_body, attr).state_dict())
    for idx in range(n_layers):
        target_body.layers[idx].load_state_dict(source_body.layers[idx].state_dict())
    model.lm_head.load_state_dict(original_model.lm_head.state_dict())
    return model

def get_assistant_model(original_model, lora, num_layers):
    """Create a shallow, LoRA-wrapped assistant from a Llama model.

    The assistant reuses a deep copy of ``original_model``'s config with
    ``num_hidden_layers`` set to ``num_layers`` and the resolved ``lora``
    settings stored as ``config.lora``. Its embeddings, final norm, first
    ``num_layers`` decoder layers and LM head are copied from
    ``original_model`` (see ``copy_weights``), then LoRA adapters are attached
    to every linear-layer name except ``lm_head``.

    Args:
        original_model: ``LlamaForCausalLM`` to copy from.
        lora: OmegaConf node providing ``rank``, ``alpha``, ``dropout``, ``bias``.
        num_layers: number of decoder layers in the assistant.

    Returns:
        The PEFT model, on ``original_model.device`` in the config's ``torch_dtype``.
    """
    assert isinstance(original_model, LlamaForCausalLM), "Only Llama models are supported in ULD."
    assistant_config = deepcopy(original_model.config)
    assistant_config.num_hidden_layers = num_layers
    assistant_config.lora = OmegaConf.to_container(lora, resolve=True)

    logger.info(f"Creating assistant model with {num_layers} layers.")
    logger.info(f"Assistant model LoRA config: {assistant_config.lora}")
    logger.info(f"Using dtype: {assistant_config.torch_dtype}")

    # Mirror the original model's attention backend, detected from the class
    # name of its first attention module.
    attn_imp_name = type(original_model.model.layers[0].self_attn).__name__
    if "FlashAttention2" not in attn_imp_name:
        logger.info(f"Using default attention implementation for assistant model as derived from original model's attention implementation: {attn_imp_name}")
        model = AutoModelForCausalLM.from_config(assistant_config, torch_dtype=assistant_config.torch_dtype)
    else:
        try:
            model = AutoModelForCausalLM.from_config(
                assistant_config,
                torch_dtype=assistant_config.torch_dtype,
                attn_implementation="flash_attention_2",
            )
            logger.info(f"Setting attention implementation to flash_attention_2 as derived from original model's attention implementation: {attn_imp_name}")
        except TypeError:
            # This from_config() rejected attn_implementation; use the legacy flag.
            logger.info("attn_implementation parameter not supported by from_config(), setting use_flash_attention_2=True instead.")
            model = AutoModelForCausalLM.from_config(
                assistant_config,
                torch_dtype=assistant_config.torch_dtype,
                use_flash_attention_2=True,
            )

    logger.info(f"Model dtype after init: {model.dtype}")
    logger.info(f"Model attention implemn. after init: {type(model.model.layers[0].self_attn)}")

    # Seed the assistant with the original model's weights.
    model = copy_weights(model, original_model)
    model = model.to(original_model.device, dtype=assistant_config.torch_dtype)

    lora_settings = LoraConfig(
        r=lora.rank,
        lora_alpha=lora.alpha,
        target_modules=find_all_linear_names(original_model),
        lora_dropout=lora.dropout,
        bias=lora.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_settings)

    # Re-apply device/dtype so the freshly added adapter weights match.
    return model.to(original_model.device, dtype=assistant_config.torch_dtype)

class ULDModel(torch.nn.Module, GenerationMixin):
    """Combine a base LLM and an assistant LLM by logit subtraction.

    Every next-token score vector is ``base_logits - weight * assist_logits``.
    When ``top_logit_filter > 0``, tokens the *base* model finds unlikely (see
    ``relative_top_filter``) are additionally set to ``-inf``. Generation is
    greedy only and is implemented by ``custom_generate``.
    """

    def __init__(self, basellm, assist_llm, tokenizer, weight, top_logit_filter):
        """
        Args:
            basellm: ``LlamaForCausalLM`` whose logits are the starting point.
            assist_llm: model whose logits are subtracted. If it is a
                ``PeftModel``, its adapters are merged and the wrapper dropped.
            tokenizer: stored on the instance; its ``pad_token_id`` (if any) is
                a fallback padding id during generation.
            weight: non-negative scale for the assistant logits.
            top_logit_filter: relative cutoff for the base-model filter;
                ``<= 0`` disables filtering.
        """
        assert isinstance(basellm, LlamaForCausalLM), "ULD currently only supports Llama models."
        super().__init__()
        self.basellm = basellm
        if isinstance(assist_llm, PeftModel):
            logger.info("Merging LoRA adapters from assistant model and unloading PEFT wrappers.")
            # merge_and_unload() folds the adapters in and *returns* the plain
            # model; keep that return value so the PEFT wrapper is really gone.
            assist_llm = assist_llm.merge_and_unload()
        self.assist_llm = assist_llm
        self.weight = float(weight)
        assert self.weight >= 0.0, "Weight must be non-negative."
        # Expose the base model's device and config on the wrapper.
        self.device = self.basellm.device
        self.config = self.basellm.config
        # Private copy so changes here never touch the base model's config.
        self.generation_config = deepcopy(basellm.generation_config)
        # Ensure these cached special-token attributes exist (presumably read by
        # transformers' generation utilities in some versions — confirm).
        for name in ("_eos_token_tensor", "_bos_token_tensor", "_pad_token_tensor"):
            if not hasattr(self.generation_config, name):
                setattr(self.generation_config, name, None)

        self.top_logit_filter = top_logit_filter
        self.tokenizer = tokenizer

        logger.info(
            "Initialized ULDModel (weight=%.3f, top_logit_filter=%.3f).",
            self.weight,
            self.top_logit_filter,
        )

        logger.info(f"Generation config: {self.generation_config}")

    @classmethod
    def from_assistant_model(cls, basellm, assist_llm, tokenizer, **kwargs):
        """Alternate constructor; ``kwargs`` are ``weight`` and ``top_logit_filter``."""
        return cls(basellm, assist_llm, tokenizer, **kwargs)

    def get_loss(self, logits, labels=None, attention_mask=None, reduction='mean'):
        """Next-token cross-entropy of ``logits`` against ``labels``.

        Logits at position ``t`` are scored against ``labels[..., t + 1]``.

        Args:
            logits: (..., seq, vocab) scores.
            labels: (..., seq) token ids or ``None`` (then ``None`` is returned).
            attention_mask: required for ``reduction='batchmean'``; its row sums
                are the per-sample denominators.
            reduction: ``'batchmean'`` for a per-sample vector of summed token
                loss divided by the mask length; otherwise passed to
                ``CrossEntropyLoss`` (``'mean'``, ``'sum'``, ``'none'``).

        Raises:
            ValueError: ``reduction='batchmean'`` without ``attention_mask``.
        """
        if labels is None:
            return None
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        # Enable model parallelism: labels follow the logits' device.
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        # Use the logits' own vocab dim (robust to resized embeddings).
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        if reduction == 'batchmean':
            if attention_mask is None:
                raise ValueError("reduction='batchmean' requires attention_mask.")
            token_loss = CrossEntropyLoss(reduction='none')(flat_logits, flat_labels)
            # Restore (..., seq - 1) before summing so each sample is reduced on
            # its own; summing the flat vector mixed all samples together.
            per_sample = token_loss.view(shift_labels.shape).sum(dim=-1)
            # NOTE(review): denominator counts all attended tokens, including the
            # first one (which is never predicted) — kept from the original.
            return per_sample / attention_mask.sum(dim=-1).to(per_sample.device)
        return CrossEntropyLoss(reduction=reduction)(flat_logits, flat_labels)

    def _combine_logits(self, base_logits, assist_logits):
        """Return ``base - weight * assist`` with base-filtered tokens set to -inf."""
        logits = base_logits - self.weight * assist_logits.to(base_logits.device)
        if self.top_logit_filter > 0.0:
            # mask out low-probability tokens according to base model
            mask = self.relative_top_filter(base_logits, self.top_logit_filter)
            logits = logits.masked_fill(mask, -float("Inf"))
        return logits

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Score a batch with the combined ULD logits (inference only).

        An optional ``labels`` keyword yields a mean next-token cross-entropy on
        the combined logits; labels are not forwarded to the wrapped models.
        Hidden states, attentions and KV caches are neither built nor returned.

        Raises:
            ValueError: if ``past_key_values`` is given — both wrapped models
                would otherwise receive the same cache object.
        """
        #! This forward only returns the logits, never use this for training
        labels = kwargs.pop("labels", None)
        if kwargs.pop("past_key_values", None) is not None:
            raise ValueError(
                "ULDModel.forward does not support past_key_values; "
                "use custom_generate(use_cache=True) for cached decoding."
            )
        kwargs["output_hidden_states"] = False
        kwargs["output_attentions"] = False
        # The caches would be discarded (past_key_values=None below).
        kwargs["use_cache"] = False

        base_outputs = self.basellm(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        assist_outputs = self.assist_llm(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        logits = self._combine_logits(base_outputs.logits, assist_outputs.logits)
        loss = self.get_loss(logits, labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @classmethod
    def can_generate(cls) -> bool:
        """Always ``True``; callable on the class or an instance."""
        return True

    def relative_top_filter(self, logits, top_logit_filter):
        """Return a boolean mask (same shape as ``logits``) of tokens to drop.

        With ``k = int(top_logit_filter * vocab_size)``, a token is dropped when
        its log-probability is below ``min(kth_best, max + log(top_logit_filter))``.
        In other words, it is kept if it is among the top ``k`` or within a
        factor ``top_logit_filter`` of the most likely token. When ``k == 0`` the
        threshold falls to the minimum log-prob and nothing is dropped.
        NOTE(review): confirm that ``k`` scaling with the vocab size is intended.
        """
        min_tokens_to_keep = int(top_logit_filter * logits.shape[-1])
        log_probs = logits.log_softmax(dim=-1)

        if min_tokens_to_keep > 0:
            min_thresh = torch.topk(log_probs, min_tokens_to_keep, dim=-1).values[..., -1]
        else:
            min_thresh = log_probs.amin(dim=-1)

        probs_max = log_probs.max(dim=-1).values
        log_ratio = log_probs.new_tensor(top_logit_filter).log()
        probs_thresh = torch.minimum(min_thresh, probs_max + log_ratio).unsqueeze(-1)
        return log_probs < probs_thresh

    @torch.no_grad()
    def custom_generate(self, input_ids, generation_config, **kwargs):
        """Greedy decoding from ``base - weight * assist`` logits.

        Args:
            input_ids: prompt token ids, (batch, prompt_len).
            generation_config: supplies defaults for ``do_sample``, ``use_cache``,
                ``max_new_tokens``, ``eos_token_id`` and ``pad_token_id``; a
                keyword argument of the same name overrides it.
            **kwargs: may also carry ``attention_mask``. Any other key is
                ignored, and a warning lists it.

        Returns:
            (batch, prompt_len + n_steps) tensor: prompt followed by generated
            tokens. Rows that emitted EOS are filled with the pad id afterwards.
            The pad id is ``pad_token_id``, else ``tokenizer.pad_token_id``, else
            the first EOS id.

        With ``use_cache`` each model keeps its own KV cache; otherwise each step
        re-runs the full sequence.
        """

        def get_gen_param(name, default=None):
            if name in kwargs:
                return kwargs.pop(name)
            return getattr(generation_config, name, default)

        do_sample = get_gen_param("do_sample", default=False)
        assert not do_sample, "ULDModel.custom_generate currently only supports greedy sampling."

        use_cache = get_gen_param("use_cache", default=False)

        max_new_tokens = get_gen_param("max_new_tokens")
        assert max_new_tokens is not None, "max_new_tokens must be specified for ULDModel.custom_generate."

        eos_token_id = get_gen_param("eos_token_id")
        if eos_token_id is not None:
            # Accept an int, list/tuple of ints, or a tensor; flatten to 1-D.
            eos_token_id = torch.as_tensor(eos_token_id, dtype=torch.long, device=input_ids.device).reshape(-1)
            if eos_token_id.numel() == 0:
                eos_token_id = None

        pad_token_id = get_gen_param("pad_token_id")
        if pad_token_id is None:
            pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = int(eos_token_id[0])

        attention_mask = get_gen_param("attention_mask")
        if attention_mask is None:
            logger.warning("No attention_mask provided, initializing to ones.")
            attention_mask = torch.ones_like(input_ids)

        if kwargs:
            logger.warning("ULDModel.custom_generate ignores unsupported arguments: %s", sorted(kwargs))

        if use_cache:
            return self._custom_generate_with_cache(input_ids, attention_mask, max_new_tokens, eos_token_id, pad_token_id)
        return self._custom_generate_no_cache(input_ids, attention_mask, max_new_tokens, eos_token_id, pad_token_id)

    @staticmethod
    def _position_ids_from_mask(attention_mask):
        """Positions counting only attended tokens (0 at each row's first one).

        Masked (padding) slots get a dummy position of 1.
        """
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        return position_ids

    @staticmethod
    def _apply_stopping(next_tokens, unfinished_sequences, eos_token_id, pad_token_id):
        """Pad tokens of already-finished rows and update the unfinished flags.

        ``next_tokens`` and ``unfinished_sequences`` are 1-D (batch,) tensors.
        """
        if eos_token_id is None:
            return next_tokens, unfinished_sequences
        if pad_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        is_eos = torch.isin(next_tokens, eos_token_id)
        unfinished_sequences = unfinished_sequences * (~is_eos).long()
        return next_tokens, unfinished_sequences

    def _custom_generate_no_cache(self, input_ids, attention_mask, max_new_tokens, eos_token_id, pad_token_id=None):
        """Greedy loop that re-encodes the whole sequence at every step."""
        batch_size = input_ids.shape[0]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

        for _ in range(max_new_tokens):
            position_ids = self._position_ids_from_mask(attention_mask)
            base_out = self.basellm(
                input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
            )
            assist_out = self.assist_llm(
                input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
            )

            logits = self._combine_logits(base_out.logits[:, -1, :], assist_out.logits[:, -1, :])
            next_tokens = torch.argmax(logits, dim=-1).to(input_ids.device)
            next_tokens, unfinished_sequences = self._apply_stopping(
                next_tokens, unfinished_sequences, eos_token_id, pad_token_id
            )

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)

            if eos_token_id is not None and unfinished_sequences.max() == 0:
                break

        return input_ids

    def _custom_generate_with_cache(self, input_ids, attention_mask, max_new_tokens, eos_token_id, pad_token_id=None):
        """Greedy loop with one KV cache per model; only new tokens are fed each step."""
        batch_size = input_ids.shape[0]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        generated = input_ids

        if DynamicCache is not None:
            base_past = DynamicCache()
            assist_past = DynamicCache()
        else:
            base_past = None
            assist_past = None

        step_input_ids = input_ids
        for _ in range(max_new_tokens):
            # The mask must cover every cached position plus the new tokens, or
            # left-padding of the prompt is forgotten after the first step.
            position_ids = self._position_ids_from_mask(attention_mask)[:, -step_input_ids.shape[1]:]
            base_out = self.basellm(
                input_ids=step_input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=base_past,
                use_cache=True,
            )
            assist_out = self.assist_llm(
                input_ids=step_input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=assist_past,
                use_cache=True,
            )

            logits = self._combine_logits(base_out.logits[:, -1, :], assist_out.logits[:, -1, :])
            next_tokens = torch.argmax(logits, dim=-1).to(input_ids.device)
            next_tokens, unfinished_sequences = self._apply_stopping(
                next_tokens, unfinished_sequences, eos_token_id, pad_token_id
            )

            generated = torch.cat([generated, next_tokens[:, None]], dim=-1)
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)

            base_past = base_out.past_key_values
            assist_past = assist_out.past_key_values
            step_input_ids = next_tokens[:, None]

            if eos_token_id is not None and unfinished_sequences.max() == 0:
                break

        return generated

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Delegate to the base model's implementation."""
        return self.basellm.prepare_inputs_for_generation(input_ids, **kwargs)

    def generate(self, inputs=None, **kwargs):
        """HF-style entry point forwarding to ``custom_generate``.

        The prompt may be given positionally (``inputs``) or as ``input_ids=``.
        An explicit ``generation_config=`` keyword replaces
        ``self.generation_config`` as the source of defaults.

        Raises:
            ValueError: if no prompt is given.
        """
        if inputs is None:
            inputs = kwargs.pop("input_ids", None)
        if inputs is None:
            raise ValueError("ULDModel.generate requires `inputs` or `input_ids`.")
        generation_config = kwargs.pop("generation_config", None)
        if generation_config is None:
            generation_config = self.generation_config
        return self.custom_generate(inputs, generation_config, **kwargs)