import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import TensorDataset, DataLoader
import random
import os
from tqdm import tqdm
import pandas as pd
from sklearn.utils import shuffle
import time
from torch.nn.utils import clip_grad_norm_
import json
import datetime
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

from deepspeed import get_accelerator
import logging
import deepspeed
import subprocess
import wandb
from sklearn.metrics import r2_score
import numpy as np
import sys
import pathlib

import os
import re
import shutil

import cluster_utils

def init_deepspeed_distributed():
    """Set up the distributed process group through DeepSpeed.

    Delegates to ``deepspeed.init_distributed()`` with its default
    arguments (the original inline note records the default backend as
    ``dist_backend='nccl'``).
    """
    deepspeed.init_distributed()


# Single-GPU, serial-only variant (kept for reference):
# def get_device():
#     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Works for both multi-GPU parallel jobs and single-GPU serial jobs.
def get_device():
    """Pick the torch device this process should run on.

    Reads the ``LOCAL_RANK`` environment variable (defaulting to 0) and the
    number of visible CUDA devices, and prints both.

    Returns:
        ``torch.device('cuda:<LOCAL_RANK>')`` when more than one GPU is
        visible (that GPU is also made the current CUDA device); otherwise
        ``cuda`` if CUDA is available, else ``cpu``.
    """
    rank = int(os.environ.get('LOCAL_RANK', 0))
    n_gpus = torch.cuda.device_count()
    print(f'Local Rank:{rank}, Device Count:{n_gpus}')

    # Zero or one visible GPU: no per-rank device selection is needed.
    if n_gpus <= 1:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Several GPUs: bind this process to the GPU matching its local rank.
    torch.cuda.set_device(rank)
    return torch.device(f'cuda:{rank}')

def read_data_from_csv(path):
    """Read the CSV file at ``path`` and return it as a ``Dataset``.

    The file is parsed with ``pandas.read_csv`` and the resulting frame is
    handed to ``Dataset.from_dict``.
    """
    return Dataset.from_dict(pd.read_csv(path))

# lora_config = LoraConfig(
#     r=8,  # Rank of the low-rank matrix
#     lora_alpha=16,
#     lora_dropout=0.1,
#     target_modules=["q_proj", "v_proj",],
#     #target_modules="all-linear",
#     # bias="none",
#     # modules_to_save=["classifier"]
#     )
class YieldPredLayer(nn.Module):
    """Regression head made of two stacked linear layers.

    No activation is applied between the two layers (the non-linear variant
    is disabled), so the head is an affine map from ``input_size`` features
    to ``output_size`` values.

    Args:
        input_size: Number of features in each input embedding.
        hidden_size: Width of the intermediate linear layer.
        output_size: Number of predicted values per input. Defaults to 1.
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super().__init__()
        # Not applied in ``forward``; retained so the attribute stays
        # available if the non-linear variant of the head is re-enabled.
        self.act = nn.SiLU()
        # Sub-module indices 0 and 1 are kept so existing checkpoints
        # (keys ``predictor.0.*`` / ``predictor.1.*``) still load.
        self.predictor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            # Honour ``output_size`` instead of hard-coding a single output.
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Return the head's prediction for ``x`` (last dim = ``input_size``)."""
        return self.predictor(x)
    
class LlamaWithLoss(nn.Module):
    """Encoder plus regression head, returning either an MSE loss or predictions.

    Args:
        llama: Callable/module invoked as ``llama(**inputs,
            output_hidden_states=True)`` whose result exposes
            ``last_hidden_state``.
        predictor: Module mapping a pooled embedding to the prediction.
    """

    # Pooling strategies understood by ``forward``.
    _POOLING_METHODS = ('mean', 'last_token')

    def __init__(self, llama, predictor):
        super().__init__()
        self.llama = llama
        self.loss_func = torch.nn.MSELoss()
        self.predictor = predictor

    def forward(self, inputs, y, pooling_method='last_token', return_loss=True):
        """Encode ``inputs``, pool the last hidden state and run the head.

        Args:
            inputs: Mapping of keyword arguments forwarded to the encoder.
            y: Regression targets; only read when ``return_loss`` is true.
            pooling_method: ``'mean'`` averages over dim 1 of the last hidden
                state; ``'last_token'`` takes index ``-1`` along dim 1.
            return_loss: When true return the MSE loss, otherwise the raw
                predictions.

        Returns:
            ``(embeddings, loss)`` if ``return_loss`` else
            ``(embeddings, pred)``.

        Raises:
            ValueError: If ``pooling_method`` is not a supported value.
        """
        # Fail fast, before the expensive encoder pass, instead of hitting an
        # unbound ``embeddings`` variable afterwards.
        if pooling_method not in self._POOLING_METHODS:
            raise ValueError(
                f"Unsupported pooling_method {pooling_method!r}; "
                f"expected one of {self._POOLING_METHODS}"
            )

        outputs = self.llama(**inputs, output_hidden_states=True)
        last_hidden_state = outputs.last_hidden_state

        # NOTE(review): neither pooling mode looks at an attention mask, so
        # padded positions take part in the mean, and index -1 is a real token
        # only if the batch is left-padded -- confirm the tokenizer's padding
        # side matches what was used in training.
        if pooling_method == 'mean':
            embeddings = last_hidden_state.mean(dim=1)
        else:  # 'last_token'
            embeddings = last_hidden_state[:, -1, :]

        pred = self.predictor(embeddings)
        if not return_loss:
            return embeddings, pred

        loss = self.loss_func(pred.view(-1), y.view(-1))
        return embeddings, loss
    
    
def _restore_dataset_order(per_rank_chunks, num_samples):
    """Undo the round-robin sharding of ``DistributedSampler(shuffle=False)``.

    With ``shuffle=False`` rank ``r`` of ``W`` receives dataset indices
    ``r, r + W, r + 2W, ...`` and the index list is padded by wrapping around
    so every rank gets the same count. ``per_rank_chunks[r][j]`` therefore
    belongs to dataset row ``j * W + r``.

    Args:
        per_rank_chunks: Sequence of equally shaped tensors, one per rank,
            ordered by rank.
        num_samples: Real dataset length; rows beyond it are sampler padding.

    Returns:
        A tensor whose first dimension follows the dataset row order and has
        length ``num_samples``.
    """
    stacked = torch.stack(list(per_rank_chunks), dim=0)   # (W, per_rank, ...)
    interleaved = stacked.transpose(0, 1)                 # (per_rank, W, ...)
    flat = interleaved.reshape(-1, *stacked.shape[2:])    # row j * W + r
    return flat[:num_samples]


def _gather_predictions(local_preds, num_samples, rank, world_size):
    """Collect every rank's predictions on global rank 0, in dataset order.

    Args:
        local_preds: This rank's predictions, concatenated in sampler order.
        num_samples: Real dataset length (used to strip sampler padding).
        rank: Global rank of this process.
        world_size: Number of participating processes.

    Returns:
        The ordered predictions on rank 0, ``None`` on every other rank.
    """
    if world_size == 1 or not dist.is_initialized():
        return local_preds[:num_samples]

    if rank == 0:
        buckets = [torch.zeros_like(local_preds) for _ in range(world_size)]
        dist.gather(local_preds, gather_list=buckets, dst=0)
        return _restore_dataset_order(buckets, num_samples)

    # Non-destination ranks only contribute their tensor to the collective.
    dist.gather(local_preds, gather_list=None, dst=0)
    return None


def main(args):
    """Run distributed yield-prediction inference and save the predictions.

    For each search space the CSV ``./data4regression/<name>/all.csv`` is
    sharded across ranks, encoded with the tokenizer, pushed through the
    DeepSpeed inference engine, and the predictions are gathered on global
    rank 0 and saved (as a CPU tensor, in CSV row order) to
    ``<save_path>/eval_results/<name>/<searchspace_name>_yields.pt`` under the
    key ``"pred_yields_by_rxn"``.

    Args:
        args: Parsed command-line namespace. Reads ``searchspace_name``,
            ``load_by_torch``, ``pretrained_model_path``, ``lora``,
            ``checkpoint_dir``, ``save_path`` and ``batch_size``.
    """
    # Tag used in the output file name.
    model_method = args.searchspace_name

    # Initialize the DeepSpeed distributed environment.
    init_deepspeed_distributed()
    device = get_device()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    distributed = dist.is_initialized()
    world_size = dist.get_world_size() if distributed else 1
    # Sharding and collectives need the *global* rank; LOCAL_RANK repeats on
    # every node and is only valid for choosing a GPU.
    rank = dist.get_rank() if distributed else 0
    print(f"local_rank: {local_rank}; World_size: {world_size}")
    print(f'[Rank {local_rank}] Starting inference on Device: {device}')

    # Construct the model and load the checkpoint.
    if args.load_by_torch:
        base_model = cluster_utils.load_model_with_state_by_torch(args.pretrained_model_path).to(device)
        print("Model load_by_torch")
    else:
        base_model = AutoModel.from_pretrained(args.pretrained_model_path).to(device)
        print("Model load_by_pretrain")
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # Size the head from the encoder when it advertises its width; fall back
    # to the previously hard-coded 4096 otherwise. Read before any LoRA
    # wrapping so the lookup hits the plain model's config.
    hidden_size = getattr(getattr(base_model, 'config', None), 'hidden_size', 4096)

    if args.lora:
        lora_config = LoraConfig(
            r=8,  # Rank of the low-rank matrix
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj",],
        )
        base_model = get_peft_model(base_model, lora_config)

    # ``.to(device)`` rather than ``.cuda()`` so the head follows whatever
    # device ``get_device`` selected.
    predictor = YieldPredLayer(hidden_size, 1024, 1).to(device)
    model = LlamaWithLoss(base_model, predictor)
    if args.checkpoint_dir:
        state_dict = get_fp32_state_dict_from_zero_checkpoint(args.checkpoint_dir)
        model.load_state_dict(state_dict)
        print("Model loaded")

    model.eval()

    # Wrap the model in a DeepSpeed inference engine (no kernel injection,
    # tensor-parallel size 1, fp32, CUDA graphs off).
    ds_inference_config = {
        "replace_with_kernel_inject": False,
        "tensor_parallel": {"tp_size": 1},
        "dtype": "fp32",  # "fp16"
        "enable_cuda_graph": False
    }
    model_engine = deepspeed.init_inference(
        model=model,
        config=ds_inference_config
    )

    print("Distributed modle is already on correcy device")

    if "grouped_exp" in args.searchspace_name:
        searchspace_names = ['buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv', 'buchwald_Cc1ccc(Nc2ccccn2)cc1.csv', 'buchwald_Cc1ccc(Nc2cccnc2)cc1.csv', 'buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv', 'buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv']
    else:
        searchspace_names = [args.searchspace_name]

    for searchspace_name in searchspace_names:
        output_state_dict_path = os.path.join(args.save_path, 'eval_results', searchspace_name)
        os.makedirs(output_state_dict_path, exist_ok=True)

        # Load the dataset.
        data = read_data_from_csv(
            os.path.join('./data4regression', searchspace_name, 'all.csv')
        )

        # Each rank handles a different slice. The sampler pads the index
        # list so all ranks see the same number of samples; the padding is
        # stripped again after the gather, so the dataset size no longer has
        # to be divisible by the number of ranks.
        data_sampler = DistributedSampler(
            data,
            num_replicas=world_size,
            rank=rank,
            shuffle=False
        )
        dataloader = DataLoader(
            data,
            batch_size=args.batch_size,
            sampler=data_sampler,
            num_workers=8,
            shuffle=False
        )
        print("Dataset loaded")

        # Inference.
        y_preds = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Rank {local_rank} inference   ", file=sys.stdout):
                prompts = batch['instruction']
                # Forwarded to the model for signature compatibility; it is
                # not read when ``return_loss=False``.
                y = batch['output'].to(torch.float).to(device)

                # Tokenize with padding to the longest prompt and truncation
                # at 3000 tokens, then move every tensor to this rank's device.
                inputs = tokenizer(prompts, max_length=3000, padding='longest', truncation=True, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}

                _, pred = model_engine(inputs, y, return_loss=False)
                y_preds.append(pred.detach())

        local_preds = torch.cat(y_preds, dim=0)
        # Labels were previously gathered as well but never used or saved,
        # so only the predictions are collected.
        all_preds = _gather_predictions(local_preds, len(data), rank, world_size)

        # Only global rank 0 holds the gathered result.
        if rank == 0:
            print('Saving embeddings and predictions...')
            # Stored on CPU so the file loads on machines without a GPU.
            data_info_dict = {"pred_yields_by_rxn": all_preds.cpu()}
            torch.save(data_info_dict, os.path.join(output_state_dict_path, f'{model_method}_yields.pt'))


if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_path",
        type=str,
        default='bert-base-uncased',
        help="The path to the pretrained model",
    )
    parser.add_argument(
        "--searchspace_name",
        type=str,
        default='tongji_searchspace_v2_97020',
        help="The name of the search space.",
    )
    # Integer flag; left as None (falsy) when omitted.
    parser.add_argument("--lora", type=int, help="Use LoRA")
    parser.add_argument(
        "--data_name",
        type=str,
        default='cpa_100',
        help="The name of the dataset",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="The path to the checkpoint directory",
        default=None,
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Local rank for distributed training",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for inference",
    )
    parser.add_argument("--load_by_torch", type=int, default=1)
    parser.add_argument("--save_path", type=str, default="")
    # Disabled options kept for reference:
    # parser.add_argument("--split_id", type=int, help="The number of the spilted data files")
    # parser.add_argument("--deepspeed_config", type=str, default="ds_config.json")
    args = parser.parse_args()
    main(args)