import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import TensorDataset, DataLoader
import random
import os
from tqdm import tqdm
import pandas as pd
from sklearn.utils import shuffle
import time
from torch.nn.utils import clip_grad_norm_
import json
import datetime
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

from deepspeed import get_accelerator
import logging
import deepspeed
import subprocess
import wandb
from sklearn.metrics import r2_score
import numpy as np
import sys
import pathlib

import os
import re
import shutil

# Single-GPU serial jobs only:
# def get_device():
#     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Multi-GPU parallel jobs / single-GPU serial jobs:
def get_device():
    """Return the torch device this process should run on.

    The task-local GPU index is read from the ``SLURM_LOCALID`` environment
    variable (0 when unset). When more than one CUDA device is visible, that
    index selects the GPU and is also made the current CUDA device. Otherwise
    the default CUDA device is used if CUDA is available, and the CPU if not.

    Raises:
        RuntimeError: if several GPUs are visible but ``SLURM_LOCALID`` is not
            a valid index into them.
    """
    local_rank = int(os.environ.get('SLURM_LOCALID', 0))  # index of this task on the node
    device_count = torch.cuda.device_count()  # number of visible GPUs
    print(local_rank, device_count)
    if device_count > 1:
        # Make sure the device index is within the number of GPUs; fail with a
        # readable message instead of an opaque "invalid device ordinal".
        if not 0 <= local_rank < device_count:
            raise RuntimeError(
                f"SLURM_LOCALID={local_rank} is out of range for "
                f"{device_count} visible CUDA devices"
            )
        torch.cuda.set_device(local_rank)  # make this GPU the current device
        return torch.device(f'cuda:{local_rank}')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def read_data_from_csv(path):
    """Load a CSV file into a Hugging Face ``datasets.Dataset``.

    Every CSV column becomes a dataset column. The pandas RangeIndex created by
    ``read_csv`` is not added as an extra column.

    Args:
        path: path of the CSV file.

    Returns:
        A ``Dataset`` with one row per CSV row.
    """
    data_df = pd.read_csv(path)
    # ``from_pandas`` is the documented DataFrame constructor. ``from_dict``
    # expects a mapping of column lists and only accepted a DataFrame by
    # duck typing.
    return Dataset.from_pandas(data_df, preserve_index=False)

# LoRA adapter settings applied to the backbone when --lora is set.
# ``r`` and ``target_modules`` determine the names and shapes of the adapter
# weights that main() loads strictly from the checkpoint. They presumably
# have to match the training-time configuration; confirm before changing.
lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,  # rank of the low-rank update matrices
    # Other options seen previously: target_modules="all-linear",
    # bias="none", modules_to_save=["classifier"].
    target_modules=["q_proj", "v_proj"],
)
class YieldPredLayer(nn.Module):
    """Regression head mapping a pooled embedding to ``output_size`` values.

    The head is two stacked ``nn.Linear`` layers (input -> hidden -> output).
    The submodule layout (``predictor.0`` / ``predictor.1``) is kept fixed so
    that previously saved state dicts still load.
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super(YieldPredLayer, self).__init__()
        # Kept for attribute compatibility. It is not used in forward because
        # the activation between the two layers is disabled. SiLU has no
        # parameters, so it does not appear in the state dict.
        self.act = nn.SiLU()
        # NOTE(review): with no activation in between, the two Linear layers
        # compose to a single affine map. This looks intentional, since the
        # activation was commented out upstream; confirm.
        self.predictor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            # Previously hard-coded to 1, which silently ignored output_size.
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Return predictions of shape ``(*x.shape[:-1], output_size)``."""
        return self.predictor(x)
class LlamaWithLoss(nn.Module):
    """Backbone encoder plus regression head, optionally returning an MSE loss.

    ``llama`` is called as ``llama(**inputs)`` and must return an object with
    a ``last_hidden_state`` attribute. ``predictor`` maps the pooled
    embeddings to predictions.
    """

    _POOLING_METHODS = ('mean', 'last_token')

    def __init__(self, llama, predictor):
        super(LlamaWithLoss, self).__init__()
        self.llama = llama
        self.loss_func = torch.nn.MSELoss()
        self.predictor = predictor

    def forward(self, inputs, y, pooling_method='last_token', return_loss=True):
        """Encode ``inputs``, pool one embedding per row and run the head.

        Args:
            inputs: mapping of keyword arguments for the backbone (e.g.
                tokenizer output). If it contains ``attention_mask``, it is
                used to exclude padding positions from pooling. The mask is
                assumed to be (batch, seq) with 1 for real tokens.
            y: regression targets; only used when ``return_loss`` is true.
            pooling_method: ``'last_token'`` (last non-padding position) or
                ``'mean'`` (mean over non-padding positions).
            return_loss: if true, return the MSE between predictions and ``y``
                instead of the predictions.

        Returns:
            ``(embeddings, loss)`` if ``return_loss`` else ``(embeddings, pred)``.

        Raises:
            ValueError: for an unknown ``pooling_method``.
        """
        if pooling_method not in self._POOLING_METHODS:
            raise ValueError(
                f"unknown pooling_method {pooling_method!r}; "
                f"expected one of {self._POOLING_METHODS}"
            )
        # Only the final layer is needed, so per-layer hidden states are not
        # requested (they cost memory for every layer).
        outputs = self.llama(**inputs)
        mask = inputs['attention_mask'] if 'attention_mask' in inputs else None
        embeddings = self._pool(outputs.last_hidden_state, mask, pooling_method)
        pred = self.predictor(embeddings)
        if return_loss:
            loss = self.loss_func(pred.view(-1), y.reshape(-1))
            return embeddings, loss
        return embeddings, pred

    @staticmethod
    def _pool(hidden, mask, pooling_method):
        """Reduce ``hidden`` (batch, seq, dim) to (batch, dim), honouring ``mask``."""
        if pooling_method == 'mean':
            if mask is None:
                return hidden.mean(dim=1)
            weights = mask.unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)
            # clamp avoids division by zero for fully masked rows
            return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        # last_token
        if mask is None:
            return hidden[:, -1, :]
        # Index of the last position whose mask is 1. This works for both
        # left and right padding; with no padding it equals seq_len - 1.
        positions = torch.arange(mask.size(1), device=mask.device)
        last_idx = (mask.long() * positions).argmax(dim=1).to(hidden.device)
        batch_idx = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[batch_idx, last_idx]
def _load_model_and_tokenizer(args, device):
    """Build backbone (+ optional LoRA) and head, load ZeRO weights, move to ``device``."""
    # Build on the CPU and move to the target device once, after the weights
    # are loaded. This avoids a GPU -> CPU -> GPU round trip.
    backbone = AutoModel.from_pretrained(args.pretrained_model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path)
    if tokenizer.pad_token is None:
        # padding='longest' raises without a pad token even for batch_size=1;
        # with one prompt per batch no pad tokens are actually inserted.
        tokenizer.pad_token = tokenizer.eos_token
    # Size the head from the backbone instead of a hard-coded 4096.
    hidden_size = getattr(backbone.config, 'hidden_size', 4096)
    if args.lora:
        backbone = get_peft_model(backbone, lora_config)
    model = LlamaWithLoss(backbone, YieldPredLayer(hidden_size, 1024, 1))
    state_dict = get_fp32_state_dict_from_zero_checkpoint(args.checkpoint_dir)
    model.load_state_dict(state_dict)
    del state_dict  # release the full fp32 copy before moving to the device
    model = model.to(device)
    model.eval()
    return model, tokenizer


def _predict_yields(model, tokenizer, dataloader, device, desc):
    """Return the model's predictions for every batch, concatenated on the CPU."""
    preds = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc, file=sys.stdout):
            inputs = tokenizer(
                batch['instruction'],
                max_length=3000,
                padding='longest',
                truncation=True,
                return_tensors="pt",
            ).to(device)
            _, pred = model(inputs, None, return_loss=False)
            preds.append(pred.cpu())
    return torch.cat(preds, dim=0)


def main(args):
    """Run yield inference for one data split and save the predictions.

    Loads the backbone (LoRA-wrapped when ``args.lora`` is set) and the
    regression head. Loads the fp32 weights consolidated from the DeepSpeed
    ZeRO checkpoint in ``args.checkpoint_dir``. Predicts every row of
    ``./data4regression/tongji_searchspace/split_{args.split_id}.csv`` and
    saves ``{"pred_yields_by_rxn": predictions}`` to
    ``./eval_results/{args.data_name}_split_{args.split_id}/yields.pt``.
    """
    print(f"split id: {args.split_id}")
    device = get_device()
    output_state_dict_path = os.path.join('./eval_results', args.data_name + f'_split_{args.split_id}')
    os.makedirs(output_state_dict_path, exist_ok=True)

    model, tokenizer = _load_model_and_tokenizer(args, device)
    print("Model loaded")

    # NOTE(review): the input directory is fixed to tongji_searchspace while
    # the output directory follows --data_name; confirm this is intended.
    data = read_data_from_csv(os.path.join('./data4regression', 'tongji_searchspace', f'split_{args.split_id}.csv'))
    dataloader = DataLoader(data, batch_size=1, shuffle=False)
    print("Dataset loaded")

    # Ground-truth labels used to be collected here but were never used, so
    # only predictions are computed and saved.
    y_preds = _predict_yields(model, tokenizer, dataloader, device, desc=f"{args.split_id}")

    print('Saving embs and yields...')
    data_info_dict = {"pred_yields_by_rxn": y_preds}
    torch.save(data_info_dict, os.path.join(output_state_dict_path, 'yields.pt'))






def build_arg_parser():
    """Return the command-line parser for this inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default='bert-base-uncased', help="The path to the pretrained model")
    parser.add_argument("--lora", type=int, default=0, help="Use LoRA (non-zero to enable)")
    parser.add_argument("--data_name", type=str, default='cpa_100', help="The name of the dataset")
    # Both are needed by main(): without them the run used to crash later
    # (os.path.join(None) / reading split_None.csv).
    parser.add_argument("--checkpoint_dir", type=str, required=True, help="The path to the checkpoint directory")
    parser.add_argument("--split_id", type=int, required=True, help="Index of the data split to evaluate (reads split_<id>.csv)")
    return parser


if __name__ == '__main__':
    main(build_arg_parser().parse_args())