import os
import argparse
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.multiprocessing as mp
from math import log, exp
import ot
import pickle
import pandas as pd
import time
sys.path.append('..')
from src.utils import set_device_seed, scip_env, BranchAgent
from src.logger import Logger
logger = Logger.logger



# Worker entry point: solve every instance in `env` with one agent and report the results.
def do_rollouts(args, agent, random_seeds, return_queue, env, are_negative):
    """
    Solve every instance yielded by ``env`` with ``agent`` and report the outcome.

    Exactly one tuple is put on ``return_queue``::

        (random_seeds, [ret], [total_nodes], [weights_per_model],
         [samples_per_model], are_negative)

    where ``ret`` is ``-exp(mean(log(nodes + 1))) + 1`` over all solved
    instances and ``total_nodes`` is the plain sum of ``nodes``. ``nodes`` is
    the first value returned by ``agent.solve``. ``random_seeds`` and
    ``are_negative`` are passed through untouched so the receiver can match
    the result to the perturbation that produced it.

    Raises ZeroDivisionError when ``env`` yields nothing.
    """
    torch.set_num_threads(1)
    node_counts = []
    log_node_counts = []
    weights_per_model = []
    samples_per_model = []

    for model in env:
        n_nodes, tree = agent.solve(model)
        # The region partition is only sampled while novelty is active.
        weights, samples = sample_from_tree(tree) if args.alpha > 0 else ([], [])
        node_counts.append(n_nodes)
        weights_per_model.append(weights)
        samples_per_model.append(samples)
        log_node_counts.append(float(log(n_nodes + 1)))

    mean_log = sum(log_node_counts) / len(log_node_counts)
    return_queue.put((
        random_seeds,
        [-exp(mean_log) + 1],
        [sum(node_counts)],
        [weights_per_model],
        [samples_per_model],
        are_negative,
    ))


def sample_from_tree(tree):
    """
    Sample one point per leaf of ``tree`` and weight it by the leaf's depth.

    The tree is walked breadth-first from ``tree.root``. Each node's
    ``data['vars']`` / ``data['bounds']`` / ``data['types']`` are appended to
    the same entries of its children, so a leaf ends up carrying the entries
    of its whole root-to-leaf path.

    For every leaf (in breadth-first order) one row of ``samples`` is filled:
    column ``var`` is ``1`` when the matching bound equals ``1`` and ``-1``
    otherwise; columns never mentioned on the path stay ``0``. The leaf's
    weight is ``2 ** -depth``.

    Side effect: the ``data`` of every non-root node is extended in place.

    Returns:
        (weights, samples): a list of floats and an ``int8`` array of shape
        ``(len(tree.leaves()), root.data['num'])``.
    """
    from collections import deque

    weights = []
    root = tree.get_node(tree.root)
    samples = np.zeros((len(tree.leaves()), root.data['num']), dtype=np.int8)
    # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
    pending = deque([tree.root])
    while pending:
        parent = pending.popleft()
        parent_data = tree.get_node(parent).data
        branch_vars = parent_data['vars']
        bounds = parent_data['bounds']
        types = parent_data['types']
        children = tree.children(parent)
        # A leaf: sample a point in the region it describes.
        if not children:
            row = len(weights)
            for var, bound in zip(branch_vars, bounds):
                samples[row, var] = 1 if bound == 1 else -1
            weights.append(2 ** (-tree.depth(parent)))
        # An inner node: push its accumulated entries down to the children.
        for child in children:
            child.data['vars'] += branch_vars
            child.data['bounds'] += bounds
            child.data['types'] += types
            pending.append(child.identifier)
    return weights, samples


def knn_novelty(all_weights, all_samples, k):
    '''
    Score each parameter set by its distance to the other parameter sets.

    all_*s: list of results for different sets of parameters
    *s: list of results of one set of parameters on different instances
    *: weights / samples of the region partition on one instance
    return: knn novelty score for each set of parameters

    For every pair of parameter sets the distance is the sum, over instances,
    of ``ot.emd2`` between the two normalised weight vectors. The ground cost
    is the L1 distance between sample rows, restricted to the columns that
    are non-zero in at least one of the two sample matrices.

    NOTE(review): after sorting, the zero self-distance on the diagonal is
    among the ``k`` smallest entries of each row, so it is part of the sum --
    confirm this is intended.
    '''
    n_sets = len(all_weights)
    pairwise = np.zeros((n_sets, n_sets))
    per_set = list(zip(all_weights, all_samples))

    # Fill only the strict lower triangle; it is mirrored afterwards.
    for row, (row_weights, row_samples) in enumerate(per_set):
        for col in range(row):
            col_weights, col_samples = per_set[col]
            instances = zip(row_weights, row_samples, col_weights, col_samples)
            for w_row, s_row, w_col, s_col in instances:
                mass_row = w_row / np.sum(w_row)
                mass_col = w_col / np.sum(w_col)
                used_row = np.abs(s_row).sum(axis=0) > 0
                used_col = np.abs(s_col).sum(axis=0) > 0
                used = np.logical_or(used_row, used_col)
                diff = s_row[:, used][:, None, :] - s_col[:, used][None, :, :]
                cost = np.abs(diff).sum(axis=2)
                pairwise[row, col] += ot.emd2(mass_row, mass_col, cost)

    symmetric = pairwise + pairwise.T
    ordered = np.sort(symmetric, axis=1)
    return ordered[:, :k].sum(axis=1)

# Gaussian update: adjust the agent's policy parameters from the rollout results.
def gaussian_update(args, synced_agent, returns, all_weights, all_samples, random_seeds, neg_list,
                    num_eps, num_frames, unperturbed_results):
    """
    Update the parameters of ``synced_agent`` in place from one batch of rollouts.

    For each of the ``args.n`` perturbed rollouts the noise is regenerated
    from its seed and every parameter is moved by
    ``args.lr / args.n * score * sign * eps``, where ``sign`` is ``-1`` for a
    negative perturbation and
    ``score = args.alpha * shaped_novelty + (1 - args.alpha) * shaped_return``.

    Args:
        args: parsed options; reads ``alpha``, ``n``, ``lr``, ``sigma``, ``silent``.
        synced_agent: agent whose ``policy_params()`` yields ``(name, tensor)`` pairs.
        returns: one return per perturbed rollout.
        all_weights, all_samples: per-rollout region partitions, only used
            when ``args.alpha > 0`` (forwarded to ``knn_novelty``).
        random_seeds: seed that generated each perturbation.
        neg_list: per rollout, truthy when the perturbation was negated.
        num_eps, num_frames: counters, used for logging only.
        unperturbed_results: return of the unperturbed agent, used for logging only.

    Returns:
        ``synced_agent`` (the same object, updated).
    """
    def fitness_shaping(returns):
        """
        A rank transformation on the rewards, which reduces the chances
        of falling into local optima early in training.

        Example: returns ``[-7924.6, -7730.1]`` are shaped to ``[0.5, 1.5]``.
        """
        # sorted() is ascending; reversing gives descending order, so
        # .index(r) is the 0-based rank of r (equal values share a rank).
        sorted_returns_backwards = sorted(returns)[::-1]
        lamb = len(returns)
        utilities = [max(0, log(lamb / 2 + 1, 2) - log(sorted_returns_backwards.index(r) + 1, 2))
                     for r in returns]
        denom = sum(utilities)
        return [u / denom + 1 / lamb for u in utilities]

    def unperturbed_rank(returns, unperturbed_results):
        """Return the 1-based place of the unperturbed result among ``returns``."""
        nth_place = 1
        for r in returns:
            if r > unperturbed_results:
                nth_place += 1
        rank_diag = ('%d out of %d (1 means gradient is uninformative)' % (nth_place, len(returns) + 1))
        return rank_diag, nth_place

    shaped_returns = fitness_shaping(returns)

    # Novelty score per parameter set; all zeros when novelty is switched off.
    if args.alpha > 0:
        shaped_novelty = fitness_shaping(knn_novelty(all_weights, all_samples, int(args.n/2)))
    else:
        shaped_novelty = [0 for _ in shaped_returns]

    rank_diag, _ = unperturbed_rank(returns, unperturbed_results)

    # Diagnostics only; nothing below depends on them.
    if not args.silent:
        logger.info('Episode num: %d\n'
            'Average reward: %f\t Stdev in rewards: %f\n'
            'Max reward: %f\t Min reward: %f\n'
            'Sigma: %f\t Learning rate: %f\t Total frames: %d \n'
            'Unperturbed reward: %f\t Unperturbed rank: %s\n\n' %
              (num_eps, -np.mean(returns), np.std(returns), max(returns), min(returns),
               args.sigma, args.lr, num_frames, unperturbed_results, rank_diag))
        logger.info('Returns: %s\tShaped novelty: %s' % (returns, shaped_novelty))

    # For each model, generate the same random numbers as we did
    # before, and update parameters.
    for i in range(args.n):
        np.random.seed(random_seeds[i])
        multiplier = -1 if neg_list[i] else 1
        reward = shaped_returns[i]

        # Blend novelty and reward into one fitness score.
        score = args.alpha * shaped_novelty[i] + (1-args.alpha) * reward

        for _name, v in synced_agent.policy_params():
            eps = np.random.normal(0, 1, v.size())
            v += torch.from_numpy(args.lr / args.n *
                                  (score * multiplier * eps)).float()

    return synced_agent

def perturb_agent(args, agent, random_seed):
    """
    Build two perturbed copies of ``agent``.

    Both copies start from ``agent.policy``'s state dict. NumPy's global RNG
    is seeded with ``random_seed``; for every parameter a standard-normal
    ``eps`` is drawn, and ``args.sigma * eps`` is added to the first copy and
    ``args.sigma * -eps`` to the second.

    Returns:
        ``[positively_perturbed_agent, negatively_perturbed_agent]``
    """
    plus_agent = BranchAgent()
    minus_agent = BranchAgent()
    for clone in (plus_agent, minus_agent):
        clone.policy.load_state_dict(agent.policy.state_dict())

    # Seed only after both agents exist, so the noise depends on the seed alone.
    np.random.seed(random_seed)
    paired_params = zip(plus_agent.policy_params(), minus_agent.policy_params())
    for (_, plus_param), (_, minus_param) in paired_params:
        noise = np.random.normal(0, 1, plus_param.size())
        plus_param += torch.from_numpy(args.sigma * noise).float()
        minus_param += torch.from_numpy(args.sigma * -noise).float()
    return [plus_agent, minus_agent]




def generate_seeds_and_agents(args, synced_model):
    """
    Draw a fresh seed and build the pair of agents perturbed with it.

    NumPy's global RNG is re-seeded without an argument before the seed is
    drawn from ``[0, 2**30)``.

    Returns:
        ``(seed, [perturbed_agent, anti_perturbed_agent])``
    """
    np.random.seed()
    seed = np.random.randint(2 ** 30)
    return seed, perturb_agent(args, synced_model, seed)


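# Validation routine.
# The validation instances are split across args.n worker processes
# (args.n defaults to 2), one do_rollouts process per slice.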
def validation(args, valid_env, synced_agent, chkpt_dir, epoch, num_eps, num_frame, best_result=None):
    """
    Evaluate ``synced_agent`` on the validation instances and checkpoint it.

    ``valid_env.dirs`` is cut into ``args.n`` equal slices and each slice is
    solved by its own ``do_rollouts`` process. The per-worker returns ``r``
    are combined as ``-exp(mean(log(-r + 1))) + 1``; a larger value counts as
    better.

    Side effects:
        * ``valid_env.list`` and ``valid_env.batch_size`` are overwritten.
        * A row ``[num_eps, num_frame, result]`` is appended to the
          module-level ``df``, which is written to
          ``../log/<ins_type>/<prefix>_<seed>se.csv``.
        * Unless ``best_result`` is None, the policy is saved to
          ``<chkpt_dir>/<prefix>_<epoch+1>.pth`` and, on improvement, also to
          ``<chkpt_dir>/<prefix>_latest.pth``.
        * Without improvement ``args.alpha`` is multiplied by ``args.alpha_decay``.

    ``prefix`` is ``nes`` when ``args.alpha > 0`` and ``es`` otherwise.

    Returns:
        The best result so far: ``result`` on the first call
        (``best_result is None``) or on improvement, else ``best_result``.
    """

    def flatten(raw_results, index):
        """Concatenate field ``index`` of every worker tuple into one list."""
        return [item for result in raw_results for item in result[index]]

    processes = []
    return_queue = mp.Queue()
    valid_dirs = valid_env.dirs
    # Instances per worker; a remainder of len(valid_dirs) / args.n is not evaluated.
    m = len(valid_dirs) // args.n

    for j in range(args.n):
        # Each worker process is started with the environment set to its own slice.
        valid_env.list = valid_dirs[j * m:(j + 1) * m]
        valid_env.batch_size = m
        p = mp.Process(target=do_rollouts, args=(args, synced_agent, ['dummy_seed'], return_queue, valid_env,
                                                 ['dummy_neg']))
        p.start()
        processes.append(p)

    # One result tuple per worker; fetched before join() so that no worker
    # is left waiting on an unread queue.
    raw_results = [return_queue.get() for _ in processes]

    for p in processes:
        p.join()

    # Field 1 of each tuple is the worker's list of returns.
    result = np.array(flatten(raw_results, 1))
    result = -np.exp(np.log(-result+1).mean()) + 1

    # Store data for plotting.
    df.loc[len(df)] = [num_eps, num_frame, result]

    # alpha only changes at the very end, so one prefix serves every file below.
    prefix = 'nes' if args.alpha > 0 else 'es'

    log_dir = f'../log/{args.ins_type}'
    os.makedirs(log_dir, exist_ok=True)
    df.to_csv(os.path.join(log_dir, f'{prefix}_{args.seed}se.csv'))

    # First call: nothing to compare against, and no checkpoint is written.
    if best_result is None:
        logger.info('initial best result is: {:.6f}'.format(result))
        return result

    torch.save(synced_agent.policy.state_dict(), os.path.join(chkpt_dir, f'{prefix}_{epoch+1}.pth'))

    if result > best_result:
        logger.info('valid result {:.6f}, improvement found at epoch {}'.format(result, epoch + 1))
        torch.save(synced_agent.policy.state_dict(), os.path.join(chkpt_dir, prefix + '_latest.pth'))
        return result

    # No improvement: shrink the novelty weight and keep the old best.
    args.alpha *= args.alpha_decay
    logger.info('valid result {:.6f}, no improvment found at epoch {}'.format(result, epoch + 1))
    return best_result


# Training loop.
def train_loop(args, synced_agent, env, valid_env, chkpt_dir):
    """
    Run the training loop for ``args.max_gradient_updates`` updates.

    Each update:
      1. builds ``args.n / 2`` antithetic pairs of perturbed agents,
      2. solves the training batch with every perturbed agent and with the
         unperturbed agent, one ``do_rollouts`` process each,
      3. applies ``gaussian_update`` using the perturbed results,
      4. every ``args.interval`` updates, runs ``validation``.

    Args:
        args: parsed command-line options.
        synced_agent: the agent being trained.
        env: training environment; ``reset()`` is called before every update.
        valid_env: validation environment, forwarded to ``validation``.
        chkpt_dir: directory for checkpoints, forwarded to ``validation``.
    """
    def flatten(raw_results, index):
        """Concatenate field ``index`` of every worker tuple into one list."""
        notflat_results = [result[index] for result in raw_results]
        return [item for sublist in notflat_results for item in sublist]

    logger.info("Num params in network %d" % synced_agent.policy.count_parameters())
    num_eps = 0
    total_num_frames = 0

    # Baseline on the validation set before any update.
    best_result = validation(args, valid_env, synced_agent, chkpt_dir, 0, num_eps, total_num_frames, None)

    for epoch in range(args.max_gradient_updates):
        env.reset()
        processes = []
        return_queue = mp.Queue()
        all_seeds, all_agents = [], []
        # Generate a perturbation and its antithesis; each seed yields two
        # agents, hence only n/2 iterations.
        for _ in range(args.n // 2):
            random_seed, two_agents = generate_seeds_and_agents(args, synced_agent)
            # Add twice because we get two models with the same seed
            all_seeds.append(random_seed)
            all_seeds.append(random_seed)
            all_agents += two_agents
        assert len(all_seeds) == len(all_agents)
        # Keep track of which perturbations were positive and negative
        # Start with negative true because pop() makes us go backwards
        is_negative = True
        # Add all perturbed models to the queue
        while all_agents:
            perturbed_agent = all_agents.pop()
            seed = all_seeds.pop()
            p = mp.Process(target=do_rollouts, args=(args, perturbed_agent, [seed], return_queue, env, [is_negative]))
            p.start()
            processes.append(p)
            is_negative = not is_negative

        assert len(all_seeds) == 0
        # Evaluate the unperturbed model as well; the 'dummy_seed' marker is
        # how its result is found again below.
        p = mp.Process(target=do_rollouts, args=(args, synced_agent, ['dummy_seed'], return_queue, env, ['dummy_neg']))
        p.start()
        processes.append(p)
        # One result tuple per worker, fetched before the workers are joined.
        raw_results = [return_queue.get() for _ in processes]
        for p in processes:
            p.join()

        # Regroup the worker tuples field by field; position i of every list
        # belongs to the same rollout.
        seeds, rewards, num_frames, all_weights, all_samples, neg_list = [flatten(raw_results, index)
                                                for index in [0, 1, 2, 3, 4, 5]]

        # Separate the unperturbed results from the perturbed results
        unperturbed_index = seeds.index('dummy_seed')
        seeds.pop(unperturbed_index)
        unperturbed_rewards = rewards.pop(unperturbed_index)
        for per_rollout in (num_frames, neg_list, all_weights, all_samples):
            per_rollout.pop(unperturbed_index)

        if not args.silent:
            logger.info('rewards: %s\tunperturbed_rewards: %s' % (rewards, unperturbed_rewards))

        total_num_frames += sum(num_frames)
        num_eps += len(rewards) * args.ins_batch_size
        synced_agent = gaussian_update(args, synced_agent, rewards, all_weights, all_samples, seeds, neg_list, num_eps, total_num_frames,
                                       unperturbed_rewards)

        if (epoch+1) % args.interval == 0:
            best_result = validation(args, valid_env, synced_agent, chkpt_dir, epoch, num_eps, total_num_frames, best_result)


# Command-line interface. The options are parsed in the __main__ block and
# handed to the functions above as `args`.
parser = argparse.ArgumentParser(description='ES')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
parser.add_argument('--sigma', type=float, default=0.001, metavar='SD', help='noise standard deviation')
parser.add_argument('--n', type=int, default=2, metavar='N', help='worker batch size, must be even')
parser.add_argument('--max_gradient_updates', type=int, default=100000, metavar='MGU', help='maximum number of updates')
parser.add_argument('--restore', default='', help='checkpoint from which to restore')
parser.add_argument('--silent', action='store_true', help='Silence print statements during training')
parser.add_argument('--ins_type', type=str, default='indset_400n_4a_0se',
                    help='instance family; names the sub-directory used under ../data/instances, '
                         '../results, ../log and ../check_points')
parser.add_argument('--ins_config', type=str, default='500n',
                    help='instance configuration; data is read from train_<ins_config> and valid_<ins_config>')
parser.add_argument('--num_train_ins', type=int, default=2,
                    help='number of training instances to load (instance_1.lp ... instance_N.lp)')
parser.add_argument('--num_valid_ins', type=int, default=2,
                    help='number of validation instances to load, must be divisible by --ins_batch_size')
parser.add_argument('--ins_batch_size', type=int, default=2,
                    help='instance batch size passed to the SCIP environments')
parser.add_argument('--seed', type=int, default=0,
                    help='seed passed to the SCIP environments; also part of the log file name')
# NOTE(review): device_id is not read in this file; presumably consumed by
# set_device_seed(args) -- confirm.
parser.add_argument('--device_id', type=int, default=0)
parser.add_argument('--interval', type=int, help='interval for validation', default=1)
parser.add_argument('--alpha', type=float, default=0.9,
                    help='weight of the novelty score in the update; 0 disables novelty')
parser.add_argument('--alpha_decay', type=float, default=0.99,
                    help='factor applied to alpha whenever validation does not improve')
parser.add_argument('--timelimit', type=int, default=120,
                    help='time limit passed to the SCIP environments')

# Not referenced anywhere else in the visible part of this file; kept as is.
optimConfig = []
maxReward = []
minReward = []
# Validation history; validation() appends one row per call and writes it to CSV.
df = pd.DataFrame(columns=['episodes', 'frame', 'rewards'])

if __name__ == '__main__':
    args = parser.parse_args()
    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and parser.error prints the usage message.
    if args.n % 2 != 0:
        parser.error('--n must be even')
    if args.num_valid_ins % args.ins_batch_size != 0:
        parser.error('--num_valid_ins must be divisible by --ins_batch_size')
    device = set_device_seed(args)

    # Paths of the training and validation instances:
    # ../data/instances/<ins_type>/{train,valid}_<ins_config>/instance_<i>.lp
    train_dirs = [os.path.join(f'../data/instances/{args.ins_type}/train_{args.ins_config}',
                               f'instance_{i+1}.lp') for i in range(args.num_train_ins)]
    valid_dirs = [os.path.join(f'../data/instances/{args.ins_type}/valid_{args.ins_config}',
                               f'instance_{i+1}.lp') for i in range(args.num_valid_ins)]

    # Per-instance values stored under ../results. According to the original
    # author's note they are produced by evaluator.py (not visible here).
    # NOTE(review): pickle.load can execute arbitrary code -- only load
    # files this project generated itself.
    with open(f'../results/{args.ins_type}/train_{args.ins_config}/vals.pkl', 'rb') as vals_file:
        train_vals = pickle.load(vals_file)
    with open(f'../results/{args.ins_type}/valid_{args.ins_config}/vals.pkl', 'rb') as vals_file:
        valid_vals = pickle.load(vals_file)

    # The environments hold no model: only instances, batch size, seeds and
    # the time limit.
    env = scip_env(train_dirs, train_vals, args.ins_batch_size, scip_seed=args.seed, seed=args.seed, timelimit=args.timelimit)
    valid_env = scip_env(valid_dirs, valid_vals, args.ins_batch_size, scip_seed=args.seed, seed=args.seed, timelimit=args.timelimit)

    # The agent to train, optionally restored from --restore.
    agent = BranchAgent(lr=args.sigma, device=None, check_point=args.restore)

    # Checkpoints are written here by validation().
    chkpt_dir = f'../check_points/{args.ins_type}/'
    os.makedirs(chkpt_dir, exist_ok=True)

    # Start training.
    train_loop(args, agent, env, valid_env, chkpt_dir)