import random
import time

import numpy as np
import torch
from tqdm import tqdm

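# Reference: module structure of the LLaMA model these utilities were written for,
# as printed by `print(model)` (32 decoder layers, hidden size 4096, i.e. a 7B-sized LLaMA).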

# LlamaForCausalLM(                                                                                                                                                                                                                                                      
#   (model): LlamaModel(                                                                                                                                                                                                                                                 
#     (embed_tokens): Embedding(32000, 4096)                                                                                                                                                                                                                             
#     (layers): ModuleList(                                                                                                                                                                                                                                              
#       (0-31): 32 x LlamaDecoderLayer(                                                                                                                                                                                                                                  
#         (self_attn): LlamaAttention(                                                                                                                                                                                                                                   
#           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (k_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (v_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (rotary_emb): LlamaRotaryEmbedding()                                                                                                                                                                                                                         
#         )                                                                                                                                                                                                                                                              
#         (mlp): LlamaMLP(                                                                                                                                                                                                                                               
#           (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)                                                    
#           (up_proj): Linear(in_features=4096, out_features=11008, bias=False)                                                      
#           (down_proj): Linear(in_features=11008, out_features=4096, bias=False)                                                    
#           (act_fn): SiLU()                                                                                                         
#         )                                                                                                                          
#         (input_layernorm): LlamaRMSNorm()                                                                                          
#         (post_attention_layernorm): LlamaRMSNorm()                                                                                 
#       )                                                                                                                            
#     )                                                                                                                              
#     (norm): LlamaRMSNorm()                                                                                                         
#   ) 
# )

def print_debug(string, is_debug=0):
    """Print ``string`` only when ``is_debug`` is truthy; otherwise do nothing."""
    if not is_debug:
        return
    print(string)

@torch.no_grad()
def stable_rank(A):
    """Return the stable rank of ``A``: sum of squared singular values over the largest one.

    Equivalently ||A||_F^2 / ||A||_2^2. Only singular values are computed
    (no U/V), after casting ``A`` to float32.

    Args:
        A: tensor to analyse (assumed to be a single 2-D matrix; for a batched
           input the sum/max below run over all batches together — TODO confirm
           that callers only pass matrices).

    Returns:
        A 0-d NumPy array on the host holding the ratio.
    """
    singular_values = torch.linalg.svdvals(A.float())
    energy = singular_values.pow(2)
    ratio = energy.sum() / energy.max()
    return ratio.cpu().numpy()

def device_choser(device=None):
    """Resolve the device that model inputs should be sent to, and print the choice.

    Args:
        device: explicit device (e.g. ``'cpu'``, ``'cuda:1'``). ``None`` means the
            model was loaded with an automatic device map; in that case the first
            layer is expected on ``cuda:0``, so inputs are sent there.

    Returns:
        ``device`` unchanged, or ``'cuda:0'`` when ``device`` is None.
    """
    # Identity check: an object whose __eq__ is permissive must not be mistaken for None.
    if device is None:
        print('use auto device map, the first layer should be on cuda:0')
        return 'cuda:0'

    print(f'use device: {device}')
    return device

@torch.no_grad()
def calculate_ppl(model, encodings, stride=512, device=None, max_length=None):
    """Compute the perplexity of ``model`` on ``encodings`` with a strided sliding window.

    Follows the Hugging Face strided-window recipe. Each window holds up to
    ``max_length`` tokens. Tokens already scored by the previous window are masked
    with -100, so each token contributes to the loss at most once. The result is
    the exponential of the negative log-likelihood averaged over all scored
    tokens. Windows are weighted by how many tokens they score, not averaged per
    window.

    Args:
        model: causal LM called as ``model(input_ids, labels=target_ids)``. It must
            return an object with a ``.loss`` attribute: the mean NLL over the
            non-ignored labels after the model's internal one-position shift.
        encodings: tokenizer output with an ``input_ids`` tensor. Dimension 1 is
            the sequence axis (presumably shape (1, seq_len); TODO confirm for
            batch > 1).
        stride: step between window starts. It should not exceed the window size,
            otherwise tokens between windows are never scored.
        device: where input ids are sent; ``None`` means ``cuda:0`` (auto device
            map), resolved by ``device_choser``.
        max_length: optional window size. Defaults to
            ``model.config.max_position_embeddings``, as before.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if no token could be scored (e.g. an empty sequence).
    """
    if max_length is None:
        max_length = model.config.max_position_embeddings
    seq_len = encodings.input_ids.size(1)
    print(f'max_length: {max_length}, seq_len: {seq_len}')

    device = device_choser(device)

    nll_sum = 0.0  # sum of per-token NLLs over all scored tokens
    n_tokens = 0   # number of scored tokens
    prev_end_loc = 0

    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from stride on the first/last window
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Context tokens that were already scored by the previous window are ignored.
        target_ids[:, :-trg_len] = -100

        outputs = model(input_ids, labels=target_ids)

        # The model shifts labels left by one internally, so ``loss`` is a mean over
        # the valid entries of target_ids[:, 1:]. Re-weight by that count so the
        # final average is per token rather than per window. A window with no
        # scorable label (e.g. a trailing 1-token window) has a NaN loss; skip it.
        n_valid = int((target_ids[:, 1:] != -100).sum().item())
        if n_valid > 0:
            nll_sum += outputs.loss.float().item() * n_valid
            n_tokens += n_valid

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if n_tokens == 0:
        raise ValueError(f'no tokens could be scored (seq_len={seq_len}); need at least 2 tokens')

    return float(np.exp(nll_sum / n_tokens))

# @torch.no_grad()
# def generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False):
#     model.eval()
#     model.to(device)
#     model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
#     generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
#     output = tokenizer.batch_decode(generated_ids)[0]
    
#     return output    

  


# @torch.no_grad()
# def visualize_output(model, device=None, tokenizer=None):
#     model.eval()
#     if device == None:
#         # use auto device map, the first layer should be on cuda:0
#         device = 'cuda:0'
#         print(f'use auto device map, the first layer should be on cuda:0')
#     else:
#         model.to(device) # temporary for seperate_steps.py
#         print(f'use device: {device}')
        
    
#     # print(f'====================================================================')
#     # prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
#     #           "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
#     #           "there?")
#     # model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
#     # generated_ids = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#     # output = tokenizer.batch_decode(generated_ids)[0]
#     # print(output)
    
    
#     print(f'====================================================================')
#     prompt = ("A chat between a curious girl and an expert.\n\nGirl: Can you introduce Statue of Liberty for me in 100 words?\nExpert: Sure, "
#               "the Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor in New York City, in the United States. It was designed by ")
#     print(generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False))
    
#     print(f'====================================================================')
#     prompt = ("[Question]: What is the mass correction of a light pseudoscalar decay? [Answer]: The mass correction of a light pseudoscalar decay refers to the effects of the masses of the final state particles on the decay width of a particle. These corrections can be")
#     print(generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False))
    
#     print(f'====================================================================')
#     prompt = ("Answer the following questions. \n\nQuestion: In a park there are 200 animals. Of these anmials, 3/5 are dogs. Of the dogs in the park, 1/5 are puppies. How many puppies are in the park?")
#     print(generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False))
    
#     print(f'====================================================================')
#     prompt = ("Select the correct answer (between A, B, C and D) for the following question and illustrate the reason. \n\nExample:\nQuestion: A signal from the brain to a muscle in the arm is transmitted by which structures?  A. sensory neurons, B. interneurons, C. motor neurons, D. mechanoreceptor neurons\nAnswer: C. motor neurons.\n\nLet's start!\nQuestion: What type of organism is commonly used in preparation of foods such as cheese and yogurt?	 A. viruses, B. protozoa, C. mesophilic organisms D. gymnosperms.\nAnswer: ")
#     print(generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False))
    
#     print(f'====================================================================')
#     prompt = ("Select the correct answer (between A, B, C and D) for the following question and illustrate the reason. \n\nExample:\nQuestion: A signal from the brain to a muscle in the arm is transmitted by which structures?  A. sensory neurons, B. interneurons, C. motor neurons, D. mechanoreceptor neurons\nAnswer: C. motor neurons.\n\nLet's start!\nQuestion: How many times does Earth rotate on its axis in one day?	A. once, B. twice, C. 24 times, D. 365 times\nAnswer: ")
#     print(generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False))
    



@torch.no_grad()
def generate_output(prompt, tokenizer, device, model, max_new_tokens=200, do_sample=False):
    """Generate a continuation of ``prompt``, print it with its generation time, and return both.

    Args:
        prompt: text to continue.
        tokenizer: callable as ``tokenizer([prompt], return_tensors="pt")``. Its
            result must support ``.to(device)``, and it must provide ``batch_decode``.
        device: device the tokenized inputs are sent to. The model is also moved
            there, unless it was dispatched across devices with a ``device_map``.
        model: model exposing ``eval``, ``to`` and ``generate``.
        max_new_tokens: forwarded to ``model.generate``.
        do_sample: forwarded to ``model.generate`` (False = greedy decoding).

    Returns:
        ``(output, elapsed)``. ``output`` is the decoded first sequence returned by
        ``generate``; ``elapsed`` is the number of seconds spent in ``generate``.
    """
    model.eval()
    # Models loaded with a ``device_map`` carry ``hf_device_map``. Moving them with
    # ``.to`` would pull every shard onto one device, defeating the auto device map.
    # Only move models that were not dispatched that way.
    # NOTE(review): a model mapped to a single device that differs from ``device``
    # is also left where it is.
    if getattr(model, 'hf_device_map', None) is None:
        model.to(device)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
    elapsed = time.perf_counter() - start

    output = tokenizer.batch_decode(generated_ids)[0]
    print(f'========================= use_time: {elapsed} for the below output =========================')
    print(output)

    return output, elapsed

@torch.no_grad()
def visualize_output(model, device=None, tokenizer=None, if_short=0, max_new_tokens=200):
    """Greedily generate continuations for a fixed set of probe prompts and print them.

    The probes are a Statue of Liberty chat, a physics Q&A continuation, an
    arithmetic word problem and two multiple-choice science questions. Each prompt
    goes through ``generate_output``, which prints the decoded text and its
    generation time. The summed generation time is printed at the end.

    Args:
        model: model to probe; it is switched to eval mode.
        device: explicit device to move ``model`` to and send inputs to. ``None``
            means the model was loaded with an automatic device map; inputs are
            then sent to ``cuda:0`` (resolved by ``device_choser``).
        tokenizer: tokenizer passed through to ``generate_output``.
        if_short: when truthy, only the first prompt is run (quick smoke test).
        max_new_tokens: generation budget per prompt (default 200).

    Returns:
        Total seconds spent generating across all prompts.
    """
    model.eval()
    if device is not None:
        model.to(device) # temporary for seperate_steps.py
    # Shared helper: prints the chosen device and maps None to 'cuda:0'.
    device = device_choser(device)

    prompt_list = [
        "A chat between a curious girl and an expert.\n\nGirl: Can you introduce Statue of Liberty for me in 100 words?\nExpert: Sure, the Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor in New York City, in the United States. It was designed by ",
        "[Question]: What is the mass correction of a light pseudoscalar decay? [Answer]: The mass correction of a light pseudoscalar decay refers to the effects of the masses of the final state particles on the decay width of a particle. These corrections can be",
        "Answer the following questions. \n\nQuestion: In a park there are 200 animals. Of these anmials, 3/5 are dogs. Of the dogs in the park, 1/5 are puppies. How many puppies are in the park?",
        "Select the correct answer (between A, B, C and D) for the following question and illustrate the reason. \n\nExample:\nQuestion: A signal from the brain to a muscle in the arm is transmitted by which structures?  A. sensory neurons, B. interneurons, C. motor neurons, D. mechanoreceptor neurons\nAnswer: C. motor neurons.\n\nLet's start!\nQuestion: What type of organism is commonly used in preparation of foods such as cheese and yogurt?	 A. viruses, B. protozoa, C. mesophilic organisms D. gymnosperms.\nAnswer: ",
        "Select the correct answer (between A, B, C and D) for the following question and illustrate the reason. \n\nExample:\nQuestion: A signal from the brain to a muscle in the arm is transmitted by which structures?  A. sensory neurons, B. interneurons, C. motor neurons, D. mechanoreceptor neurons\nAnswer: C. motor neurons.\n\nLet's start!\nQuestion: How many times does Earth rotate on its axis in one day?	A. once, B. twice, C. 24 times, D. 365 times\nAnswer: "
    ]

    if if_short:
        prompt_list = prompt_list[:1]

    total_time = 0
    for prompt in prompt_list:
        _, use_time = generate_output(prompt, tokenizer, device, model, max_new_tokens=max_new_tokens, do_sample=False)
        total_time += use_time

    print(f'===> total_time: {total_time}')
    return total_time
    
    
# @torch.no_grad()
# def test_forward(model, device=None, tokenizer=None):
#     model.eval()
#     if device == None:
#         # use auto device map, the first layer should be on cuda:0
#         device = 'cuda:0'
#     else:
#         model.to(device)
        
    
#     print(f'====================================================================')
#     prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
#               "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
#               "there?")
#     model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
#     generated_ids = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#     output = tokenizer.batch_decode(generated_ids)[0]
#     print(output)


def set_seed(seed=42):
    """Seed the global random number generators to make runs reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    (all) CUDA generators. It also forces deterministic cuDNN kernels and
    disables cuDNN autotuning.

    Args:
        seed: integer seed applied to every generator.
    """
    ############## set the seed for reproducibility ################
    # Python's own RNG was previously left unseeded; seed it alongside NumPy/PyTorch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Autotuning (benchmark=True) may select different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False