import os
import json
import pickle
import gzip
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from repe import repe_pipeline_registry
from utils.data_loader import DatasetLoader
import argparse


class NAITAnalyzer:
    """Score a target dataset with a representation reader built from a reference dataset.

    Typical flow: ``prepare_reference`` -> ``build_reader`` (or ``load_reader``)
    -> ``analyze_target`` -> ``save_results``.
    """

    def __init__(self, args):
        """
        Store the run configuration, load the model/tokenizer and create the dataset loader.

        :param args: namespace-like object (e.g. parsed ``argparse`` arguments) providing:
            - model_name_or_path: model path or hub name
            - dataset_path: root directory of the datasets
            - na_label: key of the dataset item field fed to the reader
            - reference_ds: reference dataset name
            - ref_lang: language passed to the loader for the reference dataset (may be None)
            - ntrain: sample count handed to the dataset loader (may be None)
            - target_ds: target dataset name
            - tar_lang: language passed to the loader for the target dataset (may be None)
            - batch_size: inference batch size
            - save_dir: directory where readers are saved
        """
        self.model_name_or_path = args.model_name_or_path
        self.dataset_path = args.dataset_path
        self.na_label = args.na_label
        self.reference_ds = args.reference_ds
        self.ref_lang = args.ref_lang
        self.ntrain = args.ntrain
        self.target_ds = args.target_ds
        self.tar_lang = args.tar_lang
        self.batch_size = args.batch_size
        self.rep_reader = None
        self.save_dir = args.save_dir
        # The tokenizer must exist before the loader is created, so load the model first.
        self._init_model()
        self.loader = DatasetLoader(
            dataset_path=self.dataset_path,
            ntrain=self.ntrain,
            tokenizer=self.tokenizer,
        )

    def _init_model(self):
        """Load the model and tokenizer and create the rep-reading pipeline."""
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path,
            device_map="auto",
        )
        # `architectures` can be None on some configs; treat that as "not Llama".
        architectures = getattr(self.model.config, "architectures", None) or []
        use_fast = "LlamaForCausalLM" not in architectures
        # `use_fast` is the keyword AutoTokenizer.from_pretrained recognizes for
        # choosing between the fast and slow tokenizer implementations.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            use_fast=use_fast,
            padding_side="left",
            legacy=False,
        )
        repe_pipeline_registry()
        self.rep_reading_pipeline = pipeline(
            "rep-reading", model=self.model, tokenizer=self.tokenizer
        )
        # Negative layer indices -1, -2, ..., -(num_hidden_layers - 1).
        self.hidden_layers = list(range(-1, -self.model.config.num_hidden_layers, -1))
        # NOTE(review): assumes token id 0 is a safe padding id for this vocabulary - confirm.
        self.tokenizer.pad_token_id = 0

    def prepare_reference(self):
        """Load the reference dataset and resolve ``ntrain`` when it was not provided."""
        print(f"Loading reference dataset: {self.reference_ds}")
        self.ref_dataset = self.loader.load_dataset(self.reference_ds, lang=self.ref_lang)
        if len(self.ref_dataset) > 0:
            print(self.ref_dataset[0])
        if not self.ntrain:
            # Only affects later filenames; the loader was already created with the original value.
            self.ntrain = len(self.ref_dataset)
            print(f"ntrain not provided. Set to length of reference dataset: {self.ntrain}")

    def build_reader(self):
        """Compute the representation reader from the reference dataset.

        :raises RuntimeError: if ``prepare_reference`` has not been called yet.
        """
        if getattr(self, "ref_dataset", None) is None:
            raise RuntimeError("Reference dataset not loaded; call prepare_reference() first.")
        # Reuse the pipeline created in _init_model instead of building a second one;
        # the `rep_pipeline` attribute is kept for existing callers.
        self.rep_pipeline = self.rep_reading_pipeline
        print(f"Extracting the reference dataset: {self.reference_ds} neurons activity")
        self.rep_reader = self.rep_pipeline.get_directions(
            [item[self.na_label] for item in self.ref_dataset],
            # NOTE(review): rep_token is 1 here but 0 in analyze_target - confirm this is intended.
            rep_token=1,
            hidden_layers=self.hidden_layers,
            direction_method='pca',
            batch_size=self.batch_size,
        )

    def load_reader(self, reader_path):
        """Load a previously saved reader from a gzip-compressed pickle file.

        :param reader_path: path to the ``.pkl.gz`` file written by ``save_reader``.
        """
        # SECURITY: pickle can execute arbitrary code on load; only open trusted files.
        with gzip.open(reader_path, 'rb') as file:
            self.rep_reader = pickle.load(file)

    def save_reader(self, save_dir=None):
        """Save the current reader as a gzip-compressed pickle.

        :param save_dir: destination directory; falls back to ``self.save_dir`` when omitted.
        """
        target_dir = self.save_dir if save_dir is None else save_dir
        os.makedirs(target_dir, exist_ok=True)
        if self.ref_lang is not None:
            filename = f"{self.reference_ds}_{self.ref_lang}_rep_ntrain_{self.ntrain}_reader.pkl.gz"
        else:
            filename = f"{self.reference_ds}_rep_ntrain_{self.ntrain}_reader.pkl.gz"
        path = os.path.join(target_dir, filename)

        with gzip.open(path, 'wb') as f:
            pickle.dump(self.rep_reader, f)
        print(f"Reader saved to {path}")

    def analyze_target(self):
        """Load the target dataset and project it with the current reader."""
        print(f"Loading target dataset: {self.target_ds}")

        self.target_dataset = self.loader.load_dataset(self.target_ds, lang=self.tar_lang)
        if len(self.target_dataset) > 0:
            print(self.target_dataset[0])
        # Compute the projections.
        # NOTE(review): the input key is hard-coded to 'QAPair' while build_reader
        # uses na_label - confirm whether these are meant to differ.
        self.projected = self.rep_reading_pipeline(
            [item['QAPair'] for item in self.target_dataset],
            rep_token=0,
            rep_reader=self.rep_reader,
            hidden_layers=self.hidden_layers,
            batch_size=self.batch_size,
        )

    def save_results(self, output_dir="results"):
        """Write one JSON record per target item with its mean projection score.

        :param output_dir: directory for the output JSON file (created if missing).
        """
        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): with tar_lang set and ref_lang unset the name contains "None";
        # left unchanged so existing output filenames stay stable.
        if self.tar_lang is not None:
            filename = f"{self.target_ds}_{self.tar_lang}_active_by_{self.reference_ds}_{self.ref_lang}_ntrain_{self.ntrain}.json"
        else:
            filename = f"{self.target_ds}_active_by_{self.reference_ds}_ntrain_{self.ntrain}.json"
        output_path = os.path.join(output_dir, filename)

        # Use the instance's label key rather than a module-level global so the
        # class also works when imported from another module.
        label_key = self.na_label
        results = [{
            "ori_data_format": item.get("ori_data_format"),
            label_key: item.get(label_key),
            # Score is the mean over all values of this item's projection mapping.
            "score": float(np.mean(list(status.values())))
        } for item, status in zip(self.target_dataset, self.projected)]

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
        print(f"Results saved to {output_path}")

# Command-line entry point: build (or load) a reader from the reference
# dataset, then score the target dataset with it and save the results.
if __name__ == "__main__":
    # Argument parser configuration.
    parser = argparse.ArgumentParser(description="REPE Analysis Tool")

    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="Path to pretrained model")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Root directory for datasets")
    parser.add_argument("--reference_ds", type=str, required=True,
                        help="Reference dataset name")
    parser.add_argument("--na_label", type=str, required=True,
                        help="NA label (dataset item key whose value is fed to the reader)")
    parser.add_argument("--ref_lang", type=str, required=False,
                        help="Language mode for the reference dataset")
    parser.add_argument("--ntrain", type=int, required=False, default=None,
                        help="Number of reference samples (defaults to the reference dataset size)")
    parser.add_argument("--target_ds", type=str, required=True,
                        help="Target dataset name")
    parser.add_argument("--tar_lang", type=str, required=False,
                        help="Language mode for the target dataset")
    parser.add_argument("--reader_path", type=str, required=False, default=None,
                        help="Path to a saved reader (.pkl.gz); if given, it is loaded instead of rebuilt")
    parser.add_argument("--batch_size", type=int, default=2,
                        help="Inference batch size")
    parser.add_argument("--save_dir", type=str, default="results",
                        help="Directory to save readers")

    # Keep `args` at module scope: other code in this file may read it as a global.
    args = parser.parse_args()

    # Initialize the analyzer.
    analyzer = NAITAnalyzer(args)
    # Load the reference dataset exactly once. This is needed on both paths
    # because it also resolves ntrain (used in output filenames) when
    # --ntrain is omitted.
    analyzer.prepare_reference()
    if args.reader_path:
        analyzer.load_reader(reader_path=args.reader_path)
    else:
        analyzer.build_reader()
        analyzer.save_reader(save_dir=args.save_dir)
    analyzer.analyze_target()
    analyzer.save_results(output_dir=args.save_dir)
    print("Analysis completed successfully!")