import os
os.environ["WANDB_MODE"] = "disabled"
from transformers import Qwen2ForCausalLM, Qwen2Tokenizer,Qwen2Model,Trainer, TrainingArguments, AutoTokenizer, Qwen3Model
import torch.nn as nn
import torch
from model_qwen3.model_span import Qwen_relevence_atten_model_span
from model_qwen3.model_span_cls import Qwen_relevence_atten_model_span as Qwen_relevence_atten_model_span_cls
from model_qwen3.model_list_ablation_17b import Qwen_relevence_atten_model
from dataset import qwen_dataset,my_data_collator,Bert_dataset,qwen_dataset_anker,data_list_collator,Bert_dataset_list, data_list_collator_tacred

from data.dataset_tacred import qwen_dataset_list
from data.dataset_retacred import qwen_dataset_list as qwen_dataset_list_retacred, data_list_collator_retacred
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
from scipy.special import expit ,softmax
import argparse
import numpy as np
import json
import time
import inspect
# from peft import get_peft_model, LoraConfig, TaskType

# Shared training hyper-parameters (e.g. ``batch_size``) are kept in a JSON file
# so they can be tuned without editing this script.
config_file = "config/qwen.json"
with open(config_file, encoding='utf-8') as f:
    config = json.loads(f.read())
print("Config loaded:", config)


# Default pipeline mode (the other supported value is "span_only"); the script
# entry point replaces it with the ``--mode`` command-line argument.
mode = "dataset_list_addtype"

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when available, all CUDA devices) so that runs are reproducible.

    Args:
        seed (int): the seed value.
    """
    import random

    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32 - 1]; reduce so any int torch accepts still works.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def compute_auc(scores, labels):
    """Compute the ROC AUC (area under the ROC curve) with the trapezoidal rule.

    Samples with equal scores form a single threshold: the ROC curve moves
    diagonally across the whole tie group, so the result does not depend on
    the input order of tied samples (a tied positive/negative pair counts 1/2).

    Args:
        scores (list): prediction scores; only their order matters, higher
            meaning "more likely positive" (typically in [0, 1]).
        labels (list): ground-truth labels, expected to be 0 or 1 (1 = positive).

    Returns:
        float: the AUC. Returns 1.0 when only one class is present, since the
        AUC is undefined in that case.

    Raises:
        ValueError: if ``scores`` and ``labels`` have different lengths.
    """
    from itertools import groupby
    from operator import itemgetter

    if len(scores) != len(labels):
        raise ValueError(
            f"scores and labels must have the same length ({len(scores)} != {len(labels)})"
        )

    P = sum(labels)  # number of positives
    N = len(labels) - P  # number of negatives
    if P == 0 or N == 0:
        return 1.0  # single class present: AUC undefined, report 1.0

    # Sort by score, highest first, then sweep the threshold downwards.
    ranked = sorted(zip(scores, labels), key=itemgetter(0), reverse=True)

    auc = 0.0
    tp = fp = 0
    prev_tpr = prev_fpr = 0.0
    # One ROC point per distinct score, so tied samples move TPR and FPR together.
    for _, group in groupby(ranked, key=itemgetter(0)):
        for _, label in group:
            if label == 1:
                tp += 1
            else:
                fp += 1
        tpr, fpr = tp / P, fp / N
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2  # trapezoid area
        prev_tpr, prev_fpr = tpr, fpr

    return auc

def compute_metrics(eval_pred):
    """Score single-label predictions during ``Trainer`` evaluation.

    Args:
        eval_pred: a ``(logits, labels)`` pair. Presumably ``logits`` is
            (batch, num_classes) and ``labels`` holds one class id per sample
            — TODO confirm against the caller.

    Returns:
        dict: ``{"accuracy": ..., "f1": ...}`` where ``f1`` is micro-averaged.
    """
    raw_logits, gold = eval_pred
    # Arg-max over the class axis gives the predicted class id of each sample.
    predicted = np.argmax(raw_logits, axis=1)
    f1_micro = f1_score(gold, predicted, average='micro')
    acc = accuracy_score(gold, predicted)
    return {"accuracy": acc, "f1": f1_micro}

def compute_metrics_2(eval_pred):
    """Top-k accuracy for multi-hot targets, where k is each sample's gold-label count.

    For every sample the k highest-probability classes are predicted, and the
    score is (correctly predicted gold labels) / (total gold labels). Because
    exactly k labels are predicted per sample, micro precision equals micro
    recall here, so this value is also the micro-F1 reported under ``"f1"``.

    Args:
        eval_pred: a ``(logits, labels)`` pair. ``labels`` is treated as a
            (batch, num_classes) multi-hot matrix: ``labels.sum(-1)`` gives k and
            entries equal to 1 mark gold classes. Presumably ``logits`` has the
            same shape — TODO confirm against the model output.

    Returns:
        dict: ``{"accuracy": score, "f1": score}``; the score is 0 when there are
        no gold labels at all.
    """
    logits, labels = eval_pred
    labels = np.asarray(labels)
    # Gold-label count per sample; round rather than truncate float sums.
    k = np.rint(labels.sum(-1)).astype(int)
    probs = softmax(logits, axis=1)  # (batch, num_classes)

    correct_count = 0
    all_count = 0
    for sample_probs, sample_labels, sample_k in zip(probs, labels, k):
        if sample_k <= 0:
            # Nothing to predict; also `x[-0:]` would select the WHOLE array, not nothing.
            continue
        # Indices of the k most probable classes.
        top_k_indices = np.argsort(sample_probs)[-sample_k:]
        true_indices = np.flatnonzero(sample_labels == 1)
        correct_count += np.intersect1d(top_k_indices, true_indices).size
        all_count += int(sample_k)

    accuracy = correct_count / all_count if all_count > 0 else 0  # avoid division by zero
    return {"accuracy": accuracy, "f1": accuracy}



class CustomTrainer(Trainer):
    """``Trainer`` that also writes the complete ``self.model.state_dict()`` to ``qwen_model.pth``.

    The base class saves the model as usual; the extra ``.pth`` file is a plain
    ``torch.save`` of the whole state dict for loading with ``load_state_dict``.
    """

    def save_model(self, output_dir: "str | None" = None, _internal_call: bool = False):
        """Save via ``Trainer.save_model`` and dump ``qwen_model.pth`` next to it.

        Args:
            output_dir: target directory; defaults to ``self.args.output_dir``,
                matching the base-class behaviour.
            _internal_call: forwarded unchanged to ``Trainer.save_model``.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        # Every process calls the base implementation: it already limits disk writes
        # to the process that should save, and sharded backends need all ranks to join.
        super().save_model(output_dir, _internal_call=_internal_call)
        # `should_save` is True for the writing process, including single-process runs
        # where `local_rank` is -1 (a `local_rank == 0` check never saves there).
        if self.args.should_save:
            os.makedirs(output_dir, exist_ok=True)
            # NOTE(review): under FSDP / DeepSpeed ZeRO-3, state_dict() may be sharded or
            # require all ranks — confirm before using those backends.
            torch.save(self.model.state_dict(), os.path.join(output_dir, "qwen_model.pth"))


# 5. LoRA configuration for PEFT.
# ``peft`` is an optional dependency (its file-level import is commented out), so
# import it here: without peft installed ``lora_config`` is ``None`` instead of the
# whole module failing with a NameError on ``LoraConfig`` at import time.
try:
    from peft import LoraConfig, TaskType
except ImportError:
    lora_config = None
else:
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,      # training mode
        r=32,                      # LoRA rank
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],  # modules to adapt; adjust to the model architecture
    )


if __name__ == "__main__":
    # Backbones for ``--mode dataset_list_addtype``: model_name -> (transformers class, Hub repo id).
    # Qwen2.5 checkpoints use the ``qwen2`` architecture and must be loaded with Qwen2Model;
    # Qwen3Model does not match their attention weights (qkv biases vs. Qwen3's q/k norms).
    backbones = {
        "qwen25_0_5b": (Qwen2Model, "Qwen/Qwen2.5-0.5B-Instruct"),
        "qwen3_0_6b": (Qwen3Model, "Qwen/Qwen3-0.6B"),
        "qwen3_1_7b": (Qwen3Model, "Qwen/Qwen3-1.7B"),
        "qwen3_4b_instruct": (Qwen3Model, "Qwen/Qwen3-4B-Instruct-2507"),
        "qwen3_8b": (Qwen3Model, "Qwen/Qwen3-8B"),
    }

    def _parse_bool(value):
        """Parse a boolean CLI value; ``type=bool`` would turn the string "False" into True."""
        return str(value).strip().lower() in ("1", "true", "yes", "y")

    # Command-line arguments.
    parser = argparse.ArgumentParser(description="Train a Qwen model.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--log_id", type=str, default="span_cls_noins_LAST_seed42_4096_datasetlist_3e-5_constantwormup_notype_realwoinstruction_fulldataset", help="Log ID for saving model and logs")
    parser.add_argument("--apply_mask", type=str, default="False", help="Apply mask to the entity embeddings")
    parser.add_argument("--forward_method", type=str, default="fhead_add_ftail", help="Forward method for the model choice:[fhead_add_ftail,fhead_sub_ftail,fhead_add_tail,fhead_sub_tail]")
    parser.add_argument("--add_prompt", type=_parse_bool, default=False, help="Use prompt for training")
    parser.add_argument("--dataset", type=str, default="tacred", choices=["tacred", "tacrev", "re-tacred"], help="Dataset to use for training ,choice = [tacred,tacrev,re-tacred]")
    parser.add_argument("--mode", type=str, default="dataset_list_addtype", help="dataset_list_addtype, span_only")
    parser.add_argument("--model_name", type=str, default="qwen3_0_6b", choices=list(backbones), help="dataset_list_addtype:choice = [qwen25_0_5b,qwen3_0_6b,qwen3_1_7b,qwen3_4b_instruct]")
    args = parser.parse_args()
    set_seed(args.seed)

    # NOTE(review): apply_mask is passed on as the raw CLI string ("False"/"True"); if the
    # model tests it for truthiness, "False" counts as True — confirm how the model reads it.
    apply_mask = args.apply_mask
    forward_method = args.forward_method
    model_name = args.model_name
    # Run name: each value now follows its own label (apply_mask and mode used to be swapped).
    log_id = (
        model_name + args.log_id + "_seed" + str(args.seed)
        + "_applymask" + str(args.apply_mask) + "_mode" + str(args.mode)
        + "_forwardmethod" + args.forward_method + args.dataset + "_datasetv3char"
    )
    dataset = args.dataset
    print(dataset)
    num_class = 40 if dataset == 're-tacred' else 42
    mode = args.mode
    print(f"cur mode is {mode}")

    # NOTE(review): every branch uses the Qwen2.5 tokenizer, also for Qwen3 backbones, and the
    # dataset caches below are keyed only by dataset — confirm before switching tokenizers.
    tokenizer_repo = "Qwen/Qwen2.5-0.5B-Instruct"
    if mode in ("token + attn + lasttoken", "add_chat", "dataset_list"):
        model = Qwen_relevence_atten_model(1024 * 3, 1024 * 4, num_class, apply_mask=apply_mask, forward_method=forward_method)
        tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_repo)
    elif mode == "dataset_list_addtype":
        llm_cls, repo_id = backbones[model_name]
        llm_model = llm_cls.from_pretrained(
            repo_id,
            attn_implementation="eager",
        ).to(torch.bfloat16)
        # Optional LoRA: llm_model = get_peft_model(llm_model, lora_config)
        hidden_dim = llm_model.config.hidden_size
        # Ablation switch for qwen3_1_7b: span+CLS head instead of the default model.
        ablation_cls = False
        if model_name == "qwen3_1_7b" and ablation_cls:
            model = Qwen_relevence_atten_model_span_cls(hidden_dim * 3, 1024 * 4, num_class, llm=llm_model)
        else:
            model = Qwen_relevence_atten_model(hidden_dim * 3, 1024 * 4, num_class, apply_mask=apply_mask, forward_method=forward_method, llm=llm_model)
        tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_repo)
    elif mode == "span_only":
        model = Qwen_relevence_atten_model_span(1024 * 2, 1024 * 4, num_class)  # span only
        tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_repo)
    else:
        raise ValueError(f"Unsupported --mode {mode!r}")
    print(model)
    model.embedding.use_cache = False

    forward_src = inspect.getsource(model.forward)
    print(forward_src)

    if mode == "dataset_list_addtype" or mode == "span_only":
        # Whether the cached datasets include the instruction prompt.
        # NOTE(review): the default --log_id says "noins"/"woinstruction" while this is True.
        add_instruction = True

        if add_instruction:
            if dataset == "tacred":
                relevence_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/train.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/train_data_cache_inst.json")
                eval_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/dev.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/valid_data_cache_inst.json")
                test_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/test.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/test_data_cache_inst.json")
            elif dataset == "tacrev":
                # TACREV re-annotates dev/test only, so training uses the TACRED train split.
                relevence_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/train.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/train_data_cache_inst.json")
                eval_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacrev/dev.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacrev/valid_cache_inst.json")
                test_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacrev/test.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacrev/test_cache_inst.json")
            elif dataset == "re-tacred":
                relevence_dataset = qwen_dataset_list_retacred("./dataset/tacrev_data4qwen_fullchar/re_tacred/train.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/re_tacred/train_cached_inst.json")
                eval_dataset = qwen_dataset_list_retacred("./dataset/tacrev_data4qwen_fullchar/re_tacred/dev.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/re_tacred/dev_cached_inst.json")
                test_dataset = qwen_dataset_list_retacred("./dataset/tacrev_data4qwen_fullchar/re_tacred/test.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/re_tacred/test_cached_inst.json")
        else:
            if dataset == "tacred":
                relevence_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/train.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/train_data_cache_noinst.json")
                eval_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/dev.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/valid_data_cache_noinst.json")
                test_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/test.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/test_data_cache_noinst.json")
            elif dataset == "tacrev":
                relevence_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacred/train.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacred/train_data_cache_noinst.json")
                eval_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacrev/dev.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacrev/valid_cache_noinst.json")
                test_dataset = qwen_dataset_list("./dataset/tacrev_data4qwen_fullchar/tacrev/test.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/tacrev/test_cache_noinst.json")
            elif dataset == "re-tacred":
                relevence_dataset = qwen_dataset_list_retacred("./dataset/tacrev_data4qwen_fullchar/re_tacred/train.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/re_tacred/train_cache_noinst.json")
                eval_dataset = qwen_dataset_list_retacred("./dataset/tacrev_data4qwen_fullchar/re_tacred/dev.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/re_tacred/dev_cache_noinst.json")
                test_dataset = qwen_dataset_list_retacred("./dataset/tacrev_data4qwen_fullchar/re_tacred/test.json", tokenizer=tokenizer, cache_path="./dataset/tacrev_data4qwen_fullchar/re_tacred/test_cache_noinst.json")
    else:
        # The legacy modes never built datasets and used to crash later with a NameError.
        raise ValueError(
            f"--mode {mode!r} has no dataset loader in this script; "
            "use 'dataset_list_addtype' or 'span_only'"
        )

    current_date = time.strftime("%Y%m%d", time.localtime())

    batch_size = config['batch_size'] // (4 if dataset == "re-tacred" else 2)
    if model_name in ("qwen3_4b_instruct", "qwen3_8b"):
        # Larger backbones use half the per-device batch (presumably to fit in memory).
        batch_size //= 2
    batch_size = max(1, batch_size)  # a small config value must not round down to 0

    train_arg = TrainingArguments(
        output_dir=f'./results/{current_date}_{log_id}',
        num_train_epochs=10,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=3e-5,
        # NOTE(review): the "constant" scheduler ignores warmup_steps; use
        # "constant_with_warmup" if warmup is intended (the log_id says "constantwormup").
        lr_scheduler_type="constant",
        save_total_limit=2,
        warmup_steps=1000,
        weight_decay=0.01,
        logging_dir=f'./logs/{current_date}_{log_id}',
        logging_steps=10,
        remove_unused_columns=False,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        bf16=True,
        metric_for_best_model="accuracy",
        dataloader_num_workers=16,  # tune to the number of CPU cores available
    )

    if dataset == 're-tacred':
        data_collator = data_list_collator_retacred
    elif mode in ("dataset_list", "bert_dataset_list", "span_only", "dataset_list_addtype"):
        data_collator = data_list_collator_tacred
    else:
        data_collator = my_data_collator

    trainer = CustomTrainer(model=model,
                            args=train_arg,
                            tokenizer=tokenizer,
                            train_dataset=relevence_dataset,
                            eval_dataset=eval_dataset,
                            data_collator=data_collator,
                            compute_metrics=compute_metrics_2,
                            )
    trainer.train()
    test_results = trainer.evaluate(test_dataset)
    print("Test results:", test_results)
    