import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.append(base_path)

import numpy as np
import matplotlib.pyplot as plt
from utils.COP_slover import TSP_lkh
from sklearn.datasets import make_blobs, make_circles, make_moons
from utils.utils import create_folder_overwrite_if_exist, create_folder_if_not_exist, load_existing_raw_data, merge_to_dataset
from environment.used.BaseEnv_COP import RawData
import multiprocessing
import pickle
import math
import time
import random
import concurrent.futures
from functools import partial
from tqdm import tqdm

def visualize_tsp_samples(tsp_data: np.ndarray, tsp_solution: np.ndarray = None, num_samples: int = 5, save_path: str = 'tsp_samples.png'):
    ''' Plot randomly chosen TSP instances side by side and save the figure to disk.

    Args:
        tsp_data: Collection of instances; each ``tsp_data[i]`` is indexed as
            ``[:, 0]`` (x) and ``[:, 1]`` (y), i.e. one coordinate row per city.
        tsp_solution: Optional per-instance visiting order used to index the
            cities of the matching instance. When given, the closed tour is drawn.
        num_samples: Number of instances to draw; capped at ``len(tsp_data)``.
        save_path: Destination passed to ``plt.savefig``.
    '''
    # Sampling is done without replacement, so never request more instances than exist.
    num_samples = min(num_samples, len(tsp_data))
    sample_indices = np.random.choice(len(tsp_data), num_samples, replace=False)

    fig = plt.figure(figsize=(15, 5))
    try:
        for i, idx in enumerate(sample_indices):
            cities = tsp_data[idx]
            plt.subplot(1, num_samples, i+1)
            plt.scatter(cities[:, 0], cities[:, 1])

            if tsp_solution is not None:
                solution = tsp_solution[idx]
                # Append the first city again so the plotted path closes into a tour.
                solution = np.append(solution, solution[0])
                plt.plot(cities[solution, 0], cities[solution, 1], 'b-', alpha=0.25)  # Plot the solution path
                plt.scatter(cities[solution, 0], cities[solution, 1], color='blue')  # Highlight visited cities

            plt.title(f'Sample {i+1}')
            plt.xlabel('X Coordinate')
            plt.ylabel('Y Coordinate')
            plt.gca().set_aspect('equal', adjustable='box')
            plt.xlim(0, 1)  # Set x-axis limits to [0, 1]
            plt.ylim(0, 1)  # Set y-axis limits to [0, 1]
            plt.grid(True)

        plt.tight_layout()
        plt.savefig(save_path)  # Save the plot to a file
    finally:
        # Release the figure even if plotting/saving fails; pyplot otherwise keeps
        # every figure alive, which accumulates memory across repeated calls.
        plt.close(fig)

def generate_tsp_data(seed=42, dataset_size=10000, tsp_size=20, distribution='uniform'):
    ''' Generate a batch of random 2-D TSP instances rescaled into the unit square.

    Args:
        seed: Seed for the private ``RandomState`` (and for the sklearn generators).
        dataset_size: Number of instances to generate.
        tsp_size: Number of cities per instance.
        distribution: One of 'uniform', 'normal', 'exponential', 'gamma', 'beta',
            'clusters', 'circles' or 'moons'.

    Returns:
        Array of shape (dataset_size, tsp_size, 2). Each coordinate axis is
        min-max normalized over the whole batch (not per instance).

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
    '''
    rng = np.random.RandomState(seed)  # independent generator, leaves the global numpy RNG untouched
    shape = (dataset_size, tsp_size, 2)

    if distribution == 'uniform':
        data = rng.uniform(size=shape)
    elif distribution == 'normal':
        data = rng.normal(loc=0.5, scale=1, size=shape)
    elif distribution == 'exponential':
        data = rng.exponential(scale=0.5, size=shape)
    elif distribution == 'gamma':
        data = rng.gamma(shape=2.0, scale=1.0, size=shape)
    elif distribution == 'beta':
        data = rng.beta(a=2.0, b=8.0, size=shape)
    elif distribution == 'clusters':
        # sklearn returns one flat point cloud; it is cut into instances by reshape.
        X, _ = make_blobs(n_samples=dataset_size*tsp_size, n_features=2, centers=2, cluster_std=[2, 2], random_state=seed)
        data = X.reshape(shape)
    elif distribution == 'circles':
        X, _ = make_circles(n_samples=dataset_size*tsp_size, noise=0.1, random_state=seed)
        data = X.reshape(shape)
    elif distribution == 'moons':
        X, _ = make_moons(n_samples=dataset_size*tsp_size, noise=0.1, random_state=seed)
        data = X.reshape(shape)
    else:
        # NOTE: a 'grid' layout is not implemented, so it is not advertised here.
        raise ValueError("Invalid distribution type. Supported types are 'uniform', 'normal', 'exponential', 'gamma', 'beta', 'clusters', 'circles', and 'moons'.")

    # Linearly rescale into the unit square [0, 0] - [1, 1].
    min_vals = np.min(data, axis=(0, 1), keepdims=True)
    max_vals = np.max(data, axis=(0, 1), keepdims=True)
    span = max_vals - min_vals
    # A coordinate with zero range (e.g. a single point) would divide by zero and
    # yield NaN; map such a coordinate to 0 instead.
    span[span == 0] = 1.0
    normalized_data = (data - min_vals) / span
    return normalized_data

def check_data(show_info=False, check_data_num=5000):
    ''' Sanity-check the raw TSP datasets stored on disk.

    Reads the module-level globals ``data_base_path``, ``node_num`` and
    ``data_distribution`` to locate the 'problem', 'train' and 'prompt' pickle
    files, pools whichever of them exist, and verifies that
      * problem / answer / cost lists have equal length,
      * no random seed was used twice,
      * no two problems have identical city positions,
      * for a random subset, re-solving with ``TSP_lkh`` reproduces the stored
        cost (within 1e-5) and the stored answer.

    Args:
        show_info: Print a summary (sample count, seed count, mean cost).
        check_data_num: Maximum number of samples to re-solve; <= 0 skips re-solving.

    Raises:
        AssertionError: If any check fails. Raised explicitly (not via ``assert``)
            so the checks still run under ``python -O``.
    '''
    # Pool every dataset file that exists.
    dataset = RawData(seed_list=[], problem_list=[], answer_list=[], cost_list=[])
    for data_type in ['problem', 'train', 'prompt']:
        path = f'{data_base_path}/data/used/_raw/tsp/tsp{node_num}({data_distribution})_{data_type}.pkl'
        if os.path.exists(path):
            # NOTE(review): pickle executes code on load; only point this at files
            # produced by this pipeline.
            with open(path, 'rb') as f:
                sub_dataset = pickle.load(f)
            dataset.seed_list.extend(sub_dataset.seed_list)
            dataset.problem_list.extend(sub_dataset.problem_list)
            dataset.answer_list.extend(sub_dataset.answer_list)
            dataset.cost_list.extend(sub_dataset.cost_list)

    if not (len(dataset.problem_list) == len(dataset.answer_list) == len(dataset.cost_list)):
        raise AssertionError('数据长度不匹配')
    if len(set(dataset.seed_list)) != len(dataset.seed_list):
        raise AssertionError('存在使用相同随机种子的数据生成进程')
    positions = np.array([problem['position'] for problem in dataset.problem_list])         # (data_num, node_num, 2)
    if positions.shape[0] != np.unique(positions, axis=0).shape[0]:
        raise AssertionError('存在完全相同的问题')

    if show_info:
        # Print dataset summary
        print('-'*80)
        print(f'There are [{len(dataset.answer_list)}] raw data generated with [{len(dataset.seed_list)}] different random seeds, ave cost = [{np.mean(dataset.cost_list):.3f}]')
        print('-'*80)

    # Re-solve up to check_data_num randomly chosen samples and compare with the stored results.
    if check_data_num > 0:
        dataset_size = len(dataset.answer_list)
        check_idx = random.sample(range(dataset_size), min(dataset_size, check_data_num))
        for idx in tqdm(check_idx, desc="Check data validity"):
            data_pos = dataset.problem_list[idx]['position']
            data_answer = dataset.answer_list[idx]
            data_cost = dataset.cost_list[idx]
            cost, answer = TSP_lkh(np.array(data_pos, np.float32))
            if not abs(cost - data_cost) < 1e-5:
                raise AssertionError(f'sample {idx}: stored cost {data_cost} differs from re-solved cost {cost}')
            if not answer == data_answer:
                raise AssertionError(f'sample {idx}: stored answer differs from re-solved answer')

def gen_raw_data(data_num_per_process, data_distribution, seed, node_num, tqdm_pos, lock):
    ''' Worker entry point: generate, solve and save one shard of TSP data.

    Args:
        data_num_per_process: Number of instances this worker generates.
        data_distribution: Distribution name forwarded to ``generate_tsp_data``.
        seed: Random seed for this shard; also recorded in the saved dataset.
        node_num: Number of cities per instance.
        tqdm_pos: Row index of this worker's progress bar.
        lock: Lock shared between workers; held around every solver call.

    Returns:
        Path of the pickle file containing this worker's ``RawData``. The path is
        built from the module-level global ``data_base_path``.
    '''
    # Generate the problem instances
    process_data = generate_tsp_data(seed=seed, dataset_size=data_num_per_process, tsp_size=node_num, distribution=data_distribution)

    # Solve every instance and collect the results
    dataset = RawData(seed_list=[seed,], problem_list=[], answer_list=[], cost_list=[])
    with tqdm(total=data_num_per_process, desc=f'Gen TSP data with seed {seed}', position=tqdm_pos) as pbar:
        for city_pos in process_data:
            # Solver calls are serialized across workers (presumably because the
            # solver writes shared intermediate files — TODO confirm). Using `with`
            # guarantees the lock is released even if the solver raises; a bare
            # acquire()/release() pair would leave every other worker blocked forever.
            with lock:
                cost, answer = TSP_lkh(np.array(city_pos, np.float32))
            dataset.problem_list.append({'position': city_pos})
            dataset.answer_list.append(answer)
            dataset.cost_list.append(cost)
            pbar.update()

    # Save the raw data produced by this worker
    dataset_path = f'{data_base_path}/data/used/_raw/tsp/temp/num{data_num_per_process}_seed{seed}.pkl'
    with open(dataset_path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset_path

if __name__ == "__main__":
    # Script entry point: tops up the 'problem', 'train' and 'prompt' raw TSP
    # datasets to their target sizes. Work is done in stages of at most
    # `worker_num` worker processes; after each stage the new shards are merged
    # into the target file and the result is validated.
    #
    # NOTE(review): the worker `gen_raw_data` and `check_data` read
    # `data_base_path` / `node_num` / `data_distribution` as module globals that
    # are only assigned here. Worker processes therefore see them only when they
    # inherit this module's state (fork start method) — confirm before running
    # on a platform that spawns workers.

    # Settings
    node_num = 50                       # number of TSP cities
    data_distribution = 'uniform'       # city distribution type
    data_num_per_process = 200          # number of samples generated by each worker process
    worker_num = 20                     # maximum number of worker processes running at once
    overwrite = False                   # whether to overwrite existing data
    data_num_dict = {
        'train': 250000,
        'problem': 15000,
        'prompt': 5000,
    }
    # Alternative location: data_base_path = base_path
    data_base_path = '/data3/XXX/toy-gato'

    # Temporary folder for intermediate files the solver may produce
    create_folder_overwrite_if_exist(f'{data_base_path}/temp')

    # Temporary folder for the shard files written by the worker processes
    create_folder_if_not_exist(f'{data_base_path}/data/used/_raw/tsp/temp')

    # Load the data that already exists and collect every seed used so far
    ex_datasets = {}
    ex_seeds = []
    for data_type in ['problem', 'train', 'prompt']:
        ex_data_path = f'{data_base_path}/data/used/_raw/tsp/tsp{node_num}({data_distribution})_{data_type}.pkl'
        ex_dataset = load_existing_raw_data(ex_data_path, overwrite)
        ex_datasets[data_type] = ex_dataset
        ex_seeds.extend(ex_dataset.seed_list)

    # Generate new data
    for data_type in ['problem', 'train', 'prompt']:
        target_data_num = data_num_dict[data_type]                              # target sample count for this type
        ex_dataset = ex_datasets[data_type]                                     # existing dataset of this type
        ex_data_num = len(ex_dataset.answer_list)                               # samples that already exist
        data_to_go = target_data_num - ex_data_num                              # samples still to generate
        processes_num = math.ceil(data_to_go/data_num_per_process)              # workers needed (each yields data_num_per_process samples)

        # Report progress for this dataset type
        if data_to_go > 0:
            print('-'*100)
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, [{data_to_go}] to go')
            print('-'*100)
        else:
            print('-'*100)
            print(f'There are [{ex_data_num}] raw data for [{data_type}] generated, the requirements of [{target_data_num}] have already been met')
            print('-'*100)
            continue

        # Generate in stages: each stage runs at most worker_num processes in parallel,
        # until processes_num workers have run in total. At the end of each stage the
        # shards produced by that stage are merged into target_dataset_path.
        target_dataset_path = f'{data_base_path}/data/used/_raw/tsp/tsp{node_num}({data_distribution})_{data_type}.pkl'
        target_dataset_backup_path = f'{data_base_path}/data/used/_raw/tsp/tsp{node_num}({data_distribution})_{data_type}_backup.pkl'
        gen_data_num = ex_data_num
        stage_num = 0
        while processes_num > 0:
            # Seeds for this stage: continue right after the largest seed used so far,
            # so no seed is ever reused across stages or dataset types.
            max_seed = -1 if ex_seeds == [] else max(ex_seeds)
            stage_worker_num = min(worker_num, processes_num)
            stage_data_num = data_num_per_process * stage_worker_num
            processes_num -= stage_worker_num
            seeds = list(range(max_seed+1, max_seed+1+stage_worker_num))
            ex_seeds.extend(seeds)

            print('='*50 + f' STAGE {stage_num}: seeds = [{max_seed+1}, {max_seed+1+stage_worker_num}) data {gen_data_num} ~ {gen_data_num + stage_data_num}/{target_data_num}' + '='*50)
            dataset_path_list = []
            # The manager hosts the lock shared by the workers. Using it as a context
            # manager shuts its server process down at the end of the stage instead of
            # leaking one manager process per stage.
            with multiprocessing.Manager() as manager:
                lock = manager.Lock()
                # Run this stage's workers in a process pool
                with concurrent.futures.ProcessPoolExecutor(max_workers=stage_worker_num) as executor:
                    futures = {}
                    for i, seed in enumerate(seeds):
                        pos = i % worker_num
                        futures[executor.submit(gen_raw_data, data_num_per_process, data_distribution, seed, node_num, pos, lock)] = seed

                    for future in concurrent.futures.as_completed(futures):
                        try:
                            dataset_path = future.result()
                            dataset_path_list.append(dataset_path)
                        except Exception as e:
                            print(f"{e}")

            # All workers have finished by now; result() re-raises the exception of any
            # failed worker, so a broken stage aborts here instead of being merged.
            for future in concurrent.futures.as_completed(futures):
                future.result()
            time.sleep(2)

            # Merge every shard produced by this stage into the target dataset
            gen_data_num += stage_data_num
            stage_num += 1
            merge_to_dataset(dataset_path_list, target_dataset_path, target_dataset_backup_path)

            # Validate the merged data
            check_data(show_info=False, check_data_num=200)
            os.remove(target_dataset_backup_path)   # validation passed, drop the backup file

    print('='*55 + f' ALL DONE ' + '='*55)
    check_data(show_info=True, check_data_num=1000)