
import math
import random
import numpy as np
import os
import sys
from tqdm import tqdm
# sys.path.append('..')
import higher

from collections import namedtuple
import argparse
from copy import deepcopy
from itertools import count, chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import *
from sum_tree import SumTree

#TODO select env
from RL.env_binary_question import BinaryRecommendEnv
from RL.env_enumerated_question import EnumeratedRecommendEnv
from RL.RL_evaluate import dqn_evaluate
# from gcn import GraphEncoder
from torch.autograd import Variable
import time
from torch.distributions import Categorical
import warnings
import IPython
from encoder import TransGate, NegRanker
from reward_function import Intrinsic_Reward, Reward_Ensemble
from min_norm_solvers import MinNormSolver

warnings.filterwarnings("ignore")

# Dataset name -> environment class instantiated by train().
EnvDict = {
    LAST_FM: BinaryRecommendEnv,
    LAST_FM_STAR: BinaryRecommendEnv,
    YELP: EnumeratedRecommendEnv,
    YELP_STAR: BinaryRecommendEnv,
}

# Dataset name -> key into kg.G whose entries are counted to set args.attr_num.
FeatureDict = {
    LAST_FM: 'feature',
    LAST_FM_STAR: 'feature',
    YELP: 'large_feature',
    YELP_STAR: 'feature',
}

# Record type for a single (s, a, s', r, next candidates) transition.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'next_cand'],
)

def get_entropy(probs):
    """Compute the Shannon entropy (natural log) of each probability tensor.

    Each tensor in ``probs`` is reduced over all of its elements. ``torch.xlogy``
    is used so that zero probabilities contribute 0; the naive ``p * log(p)``
    evaluates ``0 * -inf`` and yields NaN.

    :param probs: iterable of probability tensors.
    :return: ``(mean_entropy, per_tensor_entropies)``. The mean is NaN when
        ``probs`` is empty.
    """
    entropy = [-torch.xlogy(p, p).sum().item() for p in probs]
    if not entropy:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float('nan'), entropy
    return np.mean(entropy), entropy

class DQN(nn.Module):
    """Score (state, candidate) pairs with a two-layer MLP.

    Only the "advantage" head (``fc2_advantage`` -> ``out_advantage``) produces
    the output. The value head (``fc2_value`` -> ``out_value``) is still
    constructed so that saved/loaded state dicts keep the same keys, but it does
    not take part in ``forward``. Its output used to be computed and discarded.
    """

    def __init__(self, state_size, action_size, hidden_size=100):
        super(DQN, self).__init__()
        # ``state_size`` is accepted for interface compatibility but unused; the
        # state input is expected to already have width ``hidden_size``.
        # V(s) head: kept for state-dict compatibility only.
        self.fc2_value = nn.Linear(hidden_size, hidden_size)
        self.out_value = nn.Linear(hidden_size, 1)
        # Score head over the concatenated [state; candidate] features.
        self.fc2_advantage = nn.Linear(hidden_size + action_size, hidden_size)
        self.out_advantage = nn.Linear(hidden_size, 1)

    def forward(self, x, y, choose_action=True):
        """Return one score per candidate.

        :param x: encoded state, [N, L, hidden_size]. When ``choose_action`` is
            True it is tiled K times along dim 1, so L must be 1.
        :param y: candidate embeddings, [N, K, action_size].
        :param choose_action: whether to tile ``x`` to match ``y``'s K.
        :return: candidate scores, [N, K].
        """
        if choose_action:
            x = x.repeat(1, y.size(1), 1)
        state_cat_action = torch.cat((x, y), dim=2)
        hidden = F.relu(self.fc2_advantage(state_cat_action))
        return self.out_advantage(hidden).squeeze(dim=2)


class Agent(nn.Module):
    """Policy agent: graph encoder + DQN scorer + learned intrinsic reward.

    ``gcn_net`` embeds both dialogue states (``gcn_net([state])``) and candidate
    ids (``gcn_net.embedding``); ``policy_net`` scores candidates; ``int_rew_func``
    produces intrinsic rewards for transitions.
    """
    def __init__(self, device, state_size, action_size, hidden_size, gcn_net, learning_rate, l2_norm, args, reward_size=10):
        """Build the policy/target networks and the intrinsic-reward module.

        NOTE(review): ``learning_rate`` and ``reward_size`` are not used; the
        policy optimizer below uses ``args.inner_lr`` instead.
        """
        super(Agent, self).__init__()
        
        self.steps_done = 0  # incremented once per select_action call
        self.device = device
        self.gcn_net = gcn_net
        self.policy_net = DQN(state_size, action_size, hidden_size)
        # NOTE(review): target_net is not referenced anywhere else in this file.
        self.target_net = DQN(state_size, action_size, hidden_size)

        # An ensemble of reward models when more than one is requested.
        if args.num_rew_ensemble > 1:
            self.int_rew_func = Reward_Ensemble(args, state_size, action_size, num_ensemble=args.num_rew_ensemble)
        else:
            self.int_rew_func = Intrinsic_Reward(args, state_size, action_size)

        # NOTE(review): train() in this file builds its own policy/meta optimizers;
        # these two are not used there.
        self.optimizer = optim.SGD(chain(self.policy_net.parameters(),self.gcn_net.parameters()), lr=args.inner_lr, weight_decay=l2_norm)
        self.rew_optimizer = optim.Adam(self.int_rew_func.parameters(), lr=args.outer_lr, weight_decay=l2_norm)
    
    def select_action(self, state, cand_feature, cand_item, action_space, is_test=False, is_last_turn=False, fast_policy=None, inner=False, raw_policy=None, item_rank=True):
        """Score candidate features and items for ``state`` and choose an action.

        Candidates are ``cand_feature`` followed by ``cand_item``. Scores come
        from ``fast_policy`` if given, else from ``self.policy_net``.

        Test mode (``is_test=True``):
          * if ``action_space[1]`` has at most 10 entries or ``is_last_turn``,
            returns ``(tensor(action_space[1][0]), action_space[1])``;
          * otherwise returns ``(argmax candidate id, ids ordered by descending
            score as a list)``.

        Training mode: samples an index from softmax(scores) and returns
        ``(action id, ids ordered by score, log_prob, raw_log_prob, probs, cand,
        action_2)``. ``raw_log_prob`` is the log-prob of the same sampled index
        under ``raw_policy`` (None if no ``raw_policy``). ``action_2`` is the id
        of a second, independent sample.

        ``inner`` and ``item_rank`` are accepted but unused.
        """
        state_emb = self.gcn_net([state])

        # Wrap in an outer list to add a batch dimension of size 1.
        cand_feature = torch.LongTensor([cand_feature]).to(self.device)
        cand_item = torch.LongTensor([cand_item]).to(self.device)

        cand_feat_emb = self.gcn_net.embedding(cand_feature)
        cand_item_emb = self.gcn_net.embedding(cand_item)

        # Joint candidate list: features first, then items (concatenated on dim 1).
        cand = torch.cat((cand_feature, cand_item), 1)
        cand_emb = torch.cat((cand_feat_emb, cand_item_emb), 1)

        self.steps_done += 1
        if fast_policy is not None:
            action_value = fast_policy(state_emb, cand_emb)
        else:
            action_value = self.policy_net(state_emb, cand_emb)
        
        if raw_policy is not None:
            raw_action_value = raw_policy(state_emb, cand_emb)
            raw_prob = Categorical(raw_action_value.softmax(1))


        prob = Categorical(action_value.softmax(1))
        if is_test:
            # Recommend the head of action_space[1] when its list is short or the
            # dialogue is on its last turn.
            if (len(action_space[1]) <= 10 or is_last_turn):
                return torch.tensor(action_space[1][0], device=self.device, dtype=torch.long), action_space[1]
            action = cand[0][action_value.argmax().item()]
            sorted_actions = cand[0][action_value.sort(1, True)[1].tolist()]
            return action, sorted_actions.tolist()
        else:
            # NOTE: two samples are drawn; their order affects the RNG stream.
            action = prob.sample()
            action_2 = prob.sample()
            action_2 = cand[0][action_2]

            log_prob = prob.log_prob(action)
            
            if raw_policy is not None:
                # Log-prob of the same sampled index under the reference policy.
                raw_log_prob = raw_prob.log_prob(action)
            else:
                raw_log_prob = None

            # Map the sampled index to the candidate id.
            action = cand[0][action]
            sorted_actions = cand[0][action_value.sort(1, True)[1].tolist()]
            return action, sorted_actions.tolist(), log_prob, raw_log_prob, prob.probs, cand, action_2
    
    def trajectory_likelihood(self, actions, cands, states, fast_policy=None, item_rank=True):
        """Recompute per-step log-probs of stored actions under the current policy.

        ``actions``, ``cands`` and ``states`` are the per-step lists collected by
        run_episode (``cands[i]`` is the ``cand`` tensor returned by
        select_action). For each step, the position of ``actions[i][0]`` within
        ``cands[i][0]`` is located and its log-probability under
        softmax(policy scores) is taken. Scores come from ``fast_policy`` if
        given, else ``self.policy_net``. ``item_rank`` is unused.

        :return: list of log-prob tensors, one per step.
        """
        log_probs = []

        for i in range(len(states)):
            action = actions[i]
            cand = cands[i]
            state = states[i]

            state_emb = self.gcn_net([state])
            cand_emb = self.gcn_net.embedding(cand)

            if fast_policy is not None:
                action_value = fast_policy(state_emb, cand_emb)
            else:
                action_value = self.policy_net(state_emb, cand_emb)
            
            prob = Categorical(action_value.softmax(1))
            # NOTE(review): if the action id occurs more than once in cand[0],
            # this yields several indices and log_prob has several entries.
            action_idx = torch.eq(cand[0], action[0]).nonzero(as_tuple=True)[0]
            log_prob = prob.log_prob(action_idx)
            log_probs.append(log_prob)
        return log_probs

    
    def intrinsic_reward(self, state, action, next_state, turn_sign, target_item, user_id):
        """Query ``int_rew_func`` for one transition.

        Embeds the state, action, next state, target item id and user id with
        ``gcn_net`` and passes them, with ``turn_sign``, to ``int_rew_func``.

        :return: ``(int_rew, neg_int_rew)`` exactly as returned by ``int_rew_func``.
        """
        state_emb = self.gcn_net([state])
        action_emb = self.gcn_net.embedding(action)
        next_state_emb = self.gcn_net([next_state])

        target_emb = torch.LongTensor([target_item]).to(self.device)
        target_emb = self.gcn_net.embedding(target_emb)
        
        user_id = torch.LongTensor([user_id]).to(self.device)
        user_emb = self.gcn_net.embedding(user_id)
        int_rew, neg_int_rew = self.int_rew_func(state_emb, action_emb, next_state_emb,
                                                turn_sign, target_emb, user_emb)
        
        return int_rew, neg_int_rew
    
    def policy_score(self, state, item):
        """Return ``policy_net``'s score of the single candidate id ``item`` for ``state``."""
        state_emb = self.gcn_net([state])
        item = torch.LongTensor([[item]]).to(self.device)
        item_emb = self.gcn_net.embedding(item)
        score = self.policy_net(state_emb, item_emb)
        return score
            

    def save_model(self, data_name, filename, epoch_user):
        """Save the policy and encoder state dicts via ``save_rl_agent``.

        NOTE(review): ``int_rew_func`` and ``target_net`` are not included.
        """
        save_rl_agent(dataset=data_name, model={'policy': self.policy_net.state_dict(), 'gcn': self.gcn_net.state_dict()}, filename=filename, epoch_user=epoch_user)
    
    def load_model(self, data_name, filename, epoch_user):
        """Restore the policy and encoder state dicts written by ``save_model``."""
        model_dict = load_rl_agent(dataset=data_name, filename=filename, epoch_user=epoch_user)
        self.policy_net.load_state_dict(model_dict['policy'])
        self.gcn_net.load_state_dict(model_dict['gcn'])
    

def run_episode(args, env, agent, i_episode, fast_policy=None, reset_ui=True, outer=False, item_ranker=None, inner=False, raw_policy=None, reset_init_query=True):
    """Roll out one conversation with ``agent`` and build per-step return tensors.

    At each turn an action is sampled via ``agent.select_action`` (scored by
    ``fast_policy`` when given). Two extra signals are recorded for the
    transition:
      * an intrinsic reward from ``agent.intrinsic_reward``;
      * a rank-shaping reward ``log(curr_idx + 1) - log(next_idx + 1)``, clipped
        to [-1, 1], where the indices are the target's position in
        ``action_space[1]`` before and after the step.

    :param reset_ui: forwarded to ``env.reset``.
    :param reset_init_query: forwarded to ``env.reset`` in both embedding modes.
    :param inner: forwarded to ``agent.select_action``.
    :param raw_policy: optional reference policy; the log-probs it assigns to
        the sampled actions are collected into ``raw_log_prob_list``.
    ``outer`` and ``item_ranker`` are accepted but unused.
    :raises ValueError: if the target item is missing from an ``action_space[1]``
        list (from ``list.index``).
    :return: ``(reward_list, aug_reward_list, log_prob_list, rank_reward_list,
        item_loss, success, int_reg_loss, raw_log_prob_list, prob_list,
        action_list, cand_list, state_list)``. The three reward tensors hold
        discounted returns (factor ``args.gamma``), and ``item_loss`` is always
        an empty list.
    """
    print('\n================new tuple:{}===================='.format(i_episode))
    if not args.fix_emb:
        # Trainable embeddings: hand the current table to the environment.
        # reset_init_query is forwarded here too; it used to be dropped in this
        # branch, so callers passing reset_init_query=False were ignored.
        embed_table = agent.gcn_net.embedding.weight.data.cpu().detach().numpy()
        state, cand, action_space = env.reset(embed_table, reset_ui=reset_ui, reset_init_query=reset_init_query)
    else:
        state, cand, action_space = env.reset(reset_ui=reset_ui, reset_init_query=reset_init_query)

    reward_list = []
    rank_reward_list = []
    action_list = []
    cand_list = []
    state_list = []
    prob_list = []
    item_loss = []  # kept for the return signature; never filled here

    # Per-step tensors, grown with torch.cat from empty tensors on args.device.
    log_prob_list = torch.empty(0, device=args.device)
    raw_log_prob_list = torch.empty(0, device=args.device)
    int_rew_list = torch.empty(0, device=args.device)

    success = False
    # Target id compared against action_space[1]; presumably item ids are
    # offset by the number of users in the joint id space.
    target_item = env.target_item + env.user_length

    for t in count():   # one iteration per dialogue turn
        # Only consulted by select_action in test mode.
        is_last_turn = t >= args.max_turn - 1

        action, sorted_actions, log_prob, raw_log_prob, probs, cand_selected, action_2 = agent.select_action(
            state, cand[0], cand[1], action_space, is_last_turn=is_last_turn,
            fast_policy=fast_policy, inner=inner, raw_policy=raw_policy)
        prob_list.append(probs)
        action_list.append(action)
        cand_list.append(cand_selected)
        state_list.append(state)

        if raw_log_prob is not None:
            raw_log_prob_list = torch.cat([raw_log_prob_list, raw_log_prob.reshape(1)])

        if not args.fix_emb:
            next_state, next_cand, next_action_space, reward, done, success, turn_sign = env.step(
                action.item(), sorted_actions, agent.gcn_net.embedding.weight.data.cpu().detach().numpy())
        else:
            next_state, next_cand, next_action_space, reward, done, success, turn_sign = env.step(
                action.item(), sorted_actions)

        # The negative intrinsic reward is not used by this function.
        int_rew, _ = agent.intrinsic_reward(state, action, next_state, turn_sign, target_item, env.user_id)
        int_rew_list = torch.cat([int_rew_list, int_rew])

        reward_list.append(reward)
        log_prob_list = torch.cat([log_prob_list, log_prob.reshape(1)])

        # Rank shaping: positive when the target moves up in action_space[1].
        next_target_idx = next_action_space[1].index(target_item)
        curr_target_idx = action_space[1].index(target_item)
        rank_reward = np.log(curr_target_idx + 1) - np.log(next_target_idx + 1)
        rank_reward_list.append(np.clip(rank_reward, -1, 1))

        state = next_state
        cand = next_cand
        action_space = next_action_space

        if done:
            break

    reward_list = torch.Tensor(reward_list).to(args.device)
    rank_reward_list = torch.Tensor(rank_reward_list).to(args.device)

    # Rank shaping only counts for successful episodes (weight 0 otherwise).
    rank_weight = args.rank_lambda if success else 0
    rank_reward_list = reward_list + rank_weight * rank_reward_list

    aug_reward_list = reward_list + args.int_lambda * int_rew_list

    # Push the summed intrinsic reward toward 1 on success and 0 on failure.
    int_reg_target = 1 if success else 0
    int_reg_loss = (int_rew_list.sum() - int_reg_target) ** 2
    print('Int rew: ', int_rew_list)

    # Turn per-step rewards into discounted returns, in place, back to front.
    for i in range(len(reward_list)-2, -1, -1):
        reward_list[i] = reward_list[i] + args.gamma * reward_list[i+1]
        aug_reward_list[i] = aug_reward_list[i] + args.gamma * aug_reward_list[i+1]
        rank_reward_list[i] = rank_reward_list[i] + args.gamma * rank_reward_list[i+1]

    return reward_list, aug_reward_list, log_prob_list, rank_reward_list, item_loss, success, \
                                int_reg_loss, raw_log_prob_list, prob_list, action_list, cand_list, state_list

class RolloutBuffer:
    """FIFO store of whole trajectories with uniform random sampling.

    Each trajectory is stored as three parallel entries (actions, candidates,
    states). At most ``buffer_size`` trajectories are kept; adding another
    evicts the oldest.
    """

    def __init__(self, buffer_size=5000):
        from collections import deque

        if buffer_size < 1:
            raise ValueError('buffer_size must be a positive integer, got {}'.format(buffer_size))
        # The constructor argument is honoured (it was previously ignored in
        # favour of a hard-coded 5000).
        self.buffer_size = buffer_size
        # deque(maxlen=...) drops the oldest entry in O(1) once full.
        self.actions = deque(maxlen=buffer_size)
        self.cands = deque(maxlen=buffer_size)
        self.states = deque(maxlen=buffer_size)
        self.rewards = []  # kept for attribute compatibility; never populated
        self.num_traj = 0  # number of trajectories currently stored

    def add(self, actions, cands, states):
        """Append one trajectory, evicting the oldest when at capacity."""
        self.actions.append(actions)
        self.cands.append(cands)
        self.states.append(states)
        self.num_traj = len(self.actions)

    def sample(self):
        """Return ``(actions, cands, states)`` of a uniformly chosen trajectory.

        Uses ``np.random`` so that it follows the global NumPy seed.

        :raises ValueError: if the buffer is empty.
        """
        if self.num_traj == 0:
            raise ValueError('cannot sample from an empty RolloutBuffer')
        idx = np.random.randint(0, self.num_traj)
        return self.actions[idx], self.cands[idx], self.states[idx]



def train(args, kg, dataset, filename):
    """Meta-train the policy together with the learned intrinsic-reward module.

    Builds the environment for ``args.data_name``, the TransGate encoder and the
    ``Agent``, and optionally restores a checkpoint. It then runs
    ``args.max_steps`` training steps of ``args.sample_times`` meta-iterations
    each:

    1. Inner loop (``higher.innerloop_ctx``): one episode with the
       differentiable ``fast_policy``, then one policy-gradient step on the
       return augmented with intrinsic reward.
    2. Outer loop: ``args.outer_aug_times`` episodes with the adapted policy.
       Each trajectory goes to a positive or negative ``RolloutBuffer``;
       ``point_loss`` is the policy-gradient loss on the rank-shaped return.
    3. Once both buffers hold data, one positive and one negative trajectory
       are sampled and scored under ``fast_policy`` by their summed
       log-likelihoods ``a`` and ``b``. The negative one is truncated to the
       positive one's length. Then ``suc_loss = -log(exp(a) / (exp(a) +
       exp(b)))``, with the ratio clamped to [1e-5, 1 - 1e-5].
    4. The two losses are weighted by ``MinNormSolver.find_min_norm_element``
       on their gradients w.r.t. ``agent.int_rew_func``, and
       ``0.01 * int_reg_losses`` is added. Gradients are clipped to norm 0.5,
       then ``meta_opt`` (Adam over ``int_rew_func``) steps. Before both
       buffers hold data, the step uses only ``point_loss``'s gradients.

    The agent is evaluated with ``dqn_evaluate`` every ``args.eval_num`` steps
    and saved under ``filename`` every ``args.save_num`` steps.

    Special case: ``args.eval_num == 1`` means evaluate-only. The checkpoint
    ``filename`` at epoch ``args.load_rl_epoch`` is loaded, evaluated once, and
    the function returns.
    """
    env = EnvDict[args.data_name](kg=kg, dataset=dataset, data_name=args.data_name, embed=args.embed, seed=args.seed, max_turn=args.max_turn, cand_num=args.cand_num, cand_item_num=args.cand_item_num,
                       attr_num=args.attr_num, mode='train', ask_num=args.ask_num, entropy_way=args.entropy_method, fm_epoch=args.fm_epoch, args=args)
    set_random_seed(args.seed)
    # Embedding table: env.ui_embeds rows, then env.feature_emb rows, then one
    # all-zero row (presumably a padding entry -- confirm).
    embed = torch.FloatTensor(np.concatenate((env.ui_embeds, env.feature_emb, np.zeros((1,env.ui_embeds.shape[1]))), axis=0))
    
    gcn_net = TransGate(device=args.device, entity=embed.size(0), emb_size=embed.size(1), kg=kg, embeddings=embed, \
        fix_emb=args.fix_emb, seq=args.seq, gcn=args.gcn, hidden_size=args.hidden).to(args.device)

    agent = Agent(device=args.device, state_size=args.hidden, action_size=embed.size(1), \
        hidden_size=args.hidden, gcn_net=gcn_net, learning_rate=args.learning_rate, l2_norm=args.l2_norm, args=args).to(args.device)


    # Optionally warm-start from a pretrained checkpoint with a fixed name.
    if args.load_rl_epoch != 0 :
        # NOTE(review): the format string has a single '{}' placeholder, so
        # args.entropy_method is ignored and the other settings are hard-coded.
        load_filename = 'train-data-{}-RL-cand_num-10-cand_item_num-10-embed-transe-seq-transformer-gcn-True-transgate-True-rank-False_PG'.format(args.data_name, args.entropy_method)

        print('Staring loading rl model in epoch {}'.format(args.load_rl_epoch))
        model_dict = load_rl_agent(dataset=args.data_name, filename=load_filename, epoch_user=args.load_rl_epoch)
        
        agent.gcn_net.load_state_dict(model_dict['gcn'])    
        agent.policy_net.load_state_dict(model_dict['policy'])

    
    # Inner-loop optimizer: policy only (encoder frozen) or policy + encoder.
    if args.fix_encoder:
        policy_opt = optim.Adam(agent.policy_net.parameters(), lr=args.inner_lr, weight_decay=args.l2_norm)
        for param in agent.gcn_net.parameters():
            param.requires_grad = False
    else:
        policy_opt = optim.Adam(chain(agent.policy_net.parameters(), agent.gcn_net.parameters()), lr=args.inner_lr, weight_decay=args.l2_norm)

    # Outer (meta) optimizer over the intrinsic-reward module only.
    meta_opt = optim.Adam(chain(agent.int_rew_func.parameters()), lr=args.outer_lr, weight_decay=args.l2_norm)
    test_performance = []
    # NOTE: eval_num == 1 doubles as an "evaluate only" switch.
    if args.eval_num == 1:
        print('Staring loading rl model in epoch {}'.format(args.load_rl_epoch))
        model_dict = load_rl_agent(dataset=args.data_name, filename=filename, epoch_user=args.load_rl_epoch)
        agent.gcn_net.load_state_dict(model_dict['gcn'])    
        agent.policy_net.load_state_dict(model_dict['policy'])

        SR15_mean = dqn_evaluate(args, kg, dataset, agent, filename, 0)
        test_performance.append(SR15_mean)
        print('Test SR15: ', SR15_mean)
        return

    pos_traj_buffer = RolloutBuffer()
    neg_traj_buffer = RolloutBuffer()

    # Becomes True once both buffers are non-empty and stays True afterwards.
    start_update = False
    for train_step in tqdm(range(1, args.max_steps+1)):
        agent.train()
        for i_episode in range(args.sample_times):
            meta_opt.zero_grad()

            # copy_initial_weights=False: fast_policy starts from agent.policy_net's own parameters.
            with higher.innerloop_ctx(agent.policy_net, policy_opt, copy_initial_weights=False) as (fast_policy, inner_opt):
                int_reg_losses = 0

                loss_data = {}
                moo_loss = {}
                grads = {}
                

                # --- Inner step: adapt fast_policy on the intrinsic-augmented return.
                reward_list, aug_reward_list, log_prob_list, _, _, success, int_reg_loss, _, _, _, _, _ = run_episode(args, env, agent, i_episode, fast_policy, reset_ui=True, inner=True)
                aug_loss = torch.sum(torch.mul(log_prob_list, aug_reward_list).mul(-1)) 
                int_reg_losses += int_reg_loss
                inner_opt.step(aug_loss)

                # Snapshot taken before the adapted weights are copied into agent.policy_net below.
                raw_agent = deepcopy(agent)
                if train_step > args.warmup_steps:
                    agent.policy_net.load_state_dict(fast_policy.state_dict())

                # --- Outer rollouts with the adapted policy.
                point_loss = 0
                for outer_iter in range(args.outer_aug_times):
                        if outer_iter == 0:
                            reward_list, aug_reward_list, log_prob_list, rank_reward_list, _, success, int_reg_loss, raw_log_prob_list, prob_list, action_list, cand_list, state_list = run_episode(args, env, agent, i_episode, fast_policy, reset_ui=True, outer=True, raw_policy=raw_agent.policy_net)
                        else:
                            # Later passes reuse the same user/query (no reset_ui / init-query reset).
                            reward_list, aug_reward_list, log_prob_list, rank_reward_list, _, success, int_reg_loss, raw_log_prob_list, prob_list, action_list, cand_list, state_list = run_episode(args, env, agent, i_episode, fast_policy, 
                                                                                                                                            reset_ui=False, outer=True, raw_policy=raw_agent.policy_net, reset_init_query=False)
                        int_reg_losses += int_reg_loss
                        
                        # NOTE(review): success is judged by the final (undiscounted) reward
                        # being exactly 1 -- presumably the env's success reward; `success`
                        # is also available here.
                        if reward_list[-1] == 1:
                            pos_traj_buffer.add(action_list, cand_list, state_list)
                        else:
                            neg_traj_buffer.add(action_list, cand_list, state_list)
                        
                        point_loss += torch.sum(torch.mul(log_prob_list, rank_reward_list).mul(-1))
                    
                # NOTE(review): ZeroDivisionError if args.outer_aug_times == 0.
                point_loss /= args.outer_aug_times
                # One inner episode plus outer_aug_times outer episodes contributed.
                int_reg_losses /= args.outer_aug_times + 1

                    
                # Gradients of point_loss w.r.t. the intrinsic-reward parameters, for MinNormSolver.
                meta_opt.zero_grad()
                loss_data['point_loss'] = point_loss.data
                point_loss.backward(retain_graph=True)
                grads['point_loss'] = []
                moo_loss['point_loss'] = point_loss

                for param in agent.int_rew_func.parameters():
                    if param.grad is not None:
                        grads['point_loss'].append(Variable(param.grad.data.clone(), requires_grad=False))

                # --- Pairwise success loss on one positive vs one negative trajectory.
                suc_loss = 0
                if pos_traj_buffer.num_traj > 0 and neg_traj_buffer.num_traj > 0:
                        start_update = True

                        pos_actions, pos_cands, pos_states = pos_traj_buffer.sample()
                        neg_actions, neg_cands, neg_states = neg_traj_buffer.sample()

                        pos_log_probs = agent.trajectory_likelihood(pos_actions, pos_cands, pos_states, fast_policy)
                        neg_log_probs = agent.trajectory_likelihood(neg_actions, neg_cands, neg_states, fast_policy)

                        pos_log_probs = torch.stack(pos_log_probs)
                        neg_log_probs = torch.stack(neg_log_probs)

                        # Compare equal-length prefixes.
                        neg_log_probs = neg_log_probs[:len(pos_log_probs)]

                        # exp(a - logsumexp([a, b])) == exp(a) / (exp(a) + exp(b)).
                        prob_pi_1 = torch.exp(pos_log_probs.sum() - torch.logsumexp(torch.stack([pos_log_probs.sum(), neg_log_probs.sum()]), 0))
                        prob_pi_1 = torch.clamp(prob_pi_1, max=1-1e-5, min=1e-5)  
                        suc_loss = -torch.log(prob_pi_1)

                        meta_opt.zero_grad()
                        loss_data['suc_loss'] = suc_loss.data
                        suc_loss.backward(retain_graph=True)
                        grads['suc_loss'] = []
                        moo_loss['suc_loss'] = suc_loss

                        for param in agent.int_rew_func.parameters():
                            if param.grad is not None:
                                grads['suc_loss'].append(Variable(param.grad.data.clone(), requires_grad=False))                      

                # --- Meta step. Without start_update, the gradients left by
                # point_loss.backward() above are what gets clipped and applied.
                if start_update:
                    meta_opt.zero_grad() 
                    sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in ['point_loss', 'suc_loss']])
                    meta_loss = sol[0] * moo_loss['point_loss'] + sol[1] * moo_loss['suc_loss']
                    meta_loss += 0.01 * int_reg_losses
                    meta_loss.backward()
                    
                nn.utils.clip_grad_norm_(agent.int_rew_func.parameters(), 0.5)
                meta_opt.step()
            
        enablePrint() # Enable print function (helper from utils)
        if train_step % args.eval_num == 0:
            SR15_mean = dqn_evaluate(args, kg, dataset, agent, filename, train_step)
            test_performance.append(SR15_mean)
            print('Test SR15: ', SR15_mean)
        
        if train_step % args.save_num == 0:
            agent.save_model(data_name=args.data_name, filename=filename, epoch_user=train_step)
    print(test_performance)


def _build_parser():
    """Build the command-line parser for RL training and evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-seed', type=int, default=1, help='random seed.')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
    parser.add_argument('--epochs', '-me', type=int, default=50000, help='the number of RL train epoch')
    parser.add_argument('--fm_epoch', type=int, default=0, help='the epoch of FM embedding')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size.')
    parser.add_argument('--gamma', type=float, default=0.999, help='reward discount factor.')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate.')
    parser.add_argument('--l2_norm', type=float, default=1e-6, help='l2 regularization.')
    parser.add_argument('--hidden', type=int, default=100, help='hidden layer size')
    parser.add_argument('--memory_size', type=int, default=50000, help='size of memory ')

    parser.add_argument('--data_name', type=str, default=LAST_FM, choices=[LAST_FM, LAST_FM_STAR, YELP, YELP_STAR],
                        help='One of {LAST_FM, LAST_FM_STAR, YELP, YELP_STAR}.')
    # Although 'weight_entropy' performs better, 'entropy' is a cheaper alternative.
    parser.add_argument('--entropy_method', type=str, default='weight_entropy', help='entropy_method is one of {entropy, weight_entropy}')
    parser.add_argument('--max_turn', type=int, default=15, help='max conversation turn')
    parser.add_argument('--attr_num', type=int, help='the number of attributes (overridden from the KG at startup)')
    parser.add_argument('--mode', type=str, default='train', help='the mode in [train, test]')
    parser.add_argument('--ask_num', type=int, default=1, help='the number of features asked in a turn')
    parser.add_argument('--load_rl_epoch', type=int, default=0, help='the epoch of loading RL model')

    parser.add_argument('--sample_times', type=int, default=100, help='the epoch of sampling')
    parser.add_argument('--max_steps', type=int, default=500, help='max training steps')
    parser.add_argument('--eval_num', type=int, default=50, help='the number of steps to evaluate RL model and metric')
    parser.add_argument('--save_num', type=int, default=100, help='the number of steps to save RL model and metric')
    parser.add_argument('--observe_num', type=int, default=500, help='the number of steps to print metric')
    parser.add_argument('--cand_num', type=int, default=10, help='candidate sampling number')
    parser.add_argument('--cand_item_num', type=int, default=10, help='candidate item sampling number')
    # store_false: args.fix_emb / args.gcn default to True; passing the flag sets False.
    parser.add_argument('--fix_emb', action='store_false', help='pass to make embeddings trainable (embeddings are fixed by default)')
    parser.add_argument('--embed', type=str, default='transe', help='pretrained embeddings')
    parser.add_argument('--seq', type=str, default='transformer', choices=['rnn', 'transformer', 'mean'], help='sequential learning method')
    parser.add_argument('--gcn', action='store_false', help='pass to disable the GCN (enabled by default)')
    parser.add_argument('--int_lambda', type=float, default=0.1)
    parser.add_argument('--rank_lambda', type=float, default=0.5)

    parser.add_argument('--inner_lr', type=float, default=1e-5)
    parser.add_argument('--outer_lr', type=float, default=1e-4)
    parser.add_argument('--fix_encoder', action='store_true')
    parser.add_argument('--inner_steps', type=int, default=1)
    parser.add_argument('--rnn_int_rew', action='store_true')

    parser.add_argument('--item_rank', action='store_true')
    parser.add_argument('--transgate', action='store_true')

    parser.add_argument('--inner_aug_times', type=int, default=0)
    parser.add_argument('--outer_aug_times', type=int, default=1)
    parser.add_argument('--num_rew_ensemble', type=int, default=4)
    parser.add_argument('--warmup_steps', type=int, default=10)
    return parser


def main():
    """Parse CLI options, load the knowledge graph and dataset, then run ``train``."""
    args = _build_parser().parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # NOTE(review): a torch.device on GPU but the plain string 'cpu' otherwise;
    # left as-is in case downstream code compares against the string.
    args.device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
    print(args.device)
    print('data_set:{}'.format(args.data_name))
    kg = load_kg(args.data_name)
    # attr_num is derived from the KG, overriding any --attr_num value.
    feature_name = FeatureDict[args.data_name]
    feature_length = len(kg.G[feature_name].keys())
    print('dataset:{}, feature_length:{}'.format(args.data_name, feature_length))
    args.attr_num = feature_length
    print('args.attr_num:', args.attr_num)
    print('args.entropy_method:', args.entropy_method)
    print(args)

    dataset = load_dataset(args.data_name)
    filename = 'train-data-{}-RL-cand_num-{}-cand_item_num-{}-embed-{}-seq-{}-gcn-{}_moo'.format(
        args.data_name, args.cand_num, args.cand_item_num, args.embed, args.seq, args.gcn)
    train(args, kg, dataset, filename)


if __name__ == '__main__':
    main()