# Standard library imports
import os
# Set environment variables
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import gc
import json

# Third-party library imports
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
import pandas as pd
from tqdm import tqdm
import argparse

# Local application/library specific imports
from src.model import load_model, llama_model_path
from src.dataset import load_dataset
from src.utils.evaluate import eval_funcs
from src.utils.ckpt import _save_checkpoint, _reload_best_model
from src.utils.collate import collate_fn
from src.utils.seed import seed_everything
from src.utils.lr_schedule import adjust_learning_rate

def parse_args(argv=None):
    """Parse the command-line options for fine-tuning a graph-based language model.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) reads the real command line (``sys.argv[1:]``), exactly
            as before; passing an explicit list such as ``[]`` is handy for
            tests and notebooks.

    Returns:
        An ``argparse.Namespace`` holding every option defined below.
        ``llm_frozen`` is converted from its string form to a ``bool``: it is
        ``True`` only for the case-insensitive string ``"true"``.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a graph-based language model.")

    # General settings
    parser.add_argument("--model_name", type=str, default='graph_llm', help="Name of the model.")
    parser.add_argument("--project", type=str, default="project_g_retriever", help="Project name.")
    parser.add_argument("--seed", type=int, default=1, help="Random seed for reproducibility.")

    # Dataset and Learning settings
    parser.add_argument("--dataset", type=str, default='webqsp', help="Dataset to use expla_graphs scene_graphs webqsp.")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.05, help="Weight decay (L2 regularization).")
    parser.add_argument("--patience", type=int, default=2, help="Patience for early stopping.")

    # Model Training
    parser.add_argument("--do_train", type=int, default=1, help="If train.")
    parser.add_argument("--do_eval", type=int, default=1, help="If eval.")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    parser.add_argument("--grad_steps", type=int, default=4, help="Number of gradient accumulation steps.")

    # Learning Rate Scheduler
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs.")
    parser.add_argument("--warmup_epochs", type=float, default=1, help="Number of warmup epochs.")

    # Inference
    parser.add_argument("--eval_batch_size", type=int, default=2, help="Batch size for evaluation.")

    # LLM related
    parser.add_argument("--llm_model_name", type=str, default='7b', help="Name of the LLM model.")
    parser.add_argument("--llm_model_path", type=str, default='', help="Path to the LLM model.")
    parser.add_argument("--llm_frozen", type=str, default='false', help="Whether to freeze the LLM parameters (True/False).")
    parser.add_argument("--llm_num_virtual_tokens", type=int, default=10, help="Number of virtual tokens for the LLM.")
    parser.add_argument("--output_dir", type=str, default='output', help="Directory to save outputs.")
    parser.add_argument("--max_txt_len", type=int, default=512, help="Maximum length of input text.")
    parser.add_argument("--max_new_tokens", type=int, default=32, help="Maximum number of new tokens to generate.")

    parser.add_argument("--temperature", type=float, default=0.6, help="Temperature used for LLM text generation.")

    # GNN related
    parser.add_argument("--gnn_model_name", type=str, default='gt', help="Name of the GNN model.")
    parser.add_argument("--gnn_num_layers", type=int, default=2, help="Number of layers in the GNN.")
    parser.add_argument("--gnn_in_dim", type=int, default=1024, help="Input dimension of the GNN.")
    parser.add_argument("--gnn_hidden_dim", type=int, default=1024, help="Hidden dimension of the GNN.")
    parser.add_argument("--gnn_num_heads", type=int, default=4, help="Number of attention heads in the GNN.")
    parser.add_argument("--gnn_dropout", type=float, default=0.1, help="Dropout rate for the GNN.")

    # argv=None makes argparse read sys.argv[1:]; pass a list (e.g. []) to
    # parse without touching the real command line.
    args = parser.parse_args(argv)

    # Convert the string 'True'/'False' flag into a real boolean.
    args.llm_frozen = args.llm_frozen.lower() == 'true'

    return args

# Datasets served by the shared 'glbench' entries of ``load_dataset`` /
# ``eval_funcs`` instead of a dataset-specific entry.
_GLBENCH_DATASETS = frozenset({'cora', 'citeseer', 'pubmed', 'reddit', 'instagram', 'wikics', 'arxiv'})


def _is_glbench(dataset_name):
    """Return True if *dataset_name* is handled by the shared 'glbench' loader/evaluator."""
    return dataset_name in _GLBENCH_DATASETS


def _validation_loss(model, val_loader):
    """Return the mean of ``model(batch).item()`` over a non-empty *val_loader*, computed without gradients."""
    model.eval()
    total = 0.
    with torch.no_grad():
        for batch in val_loader:
            total += model(batch).item()
    return total / len(val_loader)


def _train(model, optimizer, train_loader, val_loader, args):
    """Train *model* for up to ``args.num_epochs`` epochs.

    Gradients are accumulated over ``args.grad_steps`` batches before each
    optimizer step; before stepping, the learning rate is set by
    ``adjust_learning_rate`` and the total gradient norm is clipped to 0.1.
    A trailing partial accumulation group at the end of an epoch is also
    applied (clipped, without an LR update).

    When ``args.do_eval`` is set and the validation loader is non-empty, the
    mean validation loss is computed after every epoch, a checkpoint is saved
    via ``_save_checkpoint`` whenever it improves, and training stops once it
    has not improved for ``args.patience`` epochs. Otherwise a checkpoint is
    saved after every epoch.
    """
    num_training_steps = args.num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps), desc="Training", leave=True, miniters=100)  # training progress bar
    best_val_loss = float('inf')
    # Initialised so the report / early-stop check below never hits an unbound
    # name when validation never improves (e.g. a NaN loss on the first epoch).
    best_epoch = 0
    lr = 0
    # Batches left over after the last full accumulation group of an epoch.
    # Derived from the loader length (not the loop variable) so an empty
    # training loader does not raise.
    remain_steps = len(train_loader) % args.grad_steps
    # With an empty validation set the mean loss is undefined; fall back to
    # checkpointing every epoch, as when do_eval is disabled.
    use_validation = bool(args.do_eval) and len(val_loader) > 0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_loss = 0.
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            loss = model(batch)
            # NOTE(review): the loss is not divided by args.grad_steps, so the
            # accumulated gradient is a sum over the group rather than a mean;
            # the 0.1 norm clip below caps its magnitude either way. Kept as-is
            # to preserve the original training dynamics.
            loss.backward()

            if (step + 1) % args.grad_steps == 0:
                # LR is scheduled on fractional epoch progress.
                adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)
                clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                optimizer.step()
                optimizer.zero_grad()
                lr = optimizer.param_groups[0]["lr"]

            epoch_loss += loss.item()
            progress_bar.update(1)
            # Show the running mean training loss for this epoch.
            progress_bar.set_postfix({"Train Loss": f"{epoch_loss / (step + 1):.4f}", "lr": lr})

        if remain_steps != 0:
            # Apply gradients from the final, incomplete accumulation group.
            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
            optimizer.step()
            optimizer.zero_grad()

        if use_validation:
            val_loss = _validation_loss(model, val_loader)
            print(f"Epoch: {epoch}|{args.num_epochs}: Val Loss: {val_loss}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                _save_checkpoint(model, optimizer, epoch, args, is_best=True)
                best_epoch = epoch

            print(f'Epoch {epoch} Val Loss {val_loss} Best Val Loss {best_val_loss} Best Epoch {best_epoch}')

            if epoch - best_epoch >= args.patience:
                print(f'Early stop at epoch {epoch}')
                break
        else:
            _save_checkpoint(model, optimizer, epoch, args, is_best=True)

    progress_bar.close()


def _write_predictions(model, test_loader, path):
    """Run ``model.inference`` on every test batch and write each output row to *path* as one JSON object per line."""
    model.eval()
    with open(path, "w") as f, torch.no_grad():
        for batch in tqdm(test_loader, miniters=10):
            output = model.inference(batch)
            for _, row in pd.DataFrame(output).iterrows():
                f.write(json.dumps(dict(row)) + "\n")


def main():
    """Optionally fine-tune a graph-LLM model, then evaluate it on the test split.

    Steps: parse CLI options, seed RNGs, load the dataset and its
    train/val/test split, build the model and an AdamW optimizer over its
    trainable parameters, train when ``args.do_train`` is set, reload the best
    checkpoint, write test predictions as JSON lines under
    ``args.output_dir/args.dataset`` and print the value returned by the
    matching ``eval_funcs`` entry as "Test Acc".
    """
    args = parse_args()
    print(args)

    seed_everything(seed=args.seed)

    # Step 1: Load the dataset and its index split.
    if _is_glbench(args.dataset):
        dataset = load_dataset['glbench'](args.dataset)
    else:
        dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build per-split datasets and loaders.
    train_dataset = [dataset[i] for i in idx_split['train']]
    val_dataset = [dataset[i] for i in idx_split['val']]
    test_dataset = [dataset[i] for i in idx_split['test']]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, pin_memory=True, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=False, pin_memory=True, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True, shuffle=False, collate_fn=collate_fn)

    # Step 3: Build the model.
    args.llm_model_path = llama_model_path[args.llm_model_name]
    model = load_model[args.model_name](graph_type=dataset.graph_type, args=args, init_prompt=dataset.prompt)

    # Step 4: Set up the optimizer over trainable parameters only.
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}, ],
        betas=(0.9, 0.95)
    )
    trainable_params, all_param = model.print_trainable_params()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")

    if args.do_train:
        _train(model, optimizer, train_loader, val_loader, args)

    # Step 5: Evaluate on the test split.
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    # NOTE(review): the file keeps its historical ".csv" name but holds JSON
    # lines; the evaluation functions are presumably written against this
    # path/format — confirm before renaming.
    path = f'{args.output_dir}/{args.dataset}/model_name_{args.model_name}_llm_model_name_{args.llm_model_name}_llm_frozen_{args.llm_frozen}_max_txt_len_{args.max_txt_len}_max_new_tokens_{args.max_new_tokens}_gnn_model_name_{args.gnn_model_name}_patience_{args.patience}_num_epochs_{args.num_epochs}_seed{args.seed}.csv'
    print(f'path: {path}')

    model = _reload_best_model(model, args)
    _write_predictions(model, test_loader, path)

    # Step 6: Post-process predictions and compute the metric.
    eval_key = 'glbench' if _is_glbench(args.dataset) else args.dataset
    acc = eval_funcs[eval_key](path)
    print(f'Test Acc {acc}')


if __name__ == "__main__":
    main()
    # Release cached GPU memory and reset the peak-memory counters after the run.
    # Guarded because resetting the counters queries the current CUDA device,
    # which raises on machines without CUDA. reset_peak_memory_stats replaces
    # the deprecated reset_max_memory_allocated, which only forwards to it.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    gc.collect()
