##########################################################################################
# Machine Environment Config

# Runtime switches. USE_CUDA / CUDA_DEVICE_NUM are forwarded into the trainer and tester
# parameter dicts; DEBUG_MODE is not referenced elsewhere in this file.
DEBUG_MODE, USE_CUDA, CUDA_DEVICE_NUM = False, True, 0

##########################################################################################
# Path Config

import os
import sys
import logging
import torch
# —— Configure import paths ——
# This file's directory and the project root two levels above it are put at the front of
# sys.path so the project-local modules imported below (TSPTrainer, TSPTester, gen_inst) resolve.
# THIS_DIR is resolved to an absolute path (like ROOT_DIR) so the entry does not depend on the cwd.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, "../.."))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
if THIS_DIR not in sys.path:
    sys.path.insert(0, THIS_DIR)
# Relative paths used later (e.g. './result/...', default result folders) resolve from this directory.
os.chdir(THIS_DIR)

##########################################################################################
# import

from TSPTrainer import TSPTrainer as Trainer
from TSPTester import TSPTester as Tester
from gen_inst import dataset_conf, generate_datasets

##########################################################################################
# parameters

# Environment settings passed to Trainer/Tester; 'problem_size' (and 'test_file_path') are
# overwritten per run by train_and_evaluate.
env_params = dict(
    problem_size=500,
    pomo_size=8,
)

# Model hyper-parameters passed unchanged to Trainer/Tester.
# sqrt_embedding_dim is the square root of embedding_dim (128).
model_params = dict(
    embedding_dim=128,
    sqrt_embedding_dim=128 ** 0.5,
    encoder_layer_num=6,
    qkv_dim=16,
    head_num=8,
    logit_clipping=10,
    ff_hidden_dim=512,
    eval_type='argmax',
)
import time
import sys
import os
from datetime import datetime
import logging
import logging.config
import pytz
import numpy as np
import matplotlib.pyplot as plt
import json
import shutil
# Start time of this process in the Asia/Seoul timezone; it stamps the default result folder name.
process_start_time = datetime.now(pytz.timezone("Asia/Seoul"))
# Default result folder. The literal "{desc}" placeholder is filled in later by create_logger.
result_folder = "".join(("./result/", process_start_time.strftime("%Y%m%d_%H%M%S"), "{desc}"))
def get_result_folder():
    """Return the current module-level result folder path (may still hold a ``{desc}`` placeholder)."""
    return result_folder


def set_result_folder(folder):
    """Replace the module-level result folder path that ``create_logger`` uses as its default."""
    global result_folder
    result_folder = folder


def create_logger(log_file=None):
    """Route root-logger output to a log file and to stdout.

    :param log_file: optional dict (not modified) with keys:
        ``'filepath'`` - output folder, may contain a ``{desc}`` placeholder
        (default: ``get_result_folder()``);
        ``'desc'`` - if present, ``{desc}`` becomes ``'_' + desc``, otherwise ``''``;
        ``'filename'`` - log file name inside the folder (default ``'log.txt'``).

    Side effects: creates the folder if needed, stores the resolved folder via
    ``set_result_folder``, and replaces (and closes) every existing root-logger handler with a
    file handler (append mode) and a stdout handler, both at INFO level.
    """
    # Work on a copy so the caller's dict is left untouched; None means "all defaults".
    options = dict(log_file) if log_file else {}

    # Evaluate the default lazily: only consult the module-level folder when no path was given.
    filepath = options['filepath'] if 'filepath' in options else get_result_folder()
    desc = '_' + options['desc'] if 'desc' in options else ''
    filepath = filepath.format(desc=desc)
    set_result_folder(filepath)

    filename = os.path.join(filepath, options.get('filename', 'log.txt'))
    os.makedirs(filepath, exist_ok=True)

    root_logger = logging.getLogger()
    root_logger.setLevel(level=logging.INFO)
    formatter = logging.Formatter("[%(asctime)s] %(filename)s(%(lineno)d) : %(message)s", "%Y-%m-%d %H:%M:%S")

    # Detach AND close previous handlers; only removing a FileHandler would leak its open file.
    for hdlr in root_logger.handlers[:]:
        root_logger.removeHandler(hdlr)
        hdlr.close()

    # Mode 'a' creates a missing file, so it covers both the fresh-run and resume cases.
    file_handler = logging.FileHandler(filename, mode='a')
    console_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)

class AverageMeter:
    """Running weighted mean: accumulates ``val * n`` into ``sum`` and ``n`` into ``count``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.sum, self.count = 0, 0

    def update(self, val, n=1):
        """Record *val* with weight *n*."""
        weighted = val * n
        self.sum += weighted
        self.count += n

    @property
    def avg(self):
        """Weighted mean of the recorded values, or 0 when nothing has been recorded."""
        if not self.count:
            return 0
        return self.sum / self.count


class LogData:
    """Named series of ``[x, y]`` points collected for logging and plotting.

    ``keys`` is the set of series names; ``data`` maps each name to a list of ``[x, y]`` rows.
    """

    def __init__(self):
        self.keys = set()
        self.data = {}

    def get_raw_data(self):
        """Return ``(keys, data)`` - the internal containers themselves, not copies."""
        return self.keys, self.data

    def set_raw_data(self, r_data):
        """Restore state from a ``(keys, data)`` pair as returned by ``get_raw_data``."""
        self.keys, self.data = r_data

    def append_all(self, key, *args):
        """Append many points to series *key* at once.

        ``append_all(key, ys)`` assigns consecutive x indices that continue from the series'
        current length (0, 1, ... for a new series), matching how ``append(key, y)`` numbers
        points. ``append_all(key, xs, ys)`` uses the given x values.

        :raises ValueError: unless exactly one or two value sequences are given.
        """
        if len(args) == 1:
            start = len(self.data[key]) if key in self.keys else 0
            xs = list(range(start, start + len(args[0])))
            ys = args[0]
        elif len(args) == 2:
            xs, ys = args[0], args[1]
        else:
            raise ValueError('Unsupported value type')

        # Zip the two sequences into [x, y] rows - the same layout `append` stores.
        rows = np.stack([xs, ys], axis=1).tolist()
        if key in self.keys:
            # Extend with rows, never with the raw xs/ys lists, to keep the series uniform.
            self.data[key].extend(rows)
        else:
            self.data[key] = rows
            self.keys.add(key)

    def append(self, key, *args):
        """Append one point to series *key*.

        Accepted forms: ``append(key, y)`` with an int/float *y* (x = current series length),
        ``append(key, (x, y))``, ``append(key, [x, y])`` or ``append(key, x, y)``.

        :raises ValueError: for any other argument count or type.
        """
        if len(args) == 1:
            item = args[0]
            if isinstance(item, (int, float)):
                x = len(self.data[key]) if self.has_key(key) else 0
                value = [x, item]
            elif isinstance(item, tuple):
                value = list(item)
            elif isinstance(item, list):
                value = item
            else:
                raise ValueError('Unsupported value type')
        elif len(args) == 2:
            value = [args[0], args[1]]
        else:
            raise ValueError('Unsupported value type')

        if key in self.keys:
            self.data[key].append(value)
        else:
            self.data[key] = [value]
            self.keys.add(key)

    def get_last(self, key):
        """Return the most recent ``[x, y]`` row of *key*, or None if the series does not exist."""
        if not self.has_key(key):
            return None
        return self.data[key][-1]

    def has_key(self, key):
        """Return True if series *key* exists."""
        return key in self.keys

    def get(self, key):
        """Return the y values of *key* (a bare scalar when the series has a single point)."""
        split = np.hsplit(np.array(self.data[key]), 2)

        return split[1].squeeze().tolist()

    def getXY(self, key, start_idx=0):
        """Return ``(xs, ys)`` for *key*, optionally starting at the point whose x == *start_idx*.

        :raises KeyError: if *start_idx* is non-zero and not among the x values.
        """
        split = np.hsplit(np.array(self.data[key]), 2)

        xs = split[0].squeeze().tolist()
        ys = split[1].squeeze().tolist()

        # A single point squeezes to scalars; slicing does not apply then.
        if type(xs) is not list:
            return xs, ys

        if start_idx == 0:
            return xs, ys
        elif start_idx in xs:
            idx = xs.index(start_idx)
            return xs[idx:], ys[idx:]
        else:
            raise KeyError('no start_idx value in X axis data.')

    def get_keys(self):
        """Return the set of series names (the internal set, not a copy)."""
        return self.keys


class TimeEstimator:
    """Estimates elapsed and remaining wall-clock time for a loop of ``total`` steps (e.g. epochs)."""

    def __init__(self):
        self.logger = logging.getLogger('TimeEstimator')
        self.start_time = time.time()
        self.count_zero = 0

    def reset(self, count=1):
        """Restart the clock; *count* is the first step that will be reported after this call."""
        self.start_time = time.time()
        # Steps before *count* are treated as done before the clock started.
        self.count_zero = count - 1

    def get_est(self, count, total):
        """Return ``(elapsed_hours, remaining_hours)`` after finishing step *count* of *total*.

        The remaining time extrapolates the average duration of the steps completed since the
        last ``reset``. When no step has completed yet the rate is unknown, so the remaining time
        is ``float('inf')`` instead of raising ZeroDivisionError.
        """
        elapsed_time = time.time() - self.start_time
        steps_done = count - self.count_zero
        remain = total - count
        if steps_done > 0:
            remain_time = elapsed_time * remain / steps_done
        else:
            remain_time = float('inf')

        return elapsed_time / 3600.0, remain_time / 3600.0

    def get_est_string(self, count, total):
        """Format ``get_est`` as strings: hours (``'1.50h'``) above one hour, minutes otherwise."""
        elapsed_time, remain_time = self.get_est(count, total)

        elapsed_time_str = "{:.2f}h".format(elapsed_time) if elapsed_time > 1.0 else "{:.2f}m".format(elapsed_time*60)
        remain_time_str = "{:.2f}h".format(remain_time) if remain_time > 1.0 else "{:.2f}m".format(remain_time*60)

        return elapsed_time_str, remain_time_str

    def print_est_time(self, count, total):
        """Log the elapsed/remaining estimate for step *count* of *total* at INFO level."""
        elapsed_time_str, remain_time_str = self.get_est_string(count, total)

        self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
            count, total, elapsed_time_str, remain_time_str))


def util_print_log_array(logger, result_log: LogData):
    """Log every series of *result_log* as ``<key>_list = <y values>`` at INFO level.

    :param logger: a ``logging.Logger`` (anything with an ``info`` method).
    :param result_log: the LogData (or subclass) to dump.
    :raises TypeError: if *result_log* is not a LogData; an explicit check, unlike ``assert``,
        is not stripped under ``python -O``.
    """
    if not isinstance(result_log, LogData):
        raise TypeError('use LogData Class for result_log.')

    for key in result_log.get_keys():
        logger.info('{} = {}'.format(key + '_list', result_log.get(key)))


def util_save_log_image_with_label(result_file_prefix,
                                   img_params,
                                   result_log: LogData,
                                   labels=None):
    """Plot series of *result_log* and save the figure as ``{prefix}-{labels joined by '_'}.jpg``.

    :param result_file_prefix: output path prefix; its directory part (if any) is created.
    :param img_params: dict with ``'json_foldername'`` / ``'filename'`` locating the plot-style
        JSON, passed through to ``_build_log_image_plt``.
    :param result_log: LogData holding the series.
    :param labels: iterable of series names to plot; defaults to all keys of *result_log*.
    """
    dirname = os.path.dirname(result_file_prefix)
    # A bare prefix has no directory part; os.makedirs('') would raise FileNotFoundError.
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # Materialise the labels once so the plot and the file name use the same sequence
    # (a one-shot iterator would otherwise be consumed by the plotting step).
    labels = list(result_log.get_keys() if labels is None else labels)

    _build_log_image_plt(img_params, result_log, labels)

    file_name = '_'.join(labels)
    fig = plt.gcf()
    fig.savefig('{}-{}.jpg'.format(result_file_prefix, file_name))
    plt.close(fig)


def _build_log_image_plt(img_params,
                         result_log: LogData,
                         labels=None):
    """Draw series of *result_log* on a new current matplotlib figure (not saved or closed here).

    Styling is read from the JSON file
    ``<this file's dir>/<img_params['json_foldername']>/<img_params['filename']>``, which must
    provide ``figsize`` {x, y}, ``ylim`` / ``xlim`` {min, max} (null = fit to the data) and ``grid``.

    :param img_params: dict locating the JSON style file.
    :param result_log: LogData holding the series.
    :param labels: series names to plot; defaults to all keys of *result_log*.
    :raises TypeError: if *result_log* is not a LogData.
    """
    if not isinstance(result_log, LogData):
        raise TypeError('use LogData Class for result_log.')

    # Read json
    folder_name = img_params['json_foldername']
    file_name = img_params['filename']
    log_image_config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), folder_name, file_name)

    # Explicit encoding so the config parses identically regardless of the platform locale.
    with open(log_image_config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)

    figsize = (config['figsize']['x'], config['figsize']['y'])
    plt.figure(figsize=figsize)

    if labels is None:
        labels = result_log.get_keys()
    for label in labels:
        plt.plot(*result_log.getXY(label), label=label)

    # Limits given as null in the config fall back to the extent of the plotted data.
    ylim_min = config['ylim']['min']
    ylim_max = config['ylim']['max']
    if ylim_min is None:
        ylim_min = plt.gca().dataLim.ymin
    if ylim_max is None:
        ylim_max = plt.gca().dataLim.ymax
    plt.ylim(ylim_min, ylim_max)

    xlim_min = config['xlim']['min']
    xlim_max = config['xlim']['max']
    if xlim_min is None:
        xlim_min = plt.gca().dataLim.xmin
    if xlim_max is None:
        xlim_max = plt.gca().dataLim.xmax
    plt.xlim(xlim_min, xlim_max)

    # NOTE: plt.rc changes the global rcParams, so this legend font size persists for later figures.
    plt.rc('legend', fontsize=18)
    plt.legend()
    plt.grid(config["grid"])


def copy_all_src(dst_root):
    """Copy the source file of every loaded module under the project home into ``<dst_root>/src``.

    The home directory is the shorter (and existing) of the first two ``sys.path`` entries,
    resolved against the launching script's directory. When a file name already exists in the
    destination, ``(0)``, ``(1)``, ... is inserted before the extension.

    :param dst_root: destination root folder; its ``src`` subfolder is created if needed.
    """
    # Directory the program was launched from (inside Jupyter, argv[0] is the kernel launcher).
    if os.path.basename(sys.argv[0]).startswith('ipykernel_launcher'):
        execution_path = os.getcwd()
    else:
        execution_path = os.path.dirname(sys.argv[0])

    # Candidate home dirs from the first two sys.path entries (tolerates a sys.path shorter than 2).
    candidates = [os.path.abspath(os.path.join(execution_path, entry)) for entry in sys.path[:2]]
    if not candidates:
        candidates = [os.path.abspath(execution_path)]
    home_dir = candidates[0]
    if len(candidates) > 1 and len(candidates[0]) > len(candidates[1]) and os.path.exists(candidates[1]):
        home_dir = candidates[1]

    # make target directory
    dst_path = os.path.join(dst_root, 'src')
    os.makedirs(dst_path, exist_ok=True)

    # Snapshot sys.modules: it is a live dict that may change if anything imports meanwhile.
    for module in list(sys.modules.values()):
        module_file = getattr(module, '__file__', None)
        if not module_file:
            continue

        src_abspath = os.path.abspath(module_file)
        if not _is_under(src_abspath, home_dir):
            continue

        dst_filepath = os.path.join(dst_path, os.path.basename(src_abspath))
        if os.path.exists(dst_filepath):
            # Built with f-strings (not str.format) so braces inside file names are harmless.
            stem, ext = os.path.splitext(dst_filepath)
            post_index = 0
            while os.path.exists(f'{stem}({post_index}){ext}'):
                post_index += 1
            dst_filepath = f'{stem}({post_index}){ext}'

        shutil.copy(src_abspath, dst_filepath)


def _is_under(path, root):
    """Return True if absolute *path* equals *root* or lies inside it, comparing whole components.

    Unlike ``os.path.commonprefix`` (character-based), ``/a/bc/x.py`` is NOT under ``/a/b``.
    """
    try:
        return os.path.commonpath([root, path]) == root
    except ValueError:
        # Paths on different drives (Windows) cannot be nested.
        return False

#########################################################################################################################################


def get_trainer_params(batch_size=64):
    """Build a fresh trainer configuration dict.

    :param batch_size: value stored under ``'train_batch_size'``.
    :return: a new dict on every call, so callers may mutate it (train_and_evaluate does).
    """
    logging_cfg = dict(
        model_save_interval=30,
        img_save_interval=30,
        log_image_params_1=None,
        log_image_params_2=None,
    )
    load_cfg = dict(enable=False)
    return dict(
        use_cuda=USE_CUDA,
        cuda_device_num=CUDA_DEVICE_NUM,
        epochs=30,
        train_episodes=1000,
        train_batch_size=batch_size,
        logging=logging_cfg,
        model_load=load_cfg,
    )

def get_optimizer_params():
    """Build a fresh optimizer / LR-scheduler configuration dict (new dict on every call)."""
    optimizer_cfg = dict(lr=1e-4, weight_decay=0)
    # NOTE(review): presumably consumed as MultiStepLR-style milestones; 3001 is far beyond the
    # 30 epochs set in get_trainer_params, so the decay likely never triggers - confirm.
    scheduler_cfg = dict(milestones=[3001], gamma=0.1)
    return dict(optimizer=optimizer_cfg, scheduler=scheduler_cfg)

def get_tester_params(batch_size=64):
    """Build a fresh tester configuration dict.

    :param batch_size: value stored under ``'test_batch_size'``.
    :return: a new dict on every call, so callers may mutate it (train_and_evaluate does).
    """
    # NOTE(review): hard-coded result folder of an earlier run. train_and_evaluate replaces
    # 'model_load' in train mode, but val/test evaluate this checkpoint as-is.
    checkpoint_cfg = dict(path='./result/20250515_212431{desc}', epoch=30)
    return dict(
        use_cuda=USE_CUDA,
        cuda_device_num=CUDA_DEVICE_NUM,
        model_load=checkpoint_cfg,
        test_episodes=10,
        test_batch_size=batch_size,
        augmentation_enable=False,
        aug_factor=8,
        aug_batch_size=100,
    )

##########################################################################################
# main utility functions

def load_and_preprocess_dataset(dataset_path, batch_size):
    """
    Load a saved dataset tensor and force its first dimension to ``batch_size``.

    Surplus instances are truncated; missing ones are appended as all-zero instances with the
    same trailing shape, dtype and device. Both shapes are printed.

    :param dataset_path: path to a file written by ``torch.save`` containing one tensor with
        instances along dim 0.
    :param batch_size: required number of instances.
    :return: tensor of shape ``(batch_size, *dataset.shape[1:])``.
    """
    dataset = torch.load(dataset_path)

    # Print the original shape.
    print(f"原始数据集形状: {dataset.shape}")

    count = len(dataset)
    if count < batch_size:
        # Build the padding from the full trailing shape: works for tensors of any rank and for
        # an empty dataset (the old dataset[0].repeat(n, 1, 1) only handled 3-D, non-empty data).
        padding = dataset.new_zeros((batch_size - count, *dataset.shape[1:]))
        dataset = torch.cat([dataset, padding], dim=0)
    elif count > batch_size:
        dataset = dataset[:batch_size]

    # Print the processed shape.
    print(f"处理后数据集形状: {dataset.shape}")

    return dataset

def train_and_evaluate(problem_size, mood):
    """
    Train (in 'train' mode) and evaluate the model on the dataset of one problem size.

    :param problem_size: problem size; selects ``dataset/{mood}{problem_size}_dataset.pt`` and
        overrides ``env_params['problem_size']``.
    :param mood: 'train' trains a new model and then evaluates its checkpoint on the training
        file; any other value ('val' / 'test') evaluates the checkpoint configured in
        ``get_tester_params``.
    :return: the value returned by ``Tester.run()`` (printed as the average objective).
    :raises Exception: whatever ``load_and_preprocess_dataset`` raised, re-raised after printing.
    """
    # Reported only; Trainer/Tester receive use_cuda / cuda_device_num through their params.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    basepath = os.path.dirname(__file__)

    # Generate all datasets on first use; the first training size's file is the "already generated" marker.
    if not os.path.isfile(os.path.join(basepath, f"dataset/train{dataset_conf['train'][0]}_dataset.pt")):
        generate_datasets()

    dataset_path = os.path.join(basepath, f"dataset/{mood}{problem_size}_dataset.pt")

    # Batch size per mode: 10 for training, 64 for validation/test.
    batch_size = 10 if mood == 'train' else 64

    # Load the file up front so a missing or malformed dataset fails fast. The padded/truncated
    # tensor is not used afterwards; the environment is only given the file path below.
    try:
        load_and_preprocess_dataset(dataset_path, batch_size)
    except Exception as e:
        print(f"数据集预处理失败: {e}")
        raise

    # NOTE: mutates the module-level env_params that is handed to Trainer/Tester.
    env_params['test_file_path'] = dataset_path
    env_params['problem_size'] = problem_size

    # Fresh parameter dicts for this run.
    current_trainer_params = get_trainer_params(batch_size)
    current_optimizer_params = get_optimizer_params()
    current_tester_params = get_tester_params(batch_size)

    if mood == 'train':
        # Save interval equal to the epoch count (30): presumably only the final checkpoint is
        # written, which is the one evaluated below.
        current_trainer_params['logging']['model_save_interval'] = 30

        trainer = Trainer(
            env_params=env_params,
            model_params=model_params,
            optimizer_params=current_optimizer_params,
            trainer_params=current_trainer_params
        )
        trainer.run()

        final_epoch = current_trainer_params['epochs']

        # Evaluate the model just trained, loaded from the trainer's own result folder.
        current_tester_params['model_load'] = {
            'path': trainer.result_folder,
            'epoch': final_epoch,
        }

        print(f"[*] 将使用刚训练完的模型进行评估: {trainer.result_folder}/checkpoint-{final_epoch}.pt")

        # Fall back to the newest checkpoint present if the final one is missing.
        checkpoint_path = f"{trainer.result_folder}/checkpoint-{final_epoch}.pt"
        if not os.path.exists(checkpoint_path):
            print(f"[!] 警告: 找不到刚训练的检查点: {checkpoint_path}")
            latest = _latest_checkpoint(trainer.result_folder)
            if latest is not None:
                latest_epoch, latest_checkpoint = latest
                print(f"[*] 找到替代检查点: {trainer.result_folder}/{latest_checkpoint}")
                current_tester_params['model_load']['epoch'] = latest_epoch

        current_tester_params['test_episodes'] = 10
        tester = Tester(
            env_params=env_params,
            model_params=model_params,
            tester_params=current_tester_params
        )
        avg_obj = tester.run()

        print("[*] 训练集平均目标值:")
        print(avg_obj)

        return avg_obj

    # Validation / test: evaluate the checkpoint configured in get_tester_params.
    current_tester_params['test_episodes'] = 64

    tester = Tester(
        env_params=env_params,
        model_params=model_params,
        tester_params=current_tester_params
    )

    avg_obj = tester.run()
    print(f"[*] {problem_size}问题规模的平均目标值: {avg_obj}")

    return avg_obj


def _latest_checkpoint(folder):
    """Return ``(epoch, filename)`` of the highest-numbered ``checkpoint-<N>.pt`` in *folder*, or None."""
    if not os.path.isdir(folder):
        return None
    found = []
    for name in os.listdir(folder):
        if name.startswith("checkpoint-") and name.endswith(".pt"):
            number = name[len("checkpoint-"):-len(".pt")]
            # Skip names whose middle part is not an integer instead of crashing in int().
            if number.isdecimal():
                found.append((int(number), name))
    return max(found, default=None)

def main():
    """Command-line entry point.

    Usage: ``python eval.py <problem_size> <mode> <train/val/test>``. Only the first and third
    arguments are read; the second is currently ignored. Prints usage and exits with status 1
    on missing or invalid arguments.

    :return: the result of ``train_and_evaluate`` for the requested size and mode.
    """
    usage = "用法: python eval.py <问题规模> <模式> <train/val/test>"

    print("[*] 开始运行 ...")

    # Validate the command line before create_logger so a bad invocation creates no result folder.
    if len(sys.argv) < 4:
        print(usage)
        sys.exit(1)

    try:
        problem_size = int(sys.argv[1])
    except ValueError:
        print(usage)
        sys.exit(1)

    mood = sys.argv[3]
    # Explicit check: an `assert` would be stripped under `python -O`.
    if mood not in ('train', 'val', 'test'):
        print("模式必须是 'train', 'val' 或 'test'")
        sys.exit(1)

    create_logger({'desc': 'eval_run', 'filename': 'log.txt'})

    return train_and_evaluate(problem_size, mood)

##########################################################################################

if __name__ == '__main__':
    # Script entry; main()'s return value is not used as the process exit status.
    main()