# -*- coding: utf-8 -*-
import argparse
import copy
import os
import random
import time
from itertools import product

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from Model.AdaO2B import AdaO2B
from Model.EXP3_B import EXP3_B
from Model.BLTS_B import BLTS_B
from Model.SBUCB import SBUCB
from utils.data import (Data_4_AdaO2B, get_candidata_vector, get_online_data,
                        load_base_model, prep_data_4_adao2b, get_best_batch,
                        get_batch_average_reward,
                        get_state_vector_and_reward_for_NIIDC,
                        reservoir_sampling, data_dependent_sampling)
from utils.evaluation import total_average_reward
from utils.simulator import Simulator


class DefaultConfig(object):
    """Hyper-parameters and file locations for one batch-test run.

    Args:
        base_model: name of the base bandit (the CLI offers "SBUCB",
            "EXP3_B", "BLTS_B"); also used as a directory/file-name component.
        base_model_save_name: checkpoint-name prefix of the trained base
            models; must be a string (``None`` raises ``TypeError``).
        K: how many of the most recent batches / base models are used.
        history_data_name: tag inserted into ``history_data_path``.
        data_index: selects the session file ``data_<index>_online.csv``.
    """

    def __init__(self,
                 base_model,
                 base_model_save_name,
                 K,
                 history_data_name,
                 data_index=2):
        # NOTE: attribute assignment order is significant -- the entry point
        # prints ``__dict__`` in insertion order.
        self.rec_model = "SAC"  # one of "SAC", "VC", "BBL", "LL", "CW", "NIIDC"
        self.device = "cuda:0"

        self.data_index = data_index
        self.feature_dimension = 50
        self.N = 40  # number of batches replayed
        self.B = 5000  # sessions per batch
        self.candidate_size = 100

        # --- simulator -------------------------------------------------
        simulator_dir = '../Data/Simulator/'
        self.simulator_path = simulator_dir
        self.embedding_size = 16
        self.simulator_online_path = simulator_dir + '1031_simulator.pth'
        self.user_num = 6890
        self.item_num = 973

        # --- online data -----------------------------------------------
        online_dir = "../Data/Online/"
        self.online_path = online_dir
        self.user_data_online_path = online_dir + "user_feature_online.csv"  # 25
        self.item_data_online_path = online_dir + "item_daily_feature_online.csv"  # 25
        self.session_data_online_path = "%sdata_%s_online.csv" % (online_dir,
                                                                  data_index)

        # --- base models -----------------------------------------------
        self.base_model = base_model
        self.base_model_save_name = base_model_save_name
        self.K = K
        # The K most recent batch indices: N-K, ..., N-1.
        self.base_model_batch_list = list(range(self.N - self.K, self.N))
        self.history_result_path = "../Result/" + base_model + "/"
        self.history_avg_reward_file_name = base_model_save_name + "_online_data1"

        adao2b_dir = '../Data/AdaO2B/'
        self.adao2b_path = adao2b_dir
        self.history_data_path = (adao2b_dir + base_model + "_" +
                                  history_data_name +
                                  "_online_data1_adao2b.csv")

        # --- results ---------------------------------------------------
        self.result_path = "../Result/"
        self.res_name = ""

        # --- checkpoint names (filled from the CLI) --------------------
        self.rec_model_load_name = ""
        self.rec_model_save_name = ""

        # --- per-algorithm hyper-parameters ----------------------------
        self.SBUCB_mu = 1.4
        self.BLTS_B_mu = 0.2
        self.EXP3_B_delta = 0.1

        self.seed_num = 2022


def setup_seed(seed):
    """Seed every RNG the script relies on and force deterministic cuDNN.

    Seeds torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same value, so repeated runs with the same
    ``seed`` replay the same random draws.

    Args:
        seed: integer seed shared by all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning benchmarks kernels at run time and may pick a different
    # algorithm on each run; turning it off makes the selection itself
    # deterministic (``deterministic=True`` alone does not guarantee that).
    torch.backends.cudnn.benchmark = False


# Recommenders that combine all loaded base models on every decision.
_ENSEMBLE_REC_MODELS = ("SAC", "VC", "CW")
# Recommenders that delegate every decision to one prepared base-model copy.
_SINGLE_REC_MODELS = ("BBL", "LL", "NIIDC")


def _build_base_model(cfg):
    """Create an untrained base bandit of the type named by ``cfg.base_model``.

    Raises:
        ValueError: if ``cfg.base_model`` is not "SBUCB", "BLTS_B" or "EXP3_B".
    """
    if cfg.base_model == "SBUCB":
        return SBUCB(mu=cfg.SBUCB_mu, feature_d=cfg.feature_dimension)
    if cfg.base_model == "BLTS_B":
        return BLTS_B(mu=cfg.BLTS_B_mu, feature_d=cfg.feature_dimension)
    if cfg.base_model == "EXP3_B":
        return EXP3_B(N=cfg.N,
                      B=cfg.B,
                      C=cfg.candidate_size,
                      delta=cfg.EXP3_B_delta,
                      feature_d=cfg.feature_dimension)
    raise ValueError("unknown base model: %r" % (cfg.base_model, ))


def _stacked_base_scores(base_model_list, candidate_state_vector):
    """Score the candidates with every base model, one column per model.

    The result is (n_candidate, n_base_model), assuming each
    ``get_rec_score`` call returns one score per candidate.
    """
    return np.array([
        model.get_rec_score(candidate_state_vector)
        for model in base_model_list
    ]).T


def bacth_test(config=None):
    """Replay the online sessions batch by batch and record simulated feedback.

    For each of ``N`` batches, every one of its ``B`` sessions gets one item
    chosen by the recommender named in ``rec_model``:

    * SAC sums the base models' scores.
    * VC takes a majority vote over the base models' picks.
    * CW weights the scores by per-batch average rewards.
    * BBL / LL / NIIDC use a single prepared base-model copy.

    The simulator then returns feedback for the whole batch. Afterwards the
    function prints timings, saves the feedback matrix and the
    per-recommendation latencies under ``result_path/base_model/``, and
    prints the total average reward.

    Args:
        config: a ``DefaultConfig``-like object; defaults to the module-level
            ``opt`` created by the script entry point.

    Raises:
        ValueError: if ``rec_model`` or ``base_model`` is not recognised.
    """
    cfg = opt if config is None else config
    # Fail before any expensive loading instead of hitting a NameError later.
    if cfg.rec_model not in _ENSEMBLE_REC_MODELS + _SINGLE_REC_MODELS:
        raise ValueError("unknown rec model: %r" % (cfg.rec_model, ))

    # load simulator data and model
    online_simulator = Simulator(user_num=cfg.user_num,
                                 item_num=cfg.item_num,
                                 embedding_size=cfg.embedding_size,
                                 simulator_path=cfg.simulator_online_path)

    # load online data
    user_feature_online, item_feature_online, session_data_online = get_online_data(
        cfg.user_data_online_path, cfg.item_data_online_path,
        cfg.session_data_online_path)
    uid_online = session_data_online.user_id.values
    date_online = session_data_online.date.values
    candidate_itemid_online = session_data_online.candidate.values

    base_model = _build_base_model(cfg)
    rec_model = None  # used by BBL / LL / NIIDC
    base_model_list = None  # used by SAC / VC / CW
    constant_weight = None  # used by CW

    if cfg.rec_model == "BBL":
        # Checkpoint of the batch selected by get_best_batch.
        best_batch_index = get_best_batch(cfg.history_result_path,
                                          cfg.history_avg_reward_file_name,
                                          cfg.B)
        base_model.load_model(cfg.base_model_save_name + "_batch" +
                              str(best_batch_index))
        rec_model = copy.deepcopy(base_model)
    elif cfg.rec_model == "LL":
        # Checkpoint of the last batch (N - 1).
        base_model.load_model(cfg.base_model_save_name + "_batch" +
                              str(cfg.N - 1))
        rec_model = copy.deepcopy(base_model)
    elif cfg.rec_model == "NIIDC":
        # Last checkpoint, then one extra update on history rows with reward=1.
        base_model.load_model(cfg.base_model_save_name + "_batch" +
                              str(cfg.N - 1))
        history_state_vector, history_reward = get_state_vector_and_reward_for_NIIDC(
            user_feature_online,
            item_feature_online,
            cfg.history_data_path,
            cfg.base_model_batch_list,
            reward=1)
        # Shuffle via a permutation; assumes both returns support
        # fancy indexing (NumPy arrays) -- TODO confirm.
        index = list(range(len(history_state_vector)))
        random.shuffle(index)
        base_model.update(
            history_state_vector[index].reshape(-1, cfg.feature_dimension),
            history_reward[index].reshape(-1, 1))
        rec_model = copy.deepcopy(base_model)
    else:
        base_model_list = load_base_model(base_model, cfg.base_model_save_name,
                                          cfg.base_model_batch_list)
        if cfg.rec_model == "CW":
            # Raw per-batch average rewards serve directly as the ensemble
            # weights (no softmax normalisation).
            constant_weight = get_batch_average_reward(
                cfg.history_result_path, cfg.history_avg_reward_file_name,
                cfg.base_model_batch_list, cfg.B)
            print("constant_weight:", constant_weight)

    full_user_feedback = []  # one feedback entry per batch
    recommend_time = []  # seconds spent choosing each item

    t1 = time.perf_counter()
    for n in tqdm(range(cfg.N)):
        u_n = []
        s_n = []
        for b in range(cfg.B):
            row = n * cfg.B + b
            # Observe the set of candidate items S_nb
            candidate_nb = candidate_itemid_online[row]
            u_nb = uid_online[row]
            u_n.append(u_nb)
            candidate_state_vector_nb = get_candidata_vector(
                user_feature_online, item_feature_online, u_nb, candidate_nb,
                date_online[row])

            # Recommend item s in S_nb to the user (timed)
            t_start = time.perf_counter()
            if cfg.rec_model == "SAC":
                scores = _stacked_base_scores(base_model_list,
                                              candidate_state_vector_nb)
                s_nb_index = np.argmax(scores.sum(axis=1))
            elif cfg.rec_model == "VC":
                votes = [
                    model.bt_recommend(candidate_state_vector_nb)
                    for model in base_model_list
                ]
                # Majority vote; ties resolved by set iteration order.
                s_nb_index = max(set(votes), key=votes.count)
            elif cfg.rec_model == "CW":
                scores = _stacked_base_scores(base_model_list,
                                              candidate_state_vector_nb)
                s_nb_index = np.argmax(np.dot(scores, constant_weight))
            else:  # BBL / LL / NIIDC (validated above)
                s_nb_index = rec_model.bt_recommend(candidate_state_vector_nb)
            recommend_time.append(time.perf_counter() - t_start)

            s_n.append(candidate_nb[s_nb_index])

        # Receive batch user feedback
        full_user_feedback.append(online_simulator.get_batch_feedback(u_n, s_n))
    t2 = time.perf_counter()
    t_recommend = sum(recommend_time)

    print("===============Batch Test===============")
    print("rec model: ", cfg.rec_model)
    print("dataset index: ", cfg.data_index)
    print("total time cost: %fs" % (t2 - t1))
    print("total recommenda time cost: %fs" % (t_recommend))

    result_dir = cfg.result_path + cfg.base_model + "/"
    os.makedirs(result_dir, exist_ok=True)
    file_prefix = (result_dir + cfg.rec_model + "_" + cfg.res_name +
                   "_batch_data" + str(cfg.data_index))
    feedback = np.array(full_user_feedback)
    np.savetxt(file_prefix + "_full_user_feedback.txt", feedback, fmt="%d")
    np.savetxt(file_prefix + "_recommend_time.txt", np.array(recommend_time))
    print("*****************total average reward: %4f" %
          (total_average_reward(feedback)))
    print("***********************end************************")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo of argparse")
    parser.add_argument('--model',
                        default='SAC',
                        choices=["SAC", "VC", "BBL", "LL", "CW", "NIIDC"])
    parser.add_argument('--K',
                        default=10,
                        type=int,
                        help='number of base model and data buffers')
    parser.add_argument('--base_model',
                        default='SBUCB',
                        choices=["SBUCB", "EXP3_B", "BLTS_B"])
    parser.add_argument('--ckpt_load_name', default='')
    parser.add_argument('--ckpt_save_name', default='')
    parser.add_argument('--data_index', type=int, default=2, choices=[1, 2])
    parser.add_argument('--res_name', default='', help="res_name")
    # Required: DefaultConfig concatenates it into file names, so a missing
    # value would otherwise crash with an opaque TypeError on None.
    parser.add_argument('--base_model_save_name',
                        required=True,
                        help="base_model_save_name")
    parser.add_argument('--data_selection',
                        help="data_selection",
                        choices=['S', 'R', 'D'],
                        default='S')
    parser.add_argument('--history_data_name',
                        help="history_data_name",
                        default="")
    parser.add_argument('--seed', type=int, default=2023)

    args = parser.parse_args()
    # `opt` stays module-global: bacth_test() reads it.
    opt = DefaultConfig(args.base_model, args.base_model_save_name, args.K,
                        args.history_data_name, args.data_index)
    # K selects batches out of N; K < 1 leaves the ensemble empty and K > N
    # yields negative batch indices.
    if not 1 <= opt.K <= opt.N:
        parser.error("--K must be between 1 and %d, got %d" % (opt.N, opt.K))
    opt.rec_model = args.model
    opt.rec_model_load_name = args.ckpt_load_name
    opt.rec_model_save_name = args.ckpt_save_name
    opt.res_name = args.res_name
    opt.data_selection = args.data_selection
    opt.seed_num = args.seed

    setup_seed(opt.seed_num)

    # Batch selection runs after setup_seed, so any RNG it uses is seeded.
    if opt.data_selection == 'S':
        # Sequential: the K most recent batches.
        opt.base_model_batch_list = [i for i in range(opt.N - opt.K, opt.N)]
    elif opt.data_selection == 'R':
        opt.base_model_batch_list = reservoir_sampling(opt.N, opt.K)
    elif opt.data_selection == 'D':
        opt.base_model_batch_list = data_dependent_sampling(
            opt.history_result_path, opt.history_avg_reward_file_name, opt.B,
            opt.K)

    print('\n'.join(['%s:%s' % item for item in opt.__dict__.items()]))

    # batch_test
    bacth_test()
