import os
import sys
import argparse
import json
import logging
from collections import Counter
from safetensors.torch import load_file

import numpy as np
import torch

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    default_data_collator
)

from mixtral_descendant_classification import build_adapted_mixtral
from mixtral_descendant_multi_expert import build_mixtral_multi_expert
from deepseek_descendant import build_deepseek, build_deepseek_classification
from deepseek_descendant_random import build_deepseek_random
from mixtral_descendant_random import build_mixtral_random
from qwen_descendant import build_qwen
from qwen3_descendant_random import build_qwen_random
from qwen3_descendant import build_qwen3
from data_process_new import load_and_process_dataset
from data_process import load_and_process_dataset as load_data



# Enable autograd anomaly detection for the whole process.
# NOTE(review): PyTorch documents this mode as a debugging aid that slows
# execution — confirm it is meant to stay enabled for real training runs.
torch.autograd.set_detect_anomaly(True)


# Root logger setup. Only the process whose LOCAL_RANK is 0 (or unset) logs to
# train_log.txt, which is truncated at start-up because of filemode="w".
# Every other rank passes filename=None, so no log file is opened for it.
logging.basicConfig(
    filename="train_log.txt" if int(os.getenv('LOCAL_RANK', '0')) == 0 else None,
    filemode="w",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO
)

def print_parameters_with_grad(model):
    """Print one line per trainable parameter of *model*.

    Parameters whose ``requires_grad`` flag is False are skipped. Each printed
    line carries the parameter name, its shape and its ``requires_grad`` flag.
    """
    trainable = (
        (pname, tensor)
        for pname, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for pname, tensor in trainable:
        print(f"Parameter: {pname}, Shape: {tensor.shape}, Requires Grad: {tensor.requires_grad}")

def get_device():
    """Return the torch device to run on, preferring NPU, then CUDA, then CPU.

    ``torch_npu`` is an optional dependency, so it is imported here instead of
    at module level: the script must keep working on machines where the
    package is not installed.

    Returns:
        torch.device: ``npu`` when ``torch_npu`` imports and reports an
        available device, otherwise ``cuda`` when CUDA is available,
        otherwise ``cpu``.
    """
    try:
        import torch_npu  # noqa: F401
    except ImportError:
        torch_npu = None

    if torch_npu is not None:
        # NOTE(review): torch_npu is expected to expose its backend as
        # ``torch.npu`` once imported — confirm against the Ascend docs.
        npu_backend = getattr(torch, "npu", None)
        if npu_backend is not None and npu_backend.is_available():
            return torch.device("npu")

    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def parse_args():
    """Parse the command-line options of the fine-tuning script.

    ``--model_name`` and ``--data_name`` are mandatory; every other option
    falls back to the default declared below.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="Training script for model fine-tuning.")
    add = parser.add_argument

    # --- model selection and construction ---
    add('--model_name', type=str, default="mixtral", required=True,
        help="Model name (e.g., 'mixtral', 'mixtral_multi', 'deepseek').")
    add('--model_path', type=str, default=None, help="Path of model")
    add('--load_model_path', type=str, default=None, help="Path of model checkpoint")
    add('--load_param_path', type=str, default="", help="Path of load parameters")
    add('--n_layer', type=int, default=2, help="Number of layers to use.")
    add('--n_expert', type=int, default=2, help="Number of experts to use.")
    add('--use_residual', action='store_true', help="Use residual connection.")
    add('--grad', action='store_true', help="Enable gradients.")
    add('--use_random_init', action='store_true', help="Randomly initialize model parameters.")
    add('--rank_dim', type=int, default=256, help="the rank in decompose")

    # --- optimisation and output locations ---
    add('--epochs', type=int, default=10, help="Number of epochs.")
    add('--learning_rate', type=float, default=1e-5, help="Learning rate for training.")
    add('--batch_size', type=int, default=4, help="Batch size per device for training and evaluation.")
    add('--output_dir', type=str, default="./results/finetuning_results", help="Directory to save the model.")
    add('--log_dir', type=str, default="./results/finetuning_results", help="Directory to save the log.")
    add('--max_length', type=int, default=256)

    # --- dataset ---
    add('--data_name', type=str, default="piqa", required=True,
        help="Dataset name (e.g.,'arc_easy', 'arc_challenge', 'piqa', 'winogrande', 'hellaswag', 'openbqa').")

    parsed = parser.parse_args()
    return parsed

def compute_metrics(eval_pred):
    """Compute classification accuracy from an ``(logits, labels)`` pair.

    The predicted class is the argmax over the last axis of the logits; the
    accuracy is the fraction of predictions equal to the labels.

    Returns:
        dict: ``{"eval_accuracy": <mean of correct predictions>}``.
    """
    scores, targets = eval_pred
    predicted = np.argmax(scores, axis=-1)
    correct = (predicted == targets).astype(float)
    return {"eval_accuracy": correct.mean()}

class TxtLoggingTrainer(Trainer):
    """Trainer that also mirrors every log record through ``logging.info``."""

    def log(self, logs, step=None):
        """Forward *logs* to ``Trainer.log`` and echo them on local rank 0.

        ``step`` is accepted but neither used nor forwarded to the parent.
        The echoed line carries the trainer's current global step.
        """
        super().log(logs)
        on_first_local_rank = int(os.getenv('LOCAL_RANK', '0')) == 0
        if not on_first_local_rank:
            return
        logging.info(f"Step: {self.state.global_step}, Logs: {logs}")

# Model names accepted by --model_name (compared case-insensitively).
_SUPPORTED_MODEL_NAMES = frozenset({
    "mixtral",
    "mixtral_multi",
    "mixtral_random",
    "mixtral_original",
    "deepseek",
    "deepseek_random",
    "qwen",
    "qwen_random",
})

# Model names for which --load_model_path is applied after construction.
_CHECKPOINT_MODEL_NAMES = frozenset({
    "mixtral",
    "deepseek",
    "deepseek_random",
    "mixtral_random",
})

# Fallback model directories, matched in order by substring of the model name.
_DEFAULT_MODEL_PATHS = (
    ("mixtral", "/ckpt_llm/mixtral"),
    ("deepseek", "/ckpt_llm/deepseek"),
    ("qwen", "/ckpt_llm/qwen3"),
)


def _default_model_path(mname):
    """Return the fallback model directory for *mname*, or None if no family matches."""
    for family, path in _DEFAULT_MODEL_PATHS:
        if family in mname:
            return path
    return None


def _load_checkpoint(model, checkpoint_path):
    """Load safetensors weights into *model* non-strictly and print any key mismatches."""
    print("=====================================")
    state_dict = load_file(checkpoint_path)
    # strict=False tolerates key mismatches, so always surface them instead of
    # silently training a partially loaded model.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)


def _build_model(args, mname, num_labels):
    """Construct the model selected by the lower-cased name *mname*.

    Raises:
        ValueError: if *mname* is not one of the supported model names.
    """
    if mname == "mixtral":
        return build_adapted_mixtral(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            num_labels=num_labels,
            n_layers=args.n_layer,
            classification=True,
            use_residual=args.use_residual,
            grad=args.grad,
        )
    if mname == "deepseek":
        return build_deepseek_classification(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            num_labels=num_labels,
            n_layers=args.n_layer,
            grad=args.grad,
            classification=True
        )
    if mname == "qwen":
        return build_qwen(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            n_layers=args.n_layer,
            grad=args.grad
        )
    if mname == "qwen_random":
        return build_qwen_random(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            n_layers=args.n_layer,
            grad=args.grad
        )
    if mname == "mixtral_multi":
        return build_mixtral_multi_expert(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            n_layers=args.n_layer,
            num_experts=args.n_expert,
            grad=args.grad
        )
    if mname == "deepseek_random":
        return build_deepseek_random(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            n_layers=args.n_layer,
            grad=args.grad,
            classification=True
        )
    if mname == "mixtral_random":
        return build_mixtral_random(
            model_path=args.model_path,
            load_param_path=args.load_param_path,
            n_layers=args.n_layer,
            grad=args.grad,
            num_labels=num_labels
        )
    if mname == "mixtral_original":
        from transformers import AutoModelForCausalLM
        return AutoModelForCausalLM.from_pretrained(
            args.model_path,
            device_map="cpu",
            trust_remote_code=True,
            torch_dtype=torch.float,
        )
    raise ValueError(f"Unknown model_name: {args.model_name}")


def _run_directories(args, mname):
    """Return ``(output_dir, log_dir)`` for this run, nested by model and run settings."""
    model_tag = f"{args.model_name}-layer{args.n_layer}"
    if mname == "mixtral_multi":
        # The expert count is only part of the path for the multi-expert model.
        model_tag += f"-expert{args.n_expert}"
    run_tag = f"{args.data_name}-ep{args.epochs}-lr{args.learning_rate}"
    output_dir = os.path.join(args.output_dir, model_tag, run_tag)
    log_dir = os.path.join(args.log_dir, model_tag, run_tag)
    return output_dir, log_dir


def main(args):
    """Fine-tune and evaluate the model selected by *args*.

    Steps: validate the model name, resolve the model path, load the tokenizer
    and dataset, build the model (optionally loading a checkpoint), train with
    ``TxtLoggingTrainer`` and write the final evaluation metrics to
    ``<log_dir>/evaluation/evaluation_results.json`` on local rank 0.

    Args:
        args: namespace produced by ``parse_args``. ``args.model_path`` is
            filled in with a family default when it is empty.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model name.
    """
    mname = args.model_name.lower()
    # Fail fast: an unknown name would otherwise reach the tokenizer with
    # model_path=None, or only be rejected after the dataset has been loaded.
    if mname not in _SUPPORTED_MODEL_NAMES:
        raise ValueError(f"Unknown model_name: {args.model_name}")

    device = get_device()

    if not args.model_path:
        args.model_path = _default_model_path(mname)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    if mname == "qwen":
        train_data, test_data = load_data(
            dataset_name=args.data_name,
            tokenizer=tokenizer,
            max_length=args.max_length
        )
        num_labels = 1
    else:
        train_data, test_data = load_and_process_dataset(
            dataset_name=args.data_name,
            tokenizer=tokenizer,
            max_length=args.max_length
        )
        # The number of classes is the number of distinct training labels.
        labels = train_data['labels']
        num_labels = len(set(labels))

    print("=================== num label ================== {}".format(num_labels))

    model = _build_model(args, mname, num_labels)
    if mname in _CHECKPOINT_MODEL_NAMES and args.load_model_path is not None:
        _load_checkpoint(model, args.load_model_path)

    print_parameters_with_grad(model)

    model = model.float()
    model.to(device)

    output_dir, log_dir = _run_directories(args, mname)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=output_dir,
        logging_dir=log_dir,
        report_to=[],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_steps=10,
        logging_first_step=True,
        metric_for_best_model="eval_accuracy",
        save_total_limit=2,
        fp16=False,
        no_cuda=(device.type == "cpu"),
        local_rank=int(os.getenv('LOCAL_RANK', '-1')),
        load_best_model_at_end=True,
        max_grad_norm=0.5,
        seed=42
    )

    trainer = TxtLoggingTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=test_data,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    eval_results = trainer.evaluate()
    eval_save_path = os.path.join(log_dir, "evaluation")
    os.makedirs(eval_save_path, exist_ok=True)
    output_file = os.path.join(eval_save_path, "evaluation_results.json")

    # Only local rank 0 writes the metrics file, so ranks do not overwrite each other.
    if int(os.getenv('LOCAL_RANK', '0')) == 0:
        with open(output_file, "w") as f:
            json.dump(eval_results, f, indent=2)

# Script entry point: parse the command-line options, then run the
# training and evaluation pipeline with them.
if __name__ == "__main__":
    args = parse_args()
    main(args)
