import os
import json
import tqdm
import torch
import numpy as np
from hyperGraph_utils.utils import OurHypergraph
from LGCA2CNet import LGCA2CNet, collate_LGC
from torch.distributions import Categorical
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler


class LGCAgentPPO(torch.nn.Module):
    def __init__(self, data_name, ProcessedDataFile_path, hidden_dimension, num_egat_heads,
                 num_egat_layers, num_of_LayerHypergraph, drop_rate=0, residual=True,
                 have_fp=True, have_structure=True, gamma=0.99, eps_clip=0.2, value_coefficient=0.5,
                 entropy_coefficient=0.01, learning_rate=0.001, device=torch.device('cpu')):
        """Load the hypergraph, build the actor-critic network and its optimizer.

        Args:
            data_name: sub-directory of ``ProcessedDataFile_path`` that holds
                ``hypergraph.json``.
            ProcessedDataFile_path: root directory of the processed data.
            hidden_dimension, num_egat_heads, num_egat_layers,
            num_of_LayerHypergraph, drop_rate, residual, have_fp,
            have_structure: forwarded unchanged to ``LGCA2CNet``.
            gamma: discount factor used when computing returns in ``update``.
            eps_clip: clipping range of the PPO probability ratio.
            value_coefficient: weight of the value (MSE) term in the PPO loss.
            entropy_coefficient: weight of the entropy bonus in the PPO loss.
            learning_rate: learning rate of the Adam optimizer.
            device: device the hypergraph and the network are moved to.
        """
        super(LGCAgentPPO, self).__init__()
        # hypergraph
        LG_hypergraph_info_path = os.path.join(ProcessedDataFile_path, data_name, 'hypergraph.json')
        # JSON text is UTF-8; do not depend on the platform's default locale encoding.
        with open(LG_hypergraph_info_path, 'r', encoding='utf-8') as f:
            self.hg_json = json.load(f)
        self.hypergraph_dhg = OurHypergraph(num_v=self.hg_json['num_v'],  # int
                                            e_list=self.hg_json["e_list"],  # [[], [], ...]
                                            x_tensor=torch.tensor(self.hg_json['v_fp'],
                                                                  dtype=torch.float32))  # num_v x 2048
        self.hypergraph_dhg.our_to(device)

        self.LGCA2CNet = LGCA2CNet(hidden_dimension=hidden_dimension, num_egat_heads=num_egat_heads,
                                   num_egat_layers=num_egat_layers, lg_hypergraph_dhg=self.hypergraph_dhg,
                                   num_of_LayerHypergraph=num_of_LayerHypergraph,
                                   drop_rate=drop_rate, residual=residual,
                                   have_fp=have_fp, have_structure=have_structure)
        self.LGCA2CNet.to(device)

        # Only the network's parameters are optimised.
        self.optimizer = torch.optim.Adam(self.LGCA2CNet.parameters(), lr=learning_rate)
        # reduction='none' keeps the per-sample value loss so it can be summed
        # element-wise with the policy and entropy terms before averaging.
        self.mseLoss = torch.nn.MSELoss(reduction='none')
        self.bceLoss = torch.nn.BCELoss()
        self.device = device

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.value_coefficient = value_coefficient
        self.entropy_coefficient = entropy_coefficient

        # Transitions collected since the last clear_buffer().
        self.buffer = []

    def store_transition(self, transition):
        # transition = (state, action, p_of_a, r, next_state, done)  # ([], int, float, float, [], int)
        self.buffer.append(transition)

    def clear_buffer(self):
        self.buffer = []

    def save_param(self, save_path):
        torch.save(self.LGCA2CNet.state_dict(), save_path)
        print('save success!')

    def load_param(self, ckpt_path):
        self.LGCA2CNet.load_state_dict(torch.load(ckpt_path))
        print('load success!')

    def select_action(self, one_state):
        """Sample one action for a single state from the current policy.

        Args:
            one_state: a single state,
                ``[product_dgl, product_fp, RcNodeIdx_list, mask, t]``.

        Returns:
            ``(action, p_of_action, rst_policy)``: the sampled index (the last
            index of the policy vector is reported as ``-1``), the probability
            of that entry, and the flattened policy vector.
        """
        # Wrap the state in a batch of size one.
        product_dgl, product_fp, lg_list, mask = collate_LGC([one_state])
        product_dgl = product_dgl.to(self.device)
        product_fp = product_fp.to(self.device)
        mask = mask.to(self.device)

        self.LGCA2CNet.eval()
        with torch.no_grad():
            probs = self.LGCA2CNet.policy(product_dgl, product_fp, lg_list, mask, logits=False)
        probs = probs.reshape(-1)  # v+1

        action = Categorical(probs).sample().item()

        # The final slot is encoded as -1; indexing with -1 still reads that slot.
        last_index = len(probs) - 1
        if action == last_index:
            action = -1

        return action, probs[action].item(), probs

    def select_action_infer(self, one_state, additional_mask):  # additional_mask 1x(v+1)
        """Sample one action for a single state with an extra mask on the logits.

        Args:
            one_state: a single state,
                ``[product_dgl, product_fp, RcNodeIdx_list, mask, t]``.
            additional_mask: scaled by ``-1e9`` and added to the policy logits
                before the softmax, so positive entries suppress the
                corresponding actions (presumably 1 marks a forbidden action —
                confirm with the caller).

        Returns:
            ``(action, p_of_action, rst_policy)``: the sampled index (the last
            index of the policy vector is reported as ``-1``), the probability
            of that entry, and the flattened, re-normalised policy vector.
        """
        # Wrap the state in a batch of size one.
        product_dgl, product_fp, lg_list, mask = collate_LGC([one_state])
        product_dgl = product_dgl.to(self.device)
        product_fp = product_fp.to(self.device)
        mask = mask.to(self.device)

        self.LGCA2CNet.eval()
        with torch.no_grad():
            logits = self.LGCA2CNet.policy(product_dgl, product_fp, lg_list, mask, logits=True)  # 1x(v+1)

        masked_logits = logits + additional_mask * -1e9  # 1x(v+1)
        probs = torch.softmax(masked_logits, dim=-1).reshape(-1)  # v+1

        action = Categorical(probs).sample().item()

        # The final slot is encoded as -1; indexing with -1 still reads that slot.
        last_index = len(probs) - 1
        if action == last_index:
            action = -1

        return action, probs[action].item(), probs

    def update(self, num_epochs, batch_size, imitation=False):
        if imitation:
            # (state, action, p_of_a, r, next_state, done)  # ([], int, float, float, [], int)
            list_state, list_action, list_p_of_a, list_r, list_next_state, list_done = map(list, zip(*self.buffer))
            buffer_action = torch.tensor(list_action, dtype=torch.long).reshape(-1, 1)  # Mx1

            self.LGCA2CNet.train()
            loss_list = []
            for _ in tqdm.tqdm(range(num_epochs), leave=False):
                for index in BatchSampler(SubsetRandomSampler(range(len(self.buffer))), batch_size=batch_size,
                                          drop_last=False):
                    # for i in range(len(self.buffer)):
                    #     index = [i]

                    # batch_state
                    samples_states = [list_state[i] for i in index]  # [[], [], ...]
                    batch_product_dgl, batch_product_fp, batch_lg_list, batch_mask = collate_LGC(samples_states)
                    batch_product_dgl, batch_product_fp, batch_mask = batch_product_dgl.to(self.device), \
                                                                      batch_product_fp.to(self.device), \
                                                                      batch_mask.to(self.device)

                    # new_p
                    batch_action = buffer_action[index]  # bx1
                    batch_policy = self.LGCA2CNet.policy(batch_product_dgl, batch_product_fp, batch_lg_list, batch_mask,
                                                         logits=False)  # bx(v+1)
                    batch_new_p = batch_policy[range(len(index)), batch_action.reshape(-1)].reshape(-1, 1)  # bx1

                    # label
                    label = torch.ones_like(batch_new_p).detach()  # bx1

                    loss = self.bceLoss(batch_new_p, label)
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.LGCA2CNet.parameters(), 1)
                    self.optimizer.step()
                    loss_list.append(loss.cpu().detach().item())
            return np.mean(loss_list)

        else:
            # (state, action, p_of_a, r, next_state, done)  # ([], int, float, float, [], int)
            list_state, list_action, list_p_of_a, list_r, list_next_state, list_done = map(list, zip(*self.buffer))
            buffer_action = torch.tensor(list_action, dtype=torch.long).reshape(-1, 1)  # Mx1
            buffer_p_of_a = torch.tensor(list_p_of_a, dtype=torch.float32).reshape(-1, 1)  # Mx1
            buffer_r = torch.tensor(list_r, dtype=torch.float32).reshape(-1, 1)  # Mx1
            buffer_done = torch.tensor(list_done, dtype=torch.float32).reshape(-1, 1)  # Mx1

            Gt = []
            discounted_r = 0
            for reward, d in zip(reversed(buffer_r), reversed(buffer_done)):
                if d:
                    discounted_r = 0
                discounted_r = reward + self.gamma * discounted_r
                Gt.insert(0, discounted_r)  # insert in front, cannot use append

            buffer_Gt = torch.tensor(Gt, dtype=torch.float32).reshape(-1, 1)  # Mx1

            loss_list = []
            self.LGCA2CNet.train()
            for _ in tqdm.tqdm(range(num_epochs), leave=False):
                for index in BatchSampler(SubsetRandomSampler(range(len(self.buffer))), batch_size=batch_size,
                                          drop_last=False):
                    # for i in range(len(self.buffer)):
                    #     index = [i]

                    # batch_state
                    samples_states = [list_state[i] for i in index]  # [[], [], ...]
                    batch_product_dgl, batch_product_fp, batch_lg_list, batch_mask = collate_LGC(samples_states)
                    batch_product_dgl, batch_product_fp, batch_mask = batch_product_dgl.to(self.device), \
                                                                      batch_product_fp.to(self.device), \
                                                                      batch_mask.to(self.device)

                    # batch_Gt
                    batch_Gt = buffer_Gt[index]  # bx1
                    if len(batch_Gt) > 1:
                        batch_Gt = (batch_Gt - batch_Gt.mean()) / (batch_Gt.std() + 1e-5)
                    batch_Gt = batch_Gt.to(self.device)  # bx1

                    # pre_v
                    batch_pre_v = self.LGCA2CNet.value(batch_product_dgl, batch_product_fp, batch_lg_list)  # bx1

                    # advantage
                    batch_advantage = batch_Gt - batch_pre_v
                    batch_advantage = batch_advantage.detach()  # bx1

                    # old_p
                    batch_old_p = buffer_p_of_a[index].detach()  # bx1
                    batch_old_p = batch_old_p.to(self.device)  # bx1

                    # new_p
                    batch_action = buffer_action[index]  # bx1
                    batch_policy = self.LGCA2CNet.policy(batch_product_dgl, batch_product_fp, batch_lg_list, batch_mask,
                                                         logits=False)  # bx(v+1)
                    batch_new_p = batch_policy[range(len(index)), batch_action.reshape(-1)].reshape(-1, 1)  # bx1

                    # entropy
                    dist = Categorical(batch_policy)
                    entropy = dist.entropy().reshape(-1, 1)  # bx1

                    # loss
                    ratio = torch.exp(torch.log(batch_new_p) - torch.log(batch_old_p))  # a/b == exp(log(a)-log(b))  bx1

                    surr1 = ratio * batch_advantage  # bx1
                    surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * batch_advantage  # bx1
                    # loss = -torch.min(surr1, surr2).mean() + \
                    #        self.value_coefficient * (self.mseLoss(batch_pre_v, batch_Gt.detach()).mean()) - \
                    #        self.entropy_coefficient * (entropy.mean())

                    loss = -torch.min(surr1, surr2) + \
                           self.value_coefficient * self.mseLoss(batch_pre_v, batch_Gt.detach()) - \
                           self.entropy_coefficient * entropy
                    loss = loss.mean()

                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.LGCA2CNet.parameters(), 1)
                    self.optimizer.step()
                    loss_list.append(loss.cpu().detach().item())
            return np.mean(loss_list)


