import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import TensorDataset, DataLoader
import random
import os
from tqdm import tqdm
import pandas as pd
from sklearn.utils import shuffle
import time
from torch.nn.utils import clip_grad_norm_
import json
import datetime

from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

import logging
import numpy as np

import pathlib

import os
import re
from sklearn.metrics import mean_absolute_error as MAE
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


def save_model_and_losses(
    global_rank,
    losses=None,
    base_model_save_path=None,
    lora_adapter_save_path=None,
    model=None,
    use_lora=False,
    save_root=None,
    logger=None
):
    """Persist the backbone, the optional LoRA adapter and the predictor head.

    Args:
        global_rank: Accepted for call-site compatibility; not used here.
        losses: Accepted for call-site compatibility; not used here (no loss
            history is written by this function despite its name).
        base_model_save_path: Destination handed to
            ``model.llama.base_model.save_pretrained``.
        lora_adapter_save_path: Destination handed to
            ``model.llama.save_pretrained``; only used when ``use_lora`` is
            truthy and the path is not ``None``.
        model: Object exposing ``llama`` and ``predictor`` attributes. If it
            has a ``module`` attribute, that inner object is saved instead.
        use_lora: Whether the LoRA adapter should be saved as well.
        save_root: Directory that receives ``predictor.pt``. The predictor is
            only saved when this is not ``None`` and the predictor is a
            ``torch.nn.Module``.
        logger: Optional logger. When ``None`` the messages are only printed.
    """
    def _report(message):
        # Mirror every status message to stdout and, when available, the log.
        print(message)
        if logger is not None:
            logger.info(message)

    # Unwrap containers that keep the real model under ``.module``.
    model_to_save = model.module if hasattr(model, 'module') else model

    # Save the base model.
    model_to_save.llama.base_model.save_pretrained(base_model_save_path)
    _report(f"Base model saved to: {base_model_save_path}")

    # Save the LoRA adapter when LoRA is in use.
    if use_lora and lora_adapter_save_path is not None:
        model_to_save.llama.save_pretrained(lora_adapter_save_path)
        _report(f"LoRA adapter saved to: {lora_adapter_save_path}")

    # Save the predictor head's state dict.
    if isinstance(model_to_save.predictor, torch.nn.Module) and save_root is not None:
        pathlib.Path(save_root).mkdir(parents=True, exist_ok=True)
        torch.save(model_to_save.predictor.state_dict(), os.path.join(save_root, 'predictor.pt'))


def parse_filename(filename):
    """Extract ``(epoch, step, loss)`` from a checkpoint name.

    The name must start with ``<epoch>-<step>-<loss>``; any text after the
    loss is ignored. The loss may be written in plain decimal form
    (``0.125``) or in scientific notation (``1e-05``), which is how Python
    formats small floats.

    :param filename: Name of the checkpoint entry.
    :return: ``(epoch, step, loss)`` as ``(int, int, float)``, or ``None``
        when the name does not match the pattern.
    """
    match = re.match(r'(\d+)-(\d+)-(\d+\.?\d*(?:[eE][-+]?\d+)?)', filename)
    if match is None:
        return None
    epoch, step, loss = match.groups()
    return int(epoch), int(step), float(loss)


def delete_files_wrt_loss(save_path, max_save_files=5, reverse=True):
    """Prune checkpoint entries in ``save_path`` according to their loss.

    Only entries whose names are recognised by ``parse_filename`` are
    considered; everything else in the directory is left untouched. The
    recognised entries are sorted by loss and the first ``max_save_files``
    of them are kept, the rest are deleted.

    :param save_path: Directory that holds the checkpoint entries.
    :param max_save_files: Number of entries to keep.
    :param reverse: Sort order for the loss. ``True`` (the default) sorts in
        descending order, so the entries with the LARGEST loss are kept;
        ``False`` keeps the entries with the smallest loss.
        NOTE(review): the default keeps the highest-loss checkpoints -
        confirm that this is what callers intend.
    """
    import shutil

    parsed_files = []
    for name in os.listdir(save_path):
        parsed = parse_filename(name)
        if parsed:
            parsed_files.append((name, *parsed))

    if not parsed_files:
        return

    # Sort by loss (index 3 of each ``(name, epoch, step, loss)`` tuple).
    parsed_files.sort(key=lambda x: x[3], reverse=reverse)
    print(parsed_files)

    # Keep the first ``max_save_files`` entries.
    max_save_files = min(max_save_files, len(parsed_files))
    best_files = parsed_files[:max_save_files]

    # Delete everything else. Entries may be plain files or directories with
    # nested content, so pick the removal primitive that fits each one.
    for name, _, _, _ in parsed_files[max_save_files:]:
        target = os.path.join(save_path, name)
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        else:
            os.remove(target)

    criterion = "largest" if reverse else "smallest"
    print(f"Kept {len(best_files)} files with the {criterion} loss.")

def save_ckpt_with_limited_files(model, save_path, epoch, step, loss, max_save_files=5, reverse=True):
    """Save a checkpoint tagged ``<epoch>-<step>-<loss>`` after pruning old ones.

    Args:
        model: Object exposing ``save_checkpoint(save_path, tag=...)``.
        save_path: Directory that holds the checkpoints; created if missing.
        epoch: Epoch number used in the tag.
        step: Step number used in the tag.
        loss: Loss value used in the tag. A single-element tensor is
            converted to a Python number so the tag stays parseable by
            ``parse_filename``.
        max_save_files: Forwarded to ``delete_files_wrt_loss``; ``None``
            disables pruning.
        reverse: Forwarded to ``delete_files_wrt_loss``.

    Pruning runs before the new checkpoint is written, so the directory can
    hold up to ``max_save_files + 1`` checkpoints once this call returns.
    """
    if isinstance(loss, torch.Tensor) and loss.numel() == 1:
        loss = loss.item()
    ckpt_id = f'{epoch}-{step}-{loss}'
    os.makedirs(save_path, exist_ok=True)

    if max_save_files is not None:
        # Only one process prunes. Without an initialised process group there
        # is no rank to query, so the current process does the pruning.
        distributed = dist.is_available() and dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            delete_files_wrt_loss(save_path, max_save_files, reverse=reverse)
    model.save_checkpoint(save_path, tag=ckpt_id)


def read_reaction(path, data_name):
    """Load the dataset stored at ``path/data_name`` and return three columns.

    Returns:
        A tuple ``(conditions, yields, reactions)`` taken from the
        ``'condition'``, ``'yield'`` and ``'reaction'`` columns.
    """
    records = Dataset.load_from_disk(os.path.join(path, data_name))
    yields, conditions, rxns = (
        records[column] for column in ('yield', 'condition', 'reaction')
    )
    return conditions, yields, rxns



def cleanup():
    """Destroy the default process group if one has been initialised.

    Calling this when no process group exists is a no-op, so it is safe to
    use unconditionally at shutdown.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

class YieldPredLayer(nn.Module):
    """Prediction head made of two stacked linear layers.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the intermediate layer.
        output_size: Size of the last dimension of the output (default 1).
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super().__init__()
        # Stored on the module but not part of ``self.predictor``.
        # NOTE(review): with no activation between them, the two linear
        # layers compose into a single affine map - confirm this is intended.
        self.act = nn.SiLU()
        self.predictor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Return the head's prediction for ``x``."""
        return self.predictor(x)

class LlamaWithLoss(nn.Module):
    """Wrap a backbone and a predictor head, returning embeddings plus loss or prediction.

    Args:
        llama: Callable backbone. It is invoked as
            ``llama(**inputs, output_hidden_states=True)`` and its result must
            expose ``last_hidden_state``.
        predictor: Module applied to the pooled embeddings.
    """

    _POOLING_METHODS = ('mean', 'last_token')

    def __init__(self, llama, predictor):
        super().__init__()
        self.llama = llama
        self.loss_func = torch.nn.MSELoss()

        self.predictor = predictor

    def forward(self, inputs, y, pooling_method='last_token', return_loss=True):
        """Pool the backbone output and score it with the predictor.

        Args:
            inputs: Mapping of keyword arguments for the backbone.
            y: Regression targets; required when ``return_loss`` is true,
                ignored otherwise.
            pooling_method: ``'mean'`` averages ``last_hidden_state`` over
                dimension 1; ``'last_token'`` takes index ``-1`` along
                dimension 1.
            return_loss: When true return ``(embeddings, loss)`` with an MSE
                loss between the flattened prediction and ``y``; otherwise
                return ``(embeddings, pred)``.

        Raises:
            ValueError: If ``pooling_method`` is not supported, or if
                ``return_loss`` is true and ``y`` is ``None``.
        """
        # Validate cheap arguments before running the backbone.
        if pooling_method not in self._POOLING_METHODS:
            raise ValueError(
                f"Unsupported pooling_method {pooling_method!r}; "
                f"expected one of {self._POOLING_METHODS}"
            )
        if return_loss and y is None:
            raise ValueError("y must be provided when return_loss=True")

        outputs = self.llama(**inputs, output_hidden_states=True)
        last_hidden_state = outputs.last_hidden_state
        # NOTE(review): neither pooling mode consults an attention mask, so
        # padded positions take part in the pooling - confirm the tokenizer's
        # padding side makes this safe for batched inputs.
        if pooling_method == 'mean':
            embeddings = last_hidden_state.mean(dim=1)
        else:
            embeddings = last_hidden_state[:, -1, :]

        pred = self.predictor(embeddings)
        if not return_loss:
            return embeddings, pred
        loss = self.loss_func(pred.view(-1), y.view(-1))
        return embeddings, loss

def read_data_from_csv(path):
    """Read the CSV file at ``path`` and return it as a ``Dataset``."""
    return Dataset.from_dict(pd.read_csv(path))

from collections import OrderedDict

def load_sharded_checkpoint(checkpoint_dir, device='cpu'):
    """Assemble a full state dict from a sharded checkpoint directory.

    The directory must contain ``pytorch_model.bin.index.json`` whose
    ``weight_map`` entry maps every weight name to the shard file holding it.

    Args:
        checkpoint_dir: Directory with the index file and the shard files.
        device: ``map_location`` passed to ``torch.load`` for every shard.

    Returns:
        An ``OrderedDict`` mapping weight names to the loaded values. Weights
        that live in a shard file missing from disk are left out (a warning
        is logged for each missing shard).

    NOTE: ``torch.load`` unpickles the shard files - only use this on
    checkpoints from a trusted source.
    """
    # Read the index file to find out which shard holds which weight.
    index_file_path = os.path.join(checkpoint_dir, 'pytorch_model.bin.index.json')
    with open(index_file_path, 'r', encoding='utf-8') as f:
        index = json.load(f)

    # Group weight names by shard so that every shard is read exactly once.
    shard_files = {}
    for weight_name, shard_file in index['weight_map'].items():
        shard_files.setdefault(shard_file, []).append(weight_name)

    full_state_dict = OrderedDict()

    # Load shard by shard, with a progress bar over the shards.
    for shard_file, weight_names in tqdm(shard_files.items(), desc="Loading shards", unit="shard"):
        shard_path = os.path.join(checkpoint_dir, shard_file)
        if not os.path.isfile(shard_path):
            # Stay best-effort, but make the gap visible instead of silent.
            logging.getLogger(__name__).warning(
                "Shard %s listed in the index is missing; skipping %d weights",
                shard_path,
                len(weight_names),
            )
            continue
        shard_state_dict = torch.load(shard_path, map_location=device)
        for weight_name in weight_names:
            full_state_dict[weight_name] = shard_state_dict[weight_name]

    return full_state_dict

def make_plot(y_test, y_pred, rmse, r2_score, mae, name):
    """Draw a predicted-vs-measured scatter plot and return the figure.

    ``y_pred`` goes on the x axis and ``y_test`` on the y axis. Both axes are
    limited to ``[-5, 105]``, a dashed diagonal is drawn through the points
    ``(0, 0)`` to ``(99, 99)``, and the legend shows the supplied R2, RMSE
    and MAE values. ``name`` becomes the plot title.
    """
    text_size = 16
    colour = "#5402A3"
    fig, ax = plt.subplots(figsize=(8, 8))

    # One legend entry per metric, all in the scatter colour.
    metric_labels = (
        f"R2 = {r2_score:.3f}",
        f"RMSE = {rmse:.1f}",
        f"MAE = {mae:.1f}",
    )
    legend_handles = [mpatches.Patch(label=text, color=colour) for text in metric_labels]

    plt.xlim(-5, 105)
    plt.ylim(-5, 105)
    plt.scatter(y_pred, y_test, alpha=0.2, color=colour)

    # Reference line for perfect predictions.
    diagonal = np.arange(100)
    plt.plot(diagonal, diagonal, ls="--", c=".3")

    plt.legend(handles=legend_handles, fontsize=text_size)
    ax.set_ylabel('Measured', fontsize=text_size)
    ax.set_xlabel('Predicted', fontsize=text_size)
    ax.set_title(name, fontsize=text_size)
    return fig

def _load_backbone_and_tokenizer(pretrained_model_path, lora_adapter_path, use_lora, logger):
    """Load the base model and tokenizer, wrapping the model with LoRA if requested."""
    model = AutoModel.from_pretrained(pretrained_model_path, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
    print(model)
    logger.info(model)

    if not use_lora:
        return model, tokenizer

    if lora_adapter_path and os.path.exists(lora_adapter_path):
        # An adapter already exists on disk: attach it to the base model.
        print('Load LoRA adapter from', lora_adapter_path)
        logger.info(f'Load LoRA adapter from {lora_adapter_path}')
        model = PeftModel.from_pretrained(model, lora_adapter_path)
    else:
        # No adapter on disk: wrap the model with a fresh LoRA configuration.
        print('LoRA configuring...')
        logger.info('LoRA configuring...')
        lora_config = LoraConfig(
            r=8,  # Rank of the low-rank matrix
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)
    return model, tokenizer


def train(args):
    """Run inference over a CSV of prompts and save embeddings and predictions.

    Despite its name this function performs no training: it builds the
    model, loads a sharded state dict from ``args.state_dict_path``, runs the
    ``'instruction'`` column of ``<data_path>/<data_name>/molecular.csv``
    through the model and writes two files to the output directory:

    * ``embs.pt``   - ``{"cls_embs": <pooled embeddings>}``
    * ``yields.pt`` - the contents of ``<data_path>/<data_name>/split_idx.pt``
      with an added ``"pred_yields_by_rxn"`` entry holding the predictions.

    The output directory is ``args.save_root/args.load_ds_ckpt_id``, or
    ``args.save_root`` itself when no checkpoint id is given.

    Args:
        args: Namespace with the attributes ``pretrained_model_path``,
            ``max_length``, ``data_path``, ``data_name``,
            ``lora_adapter_path``, ``load_ds_ckpt_id``, ``use_lora``,
            ``log_file``, ``save_root``, ``state_dict_path`` and
            ``batch_size``.
    """
    pretrained_model_path = args.pretrained_model_path
    max_length = args.max_length

    data_path = args.data_path
    data_name = args.data_name
    lora_adapter_path = args.lora_adapter_path

    load_ds_ckpt_id = args.load_ds_ckpt_id

    use_lora = args.use_lora

    logging.basicConfig(
        filename=args.log_file,
        level=logging.INFO,
    )
    logger = logging.getLogger()

    # ``load_ds_ckpt_id`` is optional; without it results go to ``save_root``.
    if load_ds_ckpt_id is None:
        output_state_dict_path = args.save_root
    else:
        output_state_dict_path = os.path.join(args.save_root, load_ds_ckpt_id)
    os.makedirs(args.save_root, exist_ok=True)
    os.makedirs(output_state_dict_path, exist_ok=True)

    print('Load model...')
    logger.info('Load model...')

    # The model and tokenizer are needed on every code path, with or without LoRA.
    model, tokenizer = _load_backbone_and_tokenizer(
        pretrained_model_path, lora_adapter_path, use_lora, logger
    )

    print('cuda is available:', torch.cuda.is_available())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('using device', device)
    # NOTE(review): the model was loaded with device_map='auto' and is then
    # moved again here - confirm this combination is intended.
    model.to(device)
    # NOTE(review): the head's input width is hard-coded to 4096 - confirm it
    # matches the backbone's hidden size.
    predictor = YieldPredLayer(4096, 1024, 1).to(device).train()
    model = LlamaWithLoss(model, predictor)

    # Load the sharded checkpoint (read on the loader's default device) into
    # the assembled model; ``load_state_dict`` keeps its default strictness.
    model.load_state_dict(load_sharded_checkpoint(args.state_dict_path))

    print(model)

    val_batch_size = args.batch_size
    print("Train_batch_size: ", val_batch_size)
    logger.info(f'Train_batch_size: {val_batch_size}')

    # Data
    data_file = os.path.join(data_path, data_name, 'molecular.csv')
    print('Load data from...', data_file)
    logger.info(f'Load data ...')
    data = read_data_from_csv(data_file)

    dataloader = DataLoader(data, batch_size=val_batch_size)

    # ``eval()`` recurses into the backbone and the predictor head.
    model.eval()
    preds = []
    embs = []
    with torch.no_grad():
        for batch_data in tqdm(dataloader):
            prompts = batch_data['instruction']
            inputs = tokenizer(
                prompts,
                max_length=max_length,
                padding='longest',
                truncation=True,
                return_tensors="pt",
            ).to(device)
            emb, pred = model(inputs, None, return_loss=False)
            preds.append(pred.detach().cpu())
            embs.append(emb.detach().cpu())
    y_preds = torch.cat(preds, dim=0)
    y_embs = torch.cat(embs, dim=0)

    print('Saving embs and yields...')

    save_dict = {
        "cls_embs": y_embs,
    }
    torch.save(save_dict, os.path.join(output_state_dict_path, 'embs.pt'))

    # NOTE: this unpickles ``split_idx.pt`` - only use trusted data files.
    data_info_dict = torch.load(
        os.path.join(data_path, data_name, 'split_idx.pt'),
        pickle_module=pd.compat.pickle_compat,
    )
    data_info_dict["pred_yields_by_rxn"] = y_preds
    torch.save(data_info_dict, os.path.join(output_state_dict_path, 'yields.pt'))



        

        

def main():
    """Parse the command-line options and start the evaluation run."""
    parser = argparse.ArgumentParser(description="Distributed Data Parallel Training")

    # (flag, keyword arguments) pairs, registered in this order.
    # A previously used alternative for --pretrained_model_path was
    # '/mnt/hwfile/ai4chem/share/jianpeng/llama3_8b_0916_lora_yield_pred_ds'.
    option_table = (
        ("--pretrained_model_path", {"default": '/mnt/cache/Chemllm/reaction_condition_recommendation/src/step1_llama3_8b_0916_yearly_pistachio_ep3'}),
        ("--lora_adapter_path", {"default": "/mnt/hwfile/ai4chem/share/jianpeng/llama3_8b_0916_lora_yield_pred_ds/lora_adapter"}),
        ("--batch_size", {"type": int, "default": 2}),
        ("--data_path", {"default": '/mnt/petrelfs/chenjianpeng/cjp/LLaMA-Factory/train_regression/data4regression'}),
        ("--data_name", {"default": 'suzuki_miyaura_600'}),
        ("--save_root", {"default": "./eval_results"}),
        ("--base_model_save_path", {"default": "base_model"}),
        ("--lora_adapter_save_path", {"default": "lora_adapter"}),
        ('--use_lora', {"type": int, "default": 1}),
        ('--log_file', {"type": str, "default": "training_ds.log"}),
        ('--max_length', {"type": int, "default": 3000}),
        ('--load_ds_ckpt_id', {"type": str, "default": None}),
        ('--state_dict_path', {"type": str, "default": './trained_models'}),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    print('start eval')

    # The multi-process launch code (mp.spawn / DeepSpeed launcher / SLURM
    # environment wiring) is disabled; the run happens in this process.
    train(args)


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
