import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import TensorDataset, DataLoader
import random
import os
from tqdm import tqdm
import pandas as pd
from sklearn.utils import shuffle
import time
from torch.nn.utils import clip_grad_norm_
import json
import datetime

from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

from deepspeed import get_accelerator
import logging
import deepspeed
import subprocess
import wandb
from sklearn.metrics import r2_score
import numpy as np

import pathlib

import os
import re
import shutil, yaml

import cluster_utils

def parse_filename(filename):
    """
    Parse a name that starts with ``<epoch>-<step>-<loss>``.

    Only the prefix has to match; any trailing text after the loss is ignored.
    :param filename: file or directory name to parse
    :return: (epoch, step, loss) tuple of (int, int, float), or None if the
        name does not start with the expected pattern
    """
    found = re.match(r'(\d+)-(\d+)-(\d+\.?\d*)', filename)
    if found is None:
        return None
    epoch_str, step_str, loss_str = found.groups()
    return int(epoch_str), int(step_str), float(loss_str)


def delete_files_wrt_loss(save_path, max_save_files=5, reverse=True):
    """
    Keep the ``max_save_files`` best entries in ``save_path`` and delete the rest.

    Only entries whose names parse with ``parse_filename`` (``<epoch>-<step>-<value>``)
    are considered; every other entry in the directory is left untouched.
    Entries are ranked by the trailing value.

    :param save_path: directory holding the checkpoint entries
    :param max_save_files: number of entries to keep
    :param reverse: True (the default) keeps the entries with the LARGEST value
        (suitable for a score); pass False to keep the SMALLEST value (a loss)
    """
    parsed_files = []
    for name in os.listdir(save_path):
        parsed = parse_filename(name)
        if parsed:
            parsed_files.append((name, *parsed))

    if not parsed_files:
        return

    parsed_files.sort(key=lambda entry: entry[3], reverse=reverse)
    max_save_files = min(max_save_files, len(parsed_files))
    best_files = parsed_files[:max_save_files]

    for name, _, _, _ in parsed_files[max_save_files:]:
        file_path = os.path.join(save_path, name)
        # Checkpoint entries may be nested directories (DeepSpeed layouts) or plain
        # files; the old per-child os.remove loop crashed on both.
        if os.path.isdir(file_path) and not os.path.islink(file_path):
            shutil.rmtree(file_path)
        else:
            os.remove(file_path)

    which = "largest" if reverse else "smallest"
    print(f"Kept {len(best_files)} best files with the {which} loss.")




def read_reaction(path,data_name):
    """
    Load a dataset saved with ``Dataset.save_to_disk`` from ``path/data_name``.

    :return: (conditions, yields, reactions), i.e. the 'condition', 'yield'
        and 'reaction' columns of the loaded dataset
    """
    loaded = Dataset.load_from_disk(os.path.join(path, data_name))
    yields, conditions, reactions = (loaded[column] for column in ('yield', 'condition', 'reaction'))
    return conditions, yields, reactions



def cleanup():
    """Destroy the default torch.distributed process group."""
    dist.destroy_process_group()
    
def get_attr_string_from_json_files(json_files_dir):
    """
    Turn the records of every ``*.json`` file in a directory into attribute text.

    Each file is loaded with ``json.load`` and iterated as a sequence of objects
    carrying a ``"name"`` key (NOTE(review): schema inferred from the loop
    below; confirm against the data). Every other key/value pair of an object
    becomes one ``"<key>: <value>"`` line, each terminated by a newline.

    :param json_files_dir: directory to scan (non-recursive)
    :return: {file name without ".json": {record name: attribute string}}
    """
    suffix = ".json"
    results = {}
    # sorted() makes the result order independent of the filesystem.
    for file in sorted(os.listdir(json_files_dir)):
        if not file.endswith(suffix):
            continue
        with open(os.path.join(json_files_dir, file), "r", encoding="utf-8") as f:
            json_data = json.load(f)
        data_maps = {}
        for data in json_data:
            data_maps[data["name"]] = "".join(
                f"{attr}: {value}\n" for attr, value in data.items() if attr != "name"
            )
        # Strip only the ".json" suffix: split(".")[0] turned "a.b.json" into "a".
        results[file[:-len(suffix)]] = data_maps

    return results

class YieldPredLayer(nn.Module):
    """
    Regression head that maps a pooled embedding to predicted value(s).

    The head is a single linear layer. It is wrapped in ``nn.Sequential`` so
    state dicts keep the ``predictor.0.weight`` / ``predictor.0.bias`` keys
    used by previously saved checkpoints.

    :param input_size: dimension of the incoming embedding
    :param hidden_size: currently unused; kept for call compatibility
    :param output_size: number of regression outputs (default 1)
    """
    def __init__(self, input_size, hidden_size, output_size=1):
        super(YieldPredLayer, self).__init__()
        # Not applied in forward(); kept so existing attribute access still works.
        self.act = nn.SiLU()
        # Previously hard-coded to 1, which silently ignored ``output_size``.
        self.predictor = nn.Sequential(
            nn.Linear(input_size, output_size),
        )

    def forward(self, x):
        """Map ``x`` of shape (..., input_size) to shape (..., output_size)."""
        return self.predictor(x)

class LlamaWithLoss(nn.Module):
    """
    Wrap a backbone and a regression head.

    The backbone's ``last_hidden_state`` is pooled into one embedding per
    sequence, fed to ``predictor``, and optionally scored with MSE against ``y``.
    """
    def __init__(self, llama, predictor):
        super(LlamaWithLoss, self).__init__()
        self.llama = llama
        self.loss_func = torch.nn.MSELoss()

        self.predictor = predictor

    @staticmethod
    def _pool(last_hidden_state, attention_mask, pooling_method):
        """Reduce (batch, seq, dim) hidden states to (batch, dim), skipping padded positions."""
        if pooling_method not in ('mean', 'last_token'):
            raise ValueError("pooling_method must be 'mean' or 'last_token'")

        # Without a usable mask, every position is treated as a real token.
        if attention_mask is None or tuple(attention_mask.shape) != tuple(last_hidden_state.shape[:2]):
            if pooling_method == 'mean':
                return last_hidden_state.mean(dim=1)
            return last_hidden_state[:, -1, :]

        mask = attention_mask.to(last_hidden_state.device)
        if pooling_method == 'mean':
            # Average attended positions only, so padding does not dilute the embedding.
            weights = mask.unsqueeze(-1).to(last_hidden_state.dtype)
            summed = (last_hidden_state * weights).sum(dim=1)
            counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
            return (summed.float() / counts.float()).to(last_hidden_state.dtype)

        # Index of the last attended position. This is correct for both left and
        # right padding; [:, -1] picks a pad token on right-padded rows.
        seq_len = mask.shape[1]
        last_idx = seq_len - 1 - mask.flip(dims=[1]).float().argmax(dim=1)
        batch_idx = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
        return last_hidden_state[batch_idx, last_idx]

    def forward(self, inputs, y, pooling_method='last_token', return_loss=True):
        """
        :param inputs: tokenizer output mapping (``input_ids`` and optionally
            ``attention_mask``), forwarded to the backbone as keyword arguments
        :param y: regression targets; only read when ``return_loss`` is True
        :param pooling_method: 'mean' or 'last_token'
        :param return_loss: if True return (embeddings, MSE loss), otherwise
            (embeddings, predictions)
        :raises ValueError: for an unknown ``pooling_method``
        """
        outputs = self.llama(**inputs, output_hidden_states=True)
        embeddings = self._pool(outputs.last_hidden_state, inputs.get('attention_mask'), pooling_method)
        pred = self.predictor(embeddings)
        if return_loss:
            return embeddings, self.loss_func(pred.view(-1), y.view(-1))
        return embeddings, pred

def read_data_from_csv(path):
    """Read the CSV at ``path`` with pandas and wrap its columns in a ``datasets.Dataset``."""
    return Dataset.from_dict(pd.read_csv(path))



def _add_missing_special_tokens(tokenizer, model):
    """Give the tokenizer a pad/eos token when it has none, and resize the model embeddings to match."""
    special_tokens_dict = {}
    if tokenizer.pad_token is None:
        special_tokens_dict['pad_token'] = '[PAD]'
    if tokenizer.eos_token is None:
        special_tokens_dict['eos_token'] = '</s>'
    if special_tokens_dict:
        tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))


def _load_backbone(args, logger):
    """
    Load the backbone model and its tokenizer.

    * ``use_lora`` with an existing ``lora_adapter_path``: load the adapter's
      base model and resume the adapter as trainable.
    * ``use_lora`` without an adapter: load ``pretrained_model_path`` and attach
      a fresh LoRA adapter.
    * ``use_lora == 0``: load ``pretrained_model_path`` and train it directly.
      The old code left ``model``/``tokenizer`` undefined in this case.

    :return: (model, tokenizer)
    """
    lora_adapter_path = args.lora_adapter_path

    if args.use_lora and os.path.exists(lora_adapter_path):
        print(f"Load LoRA from {lora_adapter_path}")
        logger.info(f"Load LoRA from {lora_adapter_path}")
        lora_config = PeftConfig.from_pretrained(lora_adapter_path)
        base_model_path = lora_config.base_model_name_or_path
        if args.load_by_torch:
            model = cluster_utils.load_model_with_state_by_torch(base_model_path)
        else:
            model = AutoModel.from_pretrained(base_model_path, local_files_only=True)
        tokenizer = AutoTokenizer.from_pretrained(base_model_path, local_files_only=True)
        _add_missing_special_tokens(tokenizer, model)
        model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)
        model.print_trainable_parameters()  # confirm LoRA is applied
        return model, tokenizer

    pretrained_model_path = args.pretrained_model_path
    if args.load_by_torch:
        print('Load model By Torch...')
        logger.info('Load model...')
        model = cluster_utils.load_model_with_state_by_torch(pretrained_model_path)
    else:
        print('Load model From Pretrain...')
        logger.info('Load model...')
        model = AutoModel.from_pretrained(pretrained_model_path, local_files_only=True, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, local_files_only=True, trust_remote_code=True)
    _add_missing_special_tokens(tokenizer, model)

    if args.use_lora:
        print('LoRA configuring...')
        logger.info('LoRA configuring...')
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj"],
            bias="none",
            task_type=None,
            inference_mode=False,
            init_lora_weights=True,
            fan_in_fan_out=False,
            peft_type="LORA",
            revision=None,
            use_dora=False,
            use_rslora=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()  # confirm LoRA is applied
    return model, tokenizer


def _predict(model, dataloader, tokenizer, device, max_length, pooling_method, desc=None):
    """
    Run ``model`` over ``dataloader`` without gradients.

    :return: (preds, targets) as flat float32 tensors on ``device`` holding this
        rank's predictions and the matching 'output' values (empty when the
        loader yields nothing)
    """
    preds, targets = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            inputs = tokenizer(batch['instruction'], max_length=max_length, padding='longest',
                               truncation=True, return_tensors="pt").to(device)
            # Targets stay float32. The old eval loop cast them to fp16/bf16, which rounds the labels.
            y = batch['output'].to(torch.float).to(device)
            _, pred = model(inputs, y, pooling_method=pooling_method, return_loss=False)
            preds.append(pred.detach().to(torch.float32).reshape(-1))
            targets.append(y.reshape(-1))
    if not preds:
        empty = torch.empty(0, dtype=torch.float32, device=device)
        return empty, empty.clone()
    return torch.cat(preds), torch.cat(targets)


def _save_eval_checkpoint(model, save_root, epoch, logger):
    """Save a DeepSpeed checkpoint tagged ``eval_<epoch>`` under ``save_root/eval``; every rank must call this."""
    ckpt_path = os.path.join(save_root, 'eval')
    model.save_checkpoint(ckpt_path, tag=f'eval_{epoch}')
    print(f"Model saved to: {ckpt_path}")
    logger.info(f"Model saved to: {ckpt_path}")


def train(args, cluster_args):
    """
    Fine-tune the backbone and yield-regression head with DeepSpeed.

    The function evaluates periodically on ``test.csv``, saves the checkpoint
    and adapter, and reports final test metrics on global rank 0.

    :param args: parsed CLI namespace produced by ``main()``
    :param cluster_args: dict loaded from the cluster YAML config; this function
        reads ``'cluster_loss_lambda'``
    """
    num_epoch = args.num_epoch
    batch_size = args.per_device_train_batch_size
    lr = args.lr
    max_length = args.max_length
    data_path = args.data_path
    data_name = args.data_name

    logging.basicConfig(
        filename=args.log_file,
        level=logging.INFO,
    )
    logger = logging.getLogger()
    logger.info(args.latest_method_notes)
    logger.info(cluster_args)

    get_accelerator().set_device(args.local_rank)
    device = torch.device(get_accelerator().device_name(), args.local_rank)
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    deepspeed.init_distributed()
    print('using device', device)

    world_size = dist.get_world_size()
    # Use the global rank (not the per-node local rank) for sharding and for the
    # single-writer process. Otherwise multi-node runs duplicate data shards.
    rank = dist.get_rank()
    is_main = rank == 0
    use_wandb = is_main and not args.wandb_offline

    if is_main:
        if args.wandb_offline:
            os.environ["WANDB_MODE"] = "offline"
            os.environ["WANDB_DISABLED"] = "true"
        else:
            wandb.init(
                project=args.project_name,
                name=f"{args.data_name}_cluster_{cluster_args['cluster_loss_lambda']}_B{args.per_device_train_batch_size}_E{args.num_epoch}",
            )

    model, tokenizer = _load_backbone(args, logger)

    # NOTE(review): input size 2048 must equal the backbone's hidden size; confirm for the model in use.
    predictor = YieldPredLayer(2048, 1024, 1).to(device).train()
    if os.path.exists(args.yield_predictor_path):
        try:
            predictor.load_state_dict(torch.load(args.yield_predictor_path, map_location=device))
        except Exception as e:
            # Best effort: keep the freshly initialised head if the saved one is incompatible.
            print("Error when load predictor : ", e)
            logger.info("Error when load predictor : %s", e)
    model = LlamaWithLoss(model, predictor)

    # Data
    train_csv = os.path.join(data_path, data_name, 'train.csv')
    print('Load data from...', train_csv)
    logger.info(f'Load data ...')
    train_data = read_data_from_csv(train_csv)
    train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True)
    trainloader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler, num_workers=8)

    # The regression head gets a larger learning rate than the backbone.
    optimizer = optim.AdamW([
        {'params': model.llama.parameters(), 'lr': lr},
        {'params': model.predictor.parameters(), 'lr': lr * args.mlp_lr_multiplier},
    ])

    train_batch_size = args.per_device_train_batch_size * world_size * args.gradient_accumulation_steps
    print("Train_batch_size: ", train_batch_size)
    logger.info(f'Train_batch_size: {train_batch_size}')

    with open(args.deepspeed_config, 'r') as f:
        ds_config = json.load(f)
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps
    ds_config['train_batch_size'] = train_batch_size
    scheduler_params = ds_config['scheduler']['params']
    scheduler_params['total_num_steps'] = num_epoch * len(trainloader) / args.gradient_accumulation_steps
    print(f"total effictive steps {scheduler_params['total_num_steps']}")
    scheduler_params['warmup_num_steps'] = scheduler_params['total_num_steps'] * 0.1
    model, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config_params=ds_config,
    )

    if args.load_ds_dir is not None and os.path.exists(args.load_ds_dir):
        print(f"Load deepspeed checkpoint from {args.load_ds_dir}")
        logger.info(f'Load deepspeed checkpoint from {args.load_ds_dir}')
        model.load_checkpoint(args.load_ds_dir, args.load_ds_ckpt_id)

    if args.run_test:
        test_data = read_data_from_csv(os.path.join(data_path, data_name, 'test.csv'))
        test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False)
        testloader = DataLoader(test_data, batch_size=batch_size, sampler=test_sampler, num_workers=8)

    best_eval_loss = torch.inf
    # Tracked across epochs. The old code reset it at every evaluation, so any R2 > -1 counted as "best".
    best_r2_score = -1

    for epoch in range(1, num_epoch + 1):
        print(f'Training Epoch {epoch}:')
        logger.info(f'Training Epoch {epoch}:')
        trainloader.sampler.set_epoch(epoch)  # reshuffle the DDP shards every epoch

        model.train()  # evaluation at the end of the previous epoch switched to eval mode
        losses = []
        run_time = 0
        for i, batch_data in enumerate(trainloader):
            start_time = time.time()

            prompts = batch_data['instruction']
            y = batch_data['output'].to(torch.float).to(device)
            # add some noise
            y += torch.randn_like(y)
            inputs = tokenizer(prompts, max_length=max_length, padding='longest', truncation=True, return_tensors="pt").to(device)

            # Match the target dtype to the engine's mixed-precision mode.
            if model.fp16_enabled():
                y = y.half()
            elif model.bfloat16_enabled():
                y = y.bfloat16()

            _, pred_loss = model(inputs, y, pooling_method=args.pooling_method, return_loss=True)
            # TODO: Cluster
            distance_loss = torch.tensor(0.0, dtype=torch.float32, device=device)

            loss = cluster_args['cluster_loss_lambda'] * distance_loss + pred_loss if args.use_cluster else pred_loss
            losses.append(loss.item())
            model.backward(loss)
            model.step()

            end_time = time.time()
            run_time += (end_time - start_time) / (60 * 60)
            step_msg = f"Rank {rank}, Epoch {epoch}:{i+1}/{len(trainloader)}-step, \tloss:{loss}, \tpred_loss:{pred_loss}, \tdistance_loss:{distance_loss}, \truning time:{end_time-start_time}s, \tleft_time:{(run_time/(i+1)) *(len(trainloader)-i-1)}h "
            print(step_msg)
            logger.info(step_msg)

            if use_wandb:
                wandb.log({'train_loss': loss.item(), 'distance_loss': distance_loss.item(),
                           'pred_loss': pred_loss.item(), 'current_lr': optimizer.param_groups[0]['lr']})

        # Report the epoch average before evaluation so it is never skipped.
        epoch_avg_loss = sum(losses) / len(losses) if losses else float('nan')
        print(f"Avg  Loss on Epoch {epoch}: {epoch_avg_loss}")
        logger.info(f"Avg Loss on Epoch {epoch}: {epoch_avg_loss}")
        if use_wandb:
            wandb.log({'train_loss_on_epoch': epoch_avg_loss})

        if (epoch % args.save_interval == 0) or (epoch == num_epoch) or (epoch == 1):
            model.eval()
            print(f"Evaluation ...")
            logger.info(f"Evaluation ...")

            preds = targets = None
            if args.run_test:
                preds, targets = _predict(model, testloader, tokenizer, device, max_length, args.pooling_method)

            if preds is None or preds.numel() == 0:
                # Nothing to score; optionally keep a checkpoint of this epoch anyway.
                if args.eval_save_ckpt:
                    _save_eval_checkpoint(model, args.save_root, epoch, logger)
                continue

            avg_eval_loss = torch.nn.functional.mse_loss(preds, targets).item()
            # sklearn's signature is r2_score(y_true, y_pred); the old call had the arguments swapped.
            r2 = r2_score(targets.cpu().numpy().astype(np.float64), preds.cpu().numpy().astype(np.float64))
            print(f"Avg Eval Loss: {avg_eval_loss}, R2: {r2}")
            logger.info(f"Avg Eval Loss: {avg_eval_loss}, R2: {r2}")

            # Average the per-shard R2 over all ranks. This approximates the global
            # R2, and every rank ends up with the same value.
            r2_tensor = torch.tensor(r2, dtype=torch.float64, device=device)
            dist.all_reduce(r2_tensor, op=dist.ReduceOp.SUM)
            r2 = (r2_tensor / world_size).item()

            if use_wandb:
                wandb.log({'eval_loss': avg_eval_loss, 'eval R2': r2})
            best_eval_loss = min(best_eval_loss, avg_eval_loss)

            if r2 > best_r2_score:
                best_r2_score = r2
                # DeepSpeed's save_checkpoint is collective: all ranks must call it,
                # not only rank 0, or the job can hang.
                if args.eval_save_ckpt:
                    _save_eval_checkpoint(model, args.save_root, epoch, logger)

    os.makedirs(args.lora_adapter_save_path, exist_ok=True)
    model.save_checkpoint(args.lora_adapter_save_path, tag=args.data_name + args.latest_method_notes.split(' : ')[0])
    if is_main:
        # Use a single writer so ranks do not race on the same adapter files.
        model.llama.save_pretrained(args.lora_adapter_save_path)
        with open(os.path.join(args.lora_adapter_save_path, "notes.txt"), "w") as f:
            f.write(args.latest_method_notes)
    print(f"LoRA saved to: {args.lora_adapter_save_path}")

    # Final inference on the same test split used for evaluation
    # (previously hard-coded to ./data4regression).
    data = read_data_from_csv(os.path.join(data_path, data_name, 'test.csv'))
    data_sampler = DistributedSampler(data, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(data, batch_size=batch_size, sampler=data_sampler, num_workers=8)
    print("Dataset loaded")

    # TODO(cluster): tokenize the attribute strings from
    # get_attr_string_from_json_files(cluster_args['cluster_json_dir_path']) once the
    # cluster loss consumes them. The old code built them here and never used them.

    model.eval()  # disable LoRA dropout for inference
    y_preds, y_trues = _predict(model, dataloader, tokenizer, device, max_length, args.pooling_method,
                                desc=f"Rank {rank} inference   ")

    if dist.is_initialized():
        # DistributedSampler pads every shard to the same length, which dist.gather
        # requires. The padding duplicates a few samples in the metrics below.
        gathered_preds = [torch.zeros_like(y_preds) for _ in range(world_size)] if is_main else None
        gathered_labels = [torch.zeros_like(y_trues) for _ in range(world_size)] if is_main else None
        dist.gather(y_preds, gather_list=gathered_preds, dst=0)
        dist.gather(y_trues, gather_list=gathered_labels, dst=0)

        if is_main:
            preds_all = torch.cat(gathered_preds)
            labels_all = torch.cat(gathered_labels)
            mse_loss = torch.nn.functional.mse_loss(preds_all, labels_all)
            mae_loss = torch.nn.functional.l1_loss(preds_all, labels_all)
            r2 = r2_score(labels_all.cpu().numpy(), preds_all.cpu().numpy())

            # TODO: Cluster
            distance_loss = torch.tensor(0.0, dtype=torch.float32, device=device)

            logger.info(f"Rank {rank} - MSE Loss: {mse_loss.item()}, MAE Loss: {mae_loss.item()}, R2 Score: {r2}")
            logger.info(f"Rank {rank} - Distance Loss: {distance_loss.item()}\n")
            print(f"Rank {rank} - MSE Loss: {mse_loss.item()}, MAE Loss: {mae_loss.item()}, R2 Score: {r2}, Distance Loss: {distance_loss.item()}")

    if use_wandb:
        wandb.finish()

        

def main():
    """Parse the CLI arguments, load the cluster YAML config, and run ``train``."""
    parser = argparse.ArgumentParser(description="Distributed Data Parallel Training")

    parser.add_argument("--pretrained_model_path", default='/mnt/cache/Chemllm/reaction_condition_recommendation/src/step1_llama3_8b_0916_yearly_pistachio_ep3')
    parser.add_argument("--lora_adapter_path", default="/mnt/hwfile/ai4chem/share/jianpeng/llama3_8b_0916_lora_yield_pred_ds/lora_adapter")
    parser.add_argument("--yield_predictor_path", default="/mnt/hwfile/ai4chem/share/jianpeng/llama3_8b_0916_lora_yield_pred_ds/predictor.pt")
    parser.add_argument("--num_epoch", type=int, default=2)
    # Fall back to the LOCAL_RANK env var (set by torchrun-style launchers) when
    # --local_rank is not passed explicitly.
    parser.add_argument("--local_rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0)))
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--train_batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--data_path", default='/mnt/petrelfs/chenjianpeng/cjp/LLaMA-Factory/train_regression/data4regression')
    parser.add_argument("--data_name", default='suzuki_miyaura_600')
    parser.add_argument("--save_root", default="/mnt/hwfile/ai4chem/share/jianpeng/llama3_8b_0916_lora_yield_pred_ds")
    parser.add_argument("--base_model_save_path", default="base_model")
    parser.add_argument("--lora_adapter_save_path", default="lora_adapter")
    parser.add_argument('--use_lora', type=int, default=1)
    parser.add_argument('--log_file', type=str, default="training_ds.log")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--max_length', type=int, default=3000)
    parser.add_argument('--load_ds_dir', type=str, default=None)
    parser.add_argument('--load_ds_ckpt_id', type=str, default=None)
    parser.add_argument('--mlp_lr_multiplier', type=float, default=10)
    parser.add_argument('--project_name', type=str, default="llama_regression")

    parser.add_argument("--deepspeed_config", type=str, default="ds_config.json")
    parser.add_argument("--save_interval", type=int, default=20)
    # ========= New Add =========
    parser.add_argument("--eval_save_ckpt", type=int, default=0)
    parser.add_argument("--save_lora_adapter", type=int, default=1)
    parser.add_argument("--save_predictor", type=int, default=1)
    parser.add_argument("--pooling_method", type=str, default="last_token", choices=["mean", "last_token"])
    parser.add_argument("--run_test", type=int, default=1)
    parser.add_argument("--wandb_offline", type=int, default=0)
    parser.add_argument("--load_by_torch", type=int, default=0)
    parser.add_argument("--latest_method_notes", type=str, default="None method notes")

    # cluster params in yaml
    parser.add_argument("--use_cluster", type=int, default=1)
    parser.add_argument("--cluster_config", type=str, default="cluster_config/cluster_config.yaml")

    args = parser.parse_args()

    # The config is a local, trusted file. FullLoader is kept for compatibility
    # with existing configs; the handle is now closed.
    with open(args.cluster_config, 'r') as f:
        cluster_args = yaml.load(f, Loader=yaml.FullLoader)

    os.makedirs(args.save_root, exist_ok=True)

    print('start training')
    print(f"Process local rank = {args.local_rank}")
    train(args, cluster_args)


if __name__ == '__main__':
    # Run training only when this file is executed as a script, not on import.
    main()
