import argparse
import json
import os
import logging
import re
import sys
import torch
import numpy as np
import datasets
import accelerate
import transformers
import wandb
import random

from tqdm.auto import tqdm
from pathlib import Path
from datasets import load_dataset
from typing import Any, Callable, Dict, Sequence, cast
from dataclasses import dataclass
from dataclasses_json import DataClassJsonMixin
# from torch.utils.tensorboard import SummaryWriter
from GEARLM import SimulatedGearLlamaForCausalLM, CompressionConfig, SimulatedGearMistralForCausalLM, MistralConfig, SimulatedGearQwen2ForCausalLM
from transformers import LlamaTokenizer, AutoTokenizer
from GEARLM.Simulated.utils.weight_compression import apply_low_rank_to_model
from GEARLM.Simulated.utils.calculate_kv_compress_ratio import (
    calculate_gear_kcvt_compression_ratio, 
    calculate_pcc_compact_compression_ratio,
    calculate_kcvt_compression_ratio,
    calculate_kivi_v2_compression_ratio,
    calculate_palu_50_compression_ratio
)

# Label id to be ignored by a loss; presumably mirrors PyTorch's default
# `ignore_index` of -100. Not referenced in the code visible in this file.
IGNORE_INDEX = -100
# Default special-token strings (presumably for use with
# `smart_tokenizer_and_embedding_resize`); not referenced in the code
# visible in this file.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Default for --generation_split: each generation is cut at the first
# occurrence of this string before its answer is scored, so any follow-up
# question the model invents (few-shot prompt continuation) is ignored.
MODEL_GENERATION_SPLIT = "\nQuestion: "
# Module-level logger. NOTE: the functions and main block below log through
# the root `logging` module functions rather than this logger.
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class EvaluationSample:
    """One scored GSM8K example.

    Built in the main evaluation loop from the output of
    ``evaluate_pred_answer`` and serialized as part of ``EvaluationResults``.
    """

    # Original question text from the dataset.
    question: str
    # Full decoded model output (NOT truncated at --generation_split).
    generation: str
    # Reference answer text from the dataset.
    answer: str
    # All numbers extracted from the truncated generation (commas removed).
    list_from_pred: list[str]
    # All numbers extracted from the reference answer (commas removed).
    list_from_answer: list[str]
    # Last number of the truncated generation.
    # NOTE(review): None when no number was found, despite the `float` hint.
    pred: float
    # Last number of the reference answer.
    label: float
    # True when pred == label.
    is_pred_true: bool


@dataclass(frozen=True)
class EvaluationMetrics(DataClassJsonMixin):
    """Aggregated metrics for one evaluation run (JSON-serializable)."""

    # Fraction of evaluated samples whose prediction matched the label.
    accuracy: float


@dataclass(frozen=True)
class EvaluationResults(DataClassJsonMixin):
    """Per-sample results plus aggregated metrics; dumped to JSON by the main block."""

    # Every scored example, in evaluation order.
    samples: list[EvaluationSample]
    # Aggregate metrics computed over `samples`.
    metrics: EvaluationMetrics


def evaluate_pred_answer(pred_str, ans_str):
    """Score one model generation against a GSM8K reference answer.

    The last number in each string is taken as its answer, after removing ","
    thousands separators. Negative signs are not captured by the pattern
    (both "-5" and "5" parse as 5.0).

    Args:
        pred_str: Model output text (already truncated by the caller).
        ans_str: Reference answer text.

    Returns:
        tuple: ``(is_pred_true, pred, pred_list, gold, gold_list)`` where
        ``pred`` is ``None`` (and ``is_pred_true`` is ``False``) when
        *pred_str* contains no number.

    Raises:
        ValueError: if *ans_str* contains no number.
    """
    # Raw string: "\d" / "\." are invalid escapes in a normal string literal.
    pattern = r"\d*\.?\d+"
    pred_list = re.findall(pattern, pred_str.replace(",", ""))
    gold_list = re.findall(pattern, ans_str.replace(",", ""))
    if not gold_list:
        raise ValueError(f"No numeric answer found in reference: {ans_str!r}")
    gold = float(gold_list[-1])

    if pred_list:
        pred = float(pred_list[-1])
        is_pred_true = pred == gold
    else:
        pred = None
        is_pred_true = False
    return (
        is_pred_true,
        pred,
        pred_list,
        gold,
        gold_list,
    )


def test_answer(pred_str, ans_str):
    """Return True when the last number in *pred_str* equals the last in *ans_str*.

    Thousands separators (",") are removed *before* number extraction,
    matching ``evaluate_pred_answer``. Intermediate values are printed for
    debugging.

    Args:
        pred_str: Model answer text.
        ans_str: Reference answer text.

    Returns:
        bool: ``False`` when either string contains no number.
    """
    # Raw string: "\d" / "\." are invalid escapes in a normal string literal.
    pattern = r"\d*\.?\d+"
    # Strip commas first: the pattern cannot match ",", so "1,234" would
    # otherwise be split into ["1", "234"] and compared as 234.
    pred_list = re.findall(pattern, pred_str.replace(",", ""))
    if not pred_list:
        return False
    print("#####\n Pred string:", pred_str, "\n pred_list", pred_list)
    pred = float(pred_list[-1])
    gold_list = re.findall(pattern, ans_str.replace(",", ""))
    print("\n Gold_answer", ans_str, "\n gold_list", gold_list)
    if not gold_list:
        return False
    gold = float(gold_list[-1])
    print("\n result", gold, pred, gold == pred)
    return pred == gold


# Despite its "test_" prefix this is a scoring helper, not a pytest test;
# keep pytest from collecting it if the module is imported by a test suite.
test_answer.__test__ = False


def parse_pred_ans(filename):
    """Re-score a generation file written by this script and print its accuracy.

    The file is expected to hold records in the format written at the end of
    the main block::

        Q: <question>
        A_model:
        <generation>
        A:
        <reference answer>

    While inside a model answer, a line starting with "Question: " (the model
    inventing a new few-shot question) and everything after it up to the next
    header is excluded from the scored answer.

    Args:
        filename: Path to the generation results file.

    Returns:
        tuple[list[str], list[str], list[str]]: questions, model answers and
        reference answers, one entry per complete record.

    Raises:
        ValueError: if a non-header line appears before the first header line.
    """
    # The generation file is written with utf-8 encoding by the main block.
    with open(filename, encoding="utf-8") as fd:
        lines = fd.readlines()

    questions, ans_pred, ans_gold = [], [], []
    q = am = a = None
    num_q, acc = 0, 0
    current_mode = "none"

    def flush_record():
        # Store and score the current record, but only if it is complete.
        nonlocal acc
        if q is None or am is None or a is None:
            return
        questions.append(q)
        ans_pred.append(am)
        ans_gold.append(a)
        if test_answer(am, a):
            acc += 1

    for line in lines:
        if line.startswith("Q: "):
            flush_record()
            # Reset so a record missing its answers never reuses the previous ones.
            am = a = None
            current_mode = "q"
            q = line
            num_q += 1
        elif line.startswith("A_model:"):
            current_mode = "am"
            am = line
        elif line.startswith("A:"):
            current_mode = "a"
            a = line
        elif current_mode == "am" and line.startswith("Question: "):
            # Model started a new few-shot question: ignore it for scoring.
            current_mode = "am_other"
        elif current_mode == "q":
            q += line
        elif current_mode == "am":
            am += line
        elif current_mode == "a":
            a += line
        elif current_mode == "am_other":
            pass
        else:
            raise ValueError(current_mode)

    flush_record()
    ratio = float(acc / num_q) if num_q else 0.0
    print("######\n num_q %d correct %d ratio %.4f" % (num_q, acc, ratio))
    return questions, ans_pred, ans_gold

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Register special tokens on *tokenizer* and grow *model*'s embeddings to fit.

    The model's token embeddings are always resized to ``len(tokenizer)``.
    When new tokens were actually added, their rows in both the input and the
    output embedding matrices are set to the mean of the pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    embedding_matrices = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for matrix in embedding_matrices:
        # The mean covers only the old rows, so the result is the same even
        # when input and output embeddings share one tensor.
        matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)


def setup_wandb(args):
    """Start a wandb run whose config is the full set of parsed CLI arguments.

    When ``--wandb_run_name`` is not given, the run is named after the model
    basename and the main KV-compression settings.
    """

    def _default_run_name():
        model_basename = args.model.split('/')[-1]
        return (
            f"{model_basename}_{args.compress_method}_r{args.rank}"
            f"_b{args.quantize_bit}_g{args.group_size}_l{args.loop}"
            f"_ia{args.input_axis}_ea{args.error_axis}"
        )

    run_name = (
        _default_run_name()
        if args.wandb_run_name is None
        else args.wandb_run_name
    )
    wandb.init(config=args, name=run_name, project=args.wandb_project)


def log_accuracy_to_wandb(accuracy: float):
    """Record the final *accuracy* in the active wandb run, then close that run."""
    final_metrics = {"accuracy": accuracy}
    wandb.log(final_metrics)
    wandb.finish()


def log_args(args: argparse.Namespace):
    """Log every parsed command-line option, one per line, in alphabetical order."""
    logging.info("Running with the following configuration:")
    option_table = vars(args)
    for option_name in sorted(option_table):
        logging.info(f"  - {option_name}: {option_table[option_name]}")


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

class StoppingCriteriaSub(transformers.StoppingCriteria):
    """Stop generation when the newest token of the first sequence matches a stop.

    A stop matches when ``tokenizer.decode(stop)`` equals the decoded text of
    the last token of ``input_ids[0]``. Other sequences in the batch are not
    inspected.

    Args:
        stops: Token-id tensors to stop on. Defaults to no stops.
        encounters: Accepted for backward compatibility; stored but not used.
        tokenizer: Tokenizer used for decoding. When None, falls back to the
            module-level ``tokenizer`` created by the ``__main__`` block
            (the original behaviour).
        device: Device the stop tensors are moved to (default ``"cuda"``).
    """

    def __init__(self, stops=None, encounters=1, tokenizer=None, device="cuda"):
        super().__init__()
        # `None` default instead of a shared mutable `[]` default.
        self.stops = [stop.to(device) for stop in ([] if stops is None else stops)]
        self.encounters = encounters
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        # Fall back to the script-level global for callers that pass no tokenizer.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # Decode the newest token once instead of once per stop.
        last_text = tok.decode(input_ids[0][-1])
        return any(tok.decode(stop) == last_text for stop in self.stops)
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate GSM8K Dataset")
    parser.add_argument(
        "--model", type=str, default="meta-llama/Llama-2-7b", help="Model name or path."
    )
    parser.add_argument(
        "--prompt_file", type=str, default="gsm8k_prompt_original.txt", help=""
    )
    parser.add_argument("--hf_token", type=str, default=None, help="")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size.")
    parser.add_argument("--example_subset", type=str, default=None, help="")
    parser.add_argument("--max_length", type=int, default=None, help="")
    parser.add_argument("--max_new_tokens", type=int, default=256, help="")
    parser.add_argument("--model_max_length", type=int, default=4096, help="")
    parser.add_argument("--do_sample", action="store_true", default=False, help="")
    parser.add_argument("--temperature", type=float, default=0.8, help="")
    parser.add_argument("--top_k", type=int, default=50, help="")
    parser.add_argument("--top_p", type=float, default=0.95, help="")
    parser.add_argument(
        "--generation_split", type=str, default=MODEL_GENERATION_SPLIT, help=""
    )
    parser.add_argument(
        "--root_output_dir", type=str, default="outputs", help="Root output dir"
    )
    parser.add_argument("--debug", action="store_true", default=False, help="")
    parser.add_argument("--compress_method", type=str, default="None", help="")
    parser.add_argument("--rank", type=float, default=2.0, help="")
    parser.add_argument("--rankv", type=float, default=2.0, help="")
    parser.add_argument("--prefillrank", type=float, default=2.0, help="rank compared with smaller dimension set to K cache.")
    parser.add_argument("--prefillrankv", type=float, default=2.0, help="rank compared with smaller dimension set to V cache.")
    parser.add_argument("--merge_group_k", type=float, default=0.0, help="")
    parser.add_argument("--merge_group_v", type=float, default=0.0, help="")
    parser.add_argument("--loop", type=int, default=2, help="")
    parser.add_argument("--quantize_bit", type=int, default=4, help="")
    parser.add_argument("--group_num", type=int, default=0, help="")
    parser.add_argument("--group_size", type=int, default=0, help="")
    parser.add_argument("--top_kprun", type=float, default=0.0, help="")
    parser.add_argument("--left", type=float, default=0.01, help="")
    parser.add_argument("--attention_number", type=int, default=100, help="")
    parser.add_argument("--gpu", type=int, default=0, help="")

    parser.add_argument("--heavy_size", type=int, default=0, help="")
    parser.add_argument("--recent_size", type=int, default=0, help="")
    parser.add_argument("--streaming", action="store_true", default=False, help="")
    parser.add_argument("--streaming_gap", type=int, default=0, help="")
    parser.add_argument("--zero_shot", action="store_true", default=False, help="")
    parser.add_argument("--stream_grouping", action="store_true", default=False, help="Use streaming mode.")
    parser.add_argument("--token_preserving", action="store_true", default=False, help="")
    parser.add_argument("--start", type=int, default=0, help="")
    parser.add_argument("--locality", type=int, default=0, help="")

    # Wandb args
    parser.add_argument(
        "--use_wandb", action="store_true", help="Enable logging to wandb."
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="gsm8k-evaluation",
        help="Wandb project name.",
    )
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help="Wandb run name. If not provided, a name will be generated.",
    )
    parser.add_argument("--seed", type=int, default=None, help="Set a seed for reproducibility.")
    parser.add_argument("--input_axis", type=str, default='right', help="Axis for input Hadamard transform ('left', 'right', or 'None').")
    parser.add_argument("--error_axis", type=str, default='right', help="Axis for error Hadamard transform ('left', 'right', or 'None').")
    parser.add_argument("--first_method", type=str, default='None', help="First method for sandbox.")
    parser.add_argument("--first_transform", type=str, default='None', help="First transform for sandbox.")
    parser.add_argument("--second_method", type=str, default='None', help="Second method for sandbox.")
    parser.add_argument("--second_transform", type=str, default='None', help="Second transform for sandbox.")
    parser.add_argument("--hla_rank", type=int, default=0, help="HLA rank for sandbox.")
    parser.add_argument("--max_batches", type=int, default=None, help="Maximum number of batches to process for debugging. If None, process all batches.")
    parser.add_argument("--low_rank_weight", action="store_true", default=False, help="Apply low-rank approximation to model weights using SVD.")
    parser.add_argument("--weight_rank", type=float, default=50.0, help="Percentage of rank to keep for weight low-rank approximation (default: 50.0%). Value between 0-100.")
    parser.add_argument("--low_rank_mode", type=str, default="svd", choices=["svd", "approx"], help="Mode for low-rank approximation: 'svd' for exact SVD, 'approx' for power iteration approximation.")
    parser.add_argument("--power_iter_loop", type=int, default=3, help="Number of power iteration loops for approximation mode (default: 3).")
    parser.add_argument("--use_quantized_residual", action="store_true", default=False, help="Use quantized residual (SVD + Q(W - SVD(W))) for better approximation.")
    parser.add_argument("--weight_quant_bits", type=int, default=8, help="Number of bits for residual quantization (default: 8).")
    parser.add_argument("--weight_transform", type=str, default="none", choices=["none", "hadamard", "pca", "cov"], help="Transform to apply before low-rank approximation (default: 'none').")
    parser.add_argument("--kv_transform", type=str, default="none", help="Transform to apply in KV cache compression (default: 'none').")
    parser.add_argument("--use_awq", action="store_true", default=False, help="Use activation-aware quantization (AWQ) for residual quantization.")
    parser.add_argument("--awq_calibration_samples", type=int, default=128, help="Number of calibration samples for AWQ (default: 128).")

    args = parser.parse_args()

    if args.debug:
        import ipdb

        ipdb.set_trace()

    if args.seed is not None:
        set_seed(args.seed)

    if args.use_wandb:
        setup_wandb(args)

    # Setup output dir: <root>/<model basename>/cot_<prompt stem>[_subset-<subset>]
    root_output_dir = Path(args.root_output_dir)
    output_dir = f"cot_{args.prompt_file.split('.')[0]}"
    if args.example_subset is not None:
        output_dir += f"_subset-{args.example_subset}"
    output_dir = root_output_dir / f"{args.model.split('/')[-1]}" / output_dir
    output_dir.mkdir(exist_ok=True, parents=True)
    generation_file = (
        output_dir / f"generation_results_subset-{args.example_subset}.txt"
    )
    evaluation_result_file = output_dir / "evaluation_gsm8k.json"

    split = "test" if args.example_subset is None else f"test[{args.example_subset}]"
    # NOTE(review): `ignore_verifications` is deprecated in newer `datasets`
    # releases (replaced by `verification_mode`); confirm against the pinned version.
    eval_dataset = load_dataset("gsm8k", "main", split=split, ignore_verifications=True)
    logging.basicConfig(
        filename=os.path.join(output_dir.resolve(), "log.txt"),
        filemode="a",
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    log_args(args)

    # Load Model and Tokenizer
    model_kwargs = {}

    if "Llama" in args.model or "Mistral" in args.model or "Qwen2" in args.model:
        model_kwargs["torch_dtype"] = torch.float16
        model_kwargs["device_map"] = "auto"
        model_kwargs["token"] = args.hf_token
        model_kwargs["cache_dir"] = "../cache"
    else:
        raise ValueError(f"Model {args.model} not supported")

    # Only `token` is passed: transformers rejects `use_auth_token` together
    # with `token` (ValueError whenever --hf_token is set), and the old
    # `use_auth_token=True` required a locally stored token even for public
    # models. With token=None the hub client uses a saved token if present.
    config = transformers.AutoConfig.from_pretrained(
        args.model,
        token=args.hf_token,
        use_flash_attn=False,
        trust_remote_code=True,
    )

    compress_config = (
        None
        if args.compress_method == "None"
        else CompressionConfig(
            compress_method=args.compress_method,
            rank=args.rank,
            rankv=args.rankv,
            prefill_rank = args.prefillrank,
            prefill_rankv = args.prefillrankv,

            loop=args.loop,
            quantize_bit=args.quantize_bit,
            group_num=args.group_num,
            group_size = args.group_size,
            top_k=args.top_kprun,
            left=args.left,
            attention_number=args.attention_number,
            device_num=args.gpu,
            batch_num=args.batch_size,

            streaming=args.streaming,
            streaming_gap=args.streaming_gap,
            stream_grouping=args.stream_grouping,
            input_axis=args.input_axis if args.input_axis.lower() != 'none' else None,
            error_axis=args.error_axis if args.error_axis.lower() != 'none' else None,
            first_method=args.first_method if args.first_method.lower() != 'none' else None,
            first_transform=args.first_transform if args.first_transform.lower() != 'none' else None,
            second_method=args.second_method if args.second_method.lower() != 'none' else None,
            second_transform=args.second_transform if args.second_transform.lower() != 'none' else None,
            hla_rank=args.hla_rank,
            kv_transform=args.kv_transform if args.kv_transform.lower() != 'none' else None,
        )
    )
    if compress_config is not None:
        compress_config.copy_for_all_attention()
        compress_config.calculate_compress_ratio_list(4095, 4096)

    if "Llama" in args.model:
        model = SimulatedGearLlamaForCausalLM.from_pretrained(
            args.model,
            config=config,
            **model_kwargs,
            compress_config=compress_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            args.model,
            token=args.hf_token,
            padding_side="left",
            model_max_length=args.model_max_length,
            use_fast=False,
            cache_dir="../cache",
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif "Mistral" in args.model:
        config = MistralConfig.from_pretrained(
            args.model,
            token=args.hf_token,
            use_flash_attn=False,
            trust_remote_code=True,
        )
        model = SimulatedGearMistralForCausalLM.from_pretrained(
            args.model,
            config=config,
            **model_kwargs,
            trust_remote_code=True,
            compress_config=compress_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            args.model,
            token=args.hf_token,
            padding_side="left",
            model_max_length=args.model_max_length,
            use_fast=False,
            cache_dir="../cache",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif "Qwen2" in args.model:
        model = SimulatedGearQwen2ForCausalLM.from_pretrained(
            args.model,
            config=config,
            **model_kwargs,
            compress_config=compress_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            args.model,
            token=args.hf_token,
            padding_side="left",
            model_max_length=args.model_max_length,
            use_fast=False,
            cache_dir="../cache",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

    if args.compress_method == "GEAR":
        calculate_gear_kcvt_compression_ratio(args, config)
    elif args.compress_method == "PCC_COV_OUTLIER_COMPACT":
        calculate_pcc_compact_compression_ratio(args, config)
    elif args.compress_method == "KCVT":
        calculate_kcvt_compression_ratio(args, config)
    elif args.compress_method == "KIVI_V2":
        calculate_kivi_v2_compression_ratio(args, config)
    elif args.compress_method == "PALU_50":
        calculate_palu_50_compression_ratio(args, config)

    # Load prompt template early for AWQ calibration
    with open(f"lib_prompt/{args.prompt_file}", "r") as handle:
        prompt_cot = handle.read()

    # Apply low-rank approximation if requested
    if args.low_rank_weight:
        activation_scales = None

        # Collect activation statistics if using AWQ
        if args.use_awq and args.use_quantized_residual:
            logging.info(f"Collecting activation statistics using AWQ with {args.awq_calibration_samples} calibration samples...")
            from awq import collect_linear_activations

            # Create calibration dataloader
            calibration_dataloader = torch.utils.data.DataLoader(
                cast(torch.utils.data.Dataset, eval_dataset),
                batch_size=args.batch_size,
                shuffle=True,  # Shuffle for better calibration
            )

            # Create a wrapper to convert GSM8K data to model inputs
            class CalibrationDataWrapper:
                def __init__(self, dataloader, tokenizer, prompt_cot):
                    self.dataloader = dataloader
                    self.tokenizer = tokenizer
                    self.prompt_cot = prompt_cot

                def __iter__(self):
                    for batch in self.dataloader:
                        questions = batch["question"]
                        prompts = [
                            self.prompt_cot + "\nQuestion: " + question + "\n"
                            for question in questions
                        ]
                        inputs = self.tokenizer(
                            prompts,
                            return_tensors="pt",
                            padding="longest",
                            truncation=True,
                            max_length=512,  # Limit length for calibration
                        )
                        inputs = {k: v.to("cuda") for k, v in inputs.items()}
                        yield inputs

            calibration_wrapper = CalibrationDataWrapper(calibration_dataloader, tokenizer, prompt_cot)
            activation_scales = collect_linear_activations(model, calibration_wrapper, args.awq_calibration_samples)
            logging.info(f"Collected activation scales for {len(activation_scales)} layers")

        method_desc = f"{args.low_rank_mode}"
        if args.use_quantized_residual:
            method_desc += f" + {args.weight_quant_bits}-bit quantized residual"
            if args.use_awq:
                method_desc += " + AWQ"
        if args.weight_transform != "none":
            method_desc += f" + {args.weight_transform} transform"
        logging.info(f"Applying low-rank approximation with {args.weight_rank}% rank retention, mode: {method_desc}")
        apply_low_rank_to_model(model, args.weight_rank, args.low_rank_mode, args.power_iter_loop,
                              args.use_quantized_residual, args.weight_quant_bits, args.weight_transform, args.left,
                              activation_scales)

    logging.info("Preprocessing the dataset.")

    # Zero-shot mode replaces the few-shot CoT exemplars with a short
    # instruction. Applied after AWQ calibration, which uses the few-shot prompt.
    if args.zero_shot:
        prompt_cot = "answer the question through the form of The answer is xxx. Do not generate others."

    dataloader = torch.utils.data.DataLoader(
        cast(torch.utils.data.Dataset, eval_dataset),
        batch_size=args.batch_size,
    )

    # Generation settings are identical for every batch.
    generate_kwargs = dict(
        return_dict_in_generate=True,
        max_length=args.max_length,
        max_new_tokens=args.max_new_tokens,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if args.do_sample:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = args.temperature
        generate_kwargs["top_k"] = args.top_k
        generate_kwargs["top_p"] = args.top_p
    else:
        # Explicit None silences warnings about sampling params in greedy mode.
        generate_kwargs["do_sample"] = False
        generate_kwargs["temperature"] = None
        generate_kwargs["top_k"] = None
        generate_kwargs["top_p"] = None

    all_samples = []
    all_question, all_generation, all_answer = [], [], []
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Evaluate GSM8K")):
            # Early exit if max_batches is specified
            if args.max_batches is not None and batch_idx >= args.max_batches:
                logging.info(f"Early exit: Processed {args.max_batches} batches for debugging")
                break
            questions = batch["question"]
            answers = batch["answer"]
            prompts = [
                prompt_cot + "\nQuestion: " + question + "\n"
                for question in questions
            ]

            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                truncation=True,
            )
            logging.debug("Batch %d input_ids shape: %s", batch_idx, tuple(inputs.input_ids.shape))
            inputs = inputs.to("cuda")
            outputs = model.generate(**inputs, **generate_kwargs)
            # Decode only the newly generated tokens (prompt is left-padded).
            generations = tokenizer.batch_decode(
                outputs.sequences[:, inputs.input_ids.shape[1] :],
                skip_special_tokens=True,
            )

            all_question += questions
            all_generation += generations
            all_answer += answers

            for question, generation, answer in zip(questions, generations, answers):
                # Score only the text before the model starts a new question.
                is_pred_true, pred, pred_list, gold, gold_list = evaluate_pred_answer(
                    generation.split(args.generation_split)[0], answer
                )
                sample = EvaluationSample(
                    question=question,
                    generation=generation,
                    answer=answer,
                    list_from_pred=pred_list,
                    list_from_answer=gold_list,
                    pred=pred,
                    label=gold,
                    is_pred_true=is_pred_true,
                )
                all_samples.append(sample)

        num_samples = len(all_samples)
        if num_samples == 0:
            # e.g. --max_batches 0 or an empty subset: avoid ZeroDivisionError.
            logging.warning("No samples were evaluated; reporting accuracy 0.0.")
            accuracy = 0.0
        else:
            accuracy = sum(sample.is_pred_true for sample in all_samples) / num_samples
        evaluation_metric = EvaluationMetrics(accuracy=accuracy)
        evaluation_result = EvaluationResults(
            samples=all_samples,
            metrics=evaluation_metric,
        )

    logging.info(f"Accuracy: {accuracy}")

    with evaluation_result_file.open("w") as handle:
        json.dump(evaluation_result.to_dict(), handle)

    with generation_file.open("w", encoding="utf-8") as handle:
        for question, generation, answer in zip(
            all_question, all_generation, all_answer
        ):
            handle.write(
                "Q: %s\nA_model:\n%s\nA:\n%s\n\n" % (question, generation, answer)
            )

    if args.use_wandb:
        log_accuracy_to_wandb(accuracy)
    
