import torch
from models.llama_kivi import LlamaForCausalLM_KIVI
from models.mistral_kivi import MistralForCausalLM_KIVI
from models.llama_sinkq import LlamaForCausalLM_SinkQ
from models.mistral_sinkq import MistralForCausalLM_SinkQ
from transformers import AutoTokenizer,AutoConfig, AutoModelForCausalLM
from transformers.generation.utils import *
import argparse
import json

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` holding the model selection options
        (``model_path``, ``model_type``, ``method``) and the quantization
        hyper-parameters (``k_bits``, ``v_bits``, ``group_size``,
        ``residual_length``, ``sink_num``, ``sink_max_size``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="model path")
    parser.add_argument(
        "--model_type",
        type=str,
        default="llama",
        help="model type",
        choices=["llama", "mistral"],
    )
    parser.add_argument(
        "--method",
        type=str,
        default="SinkQ",
        help="quant method",
        choices=["FP16", "KIVI", "SinkQ"],
    )
    # Quantization hyper-parameters: all plain integers. They are registered in
    # this fixed order so that ``--help`` lists them in a stable sequence.
    integer_options = (
        ("--k_bits", 2),
        ("--v_bits", 2),
        ("--group_size", 128),
        ("--residual_length", 128),
        ("--sink_num", 3),
        ("--sink_max_size", 32),
    )
    for flag, default_value in integer_options:
        parser.add_argument(flag, type=int, default=default_value)
    return parser.parse_args()

def greedy_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
    """Greedy decoding loop instrumented to log peak GPU memory per step.

    This mirrors the Hugging Face ``GenerationMixin.greedy_search`` loop and is
    meant to be assigned over it (``main`` does so). It differs deliberately:

    * EOS does not end generation (the unfinished-sequence update is
      commented out), so decoding stops only when ``stopping_criteria`` fires
      or an exception interrupts the loop (e.g. running out of GPU memory).
    * Every 10 steps the peak allocated CUDA memory (GiB) is recorded, and
      every 100 steps it is also printed. The records are written to
      ``long_sequence_output.json`` once the loop ends.

    Arguments have the same meaning as in the upstream ``greedy_search``.

    Returns:
        The generated ``input_ids`` (or a ``GreedySearch*Output`` when
        ``return_dict_in_generate`` is set), containing every token produced
        before the loop stopped.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # used by synced_gpus only
    steps = 0
    step_memory = []  # [{"step": int, "memory": peak GiB}] sampled every 10 steps
    # Start peak tracking from here so the numbers reflect generation only.
    torch.cuda.reset_peak_memory_stats()
    while True:
        try:
            steps += 1
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            with torch.no_grad():
                # forward pass to get next token
                outputs = self(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )
                if steps % 10 == 0:
                    used_mem = torch.cuda.max_memory_allocated() / 1024 ** 3
                    step_memory.append({"step": steps, "memory": used_mem})
                    # Every multiple of 100 is a multiple of 10, so nesting the
                    # print here guarantees `used_mem` is freshly bound.
                    if steps % 100 == 0:
                        print(f"step: {steps}     memory: {used_mem}")
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            # (unfinished_sequences is never updated below, so this is a no-op here)
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # EOS-based stopping is intentionally disabled so the benchmark
            # keeps generating regardless of what tokens the model emits.

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break
        except (Exception, KeyboardInterrupt) as err:
            # Best-effort by design: the loop is expected to run until something
            # (usually GPU memory) gives out, and Ctrl-C should still save logs.
            # Report *why* it stopped so that genuine bugs are not mistaken for
            # hitting the memory limit.
            print(f"Generation stopped at step {steps}: {type(err).__name__}: {err}")
            break
    print("="*80)
    with open("long_sequence_output.json", 'w', encoding="utf-8") as json_file:
        json.dump(step_memory, json_file, indent=4)
    print("Finished! Logs are saved in long_sequence_output.json")
    print("="*80)
    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GreedySearchEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids


def main(args):
    """Load the requested model and run the long-sequence memory test.

    ``GenerationMixin.greedy_search`` is replaced with this file's instrumented
    ``greedy_search`` so that ``model.generate(..., do_sample=False)`` logs
    peak GPU memory per decoding step.

    Args:
        args: Namespace produced by ``parse_args``. ``args.method`` and
            ``args.model_type`` select the model class; "FP16" (or any
            combination without a custom class) falls back to
            ``AutoModelForCausalLM``.
    """
    GenerationMixin.greedy_search = greedy_search

    config = AutoConfig.from_pretrained(args.model_path)
    # Quantization / sink hyper-parameters ride along on the config object that
    # is passed to from_pretrained below (presumably read by the custom
    # KIVI/SinkQ model classes).
    config.k_bits = args.k_bits
    config.v_bits = args.v_bits  # previously args.k_bits, which ignored --v_bits
    config.group_size = args.group_size
    config.residual_length = args.residual_length
    config.sink_num = args.sink_num
    config.sink_max_size = args.sink_max_size

    use_fast = args.model_type == "llama"
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, use_fast=use_fast)

    model_classes = {
        ("KIVI", "llama"): LlamaForCausalLM_KIVI,
        ("KIVI", "mistral"): MistralForCausalLM_KIVI,
        ("SinkQ", "llama"): LlamaForCausalLM_SinkQ,
        ("SinkQ", "mistral"): MistralForCausalLM_SinkQ,
    }
    model_cls = model_classes.get((args.method, args.model_type), AutoModelForCausalLM)
    model = model_cls.from_pretrained(
        args.model_path, config=config, torch_dtype=torch.float16, device_map="auto"
    )

    # A one-token prompt: the test is about how far generation can go.
    inputs = tokenizer("t", return_tensors="pt").to('cuda')
    print("start testing max length...")
    # NOTE(review): no max_length / max_new_tokens is passed, so the stopping
    # length comes from the model's generation config (or transformers'
    # default) — confirm that is the intended upper bound for this test.
    model.generate(**inputs, do_sample=False)
    
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and run the memory test.
    main(parse_args())
