import numpy as np
import torch
import time
import torch.nn.functional as F
import torch.optim as optim
import cp_solver
import random
import os
import pickle
from copy import deepcopy
from collections import Counter
from collections import OrderedDict
from utils import ReplayBuffer


class DQN:
    """Deep Q-Network agent that learns a node-selection policy for graph colouring.

    Each step the agent picks one node of the current graph and calls
    ``env.step(action)``, which returns ``(state_next, reward, done)``. States are
    built as ``np.vstack((s, adj))`` from the ``(s, adj)`` pair returned by
    ``env.reset()``. For a single state, row 0 is read as the per-node state
    vector (see ``predict``/``select_node``).

    Action modes:

    * ``IAP=True``: the greedy action is the argmax over all nodes.
    * ``IAP=False``: only nodes whose state equals
      ``env.get_allowed_action_states()`` are eligible; others are masked out.

    Args:
        envs: One training environment or a list of them. Each must expose
            ``node_num``, ``reset()``, ``step(action)``, ``current_step``,
            ``color_num`` and ``get_allowed_action_states()``.
        path: Output directory for checkpoints and the pickled loss history.
        graph_save_loc: Passed to ``cp_solver.get_opt`` to obtain the reference
            colour counts used for the test accuracy.
        network: Zero-argument callable returning a ``torch.nn.Module``; it is
            called twice (online network and target network).
        IAP: Action mode, see above.
        init_network_params: Optional path of a state dict loaded at start-up.
        init_weight_std: If given (and no params are loaded), the weights of every
            layer whose type is exactly ``torch.nn.Linear`` are redrawn from
            N(0, init_weight_std).
        double_dqn: Let the online network choose the next action for the target.
        gamma: Discount factor.
        update_exploration, initial_exploration_rate, final_exploration_rate,
            final_exploration_step: Linear epsilon schedule.
        update_target_frequency: Steps between target-network syncs.
        evaluate, test_envs, test_episodes, test_frequency: Periodic evaluation.
        test_score_save_path, test_color_save_path, test_accurary_path: Files the
            evaluation histories are pickled to at the end of ``learn``.
        replay_start_size: Transitions needed in the current buffer before training.
        replay_buffer_size: Size argument of each per-graph-size ``ReplayBuffer``
            (presumably its capacity).
        minibatch_size: Transitions per gradient step.
        update_frequency: Steps between gradient steps.
        update_learning_rate, initial_learning_rate, peak_learning_rate,
            peak_learning_rate_step, final_learning_rate, final_learning_rate_step:
            Piecewise-linear warm-up / decay learning-rate schedule.
        weight_decay: Adam weight decay.
        loss: Currently not consulted; the loss is always ``F.mse_loss``.
        adam_epsilon: Adam ``eps``.
        save_network_frequency: Steps between periodic checkpoints.
    """

    def __init__(self,
                 envs,
                 path,
                 graph_save_loc,
                 network=None,
                 IAP=True,
                 # Initial network parameters.
                 init_network_params=None,
                 init_weight_std=None,

                 # DQN parameters
                 double_dqn=True,
                 gamma=0.99,
                 update_exploration=True,
                 initial_exploration_rate=1,
                 final_exploration_rate=0.05,
                 final_exploration_step=100000,
                 update_target_frequency=1000,

                 # Test
                 evaluate=True,
                 test_envs=None,
                 test_episodes=50,
                 test_frequency=10000,
                 test_score_save_path='test_scores',
                 test_color_save_path='color_nums',
                 test_accurary_path='test_accurary',

                 # Replay buffer
                 replay_start_size=500,
                 replay_buffer_size=5000,
                 minibatch_size=1,
                 update_frequency=16,

                 # Learning rate
                 update_learning_rate=True,
                 initial_learning_rate=0,
                 peak_learning_rate=1e-3,
                 peak_learning_rate_step=10000,
                 final_learning_rate=5e-5,
                 final_learning_rate_step=200000,

                 # Regularization
                 weight_decay=0,
                 loss="mse",

                 # Optimizer / checkpointing
                 adam_epsilon=1e-8,
                 save_network_frequency=100000,
                 ):
        if network is None:
            # Same exception type as calling None, but with an actionable message.
            raise TypeError("DQN requires `network`: a zero-argument callable returning an nn.Module.")

        self.graph_save_loc = graph_save_loc
        self.double_dqn = double_dqn

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.network = network().to(self.device)

        # The target network is only ever updated by copying weights, so it
        # needs no gradients.
        self.target_network = network().to(self.device)
        self.target_network.load_state_dict(self.network.state_dict())
        for param in self.target_network.parameters():
            param.requires_grad = False

        self.replay_start_size = replay_start_size
        self.replay_buffer_size = replay_buffer_size
        self.gamma = gamma
        self.update_exploration = update_exploration
        self.initial_exploration_rate = initial_exploration_rate
        self.epsilon = self.initial_exploration_rate
        self.final_exploration_rate = final_exploration_rate
        self.final_exploration_step = final_exploration_step

        self.minibatch_size = minibatch_size

        self.update_target_frequency = update_target_frequency
        self.update_frequency = update_frequency

        self.initial_learning_rate = initial_learning_rate
        self.learning_rate = initial_learning_rate
        self.update_learning_rate = update_learning_rate
        self.peak_learning_rate = peak_learning_rate
        self.peak_learning_rate_step = peak_learning_rate_step
        self.final_learning_rate = final_learning_rate
        self.final_learning_rate_step = final_learning_rate_step

        self.weight_decay = weight_decay

        self.adam_epsilon = adam_epsilon
        self.loss = F.mse_loss  # NOTE: the `loss` argument is not consulted.
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.initial_learning_rate, eps=self.adam_epsilon,
                                    weight_decay=self.weight_decay)

        if not isinstance(envs, list):
            envs = [envs]
        self.envs = envs
        self.env = self.get_sample_env()

        # One replay buffer per distinct graph size (node_num); a minibatch is
        # always drawn from a single buffer (see learn / get_random_replay_buffer).
        self.replay_buffers = {
            n: ReplayBuffer(self.replay_buffer_size)
            for n in set(env.node_num for env in self.envs)
        }
        self.replay_buffer = self.get_replay_buffer()

        self.init_network_params = init_network_params
        self.init_weight_std = init_weight_std
        if self.init_network_params is not None:
            print("Pre-loading network parameters from {}.\n".format(init_network_params))
            self.load(init_network_params)
        elif self.init_weight_std is not None:
            def init_weights(m):
                # Exact type match on purpose: subclasses of Linear are left alone.
                if type(m) is torch.nn.Linear:
                    print("Setting weights for", m)
                    m.weight.normal_(0, init_weight_std)

            with torch.no_grad():
                self.network.apply(init_weights)

        # Re-sync after loading / re-initialising so that the target network does
        # not start from the discarded initial weights.
        self.target_network.load_state_dict(self.network.state_dict())

        if not isinstance(test_envs, list):
            test_envs = [test_envs]
        self.test_envs = test_envs
        self.test_env = self.get_random_test_env()
        self.evaluate = evaluate
        self.test_episodes = test_episodes
        self.test_frequency = test_frequency
        self.test_score_save_path = test_score_save_path
        self.test_color_save_path = test_color_save_path
        self.test_accurary_path = test_accurary_path

        self.save_network_frequency = save_network_frequency
        self.path = path
        self.losses_save_path = os.path.join(self.path + '/network', "losses.pkl")

        self.IAP = IAP
        # Node-state value that marks a node as selectable when IAP is False.
        self.allowed_action_state = self.env.get_allowed_action_states()

    def load(self, path):
        """Load a state dict from ``path`` into the online network."""
        self.network.load_state_dict(torch.load(path, map_location=self.device))

    def get_sample_env(self):
        """Return one training environment chosen uniformly at random."""
        return random.sample(self.envs, k=1)[0]

    def get_random_test_env(self):
        """Return one test environment chosen uniformly at random."""
        return random.sample(self.test_envs, k=1)[0]

    def get_replay_buffer(self):
        """Return the replay buffer matching the current environment's graph size."""
        return self.replay_buffers[self.env.node_num]

    @torch.no_grad()
    def predict(self, states):
        """Return the greedy action(s) of the online network for ``states``.

        When the squeezed network output is 1-D (single state), one action is
        returned: a 0-dim tensor in IAP mode, a Python int otherwise. For a batch,
        a NumPy array with one action per row is returned. With ``IAP=False``,
        only nodes whose state equals ``self.allowed_action_state`` are eligible.
        """
        q = self.network(states)
        q.squeeze_(-1)

        if self.IAP:
            if q.dim() == 1:
                return q.argmax()
            return q.argmax(1, True).squeeze(1).cpu().numpy()

        if q.dim() == 1:
            # Indices of the nodes that may still be chosen (row 0 of the state).
            allowed = (states[0, :] == self.allowed_action_state).nonzero()
            # Best Q among the allowed nodes, mapped back to its node index.
            return allowed[q[allowed].argmax().item()].item()

        # Batched input: the node-state vector is read from states[:, :, 0].
        disallowed_actions_mask = (states[:, :, 0] != self.allowed_action_state)
        qs_allowed = q.masked_fill(disallowed_actions_mask, -10000)
        return qs_allowed.argmax(1, True).squeeze(1).cpu().numpy()

    def select_node(self, state, is_ready_training=True):
        """Epsilon-greedy choice of the next node to colour.

        Returns:
            ``(action_node, is_predict)``: the node index as ``int`` and whether it
            came from the network (True) or was drawn at random (False).
        """
        if is_ready_training and random.uniform(0, 1) >= self.epsilon:
            return int(self.predict(state)), True

        if self.IAP:
            # Any node may be picked. [0] extracts the scalar: int() on a 1-element
            # ndarray is deprecated since NumPy 1.25.
            action_node = np.random.choice(self.env.node_num, 1, replace=False)[0]
        else:
            # Uniform choice among the nodes still allowed (row 0 of the state).
            allowed = (state[0, :] == self.allowed_action_state).nonzero()
            action_node = allowed[np.random.randint(0, len(allowed))].item()
        return int(action_node), False

    @staticmethod
    def _accuracy(cp_colors, color_nums):
        """Fraction of episodes whose colour count is <= the reference count.

        Entries are paired by position and truncated to the shorter sequence, so
        a reference list longer than the number of episodes no longer raises
        IndexError. The denominator is the number of episodes; 0.0 when there are
        none.
        """
        if len(color_nums) == 0:
            return 0.0
        matched = sum(1 for ref, got in zip(cp_colors, color_nums) if ref >= got)
        return matched / len(color_nums)

    @torch.no_grad()
    def evaluate_agent(self, cp_colors, batch_size=None):
        """Run ``self.test_episodes`` greedy episodes and summarise them.

        Args:
            cp_colors: Reference colour counts (from ``cp_solver.get_opt``);
                entry i is compared with the colour count of episode i.
            batch_size: Unused; kept for backward compatibility.

        Returns:
            ``(mean score, mean colour count, accuracy)``. Accuracy is the
            fraction of episodes using no more colours than the reference.
        """
        test_scores = []
        color_nums = []
        for _ in range(self.test_episodes):
            # NOTE(review): the test env is sampled at random, so episode i is not
            # guaranteed to correspond to cp_colors[i] unless there is a single
            # test env — confirm the intended pairing.
            test_env = self.get_random_test_env()
            s, adj = test_env.reset()
            state = np.vstack((s, adj))
            done = False
            score = 0

            while not done:
                action = self.predict(torch.FloatTensor(state).to(self.device).float())
                state, reward, done = test_env.step(action)
                score += reward

            # -1 in row 0 of env.state presumably marks an uncoloured node.
            if (test_env.state[0] != -1).all():
                color_nums.append(test_env.color_num)
            else:
                # Incomplete colouring: record the fixed placeholder count 10.
                color_nums.append(10)
            test_scores.append(score)

        accuracy = self._accuracy(cp_colors, color_nums)
        print(Counter(color_nums))
        return np.mean(test_scores), np.mean(color_nums), accuracy

    def update_epsilon(self, timestep):
        """Linearly anneal the random-action rate, clamped at the final rate."""
        decay = (self.initial_exploration_rate - self.final_exploration_rate) * (
                timestep / self.final_exploration_step
        )
        self.epsilon = max(self.initial_exploration_rate - decay, self.final_exploration_rate)

    def update_lr(self, timestep):
        """Apply the piecewise-linear learning-rate schedule for ``timestep``.

        Linear from ``initial_learning_rate`` (step 0) up to ``peak_learning_rate``
        (``peak_learning_rate_step``), then linear down to ``final_learning_rate``
        (``final_learning_rate_step``). Past that step the optimizer is untouched.
        """
        if timestep <= self.peak_learning_rate_step:
            lr = self.initial_learning_rate - (self.initial_learning_rate - self.peak_learning_rate) * (
                    timestep / self.peak_learning_rate_step
            )
        elif timestep <= self.final_learning_rate_step:
            lr = self.peak_learning_rate - (self.peak_learning_rate - self.final_learning_rate) * (
                    (timestep - self.peak_learning_rate_step) / (
                    self.final_learning_rate_step - self.peak_learning_rate_step)
            )
        else:
            return

        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def get_random_env(self):
        """Return one training environment chosen uniformly at random."""
        return random.sample(self.envs, k=1)[0]

    def get_replay_buffer_for_env(self, env):
        """Return the replay buffer matching ``env``'s graph size."""
        return self.replay_buffers[env.node_num]

    def get_random_replay_buffer(self):
        """Return a uniformly random replay buffer among those that hold data.

        Empty buffers (graph sizes not visited yet) are skipped because no
        minibatch can be drawn from them. If every buffer is empty, all of them
        are candidates, as before.
        """
        buffers = list(self.replay_buffers.values())
        non_empty = [buf for buf in buffers if len(buf) > 0]
        return random.sample(non_empty or buffers, k=1)[0]

    @staticmethod
    def _is_new_best(value, history, higher_is_better=True):
        """True if ``value`` is at least as good as every entry of ``history``.

        Always True for an empty history.
        """
        if higher_is_better:
            return all(value >= past for past in history)
        return all(value <= past for past in history)

    @staticmethod
    def _dump_pickle(path, obj):
        """Pickle ``obj`` to ``path`` (overwriting) using the highest protocol."""
        with open(path, 'wb+') as output:
            pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL)

    def learn(self, timesteps):
        """Train the agent for ``timesteps`` environment steps.

        Each step does the following:
        - pick an epsilon-greedy action;
        - update the schedules;
        - step the environment;
        - store the transition in the current env's replay buffer.

        Once that buffer holds ``replay_start_size`` transitions:
        - a gradient step runs every ``update_frequency`` steps;
        - the target network is synced every ``update_target_frequency`` steps.

        Evaluation runs every ``test_frequency`` steps, and checkpoints are
        written every ``save_network_frequency`` steps. At the end, the
        evaluation and loss histories are pickled and the final weights are
        saved to ``<path>/network_N2_last.pth``.
        """
        # s: initial node state, adj: second part of the stacked state matrix.
        s, adj = self.env.reset()
        state = torch.as_tensor(np.vstack((s, adj)), dtype=torch.float)

        score = 0
        losses = []  # [timestep, loss] for every gradient step
        losses_eps = []  # losses of the current episode only
        is_ready_training = False
        test_scores = []
        test_color_nums = []
        accuracy_lt = []

        # Best-so-far bookkeeping. All three are set on the first evaluation,
        # because every criterion is trivially "best" against an empty history.
        step1 = max_score_color = score_accurary = None
        step2 = min_color_score = color_accurary = None
        step3 = max_accurary_score = accurary_color_num = None

        # Reference colour counts for the graphs in self.graph_save_loc, computed
        # by cp_solver (presumably the solver's optimum — confirm).
        cp_colors = cp_solver.get_opt(self.graph_save_loc)
        t1 = time.time()

        for timestep in range(timesteps):
            if not is_ready_training and self.replay_start_size <= len(self.replay_buffer):
                is_ready_training = True

            # --------------------------------------select--------------------------------------
            # NOTE(review): True is passed instead of is_ready_training, so the
            # network may act before training starts; epsilon alone governs exploration.
            action, _ = self.select_node(state.to(self.device).float(), is_ready_training=True)

            if self.update_exploration:
                self.update_epsilon(timestep)
            if self.update_learning_rate:
                self.update_lr(timestep)

            state_next, reward, done = self.env.step(action)
            score += reward

            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            state_next = torch.as_tensor(state_next, dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)

            self.replay_buffer.add(state, action, reward, state_next, done)

            if done:
                # losses_eps is empty when no gradient step fell inside this episode.
                has_losses = is_ready_training and len(losses_eps) > 0
                loss_str = "{:.2e}".format(np.mean(losses_eps)) if has_losses else "N/A"
                print("timestep : {}, episode time: {}, score : {}, mean loss: {}, time : {} s, color_num:{}".format(
                    (timestep + 1),
                    self.env.current_step,
                    round(float(score), 4),
                    loss_str,
                    round(time.time() - t1, 3),
                    self.env.color_num,
                ))

                self.env = self.get_random_env()
                self.replay_buffer = self.get_replay_buffer_for_env(self.env)
                s, adj = self.env.reset()
                # Same dtype as the initial state so stored states are consistent.
                state = torch.as_tensor(np.vstack((s, adj)), dtype=torch.float)
                score = 0
                losses_eps = []
                t1 = time.time()
            else:
                state = state_next

            if is_ready_training:
                if timestep % self.update_frequency == 0:
                    transitions = self.get_random_replay_buffer().sample(self.minibatch_size, self.device)
                    # --------------------------------------train_step--------------------------------------
                    loss = self.train_step(transitions)
                    losses.append([timestep, loss])
                    losses_eps.append(loss)

                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())

            save_now = is_ready_training and (timestep + 1) % self.save_network_frequency == 0
            if save_now:
                # NOTE(review): hard-coded, experiment-specific location that ignores
                # self.path; the same weights are also saved under self.path below.
                torch.save(self.network.state_dict(),
                           r"./experiments/ER_n60_p15/network_new.pth")

            if (timestep + 1) % self.test_frequency == 0 and self.evaluate and is_ready_training:
                test_score, test_color_num, test_accuracy = self.evaluate_agent(cp_colors=cp_colors)
                print('\nTest score: {}, color_num:{}, accurary:{}\n'.format(np.round(test_score, 4), test_color_num,
                                                                             np.round(test_accuracy, 4)))

                # Compare against the history *before* this evaluation is appended.
                best_network = self._is_new_best(test_score, test_scores)
                best_network1 = self._is_new_best(test_color_num, test_color_nums, higher_is_better=False)
                # Higher accuracy is better. The old check compared the value with
                # itself, so it was always True.
                best_network2 = self._is_new_best(test_accuracy, accuracy_lt)

                if best_network:
                    torch.save(self.network.state_dict(), self.path + "/network_N2_best.pth")
                    step1 = timestep
                    max_score_color = test_color_num
                    score_accurary = test_accuracy

                if best_network1:
                    torch.save(self.network.state_dict(), self.path + "/network_N2_best_co.pth")
                    step2 = timestep
                    min_color_score = test_score
                    color_accurary = test_accuracy

                if best_network2:
                    torch.save(self.network.state_dict(), self.path + "/network_N2_best_ac.pth")
                    step3 = timestep
                    max_accurary_score = test_score
                    accurary_color_num = test_color_num

                test_scores.append(test_score)
                test_color_nums.append(test_color_num)
                accuracy_lt.append(test_accuracy)
                print(np.round(max_score_color, 4), np.round(max(test_scores), 4), np.round(score_accurary, 4), step1)
                print(np.round(min(test_color_nums), 4), np.round(min_color_score, 4), np.round(color_accurary, 4),
                      step2)
                print(np.round(accurary_color_num, 4), np.round(max_accurary_score, 4), np.round(max(accuracy_lt), 4),
                      step3)

            if save_now:
                torch.save(self.network.state_dict(), self.path + "/network/" + str(timestep + 1) + "network_N2.pth")

        self._dump_pickle(self.test_color_save_path, np.array(test_color_nums))
        self._dump_pickle(self.test_score_save_path, np.array(test_scores))
        self._dump_pickle(self.test_accurary_path, np.array(accuracy_lt))
        self._dump_pickle(self.losses_save_path, np.array(losses))

        torch.save(self.network.state_dict(), self.path + "/network_N2_last.pth")

    def train_step(self, transitions):
        """Do one gradient step on a sampled minibatch and return the loss value.

        ``transitions`` is the ``(states, actions, rewards, states_next, dones)``
        tuple returned by ``ReplayBuffer.sample``. The TD target is
        ``rewards + (1 - dones) * gamma * Q_target(s', a*)``. Here ``a*`` is the
        online network's greedy action (double DQN) or the target network's own
        maximum. With ``IAP=False``, disallowed nodes are masked with -10000.
        """
        states, actions, rewards, states_next, dones = transitions
        if self.IAP:
            with torch.no_grad():
                if self.double_dqn:
                    greedy_actions = self.network(states_next.float()).argmax(1, True)
                    # NOTE(review): the in-place transpose means the online and
                    # target networks see differently oriented inputs here, while
                    # the non-double branch does not transpose — confirm intent.
                    states_next.transpose_(-1, -2)
                    q_value_target = self.target_network(states_next.float()).gather(1, greedy_actions)
                else:
                    q_value_target = self.target_network(states_next.float()).max(1, True)[0]
        else:
            with torch.no_grad():
                target_preds = self.target_network(states_next.float())
                states_next.transpose_(-1, -2)
                # [:, 0, :] after the transpose equals [:, :, 0] before it, the same
                # node-state slice that predict() masks on for batches.
                disallowed_actions_mask = (states_next[:, 0, :] != self.allowed_action_state)
                if self.double_dqn:
                    # NOTE(review): the online network sees the transposed states.
                    network_preds = self.network(states_next.float())
                    network_preds_allowed = network_preds.masked_fill(disallowed_actions_mask, -10000)
                    greedy_actions = network_preds_allowed.argmax(1, True)
                    q_value_target = target_preds.gather(1, greedy_actions)
                else:
                    q_value_target = target_preds.masked_fill(disallowed_actions_mask, -10000).max(1, True)[0]

        # TD target (Bellman equation); terminal transitions do not bootstrap.
        td_target = rewards + (1 - dones) * self.gamma * q_value_target

        # Q value of the actions actually taken.
        q_value = self.network(states.float()).gather(1, actions)

        loss = self.loss(q_value, td_target, reduction='mean')

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()
