# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
from src.data_utils import get_dataset
from src.model_utils import  print_model_summary,load_safetensors_model,print_model_summary,save_model
from tqdm import tqdm
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from eval.evaluate_mmlu import eval_mmlu
from eval.evaluate_mnli import eval_mnli
from eval.evaluate_qnli import eval_qnli
from eval.evaluate_flops import eval_flops
from eval.evaluate_advglue import eval_advglu
from eval.evaluate_tQA import eval_tQA

import torch
import torch.nn as nn
from tqdm import tqdm

def generate_text(model, tokenizer, input_text, max_length=50):
    """Generate a single continuation of ``input_text`` and return it as text.

    Two diagnostic lines are printed along the way: the model's
    ``config.hidden_size`` and the shape of the encoded prompt.

    Args:
        model: Object exposing ``config.hidden_size``, ``device`` and
            ``generate(input_ids, max_length=..., num_return_sequences=...)``.
        tokenizer: Object exposing ``encode(text, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=True)``.
        input_text: Prompt to continue.
        max_length: Forwarded unchanged to ``model.generate``.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    hidden_dim = model.config.hidden_size  # or model.embed_tokens.weight.shape[1]
    print("Model embedding dimension:", hidden_dim)

    # Tokenize the prompt and place it on the same device as the model.
    prompt_ids = tokenizer.encode(input_text, return_tensors='pt')
    prompt_ids = prompt_ids.to(model.device)
    print("input_ids.shape:", prompt_ids.shape)

    # Generation is inference only, so gradient tracking is switched off.
    with torch.no_grad():
        sequences = model.generate(
            prompt_ids,
            max_length=max_length,
            num_return_sequences=1,
        )

    # Only the first returned sequence is turned back into text.
    return tokenizer.decode(sequences[0], skip_special_tokens=True)


def evaluate_model_performance(model, tokenizer, input_text, device):
    """Profile one inference forward pass of ``model`` on ``input_text``.

    The forward pass runs with gradient tracking disabled so the profile
    reflects inference cost rather than autograd bookkeeping, matching how
    the other helpers in this file invoke the model.

    Args:
        model: Callable invoked as ``model(input_ids)``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``.
        input_text: Text to tokenize and feed to the model.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        The finished ``torch.profiler.profile`` object (memory and FLOP
        recording enabled), ready for ``key_averages()``.
    """
    # Encode input text
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)

    # Only request CUDA activity when a CUDA runtime is actually present;
    # asking for it on a CPU-only host just produces a profiler warning.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    # Use torch.profiler for performance evaluation
    with torch.no_grad(), torch.profiler.profile(
            activities=activities,
            profile_memory=True,
            with_flops=True
    ) as prof:
        # Perform forward pass; the result itself is not needed.
        model(input_ids)

    return prof


def _resolve_input_device(model):
    """Return the device on which input token ids should be placed.

    Uses ``model.device`` when the model reports a usable one and falls back
    to ``'cuda'`` (the previously hard-coded target) otherwise.
    """
    device = getattr(model, "device", None)
    if device is None or torch.device(device).type == "meta":
        return torch.device("cuda")
    return torch.device(device)


@torch.no_grad()
def _sliding_window_perplexity(model, tokenizer, dataset, max_length, stride, max_windows=None):
    """Shared sliding-window perplexity loop behind the public entry points.

    Each window feeds up to ``max_length`` tokens of context to the model but
    only scores the tokens that were not already scored by the previous
    window. Window losses are combined weighted by the number of tokens they
    score, so the result is the exponential of the mean per-token negative
    log-likelihood.

    Args:
        model: Causal LM exposing ``model.model(ids, use_cache=False)`` and
            ``model.lm_head(hidden_states)``.
        tokenizer: Passed through to ``get_dataset``.
        dataset: Dataset identifier passed through to ``get_dataset``.
        max_length: Maximum number of predicted positions per window.
        stride: Step between the starts of consecutive windows.
        max_windows: If given, stop after this many windows.

    Returns:
        ``(ppl, history)`` where ``ppl`` is a 0-dim float32 CPU tensor and
        ``history`` holds one ``(begin_loc, end_loc, running_ppl)`` tuple per
        processed window (``running_ppl`` is the perplexity over everything
        scored so far, as a Python float).

    Raises:
        ValueError: If no window was processed (e.g. the token sequence has
            fewer than two tokens).
    """
    print("Loading dataset...")
    encodings = get_dataset(dataset, tokenizer)
    seq_len = encodings.size(1)

    print("Calculating perplexity...")
    print(f"Sequence length: {seq_len}")
    print(f"Max length: {max_length}")
    print(f"Stride: {stride}")

    device = _resolve_input_device(model)
    loss_fct = nn.CrossEntropyLoss()  # mean reduction; re-weighted below
    nll_sum = torch.zeros((), dtype=torch.float32)  # total NLL over scored tokens
    n_tokens = 0
    history = []
    prev_end_loc = 0
    ppl = None

    pbar = tqdm(range(0, seq_len - 1, stride))
    for window_idx, begin_loc in enumerate(pbar):
        if max_windows is not None and window_idx >= max_windows:
            break
        end_loc = min(seq_len - 1, begin_loc + max_length)
        # Score only tokens not covered by the previous window. Clamp to the
        # window size so stride > max_length cannot request more targets than
        # the window holds (which would misalign logits and labels).
        trg_len = min(end_loc - prev_end_loc, end_loc - begin_loc)
        input_ids = encodings[:, begin_loc:end_loc + 1].to(device)  # +1 for labels

        # NOTE: Call the backbone directly and project only the scored
        # positions through lm_head to reduce memory overhead.
        hidden_states = model.model(input_ids[:, :-1], use_cache=False)[0]
        logits = model.lm_head(hidden_states[..., -trg_len:, :])

        # Labels are the last trg_len tokens (inputs shifted by one).
        labels = input_ids[:, -trg_len:].contiguous()
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        # Windows score different numbers of tokens, so weight each window's
        # mean loss by its token count instead of averaging window means.
        n_scored = labels.numel()
        nll_sum += loss.to('cpu').to(torch.float32) * n_scored
        n_tokens += n_scored

        ppl = torch.exp(nll_sum / n_tokens)
        pbar.set_description(f"Perplexity: {ppl:.2f}")
        history.append((begin_loc, end_loc, ppl.item()))

        prev_end_loc = end_loc
        if end_loc == (seq_len - 1):
            break

    if ppl is None:
        raise ValueError(
            f"No evaluation windows were produced for dataset {dataset!r} "
            f"(sequence length {seq_len})"
        )
    return ppl, history


@torch.no_grad()
def calculate_perplexity(model, tokenizer, dataset: str, max_length: int, stride: int = 2048) -> torch.Tensor:
    """Compute sliding-window perplexity of ``model`` on ``dataset``.

    Args:
        model: Causal LM exposing ``model.model`` and ``model.lm_head``.
        tokenizer: Tokenizer handed to ``get_dataset``.
        dataset: Dataset identifier handed to ``get_dataset``.
        max_length: Maximum number of predicted positions per window.
        stride: Step between the starts of consecutive windows.

    Returns:
        Token-weighted perplexity as a 0-dim float32 CPU tensor.

    Raises:
        ValueError: If the dataset yields no evaluation window.
    """
    ppl, _ = _sliding_window_perplexity(model, tokenizer, dataset, max_length, stride)
    return ppl


@torch.no_grad()
def calculate_perplexity_data(model, tokenizer, dataset: str, max_length: int, stride: int = 2048) -> torch.Tensor:
    """Compute perplexity like ``calculate_perplexity`` and print a per-window log.

    After the run, every window whose running perplexity compares greater
    than zero (i.e. every window unless the value is NaN) is printed twice:
    once with its start/end offsets and once as a bare number per line. The
    logged value is the running perplexity up to and including that window.

    Args:
        model: Causal LM exposing ``model.model`` and ``model.lm_head``.
        tokenizer: Tokenizer handed to ``get_dataset``.
        dataset: Dataset identifier handed to ``get_dataset``.
        max_length: Maximum number of predicted positions per window.
        stride: Step between the starts of consecutive windows.

    Returns:
        Token-weighted perplexity as a 0-dim float32 CPU tensor.

    Raises:
        ValueError: If the dataset yields no evaluation window.
    """
    ppl, history = _sliding_window_perplexity(model, tokenizer, dataset, max_length, stride)

    recorded_samples = [entry for entry in history if entry[2] > 0]
    print(f"Recorded samples with PPL > 0: {len(recorded_samples)}")
    for idx, (start, end, perplexity) in enumerate(recorded_samples):
        print(f"Sample {idx + 1}: Start {start}, End {end}, Perplexity {perplexity:.2f}")
    for _, _, perplexity in recorded_samples:
        print(f"{perplexity:.2f}")
    return ppl


@torch.no_grad()
def calculate_perplexity_temp(model, tokenizer, dataset: str, max_length: int, stride: int = 2048) -> torch.Tensor:
    """Quick perplexity estimate using only the first 10 windows.

    Args:
        model: Causal LM exposing ``model.model`` and ``model.lm_head``.
        tokenizer: Tokenizer handed to ``get_dataset``.
        dataset: Dataset identifier handed to ``get_dataset``.
        max_length: Maximum number of predicted positions per window.
        stride: Step between the starts of consecutive windows.

    Returns:
        Token-weighted perplexity over the processed windows as a 0-dim
        float32 CPU tensor.

    Raises:
        ValueError: If the dataset yields no evaluation window.
    """
    ppl, _ = _sliding_window_perplexity(
        model, tokenizer, dataset, max_length, stride, max_windows=10
    )
    return ppl


@torch.no_grad()
def eval(model, tokenizer, input_text, dataset):
    """Run the full evaluation suite on ``model`` and print every result.

    Stages, in order: parameter summary; MMLU, MNLI and QNLI; perplexity on
    ``dataset`` and on ``"ptb_text_only"`` (peak CUDA memory is measured
    across these two runs); FLOPs, a sample generation and a profiler table;
    AdvGLUE and TruthfulQA. Nothing is returned.

    Args:
        model: Model under evaluation.
        tokenizer: Tokenizer matching ``model``.
        input_text: Prompt used for the sample generation and profiling.
        dataset: Dataset identifier for the first perplexity run.
    """
    print("0. Model parameters summary")
    print_model_summary(model)

    print("1. Compression performance evaluation")
    print("1.1 Knowledge ability")
    eval_mmlu(model, tokenizer, ntrain=0, data_dir="data/MMLU")
    print("1.2 Reasoning ability")
    for run_nli, nli_dir in ((eval_mnli, "data/MNLI"), (eval_qnli, "data/QNLI")):
        run_nli(model, tokenizer, ntrain=0, data_dir=nli_dir)

    print("2. Generalization ability evaluation")
    # Start peak-memory tracking from a clean slate before the perplexity runs.
    torch.cuda.reset_peak_memory_stats()
    print("Perplexity evaluation on this model")
    for corpus in (dataset, "ptb_text_only"):
        score = calculate_perplexity(model, tokenizer, corpus, max_length=2048, stride=512)
        print(f"{corpus} perplexity: {score}")
    # Bytes -> MB
    peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

    print("3. Inference efficiency evaluation")
    print(f"Peak memory usage during inference: {peak_mb:.2f} MB")
    eval_flops(model, tokenizer, seqlen=128)
    # The generated sample gives a quick qualitative look at answer quality.
    sample_text = generate_text(model, tokenizer, input_text)
    print("Generated text:", sample_text)
    profile = evaluate_model_performance(model, tokenizer, input_text, model.device)
    print("Inference performance report:")
    print(profile.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

    print("4. Trustworthiness evaluation")
    eval_advglu(
        model,
        tokenizer,
        ntrain=0,
        data_file='data/adv_glue/dev_ann.json',
        test_origin=False,
    )
    eval_tQA(
        model,
        tokenizer,
        preset='qa',
        input_path='data/TruthfulQA/TruthfulQA.csv',
        device='cuda',
    )


@torch.no_grad()
def eval_only_ppl(model, tokenizer, input_text, dataset):
    """Run only the perplexity stage of the evaluation and print the result.

    Args:
        model: Model under evaluation.
        tokenizer: Tokenizer matching ``model``.
        input_text: Unused; present so the signature matches ``eval``.
        dataset: Dataset identifier passed to ``calculate_perplexity``.
    """
    print("2. Generalization ability evaluation")
    # Peak CUDA memory statistics are reset before the run.
    torch.cuda.reset_peak_memory_stats()
    print("Perplexity evaluation on this model")
    score = calculate_perplexity(
        model,
        tokenizer,
        dataset,
        max_length=2048,
        stride=512,
    )
    print(f"{dataset} perplexity: {score}")


@torch.no_grad()
def eval_only_ppl_data(model, tokenizer, input_text, dataset):
    """Run the perplexity stage with per-window logging and print the result.

    Identical to ``eval_only_ppl`` except that it delegates to
    ``calculate_perplexity_data``.

    Args:
        model: Model under evaluation.
        tokenizer: Tokenizer matching ``model``.
        input_text: Unused; present so the signature matches ``eval``.
        dataset: Dataset identifier passed to ``calculate_perplexity_data``.
    """
    print("2. Generalization ability evaluation")
    # Peak CUDA memory statistics are reset before the run.
    torch.cuda.reset_peak_memory_stats()
    print("Perplexity evaluation on this model")
    score = calculate_perplexity_data(
        model,
        tokenizer,
        dataset,
        max_length=2048,
        stride=512,
    )
    print(f"{dataset} perplexity: {score}")
