import os
from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util.globals import *
from ...util.nethook import Trace, set_requires_grad
from ...util.runningstats import CombinedStat, Mean, NormMean, SecondMoment, tally
import numpy.linalg as LA
from .tok_dataset import (
    TokenizedDataset,
    dict_to_,
    flatten_masked_batch,
    length_collation,
)

# Registry of the statistics that can be requested through ``to_collect``:
# each key is the name used by callers (and in cache file names), each value
# is the running-statistic class that gets instantiated for that name.
STAT_TYPES = dict(
    mom2=SecondMoment,
    mean=Mean,
    norm_mean=NormMean,
)


def main():
    """
    Command-line utility to precompute cached stats.

    Parses the command-line options, loads the requested model and tokenizer
    (the model is moved to the current CUDA device and frozen), and then calls
    ``layer_stats`` once for every layer listed in ``--layers``.

    Options:
        --model_name    Hugging Face model identifier.
        --dataset       Name of the corpus passed to ``layer_stats``.
        --layers        Comma-separated layer indices, e.g. ``17`` or ``3,5,7``.
        --to_collect    Comma-separated keys of ``STAT_TYPES``.
        --sample_size   Number of samples, or ``all`` for no limit.
        --batch_tokens  Token budget per batch, or ``any`` for the default.
        --precision     Torch dtype name used for the collected features.
        --stats_dir     Root directory of the statistics cache.
        --download      0 or 1; forwarded to ``layer_stats``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ROME Statistics Collector")

    def aa(*args, **kwargs):
        parser.add_argument(*args, **kwargs)

    aa("--model_name", default="gpt2-xl", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"])
    aa(
        "--dataset",
        default="wikipedia",
        choices=["wikitext", "wikipedia", "comprehension", "math", "Chinese"],
    )
    aa("--layers", default=[17], type=lambda x: list(map(int, x.split(","))))
    aa("--to_collect", default=["mom2"], type=lambda x: x.split(","))
    aa("--sample_size", default=100000, type=lambda x: None if x == "all" else int(x))
    aa("--batch_tokens", default=None, type=lambda x: None if x == "any" else int(x))
    aa("--precision", default="float32", choices=["float64", "float32", "float16"])
    aa("--stats_dir", default=STATS_DIR)
    # The keyword must be spelled ``choices``; argparse rejects unknown
    # keywords with a TypeError before any argument is parsed.
    aa("--download", default=1, type=int, choices=[0, 1])

    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name).eval().cuda()
    set_requires_grad(False, model)

    # ``layer_stats`` reads the CUDA device index from ``hparams.device``.
    # The model was placed with ``.cuda()``, i.e. on the current device.
    hparams = argparse.Namespace(device=torch.cuda.current_device())

    for layer_num in args.layers:
        print(
            f"Computing stats for layer {layer_num} of {args.model_name} "
            f'over {args.sample_size or "all"} samples of {args.dataset}. '
            "Note, the statistics are collected over the inputs to the second MLP layer, "
            "or equivalently the outputs of the first MLP layer."
        )
        proj_layer_name = "c_proj" if "gpt2" in args.model_name else "fc_out"
        layer_name = f"transformer.h.{layer_num}.mlp.{proj_layer_name}"

        layer_stats(
            model,
            tokenizer,
            layer_name,
            args.stats_dir,
            args.dataset,
            args.to_collect,
            rank=None,
            # Reuse an existing cache file instead of recomputing it.
            force_recompute=False,
            sample_size=args.sample_size,
            precision=args.precision,
            batch_tokens=args.batch_tokens,
            download=args.download,
            hparams=hparams,
        )


# Local JSON files backing the datasets that are not fetched from the hub.
_JSON_DATASET_FILES = {
    "Chinese": "/mnt/hdfs/lss/lssedit2/EasyEdit/wikidump_zh/AA/wiki_00.json",
    "math": "/mnt/hdfs/lss/lssedit2/KnowEdit/benchmark/gsm8k.json",
    "SST-2": "/mnt/hdfs/lss/lssedit2/KnowEdit/benchmark/SST-2.json",
}


def _max_positions(config):
    """
    Return the maximum sequence length to use for a model configuration.

    The first attribute found among ``n_positions``, ``max_sequence_length``,
    ``max_position_embeddings`` and ``seq_length`` is used. Mistral models are
    then overridden with their ``sliding_window`` (4096 when it is missing or
    falsy) and Qwen2 models with 4096.

    Raises:
        NotImplementedError: if the config exposes none of the attributes.
    """
    for attr in (
        "n_positions",
        "max_sequence_length",
        "max_position_embeddings",
        "seq_length",
    ):
        if hasattr(config, attr):
            limit = getattr(config, attr)
            break
    else:
        raise NotImplementedError(
            "model config does not expose a maximum sequence length"
        )

    model_type = getattr(config, "model_type", None) or ""
    if "mistral" in model_type:
        limit = getattr(config, "sliding_window", None) or 4096
    if "qwen2" in model_type:
        limit = 4096
    return limit


def _load_raw_dataset(ds_name):
    """
    Load the raw text dataset registered under ``ds_name``.

    Every branch requests an explicit split so that a single dataset (not a
    dictionary of splits) is returned, whichever source is used.

    Raises:
        ValueError: if ``ds_name`` is not one of the supported names.
    """
    if ds_name == "wikipedia":
        return load_dataset("wikipedia", "20200501.en", split="train[:200000]")
    if ds_name in _JSON_DATASET_FILES:
        return load_dataset(
            "json", data_files=_JSON_DATASET_FILES[ds_name], split="train"
        )
    supported = ["wikipedia", *_JSON_DATASET_FILES]
    raise ValueError(
        f"Unsupported dataset {ds_name!r}; expected one of {supported}"
    )


def _low_rank_approximation(matrix, rank):
    """
    Return a rank-``rank`` approximation of a 2-D tensor.

    Uses a randomized scheme: a Gaussian random projection, two power
    iterations, a QR factorisation of the projected matrix and an SVD of the
    resulting small matrix. The result has the shape, dtype and device of
    ``matrix``. It depends on torch's global random state.
    """
    # QR and SVD are not available for half-precision tensors, so those are
    # processed in float32 and cast back at the end.
    work = matrix.float() if matrix.dtype in (torch.float16, torch.bfloat16) else matrix
    n_cols = work.shape[1]

    # Random projection matrix.
    omega = torch.randn(n_cols, rank, device=work.device, dtype=work.dtype)

    # Power iteration; two or three rounds are usually enough.
    sketch = work @ omega
    for _ in range(2):
        sketch = work @ (work.T @ sketch)

    # Orthonormal basis of the sketched range.
    basis, _ = torch.linalg.qr(sketch)

    # SVD of the small projected matrix.
    projected = basis.T @ work
    u_small, sing, vh = torch.linalg.svd(projected, full_matrices=False)

    # Rebuild the low-rank matrix in the original space.
    u_full = basis @ u_small
    approx = (u_full[:, :rank] * sing[:rank]) @ vh[:rank, :]
    return approx.to(matrix.dtype)


def layer_stats(
    model,
    tokenizer,
    layer_name,
    stats_dir,
    ds_name,
    to_collect,
    rank,
    force_recompute,
    model_name=None,
    sample_size=None,
    precision=None,
    batch_tokens=None,
    download=True,
    progress=tqdm,
    hparams=None,
):
    """
    Function to load or compute cached stats.

    Runs ``model`` over a tokenized text dataset, captures the input of the
    module named ``layer_name`` and accumulates the requested statistics over
    all non-padding token positions.

    Args:
        model: Model whose ``config`` provides the maximum sequence length.
        tokenizer: Tokenizer handed to ``TokenizedDataset``.
        layer_name: Name of the traced module; its input is collected.
        stats_dir: Root directory of the statistics cache.
        ds_name: One of ``"wikipedia"``, ``"Chinese"``, ``"math"``, ``"SST-2"``.
        to_collect: Iterable of keys of ``STAT_TYPES``.
        rank: If not ``None``, every batch of features is replaced by its
            rank-``rank`` approximation before being accumulated.
        force_recompute: If true, no cache file is passed to ``tally`` and the
            dataset is always loaded.
        model_name: Cache sub-directory; defaults to the last path component
            of ``model.config._name_or_path``.
        sample_size: Number of samples to use; ``None`` means the whole dataset.
        precision: Name of a torch dtype; defaults to ``"float64"``.
        batch_tokens: Token budget per batch; defaults to three times the
            model's maximum sequence length.
        download: Accepted for interface compatibility; not used here.
        progress: Callable wrapping the loader (called with ``total=``);
            ``None`` disables progress reporting.
        hparams: Object whose ``device`` attribute is a CUDA device index.
            When ``None`` the device of the model's first parameter is used.

    Returns:
        The ``CombinedStat`` holding one statistic per entry of ``to_collect``.

    Raises:
        NotImplementedError: if the model config has no known length attribute.
        ValueError: if ``ds_name`` is not supported.
    """
    print("rank", rank, "force", force_recompute)

    batch_size = 1000  # Examine this many dataset texts at once
    npos = _max_positions(model.config)

    if batch_tokens is None:
        batch_tokens = npos * 3  # Sort and divide into batches with this many tokens
    if precision is None:
        precision = "float64"
    dtype = getattr(torch, precision)

    size_suffix = "" if sample_size is None else f"_{sample_size}"
    if batch_tokens < npos:
        # The token budget truncates the texts, so it must be part of the
        # cache key; this needs an f-string to interpolate the value.
        size_suffix = f"_t{batch_tokens}" + size_suffix
    if model_name is None:
        model_name = model.config._name_or_path.rsplit("/")[-1]

    # The file name records whether a low-rank approximation was applied.
    rank_tag = "" if rank is None else f"_rank{rank}"
    file_extension = (
        f"{model_name}/{ds_name}_stats/{layer_name}_{precision}{rank_tag}_"
        f"{'-'.join(sorted(to_collect))}{size_suffix}.npz"
    )
    filename = Path(stats_dir) / file_extension

    print(f"Computing stats (low-rank: {rank is not None})....")

    # The dataset is needed whenever the statistics are (re)computed, which
    # includes a forced recompute while a cache file already exists.
    if force_recompute or not filename.exists():
        ds = TokenizedDataset(
            _load_raw_dataset(ds_name), tokenizer, maxlen=min(npos, batch_tokens)
        )
    else:
        ds = None

    if progress is None:
        # Must accept the ``total`` keyword used below.
        progress = lambda x, **_: x

    if hparams is not None:
        device = f"cuda:{hparams.device}"
    else:
        device = str(next(model.parameters()).device)

    stat = CombinedStat(**{k: STAT_TYPES[k]() for k in to_collect})
    loader = tally(
        stat,
        ds,
        cache=(filename if not force_recompute else None),
        sample_size=sample_size,
        batch_size=batch_size,
        collate_fn=length_collation(batch_tokens),
        pin_memory=True,
        random_sample=1,
        num_workers=8,
    )

    # Number of loader steps, rounded up; unknown when neither a sample size
    # nor a dataset is available (statistics served from the cache file).
    total_samples = sample_size or (len(ds) if ds is not None else None)
    batch_count = None if total_samples is None else -(-total_samples // batch_size)

    with torch.no_grad():
        for batch_group in progress(loader, total=batch_count):
            for batch in batch_group:
                batch = dict_to_(batch, device)
                # stop=True: presumably aborts the forward pass once the
                # traced layer has been reached — confirm in util.nethook.
                with Trace(
                    model, layer_name, retain_input=True, retain_output=False, stop=True
                ) as tr:
                    model(**batch)
                feats = flatten_masked_batch(tr.input, batch["attention_mask"])
                feats = feats.to(dtype=dtype)

                if rank is not None:
                    print(f"Applying low-rank approximation with rank {rank} to feature matrix of shape {feats.shape}")
                    feats = _low_rank_approximation(feats, rank).to(device=device)
                    print('feats.shape',feats.shape)

                stat.add(feats)
                # Release this batch's memory before processing the next one.
                del feats
                torch.cuda.empty_cache()

    return stat

# Run the command-line entry point only when this file is executed as a
# script; importing the module has no such side effect.
if __name__ == "__main__":
    main()
