import os
import gc
import json
import pickle
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import transformers
from transformers import HfArgumentParser
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training
from sft_dataset import make_supervised_data_module
from base_dataset import NewData, KshotDataset
from load_model import MyAutoModelForCausalLM
from peft_model import MyPeftModelForCausalLM
from constants import *
from mytrainer import MyTrainer

# ================== Args ==================
@dataclass
class Args:
    """Experiment, method and data options parsed from the command line by HfArgumentParser."""

    # ---- model / method selection ----
    model_name_or_path: str = MODEL_IDENTIFIER
    loss_type: str = 'sft+distance'
    use_sparse_attention: bool = False
    # NOTE(review): typed as bool although the name suggests a count — confirm intended type.
    num_test: bool = False
    # train() requires this to contain 'lora'; 'prompt-tuning' additionally unfreezes new embeddings.
    parameter_efficient_mode: str = 'lora+prompt-tuning'
    # 'mlp' (llama / phi families) or 'atten'; selects the LoRA target modules.
    lora_module: str = 'mlp'
    # In train(): caps max_steps at 10 and keeps only the first 1000 training examples.
    debug_mode: bool = False

    # ---- method args ----
    distance: str = 'prob'
    # Only selects whether the extended-vocabulary saving logic is used; no longer tied to ANC.
    add_tokens: bool = True
    add_soft_prompts: bool = False
    num_cali_types: int = 10
    num_prefix: int = 3
    num_cali: int = 3
    use_demo: bool = False
    k_shots: int = 4
    position: str = 'right'
    # Weighting coefficient between the CE loss and the confidence supervision (used in the Trainer).
    weight: float = 0.1

    # ---- data args ----
    raw_data_path: str = DATASET_PATH
    dataset: str = 'triviaqa_brief'
    processed_data: str = ''
    exp_name: str = ''

    # Skip training and only run evaluation.
    only_eval: bool = False

@dataclass
class CustomTrainingArgs(transformers.TrainingArguments):
    """transformers.TrainingArguments with project defaults and a few extra loading flags."""

    per_device_train_batch_size: int = 8
    num_train_epochs: int = 3
    cache_dir: Optional[str] = None
    # Base directory; train() appends <dataset>/<exp_name> to it.
    output_dir: Optional[str] = ''
    overwrite_output_dir: Optional[bool] = True
    optim: str = "adamw_torch"
    # Copied onto tokenizer.model_max_length in train().
    model_max_length: int = 512
    remove_unused_columns: Optional[bool] = False
    # NOTE(review): not read by the visible code in this file; presumably consumed elsewhere — confirm.
    resume: Optional[bool] = False
    # Training-time loading: base model in 8-bit (fp16 dtype).
    int8_training: Optional[bool] = False
    # Evaluation-time loading: fine-tuned model in 8-bit.
    load_in_8bit: Optional[bool] = False
    # Training-time loading: base model in fp16.
    load_in_16fp: Optional[bool] = False
    logging_steps: Optional[int] = 1
    report_to: Optional[str] = 'wandb'
    # Overwritten by train() with the experiment name.
    run_name: str = 'default'
    save_strategy: Optional[str] = 'no'

# ================== Utils ==================
def print_trainable_parameters(model):
    """Print the number of trainable vs. total parameters of ``model``.

    Args:
        model: any ``torch.nn.Module`` (shared parameters are counted once).

    Returns:
        ``(trainable_params, all_param)`` so callers can also use the counts.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A module without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.4f}")
    return trainable_params, all_param

def enable_prompt_tuning(model):
    """Mark the ``new_embedding`` / ``new_linear`` weights as trainable when they exist.

    After ANC was removed, the input/output embedding modules may not carry
    ``new_embedding`` / ``new_linear`` at all, so each is only touched if present.

    Args:
        model: object exposing ``get_input_embeddings()`` and ``get_output_embeddings()``.
    """
    accessors = (
        (model.get_input_embeddings, "new_embedding"),
        (model.get_output_embeddings, "new_linear"),
    )
    for fetch_module, extra_name in accessors:
        embedding_module = fetch_module()
        if not hasattr(embedding_module, extra_name):
            continue
        getattr(embedding_module, extra_name).weight.requires_grad = True

# ================== Confidence Head ==================
class ConfidenceHead(nn.Module):
    """Small MLP mapping a hidden-state vector to a confidence score.

    Architecture: LayerNorm -> Linear(in_dim, mid) -> GELU -> Dropout(p) -> Linear(mid, 1)
    -> sigmoid, with ``mid = max(32, int(in_dim * hidden_ratio))``. The final linear
    layer is zero-initialised, so a freshly built head outputs exactly 0.5.

    Args:
        in_dim: size of the input's last dimension (the model hidden size).
        hidden_ratio: hidden width relative to ``in_dim`` (floored at 32).
        p: dropout probability.
    """

    def __init__(self, in_dim: int, hidden_ratio: float = 0.5, p: float = 0.1):
        super().__init__()
        mid = max(32, int(in_dim * hidden_ratio))
        self.mlp = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, mid),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(mid, 1),
        )
        # Zero weight and bias -> logit 0 -> initial output sigmoid(0) = 0.5.
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``[..., in_dim]`` (typically ``[B, H]``) to ``[..., 1]``.

        The output is clamped to ``[1e-4, 1 - 1e-4]``, presumably so log-based losses
        on it stay finite. It is returned in the head's parameter dtype.
        """
        # Hidden states may come from an fp16/bf16 backbone while this head keeps its
        # own (default fp32) parameters; align the input dtype so LayerNorm/Linear do
        # not fail with a dtype mismatch. This is a no-op when dtypes already match.
        x = x.to(self.mlp[1].weight.dtype)
        return torch.sigmoid(self.mlp(x)).clamp(1e-4, 1 - 1e-4)

def attach_conf_head(model, in_dim):
    """Attach a fresh ``ConfidenceHead(in_dim)`` as ``model.conf_head`` unless one exists.

    Assigning an ``nn.Module`` attribute registers the head as a submodule of ``model``.
    The head is moved to the device of the model's first parameter.
    NOTE(review): with ``device_map="auto"`` that device may differ from where the final
    hidden states live — confirm for multi-device runs.
    """
    if hasattr(model, "conf_head"):
        return
    new_head = ConfidenceHead(in_dim)
    first_param_device = next(model.parameters()).device
    model.conf_head = new_head.to(first_param_device)

def save_conf_head(model, save_dir):
    """Save ``model.conf_head``'s state dict to ``<save_dir>/confidence_head.pt``.

    Does nothing when ``model`` has no ``conf_head``. ``save_dir`` is created if it
    does not exist yet, so ``torch.save`` cannot fail on a missing directory.

    Returns:
        The path written, or ``None`` when there was no head to save.
    """
    if not hasattr(model, "conf_head"):
        return None
    os.makedirs(save_dir, exist_ok=True)
    outp = os.path.join(save_dir, 'confidence_head.pt')
    torch.save(model.conf_head.state_dict(), outp)
    print(f"[train.py] saved confidence head -> {outp}")
    return outp

def load_conf_head_if_exists(model, load_dir, in_dim):
    """Ensure ``model.conf_head`` exists and restore it from ``load_dir`` when possible.

    A head is always attached (via ``attach_conf_head``). If
    ``<load_dir>/confidence_head.pt`` exists, its state dict is loaded strictly and the
    head is moved to the device of the model's first parameter; otherwise the freshly
    attached head is kept as-is.
    """
    attach_conf_head(model, in_dim)
    ckpt_path = os.path.join(load_dir, "confidence_head.pt")
    if not os.path.exists(ckpt_path):
        return
    # NOTE: torch.load unpickles the file; only load checkpoints you produced/trust.
    saved_state = torch.load(ckpt_path, map_location="cpu")
    head = model.conf_head
    head.load_state_dict(saved_state, strict=True)
    head.to(next(model.parameters()).device)

# ================== Main Train ==================
def _hidden_size(config):
    """Return the hidden size declared by a model config (``hidden_size`` or ``n_embd``).

    Raises:
        ValueError: if neither attribute is set, instead of an opaque TypeError later
            inside ConfidenceHead.
    """
    hsz = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None)
    if hsz is None:
        raise ValueError("cannot infer hidden size: model config has neither `hidden_size` nor `n_embd`")
    return hsz


def _setup_tokenizer(args, training_args):
    """Load the tokenizer, set max length / left padding, and ensure EOS and PAD exist.

    Returns:
        ``(tokenizer, num_new_tokens)`` where ``num_new_tokens`` counts every token added
        to the vocabulary here (at most a fallback ``</s>`` EOS token).
    """
    name = args.model_name_or_path
    if 'llama' in name or 'alpaca' in name:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(name)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    tokenizer.model_max_length = training_args.model_max_length

    # Add a fallback EOS only when missing, and count it so it is reflected in
    # num_new_tokens (passed as n_tokens to the model loader).
    num_new_tokens = 0
    if tokenizer.eos_token is None:
        num_new_tokens += tokenizer.add_special_tokens({"eos_token": "</s>"})
    # Without a dedicated PAD token, reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return tokenizer, num_new_tokens


def _build_soft_prompts(args):
    """Build the soft-prompt placeholder tokens and the prompt strings that use them.

    Returns:
        ``(prompt_text, special_tokens_list)``. ``prompt_text`` maps ``'prefix'`` (plus
        ``'type1'..'typeN'`` when ``args.add_soft_prompts``) to a string of space-prefixed
        placeholders such as ``' <prefix_0> <prefix_1>'``; it is ``{'prefix': ''}`` when
        soft prompts are disabled. ``special_tokens_list`` lists every placeholder in order.
    """
    prompt_text = {'prefix': ''}
    special_tokens_list = []
    if not args.add_soft_prompts:
        return prompt_text, special_tokens_list

    for i in range(1, args.num_cali_types + 1):
        prompt_text[f'type{i}'] = ''
    for key in prompt_text:
        count = args.num_prefix if key == 'prefix' else args.num_cali
        tokens = [f'<{key}_{i}>' for i in range(count)]
        special_tokens_list.extend(tokens)
        prompt_text[key] = ''.join(' ' + tok for tok in tokens)
    return prompt_text, special_tokens_list


def _load_datasets(args, model_name, prompt_text):
    """Load the processed training split and the test split and wrap both in NewData."""
    # NOTE(review): the data root is DATASET_PATH, not args.raw_data_path — confirm intended.
    data_path = os.path.join(DATASET_PATH, args.dataset, model_name, args.processed_data)
    # NOTE: pickle.load executes arbitrary code; only use trusted, locally produced files.
    with open(data_path, 'rb') as f:
        raw_data = pickle.load(f)
    train_data = raw_data[:1000] if args.debug_mode else raw_data
    train_dataset = NewData(args, train_data, 'train', prompt_text, position=args.position, add_soft_prompts=args.add_soft_prompts)
    print(train_dataset[0])

    test_path = os.path.join(DATASET_PATH, args.dataset, model_name, 'data_with_answer.pkl')
    with open(test_path, 'rb') as f:
        all_data = pickle.load(f)
    test_dataset = NewData(args, all_data['test'], 'test', prompt_text, position=args.position, add_soft_prompts=args.add_soft_prompts)

    if args.use_demo:
        # NOTE(review): demonstrations come from the training set and only the training
        # split is wrapped; evaluation runs without demos — confirm this is intended.
        train_dataset = KshotDataset(train_dataset, train_dataset, k=args.k_shots)
    return train_dataset, test_dataset


def _lora_target_modules(args, model_name):
    """Return the LoRA target module names for ``args.lora_module`` and the model family."""
    if args.lora_module == 'mlp' and 'llama' in model_name:
        return ["gate_proj", "up_proj", "down_proj"]
    if args.lora_module == 'mlp' and 'phi' in model_name:
        return ["gate_up_proj", "down_proj"]
    if args.lora_module == 'atten':
        return ["q_proj", "k_proj", "v_proj", "o_proj"]
    raise NotImplementedError(f"unsupported lora_module={args.lora_module!r} for model {model_name!r}")


def _build_train_model(args, training_args, model_name, num_new_tokens):
    """Load the base model, wrap it with LoRA (+ optional prompt tuning), attach a confidence head."""
    # Validate the configuration before the (expensive) model load.
    if 'lora' not in args.parameter_efficient_mode:
        raise NotImplementedError(
            f"training requires a LoRA parameter_efficient_mode, got {args.parameter_efficient_mode!r}")
    target_modules = _lora_target_modules(args, model_name)

    load_kwargs = dict(
        n_tokens=num_new_tokens,
        initialize_tokens=None,
        sparse=args.use_sparse_attention,
        parameter_efficient_mode=args.parameter_efficient_mode,
        pretrained_model_name_or_path=args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map="auto",
        offload_folder="offload",
        offload_state_dict=True,
    )
    if training_args.load_in_16fp or training_args.int8_training:
        # NOTE(review): prepare_model_for_kbit_training is also applied on the plain fp16
        # path (not only 8-bit) — confirm this is intended.
        model = MyAutoModelForCausalLM.from_pretrained(
            torch_dtype=torch.float16, load_in_8bit=training_args.int8_training, **load_kwargs)
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
    else:
        model = MyAutoModelForCausalLM.from_pretrained(**load_kwargs)

    peft_config = LoraConfig(
        r=16, lora_alpha=16, target_modules=target_modules,
        lora_dropout=0.05, bias="none", inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )
    model = MyPeftModelForCausalLM(model, peft_config, add_tokens=args.add_tokens)
    if "prompt-tuning" in args.parameter_efficient_mode:
        enable_prompt_tuning(model.base_model.model)

    # Confidence head: request hidden states from the backbone and attach a trainable head.
    model.config.output_hidden_states = True
    hsz = _hidden_size(model.config)
    attach_conf_head(model, hsz)
    for p in model.conf_head.parameters():
        p.requires_grad = True
    print_trainable_parameters(model)

    # Restore a previously saved head only when explicitly resuming; otherwise a stale
    # confidence_head.pt left in output_dir by an earlier run would silently seed training.
    if training_args.resume:
        load_conf_head_if_exists(model, training_args.output_dir, hsz)
    print(f"[train.py] conf_head present: {hasattr(model, 'conf_head')}; will save to {training_args.output_dir}/confidence_head.pt")
    return model


def _load_eval_model(args, training_args, model_path, num_new_tokens, prompt_tokens):
    """Reload the fine-tuned model (and LoRA adapter, if used) from ``model_path``."""
    input_embedding_file = os.path.join(model_path, 'input_embeddings.pt')
    if not os.path.exists(input_embedding_file):
        input_embedding_file = None
    output_embedding_file = os.path.join(model_path, 'output_embeddings.pt')
    if not os.path.exists(output_embedding_file):
        output_embedding_file = None

    load_kwargs = dict(
        n_tokens=num_new_tokens,
        input_embedding_file=input_embedding_file,
        output_embedding_file=output_embedding_file,
        sparse=args.use_sparse_attention,
        prompt_tokens=prompt_tokens,
        pretrained_model_name_or_path=model_path,
        parameter_efficient_mode=args.parameter_efficient_mode,
        cache_dir=training_args.cache_dir,
        device_map="auto",
        offload_folder="offload",
        offload_state_dict=True,
    )
    if training_args.load_in_8bit:
        # 8-bit must be requested on the config itself: a BitsAndBytesConfig without
        # load_in_8bit=True does not quantize, and transformers either overrides or
        # rejects a separate load_in_8bit kwarg when quantization_config is given.
        load_kwargs["quantization_config"] = transformers.BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=getattr(args, "enable_cpu_offload", False),
        )
    else:
        load_kwargs["torch_dtype"] = torch.float32
    model = MyAutoModelForCausalLM.from_pretrained(**load_kwargs)

    if 'lora' in args.parameter_efficient_mode:
        model = MyPeftModelForCausalLM.from_pretrained(model, model_path, load_embeddings=True, n_tokens=num_new_tokens)
    return model


def train(args: Args, training_args: CustomTrainingArgs):
    """Fine-tune (unless ``args.only_eval``) and then evaluate the model.

    Mutates ``args.exp_name``, ``training_args.run_name``, ``training_args.output_dir``
    (and ``training_args.max_steps`` in debug mode). Artifacts (PEFT model, trainer
    state, ``confidence_head.pt``, ``prompt_text.json``) go to
    ``<output_dir>/<dataset>/<exp_name>``; evaluation reloads the model from there.
    """
    # ====== experiment naming / output dir ======
    print(args.position)
    # round() rather than int(): int(0.57 * 100) == 56 because of float representation.
    args.exp_name = f'{args.dataset}_{args.exp_name}_sft_{round(args.weight * 100)}'
    training_args.run_name = args.exp_name
    if args.debug_mode:
        training_args.max_steps = 10
    print(f'reg weight:{args.weight} \nexp_name:{args.exp_name}')

    model_name = INDENTIFIER2NAME[args.model_name_or_path]
    # An absolute dataset path contributes only its last component to the output dir.
    ds_id = os.path.basename(os.path.normpath(args.dataset)) if os.path.isabs(args.dataset) else args.dataset
    training_args.output_dir = os.path.join(training_args.output_dir, ds_id, args.exp_name)
    os.makedirs(training_args.output_dir, exist_ok=True)
    print("[train.py] final output_dir =", training_args.output_dir)
    print(f'The model will be saved in {training_args.output_dir}')
    print(f'base model: {args.model_name_or_path}')

    # ====== tokenizer / soft prompts ======
    tokenizer, num_new_tokens = _setup_tokenizer(args, training_args)
    prompt_text, special_tokens_list = _build_soft_prompts(args)
    print(f'new tokens: {special_tokens_list}')
    num_new_tokens += tokenizer.add_tokens(special_tokens_list)
    prompt_tokens = tokenizer.convert_tokens_to_ids(special_tokens_list)

    # ====== data ======
    train_dataset, test_dataset = _load_datasets(args, model_name, prompt_text)
    data_module = make_supervised_data_module(
        tokenizer, train_dataset, test_dataset,
        prompt_tokens, args.use_sparse_attention,
        training_args.remove_unused_columns, max_num_eval=len(test_dataset)
    )

    # ====== build & train ======
    if not args.only_eval:
        model = _build_train_model(args, training_args, model_name, num_new_tokens)
        trainer = MyTrainer(custom_args=args, ref_model=None, model=model, tokenizer=tokenizer, args=training_args, **data_module)
        trainer.train()

        trainer.save_state()
        trainer.save_model(output_dir=training_args.output_dir)   # saves PEFT adapter + base
        save_conf_head(model, training_args.output_dir)           # confidence head saved separately
        with open(os.path.join(training_args.output_dir, 'prompt_text.json'), 'w', encoding='utf-8') as f:
            json.dump(prompt_text, f, indent=4)

        # Free the training model before reloading for evaluation.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
        print('finished training')

    # ====== evaluation ======
    print('Start Evaluation')
    model_path = training_args.output_dir
    print("model_path", model_path)
    model = _load_eval_model(args, training_args, model_path, num_new_tokens, prompt_tokens)

    hsz = _hidden_size(model.config)
    if not os.path.exists(os.path.join(model_path, "confidence_head.pt")):
        print(f"[train.py] WARNING: no confidence_head.pt in {model_path}; evaluating a freshly initialised confidence head")
    load_conf_head_if_exists(model, model_path, hsz)
    model.eval()

    trainer = MyTrainer(custom_args=args, model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.evaluate()


if __name__ == "__main__":
    # Parse CLI flags into Args + CustomTrainingArgs, then train and evaluate.
    cli_parser = HfArgumentParser((Args, CustomTrainingArgs))
    parsed_args, parsed_training_args = cli_parser.parse_args_into_dataclasses()
    train(parsed_args, parsed_training_args)
