import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import TensorDataset, DataLoader
import random
import os
from tqdm import tqdm
import pandas as pd
from sklearn.utils import shuffle
import time
from torch.nn.utils import clip_grad_norm_
import json
import datetime
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

from deepspeed import get_accelerator
import logging
import deepspeed
import subprocess
import wandb
from sklearn.metrics import r2_score
import numpy as np
import sys
import pathlib

import os
import re
import shutil

from collections import OrderedDict
import torch
from transformers import AutoConfig, AutoModel

def load_model_with_state_by_torch(weights_path):
    """Build a bare ``AutoModel`` from ``weights_path`` and load its weights manually.

    The config is read from ``weights_path`` and the weights from
    ``weights_path/pytorch_model.bin``. Only state-dict keys that start with
    ``"model."`` are kept (with that prefix stripped); every other key is dropped
    before the strict ``load_state_dict`` call.

    :param weights_path: directory holding the model config and ``pytorch_model.bin``
    :return: the model with the filtered weights loaded
    """
    prefix = "model."
    config = AutoConfig.from_pretrained(weights_path)
    model = AutoModel.from_config(config)
    # map_location="cpu": the freshly built model's parameters live on CPU, so
    # materialising the checkpoint on its original (possibly GPU) device would
    # only add a transient extra copy before load_state_dict copies it over.
    state_dict = torch.load(os.path.join(weights_path, "pytorch_model.bin"), map_location="cpu")

    new_sd = OrderedDict(
        (key[len(prefix):], value) for key, value in state_dict.items() if key.startswith(prefix)
    )

    model.load_state_dict(new_sd)
    return model

class JsonTextDataset(Dataset):
    """Map-style dataset with one item per ``*.json`` file in ``json_dir``.

    Each item is ``{record_name: attr_string}`` built from the list of objects in
    that file. Every key of an object, including ``"name"`` itself, becomes one
    ``"<key>: <value>"`` line of the attribute string.

    NOTE(review): ``Dataset`` resolves to ``datasets.Dataset`` (Hugging Face, imported
    at the top of this file), not ``torch.utils.data.Dataset``, and its ``__init__`` is
    never called here, so only ``__len__``/``__getitem__`` are safe to use. Confirm
    which base class was intended.
    """

    def __init__(self, json_dir):
        self.json_dir = json_dir
        # Only *.json files can be parsed by json.load. Sorting makes the
        # index -> file mapping independent of os.listdir's arbitrary order.
        self.paths = sorted(p for p in os.listdir(json_dir) if p.endswith(".json"))

    def __len__(self):
        """Return the number of JSON files found in ``json_dir``."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th JSON file and return ``{record_name: attr_string}``.

        A later record with the same ``"name"`` overwrites an earlier one.
        """
        file_path = os.path.join(self.json_dir, self.paths[idx])
        with open(file_path, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        return {
            data["name"]: "".join(f"{attr}: {value}\n" for attr, value in data.items())
            for data in json_data
        }

def init_deepspeed_distributed():
    """Set up the torch.distributed process group through DeepSpeed.

    No ``dist_backend`` is passed, so DeepSpeed picks its default backend
    (NCCL on CUDA machines).
    """
    deepspeed.init_distributed()


# Single-GPU, serial-only variant (kept for reference):
# def get_device():
#     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Works for both multi-GPU parallel runs and single-GPU serial runs.
def get_device():
    """Return the torch device this process should run on.

    ``LOCAL_RANK`` (default 0) is read from the environment. With more than one
    visible GPU, the process is pinned to GPU ``LOCAL_RANK`` via
    ``torch.cuda.set_device`` and ``cuda:<LOCAL_RANK>`` is returned. With exactly
    one GPU the generic ``cuda`` device is returned; with none, the CPU.

    :raises RuntimeError: if several GPUs are visible and ``LOCAL_RANK`` does not
        name one of them.
    """
    local_rank = int(os.environ.get('LOCAL_RANK', 0))  # index of this process on its node
    device_count = torch.cuda.device_count()  # number of visible GPUs
    print(f'Local Rank:{local_rank}, Device Count:{device_count}')
    if device_count > 1:
        # Check the rank explicitly so a misconfigured launcher fails with a clear message.
        if not 0 <= local_rank < device_count:
            raise RuntimeError(
                f"LOCAL_RANK={local_rank} is out of range for {device_count} visible GPUs"
            )
        torch.cuda.set_device(local_rank)  # pin this process to its own GPU
        return torch.device(f'cuda:{local_rank}')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def read_data_from_csv(path):
    """Read a CSV file with pandas and wrap its columns in a Hugging Face ``Dataset``.

    :param path: path of the CSV file
    :return: ``Dataset`` built column-by-column from the DataFrame
    """
    frame = pd.read_csv(path)
    return Dataset.from_dict(frame)

# lora_config = LoraConfig(
#     r=8,  # Rank of the low-rank matrix
#     lora_alpha=16,
#     lora_dropout=0.1,
#     target_modules=["q_proj", "v_proj",],
#     #target_modules="all-linear",
#     # bias="none",
#     # modules_to_save=["classifier"]
#     )
class YieldPredLayer(nn.Module):
    """Regression head mapping a pooled embedding to ``output_size`` values.

    It is two stacked ``nn.Linear`` layers (``input_size -> hidden_size ->
    output_size``) with no activation in between.
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super(YieldPredLayer, self).__init__()
        # Not used in forward (the activation between the layers is disabled).
        # SiLU has no parameters, so it does not affect the state_dict.
        self.act = nn.SiLU()
        # The Sequential indices ("predictor.0", "predictor.1") form the
        # state_dict keys, so this layout must stay fixed for saved checkpoints.
        self.predictor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, output_size),  # previously hard-coded to 1
        )

    def forward(self, x):
        """Apply the two linear layers to ``x`` (last dimension ``input_size``)."""
        return self.predictor(x)
    
class LlamaWithLoss(nn.Module):
    """Base language model + regression head, trained with an MSE loss.

    ``forward`` runs the base model, pools its ``last_hidden_state`` into one
    vector per sequence, and feeds that vector to ``predictor``.
    """

    _POOLING_METHODS = ('mean', 'last_token')

    def __init__(self, llama, predictor):
        super(LlamaWithLoss, self).__init__()
        self.llama = llama
        self.loss_func = torch.nn.MSELoss()
        self.predictor = predictor

    @staticmethod
    def _pool(last_hidden_state, attention_mask, pooling_method):
        """Reduce hidden states indexed as (batch, seq, dim) to (batch, dim), skipping padding."""
        if attention_mask is None:
            if pooling_method == 'mean':
                return last_hidden_state.mean(dim=1)
            return last_hidden_state[:, -1, :]

        mask = attention_mask.to(last_hidden_state.device) != 0
        if pooling_method == 'mean':
            weights = mask.unsqueeze(-1).to(last_hidden_state.dtype)
            summed = (last_hidden_state * weights).sum(dim=1)
            # clamp guards against an all-padding row (division by zero).
            return summed / weights.sum(dim=1).clamp(min=1)

        # last_token: highest position whose mask is set. This is correct for both
        # right- and left-padded batches, and equals seq-1 when nothing is padded.
        positions = torch.arange(mask.shape[1], device=mask.device)
        last_idx = (positions * mask).max(dim=1).values
        batch_idx = torch.arange(mask.shape[0], device=mask.device)
        return last_hidden_state[batch_idx, last_idx]

    def forward(self, inputs, y, pooling_method='last_token', return_loss=True):
        """Embed ``inputs`` and either score the prediction against ``y`` or return it.

        :param inputs: mapping of keyword arguments for the base model (e.g. a
            tokenizer's ``BatchEncoding``). If it has an ``attention_mask``, the
            mask is used to exclude padding positions from pooling.
        :param y: regression targets, compared to the flattened prediction; only
            used when ``return_loss`` is true.
        :param pooling_method: ``'last_token'`` (default) or ``'mean'``.
        :param return_loss: choose the second element of the returned tuple.
        :return: ``(embeddings, loss)`` if ``return_loss`` else ``(embeddings, pred)``.
        :raises ValueError: for an unknown ``pooling_method`` (checked before the
            base model runs).
        """
        if pooling_method not in self._POOLING_METHODS:
            raise ValueError(
                f"pooling_method must be one of {self._POOLING_METHODS}, got {pooling_method!r}"
            )
        # Only last_hidden_state is consumed, so per-layer hidden states are not
        # requested (they would otherwise all be kept in memory).
        outputs = self.llama(**inputs)
        embeddings = self._pool(outputs.last_hidden_state, inputs.get('attention_mask'), pooling_method)
        pred = self.predictor(embeddings)
        if return_loss:
            return embeddings, self.loss_func(pred.view(-1), y.view(-1))
        return embeddings, pred
    
def get_attr_string_from_json_files(json_files_dir):
    """
    Convert the records of every ``*.json`` file in a directory into attribute strings.

    Each JSON file must hold a list of objects that all have a ``"name"`` key. For
    every object, all other keys are rendered as ``"<key>: <value>"`` lines, each
    ending in a newline, in the object's key order. A later record with the same
    name overwrites an earlier one within a file.

    :param json_files_dir: directory to scan (not recursive); entries whose name
        does not end in ``.json`` are ignored
    :return: ``{file_stem: {record_name: attr_string}}``, where ``file_stem`` is
        the file name with its trailing ``.json`` removed
    """
    suffix = ".json"
    results = {}
    # os.listdir order is arbitrary; sorting makes the result order reproducible.
    for file in sorted(os.listdir(json_files_dir)):
        if not file.endswith(suffix):
            continue
        file_path = os.path.join(json_files_dir, file)
        with open(file_path, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        data_maps = {}
        for data in json_data:
            data_maps[data["name"]] = "".join(
                f"{attr}: {value}\n" for attr, value in data.items() if attr != "name"
            )
        # Strip only the trailing extension. str.split(".json")[0] would cut a
        # name like "a.json.bak.json" at its first ".json".
        results[file[:-len(suffix)]] = data_maps

    return results
    
# Root directory holding one sub-directory of JSON files per search space.
_EXP_DATA_ROOT = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_data"


def _build_model(args, device):
    """Create the tokenizer and the eval-mode ``LlamaWithLoss``, loading the checkpoint if one is given."""
    if args.load_by_torch:
        base_model = load_model_with_state_by_torch(args.pretrained_model_path).to(device)
    else:
        base_model = AutoModel.from_pretrained(args.pretrained_model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    # Size the regression head from the model config instead of a hard-coded 4096.
    # Read it before any LoRA wrapping.
    hidden_size = base_model.config.hidden_size
    if args.lora:
        lora_config = LoraConfig(
            r=8,  # rank of the low-rank update matrices
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],
        )
        base_model = get_peft_model(base_model, lora_config)

    predictor = YieldPredLayer(hidden_size, 1024, 1).to(device)
    model = LlamaWithLoss(base_model, predictor)
    if args.checkpoint_dir:
        # Consolidates the ZeRO shards into one fp32 state dict. With no tag given,
        # DeepSpeed resolves the checkpoint's "latest" file itself.
        state_dict = get_fp32_state_dict_from_zero_checkpoint(args.checkpoint_dir)
        model.load_state_dict(state_dict)
        print("Model loaded")
    model.eval()
    return model, tokenizer


def _embed_records(model_engine, tokenizer, attr_strings_maps, device):
    """Return ``{record_name: embedding}`` (CPU tensors) for one JSON file's records."""
    data_maps = {}
    for item, attr_string in attr_strings_maps.items():
        # One string per call, so no padding is actually introduced.
        inputs = tokenizer(attr_string, max_length=3000, padding='longest', truncation=True, return_tensors="pt").to(device)
        embed, _ = model_engine(inputs, None, return_loss=False)
        # Move to CPU so the saved file does not depend on this rank's GPU.
        data_maps[item] = embed.cpu()
    return data_maps


def main(args):
    """Embed every record of a search space with a fine-tuned model and save the embeddings.

    Steps:
    1. Initialise DeepSpeed's process group.
    2. Build the base model (optionally LoRA-wrapped) plus regression head.
    3. Load the ZeRO checkpoint in ``args.checkpoint_dir`` if one is given.
    4. Wrap the model with ``deepspeed.init_inference``.
    5. Embed every record found by ``get_attr_string_from_json_files`` for
       ``args.searchspace_name``.

    One ``<file_stem>_embeddings.pt`` file (a dict ``{record_name: embedding}``)
    is written per JSON file under ``args.save_path/args.searchspace_name``. The
    JSON files are split round-robin over the sorted file stems across ranks, so
    each output file is written by exactly one process.

    :param args: parsed command-line namespace produced by the ``__main__`` block
    """
    init_deepspeed_distributed()
    device = get_device()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))  # index of this process on its node
    distributed = dist.is_initialized()
    world_size = dist.get_world_size() if distributed else 1
    global_rank = dist.get_rank() if distributed else 0
    print(f"local_rank: {local_rank}; World_size: {world_size}")
    print(f'[Rank {local_rank}] Starting inference on Device: {device}')

    if global_rank == 0:
        # Without configuration the root logger drops INFO records; send them to --log_file.
        logging.basicConfig(filename=args.log_file, level=logging.INFO)
        logging.info(f"\nStart eval model : {args.checkpoint_dir} on dataset: {args.data_name}")

    output_state_dict_path = os.path.join(args.save_path, args.searchspace_name)
    os.makedirs(output_state_dict_path, exist_ok=True)

    model, tokenizer = _build_model(args, device)

    # Inference config: DeepSpeed data parallel only (no tensor parallelism, no kernel injection).
    ds_inference_config = {
        "replace_with_kernel_inject": False,
        "tensor_parallel": {"tp_size": 1},
        "dtype": "fp32",  # "fp16" is the alternative that was tried
        "enable_cuda_graph": False,
    }
    model_engine = deepspeed.init_inference(
        model=model,
        config=ds_inference_config,
    )
    print("Distributed modle is already on correcy device")

    total_attr_strings_maps = get_attr_string_from_json_files(
        os.path.join(_EXP_DATA_ROOT, args.searchspace_name)
    )
    # Round-robin split of files across ranks: no duplicated work, no concurrent
    # writes to the same output file.
    my_names = sorted(total_attr_strings_maps)[global_rank::world_size]
    with torch.no_grad():
        for name in my_names:
            data_maps = _embed_records(model_engine, tokenizer, total_attr_strings_maps[name], device)
            torch.save(data_maps, os.path.join(output_state_dict_path, f"{name}_embeddings.pt"))


if __name__ == '__main__':
    # Command-line interface; the parsed namespace is passed to main() unchanged.
    cli = argparse.ArgumentParser()
    add_arg = cli.add_argument
    add_arg("--pretrained_model_path", type=str, default='bert-base-uncased', help="The path to the pretrained model")
    add_arg("--searchspace_name", type=str, default='tongji_searchspace_v2_97020', help="The name of the search space.")
    add_arg("--lora", type=int, help="Use LoRA")
    add_arg("--data_name", type=str, default='cpa_100', help="The name of the dataset")
    add_arg("--checkpoint_dir", type=str, help="The path to the checkpoint directory")
    add_arg("--local_rank", type=int, default=0, help="Local rank for distributed training")
    add_arg("--batch_size", type=int, default=16, help="Batch size for evaluation")
    add_arg("--log_file", type=str, default='eval.log', help="Log file path")
    add_arg("--save_path", type=str, default='/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/sft_model_resutls')
    add_arg("--load_by_torch", type=int, default=0)
    main(cli.parse_args())