import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    TrainingArguments,
    Trainer,
)
import argparse
import json
import os
from datetime import datetime
import numpy as np
import wandb

from utils.data_utils import DataCollatorForSupervisedDataset, load_and_preprocess_it
from block_hadamard_hira import (
    BlockHadamardHiRAConfig,
    get_block_hadamard_hira_model,
)
from setproctitle import setproctitle
# Module-level side effect: rename this process to the generic title "python"
# as soon as the module is imported or run.
# NOTE(review): presumably done so process listings show a neutral name - confirm intent.
setproctitle("python") 

def create_run_directory(args):
    """Create a timestamped directory for one training run and return its path.

    Layout:
        experiments/block_hadamard_hira_arithmetic/<model_name>/<timestamp>_<tags>/
            checkpoints/
            logs/
            config.json   (JSON dump of ``vars(args)``)

    Args:
        args: Namespace providing ``model``, ``lora_r``, ``lr``, ``lora_alpha``,
            ``num_blocks_out`` and ``dataset_split``. Every attribute of the
            namespace is written to ``config.json``.

    Returns:
        The run directory path (relative to the current working directory).
    """
    base_dir = "experiments/block_hadamard_hira_arithmetic"
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Use the last path component of the model id. normpath() first so a local
    # path with a trailing separator ("/models/llama/") still yields "llama"
    # instead of an empty name that would drop the per-model directory level.
    model_name = os.path.basename(os.path.normpath(args.model))

    # Strip the slicing punctuation from split specs such as "train[:50000]".
    split_tag = args.dataset_split.translate(str.maketrans("", "", ":[]"))

    name_parts = [
        f"rank_{args.lora_r}",
        f"lr{args.lr}",
        f"alpha_{args.lora_alpha}",
        f"blocks{args.num_blocks_out}",
        split_tag,
    ]
    run_dir = os.path.join(base_dir, model_name, f"{ts}_" + "_".join(name_parts))

    for subdir in ("checkpoints", "logs"):
        os.makedirs(os.path.join(run_dir, subdir), exist_ok=True)

    with open(os.path.join(run_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)
    return run_dir


def create_model_tokenizer(args):
    """Load the base causal language model and a matching tokenizer.

    The model is loaded from ``args.model`` in bfloat16 with
    ``device_map="auto"``.

    Tokenizer class selection:
        * ``LlamaTokenizer`` when ``args.model`` contains "llama" (any case)
          but not the exact substring "Llama-3";
        * ``AutoTokenizer`` in every other case.

    After loading, the pad token id is set to the EOS token id (or 0 when the
    tokenizer has no EOS token) and the padding side is set to "left".

    Args:
        args: Namespace providing ``model`` and ``max_seq_length``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    model_id = args.model
    # Note the mixed case handling: "llama" is matched case-insensitively,
    # "Llama-3" case-sensitively.
    use_llama_tokenizer = "llama" in model_id.lower() and "Llama-3" not in model_id
    tokenizer_cls = LlamaTokenizer if use_llama_tokenizer else AutoTokenizer

    # All selections share the same loading arguments.
    tokenizer = tokenizer_cls.from_pretrained(
        model_id,
        use_fast=True,
        model_max_length=args.max_seq_length,
        padding="max_length",
    )

    eos_id = tokenizer.eos_token_id
    tokenizer.pad_token_id = 0 if eos_id is None else eos_id
    tokenizer.padding_side = "left"
    return model, tokenizer


def apply_bhra(model, args):
    """Wrap ``model`` with a Block Hadamard HiRA adapter built from ``args``.

    ``args.target_modules`` may be a single string, a comma-separated string,
    or a list of strings whose entries may themselves be comma-separated
    (what argparse ``nargs="+"`` produces for ``--target_modules a,b``).
    Names are split on commas, stripped of surrounding whitespace, and empty
    entries are dropped. ``None`` is passed through unchanged.

    Args:
        model: The base model to adapt.
        args: Namespace providing ``target_modules``, ``num_blocks_out``,
            ``num_blocks_in``, ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Returns:
        A ``(model, cfg)`` tuple: the adapted model and the
        ``BlockHadamardHiRAConfig`` used to build it.

    Raises:
        ValueError: If ``num_blocks_out`` differs from ``num_blocks_in``; only
            the square block arrangement is configured here.
    """
    # Validate first so a bad block configuration fails before any other work.
    if args.num_blocks_out != args.num_blocks_in:
        raise ValueError(f"Only square blocks supported: num_blocks_out ({args.num_blocks_out}) must equal num_blocks_in ({args.num_blocks_in})")
    num_blocks = args.num_blocks_out

    target_modules = args.target_modules
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    if target_modules is not None:
        # Flatten comma-separated entries and discard blanks, so inputs such as
        # "q_proj, v_proj" or ["q_proj,v_proj"] yield clean module names.
        target_modules = [
            name.strip()
            for entry in target_modules
            for name in entry.split(",")
            if name.strip()
        ]

    cfg = BlockHadamardHiRAConfig(
        r=args.lora_r,
        alpha=args.lora_alpha,
        dropout=args.lora_dropout,
        target_modules=target_modules,
        bias="none",
        init_lora_weights=True,
        num_blocks=num_blocks,
        block_arrangement="square",
        use_fast_inference=True,
    )
    model = get_block_hadamard_hira_model(model, cfg, adapter_name="block_hira_arith")
    return model, cfg


def count_bhra_params(model):
    """Tally parameter element counts for ``model`` in a single pass.

    Args:
        model: An object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        A dict with three integer entries:
            ``total``     - elements across all parameters;
            ``trainable`` - elements of parameters with ``requires_grad``;
            ``hira``      - elements of trainable parameters whose name
                            contains "block_lora_A" or "block_lora_B".
    """
    n_total = 0
    n_trainable = 0
    n_adapter = 0
    for param_name, param in model.named_parameters():
        size = param.numel()
        n_total += size
        if not param.requires_grad:
            continue
        n_trainable += size
        if "block_lora_A" in param_name or "block_lora_B" in param_name:
            n_adapter += size
    return {"total": n_total, "trainable": n_trainable, "hira": n_adapter}


def finetune(args):
    """Run one BHRA fine-tuning job end to end and return its run directory.

    Steps, in order: create the run directory, start a wandb run (its id is
    written to ``wandb_run_id.txt``), load the model and tokenizer, build the
    training dataset and collator, wrap the model with the BHRA adapter, log
    parameter counts to wandb, train with ``transformers.Trainer``, and save
    the final model and tokenizer under ``<run_dir>/final_model``.

    Args:
        args: Parsed command-line namespace. This function reads ``lr``,
            ``epochs``, ``batch_size``, ``warmup_ratio``, ``scheduler`` and
            ``seed`` directly and forwards the whole namespace to the helper
            functions it calls.

    Returns:
        Path of the run directory created for this job.
    """
    run_dir = create_run_directory(args)
    logs_dir = os.path.join(run_dir, "logs")
    checkpoints_dir = os.path.join(run_dir, "checkpoints")

    # Start experiment tracking and persist the run id next to the artifacts.
    tracking_run = wandb.init(
        project="block_hadamard_hira_arithmetic",
        config=vars(args),
        dir=logs_dir,
        name=os.path.basename(run_dir),
    )
    with open(os.path.join(run_dir, "wandb_run_id.txt"), "w") as id_file:
        id_file.write(tracking_run.id)

    model, tokenizer = create_model_tokenizer(args)

    # Dataset and collator are built before the adapter is attached.
    train_dataset = load_and_preprocess_it(tokenizer=tokenizer, args=args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    # The returned adapter config is not needed beyond this point.
    model, _ = apply_bhra(model, args)

    param_counts = count_bhra_params(model)
    wandb.log(
        {
            "params/total": param_counts["total"],
            "params/trainable": param_counts["trainable"],
            "params/bhra": param_counts["hira"],
        }
    )

    # NOTE(review): this optimizer is built by hand and handed to the Trainer,
    # so it uses AdamW's own default weight decay; the ``weight_decay=0`` given
    # to TrainingArguments below is not passed to it. Confirm this is intended.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    training_args = TrainingArguments(
        # Output locations and checkpointing (intermediate saves disabled).
        output_dir=checkpoints_dir,
        logging_dir=logs_dir,
        save_strategy="no",
        # Optimisation schedule.
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=32,
        learning_rate=args.lr,
        weight_decay=0,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        seed=args.seed,
        # Numeric precision.
        bf16=True,
        fp16=False,
        tf32=False,
        # Logging: report every step to wandb.
        report_to="wandb",
        logging_steps=1,
        logging_first_step=True,
    )

    # Only the optimizer is supplied; the scheduler slot is left as None.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        optimizers=(optimizer, None),
    )

    # Save the tokenizer before training starts.
    tokenizer.save_pretrained(os.path.join(run_dir, "tokenizer"))

    trainer.train()

    final_model_path = os.path.join(run_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    return run_dir


def _build_arg_parser():
    """Create the command-line parser for this training script."""
    cli = argparse.ArgumentParser(description="Block Hadamard HiRA Arithmetic Training")

    # Dataset options.
    cli.add_argument("--data_path", type=str, default="meta-math/MetaMathQA")
    cli.add_argument("--dataset_split", type=str, default="train[:50000]")
    cli.add_argument("--dataset_field", type=str, nargs="+", default=["query", "response"])

    # Base model and adapter hyperparameters.
    cli.add_argument("--model", type=str, default="mistralai/Mistral-7B-v0.1")
    cli.add_argument("--lora_r", type=int, default=32)
    cli.add_argument("--lora_alpha", type=int, default=16)
    cli.add_argument("--lora_dropout", type=float, default=0)
    cli.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        help="Target linear modules to adapt",
    )

    # BHRA-specific: two block-count parameters. Only square block layouts are
    # currently supported, so the two values must be equal.
    cli.add_argument("--num_blocks_out", type=int, default=4, help="Number of output blocks (must equal num_blocks_in)")
    cli.add_argument("--num_blocks_in", type=int, default=4, help="Number of input blocks (must equal num_blocks_out)")

    # Training loop options.
    cli.add_argument("--batch_size", type=int, default=1)
    cli.add_argument("--epochs", type=int, default=1)
    cli.add_argument("--scheduler", type=str, default="cosine")
    cli.add_argument("--warmup_ratio", type=float, default=0.02)
    cli.add_argument("--max_seq_length", type=int, default=512)
    cli.add_argument("--lr", type=float, default=1e-4)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--device", type=str, default="cuda")
    return cli


def _seed_everything(seed):
    """Seed the NumPy and PyTorch (CPU and all CUDA devices) generators."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    """Parse the command line, seed the RNGs, run training and report the run directory."""
    args = _build_arg_parser().parse_args()
    _seed_everything(args.seed)

    run_dir = finetune(args)
    print(f"Training finished. Run directory: {run_dir}")


if __name__ == "__main__":
    main()
