# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import torch
import argparse
import random
import copy
import os
import time
from scipy.special import expit
from tqdm import tqdm
from Model.SBUCB import SBUCB
from Model.EXP3_B import EXP3_B
from Model.BLTS_B import BLTS_B
from utils.simulator import get_batch_feedback
from itertools import product

from utils.data import (Data_4_AdaO2B, get_candidata_vector, load_base_model,
                        prep_data_4_adao2b, get_best_batch,
                        get_batch_average_reward,
                        get_state_vector_and_reward_for_NIIDC,
                        reservoir_sampling, data_dependent_sampling)
from utils.evaluation import total_average_reward


class DefaultConfig(object):
    """Experiment configuration for the batched online test.

    Holds the simulator, data-path, base-model and recommendation-model
    settings read by ``bacth_test``. Most values are fixed defaults; the
    ``__main__`` entry point overrides a few of them from command-line
    arguments after construction.

    Args:
        base_model: base bandit name ("SBUCB", "EXP3_B" or "BLTS_B" on the
            command line); also used as a directory / file-name component.
        base_model_save_name: checkpoint name prefix of the trained base
            model(s); also used to name the history average-reward file.
        K: number of base models / data buffers. The default
            ``base_model_batch_list`` is the last ``K`` of the ``N`` batches.
        history_data_name: tag embedded in the AdaO2B history CSV file name.
        data_index: which candidate dataset ``get_candidata_vector`` draws
            from.

    Raises:
        TypeError: if ``base_model``, ``base_model_save_name`` or
            ``history_data_name`` is not a ``str`` (they are concatenated into
            file paths).

    Note:
        ``w`` is sampled from NumPy's global RNG inside the constructor, so it
        is only reproducible if that RNG has been seeded beforehand.
    """

    def __init__(self,
                 base_model,
                 base_model_save_name,
                 K,
                 history_data_name,
                 data_index=2):
        # These strings are concatenated into file paths below; fail early
        # with a clear message rather than an opaque "can only concatenate
        # str" error (e.g. when --base_model_save_name is omitted -> None).
        for arg_name, value in (("base_model", base_model),
                                ("base_model_save_name", base_model_save_name),
                                ("history_data_name", history_data_name)):
            if not isinstance(value, str):
                raise TypeError("%s must be a str, got %r" % (arg_name, value))

        # one of "SAC", "VC", "BBL", "LL", "CW", "NIIDC"
        self.rec_model = "SAC"
        self.device = "cuda:1"  # NOTE(review): not referenced in this script

        self.data_index = data_index
        self.feature_dimension = 10
        self.N = 40  # number of batches in the online test
        self.B = 5000  # recommendations per batch
        self.candidate_size = 10  # candidate items offered per recommendation

        # simulator
        # Reward weight passed to get_batch_feedback; drawn from the global
        # NumPy RNG at construction time.
        self.w = np.random.normal(0.1, 0.01,
                                  (self.feature_dimension, 1))  # reward weight

        self.base_model = base_model
        self.base_model_save_name = base_model_save_name
        self.K = K
        # Default ('S') data selection: the last K of the N batches.
        self.base_model_batch_list = list(range(self.N - self.K, self.N))
        # Data-selection strategy; the entry point overrides it ('S'/'R'/'D').
        self.data_selection = 'S'
        self.history_result_path = "../Result/" + self.base_model + "/"
        self.history_avg_reward_file_name = self.base_model_save_name + "_online_data1"

        self.adao2b_path = '../Data/AdaO2B/'
        self.history_data_path = self.adao2b_path + self.base_model + "_" + history_data_name + "_online_data1_adao2b.csv"

        # result
        self.result_path = "../Result/"
        self.res_name = ""

        # model save and load
        self.rec_model_load_name = ""
        self.rec_model_save_name = ""

        # SBUCB: `mu` passed to the SBUCB constructor
        self.SBUCB_mu = 0.4

        # BLTS_B: `mu` passed to the BLTS_B constructor
        self.BLTS_B_mu = 0.8

        # EXP3_B: `delta` passed to the EXP3_B constructor
        self.EXP3_B_delta = 0.1

        self.seed_num = 2022


def setup_seed(seed):
    """Seed every random source used by the experiment.

    Applies ``seed`` to torch (CPU, then all CUDA devices), NumPy's global
    generator and Python's ``random`` module, in that order, and then asks
    cuDNN to use deterministic algorithms.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


def _build_base_model():
    """Instantiate the (not yet loaded) base bandit named by ``opt.base_model``."""
    if opt.base_model == "SBUCB":
        return SBUCB(mu=opt.SBUCB_mu, feature_d=opt.feature_dimension)
    if opt.base_model == "BLTS_B":
        return BLTS_B(mu=opt.BLTS_B_mu, feature_d=opt.feature_dimension)
    if opt.base_model == "EXP3_B":
        return EXP3_B(N=opt.N,
                      B=opt.B,
                      C=opt.candidate_size,
                      delta=opt.EXP3_B_delta,
                      feature_d=opt.feature_dimension)
    raise ValueError("unknown base model: %r" % (opt.base_model,))


def _prepare_models(base_model):
    """Load checkpoints and build what ``opt.rec_model`` needs.

    Returns:
        ``(rec_model, base_model_list, constant_weight)``. ``rec_model`` is set
        for BBL/LL/NIIDC; ``base_model_list`` for SAC/VC/CW; ``constant_weight``
        only for CW. Entries that do not apply are ``None``.
    """
    rec_model = None
    base_model_list = None
    constant_weight = None

    if opt.rec_model == "BBL":
        best_batch_index = get_best_batch(opt.history_result_path,
                                          opt.history_avg_reward_file_name,
                                          opt.B)
        print("best batch index: ", best_batch_index)
        # NOTE(review): checkpoint suffix is best_batch_index - 1, presumably
        # because the reported index is 1-based; confirm against get_best_batch.
        base_model.load_model(opt.base_model_save_name + "_batch" +
                              str(best_batch_index - 1))
        rec_model = copy.deepcopy(base_model)
    elif opt.rec_model == "LL":
        base_model.load_model(opt.base_model_save_name)
        rec_model = copy.deepcopy(base_model)
    elif opt.rec_model == "NIIDC":
        base_model.load_model(opt.base_model_save_name)
        # One extra update pass over the history data fetched with reward=1,
        # fed in a shuffled order.
        history_state_vector, history_reward = get_state_vector_and_reward_for_NIIDC(
            opt.history_data_path, opt.base_model_batch_list, reward=1)
        index = list(range(len(history_state_vector)))
        random.shuffle(index)
        base_model.update(
            history_state_vector[index].reshape(-1, opt.feature_dimension),
            history_reward[index].reshape(-1, 1))
        rec_model = copy.deepcopy(base_model)
    else:
        base_model_list = load_base_model(base_model, opt.base_model_save_name,
                                          opt.base_model_batch_list)
        if opt.rec_model == "CW":
            base_model_batch_avg_rw = get_batch_average_reward(
                opt.history_result_path, opt.history_avg_reward_file_name,
                opt.base_model_batch_list, opt.B)
            # Other weightings are possible here, e.g. a softmax over the
            # batch average rewards; the raw averages are used as weights.
            constant_weight = base_model_batch_avg_rw
            print("constant_weight:", constant_weight)

    return rec_model, base_model_list, constant_weight


def _stack_base_scores(base_model_list, candidates):
    """Return every base model's scores, transposed to (n_candidate, n_base_model)."""
    return np.array(
        [m.get_rec_score(candidates) for m in base_model_list]).T


def _recommend_index(candidates, rec_model, base_model_list, constant_weight):
    """Return the index into ``candidates`` recommended by ``opt.rec_model``."""
    if opt.rec_model == "SAC":
        # Sum base-model scores per candidate and take the best one.
        return np.argmax(
            _stack_base_scores(base_model_list, candidates).sum(axis=1))
    if opt.rec_model == "VC":
        # Majority vote over each base model's own pick.
        votes = [m.bt_recommend(candidates) for m in base_model_list]
        return max(set(votes), key=votes.count)
    if opt.rec_model == "CW":
        # Weight base-model scores by the constant per-model weights.
        return np.argmax(
            np.dot(_stack_base_scores(base_model_list, candidates),
                   constant_weight))
    if opt.rec_model in ("BBL", "LL", "NIIDC"):
        return rec_model.bt_recommend(candidates)
    raise ValueError("unknown rec model: %r" % (opt.rec_model,))


def _save_results(full_user_feedback, recommend_time):
    """Write the feedback matrix and per-recommendation latencies to disk."""
    out_dir = opt.result_path + opt.base_model + "/"
    # np.savetxt does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    prefix = (out_dir + opt.rec_model + "_" + opt.res_name + "_batch_data" +
              str(opt.data_index))
    np.savetxt(prefix + "_full_user_feedback.txt",
               np.array(full_user_feedback),
               fmt="%d")
    np.savetxt(prefix + "_recommend_time.txt", np.array(recommend_time))


def bacth_test():
    """Run the batched online test of ``opt.rec_model`` on the simulator.

    For each of ``opt.N`` batches, draws ``opt.B`` candidate sets with
    ``get_candidata_vector``, lets the recommendation model pick one item per
    set, and then collects the batch's simulated feedback with
    ``get_batch_feedback(..., opt.w)``. Prints timings and the total average
    reward, and writes the full feedback matrix and the per-recommendation
    latencies under ``opt.result_path + opt.base_model + "/"``.

    All settings are read from the module-level ``opt`` created in
    ``__main__``.
    """
    base_model = _build_base_model()
    rec_model, base_model_list, constant_weight = _prepare_models(base_model)

    # save result for statistics
    full_user_feedback = []
    t_recommend = 0
    recommend_time = []

    t1 = time.time()
    for _ in tqdm(range(opt.N)):
        rec_state_n = []

        for _ in range(opt.B):
            # Observe the set of candidate items S_nb
            candidate_state_vector_nb = get_candidata_vector(
                opt.data_index, opt.candidate_size, opt.feature_dimension)

            # Recommend item s in S_nb to the user; only this step is timed.
            t_recommend_start = time.time()
            rec_state_nb_index = _recommend_index(candidate_state_vector_nb,
                                                  rec_model, base_model_list,
                                                  constant_weight)
            t_recommend_end = time.time()
            elapsed = t_recommend_end - t_recommend_start
            t_recommend += elapsed
            recommend_time.append(elapsed)

            rec_state_n.append(candidate_state_vector_nb[rec_state_nb_index])

        # Receive batch user feedback
        user_feedback_n = get_batch_feedback(
            np.array(rec_state_n).reshape(opt.B, -1), opt.w)

        # save batch result
        full_user_feedback.append(user_feedback_n)

    t2 = time.time()
    print("===============Batch Test===============")
    print("rec model: ", opt.rec_model)
    print("dataset index: ", opt.data_index)
    print("total time cost: %fs" % (t2 - t1))
    print("total recommenda time cost: %fs" % (t_recommend))

    _save_results(full_user_feedback, recommend_time)
    print("*****************total average reward: %4f" %
          (total_average_reward(np.array(full_user_feedback))))
    print("***********************end************************")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo of argparse")
    parser.add_argument('--model',
                        default='SAC',
                        choices=["SAC", "VC", "BBL", "LL", "CW", "NIIDC"])
    parser.add_argument('--K',
                        default=10,
                        type=int,
                        help='number of base model and data buffers')
    parser.add_argument('--base_model',
                        default='SBUCB',
                        choices=["SBUCB", "EXP3_B", "BLTS_B"])
    parser.add_argument('--ckpt_load_name', default='')
    parser.add_argument('--ckpt_save_name', default='')
    parser.add_argument('--data_index', type=int, default=2, choices=[1, 2])
    parser.add_argument('--res_name', default='', help="res_name")
    # Required: the value is concatenated into checkpoint / result paths, so
    # an omitted flag (None) would otherwise crash inside DefaultConfig.
    parser.add_argument('--base_model_save_name',
                        required=True,
                        help="base_model_save_name")
    parser.add_argument('--data_selection',
                        help="data_selection",
                        choices=['S', 'R', 'D'],
                        default='S')
    parser.add_argument('--history_data_name',
                        help="history_data_name",
                        default="")
    parser.add_argument('--seed', type=int, default=2022)

    args = parser.parse_args()

    # DefaultConfig draws the simulator reward weight `w` from np.random in
    # its constructor, so seed first; otherwise `w` (and therefore every
    # simulated reward) differs between runs that use the same --seed.
    setup_seed(args.seed)
    opt = DefaultConfig(args.base_model, args.base_model_save_name, args.K,
                        args.history_data_name, args.data_index)
    opt.rec_model = args.model
    opt.rec_model_load_name = args.ckpt_load_name
    opt.rec_model_save_name = args.ckpt_save_name
    opt.res_name = args.res_name
    opt.data_selection = args.data_selection
    opt.seed_num = args.seed

    # Re-seed so the random streams used by data selection and the test loop
    # start from the seed itself, independent of the draws made above.
    setup_seed(opt.seed_num)

    # S: the last K batches; R: reservoir sampling; D: data-dependent sampling
    if opt.data_selection == 'S':
        opt.base_model_batch_list = list(range(opt.N - opt.K, opt.N))
    elif opt.data_selection == 'R':
        opt.base_model_batch_list = reservoir_sampling(opt.N, opt.K)
    elif opt.data_selection == 'D':
        opt.base_model_batch_list = data_dependent_sampling(
            opt.history_result_path, opt.history_avg_reward_file_name, opt.B,
            opt.K)

    print('\n'.join(['%s:%s' % item for item in opt.__dict__.items()]))

    # batch_test
    bacth_test()
