# %%
import numpy as np
import pandas as pd
from scipy.stats import binom
from scipy.stats import multivariate_normal
import math


import CoxVB



# simulation class
class BayesCoxCP():
    """Episode-based dynamic-pricing simulation driven by a Cox VB model.

    The time horizon ``TH`` is split into episodes of doubling length (see
    ``define_episode``). During the first episode prices are drawn uniformly
    from the interior price grid. In later episodes the price is either drawn
    at random (exploration) or chosen to maximise the estimated expected
    revenue ``p * S(p) ** exp(x @ beta)`` using the estimates fitted on the
    previous episode (exploitation). After every episode a
    ``CoxVB.CoxVB_trunc_post`` model is fitted on that episode's data.

    Parameters
    ----------
    ini : int
        Length of the first episode.
    deg_gam, deg_gam_2 : float
        Scaling constants of the exploration probability.
    pmin, pmax : float
        End points of the price range.
    ngrid : int
        Number of interior price grid points.
    gen_x : callable
        Called as ``gen_x(TH, dim, seed=None)``; must return the covariates,
        indexable by round (presumably an array of shape ``(TH, dim)``).
    dim : int
        Dimension of the covariates.
    beta0 : array-like
        True regression coefficients.
    cdf0 : callable
        True baseline cdf.
    cdf_true : callable
        Called as ``cdf_true(cdf0, price, linear_predictor)``; returns the
        probability that the valuation is below ``price``.
    TH : int
        Time horizon (number of rounds per replication).
    nrepl : int
        Number of replications.
    tuning : bool
        If False the regret/revenue tables are written to CSV; if True the
        mean cumulative realised revenue is stored instead.
    save_dir : str
        Output directory for the CSV files.
    name_flag : str
        Suffix used in the CSV file names.
    a0l_tune : float
        Divisor applied to the prior hyperparameters ``a0l`` and ``rho0l``.
    s : int
        Degree of the I-splines passed to the Cox VB model.
    """

    # NOTE: the default for ``a0l_tune`` is written as ``10**5``. In Python
    # ``10^5`` is a bitwise XOR and evaluates to 15, not 100000.
    def __init__(self, ini, deg_gam, deg_gam_2, pmin, pmax, ngrid, gen_x, dim, beta0, cdf0, cdf_true, TH, nrepl, tuning, save_dir, name_flag='', a0l_tune=10**5, s=1):
        self.ini = ini      # length of the first episode (hyperparameter)
        self.deg_gam = deg_gam  # degree of exploration (hyperparameter)
        self.deg_gam_2 = deg_gam_2  # degree of exploration (hyperparameter)
        self.pmin = pmin
        self.pmax = pmax
        self.ngrid = ngrid  # number of interior price grid points
        self.gen_x = gen_x  # data generating process of the covariates
        self.dim = dim      # dimension of the covariates
        self.beta0 = beta0  # true regression coefficients
        self.cdf0 = cdf0    # true baseline cdf
        self.cdf_true = cdf_true
        self.TH = TH        # time horizon
        self.nrepl = nrepl  # number of replications
        self.tuning = tuning    # if tuning is False, results are saved to CSV
        self.save_dir = save_dir
        self.name_flag = name_flag
        self.a0l_tune = a0l_tune
        self.unk_cumRev_real = 0
        self.s = s          # degree of the I-splines

        # price grid including pmin and pmax
        self.pseq_total = np.linspace(self.pmin, self.pmax, self.ngrid + 2)
        # interior price grid (ngrid points, end points excluded)
        self.pseq = self.pseq_total[1:-1].copy()
        # column names of the covariates
        self.unk_X_colname = [f'X{X_num}' for X_num in range(1, self.dim + 1)]

        # per-round regret, one row per replication
        self.unk_reg_iter_stack = np.full((self.nrepl, self.TH), np.nan)

        # oers  : expected reward of the optimal price
        # uoers : expected reward of the offered price
        # urrs  : realised reward of the offered price
        self.oers = np.full((self.nrepl, self.TH), np.nan)
        self.uoers = np.full((self.nrepl, self.TH), np.nan)
        self.urrs = np.full((self.nrepl, self.TH), np.nan)

    def fitting(self):
        """Run all replications and store (or save) regret and revenue.

        Fills ``unk_reg_iter_stack``, ``oers``, ``uoers`` and ``urrs``. When
        ``tuning`` is False three CSV files are written to ``save_dir``;
        otherwise the mean cumulative realised revenue at the end of the
        horizon is stored in ``unk_cumRev_real``.
        """
        for iter_ in range(self.nrepl):
            print(f'Current iteration: {iter_ + 1}')

            # generate covariates
            X = self.gen_x(self.TH, self.dim, seed=None)

            # episode label of every round and number of episodes
            epis, epi = self.define_episode()

            # per-round records for this replication
            opt_p_stack = np.full(self.TH, np.nan)
            opt_rew_stack = np.full(self.TH, np.nan)
            unk_p_stack = np.full(self.TH, np.nan)
            unk_y_stack = np.full(self.TH, np.nan)
            unk_rew_stack = np.full(self.TH, np.nan)
            unk_rewreal_stack = np.full(self.TH, np.nan)

            # per-episode estimates of beta and of the survival function
            unk_beta_stack = np.zeros((epi, self.dim))
            unk_surv_stack = np.zeros((epi, self.ngrid))

            # [cur_start, cur_end) is the 0-based round range of the episode
            cur_start = 0
            for epis_curr in range(1, epi + 1):
                epis_num = int(np.sum(epis == epis_curr))
                print(f'epis_curr: {epis_curr}, epis_num : {epis_num}')
                cur_end = cur_start + epis_num

                if epis_curr != 1:
                    temp_b_unk = self._exploration_prob(epis_curr)

                for t in range(cur_start, cur_end):
                    # observe new covariates
                    newdat = X[t]
                    lin_pred = np.sum(newdat * self.beta0)

                    # the first episode is pure exploration; afterwards
                    # exploit with probability 1 - temp_b_unk
                    if epis_curr == 1:
                        exploit = False
                    else:
                        exploit = binom.rvs(1, 1 - temp_b_unk) != 0

                    if exploit:
                        # estimated expected revenue from the previous
                        # episode's fit
                        prev = epis_curr - 2
                        unk_expected_rewards = self.pseq * (unk_surv_stack[prev] ** (np.exp(np.dot(newdat, unk_beta_stack[prev]))))
                        price = self.pseq[np.argmax(unk_expected_rewards)]
                    else:
                        # random price
                        price = np.random.choice(self.pseq, 1, replace=True).item()

                    unk_p_stack[t] = price
                    # purchase indicator
                    unk_y_stack[t] = binom.rvs(1, self._purchase_prob(price, lin_pred))

                    # optimal price under the true model
                    opt_expected_rewards = self.pseq * self.func_survival(newdat)
                    opt_price = self.pseq[np.argmax(opt_expected_rewards)]
                    opt_p_stack[t] = opt_price

                    # realised and expected rewards
                    unk_rewreal_stack[t] = price * unk_y_stack[t]
                    opt_rew_stack[t] = opt_price * self._purchase_prob(opt_price, lin_pred)
                    unk_rew_stack[t] = price * self._purchase_prob(price, lin_pred)

                # estimate the model on this episode's data
                X_data = X[cur_start:cur_end]
                L_data = unk_p_stack[cur_start:cur_end]
                delta3 = unk_y_stack[cur_start:cur_end]
                delta1 = 1 - delta3
                # hyperparameters
                a0l = np.ones(self.s + self.ngrid) / self.a0l_tune
                rho0l = np.ones(self.s + self.ngrid) * self.ngrid / self.a0l_tune
                mu0 = np.zeros(self.dim)
                alpha0 = 1
                a_trunc = -10
                b_trunc = 10

                # modeling
                coxvb_model = CoxVB.CoxVB_trunc_post(X_data, L_data, delta1, delta3, self.s, self.ngrid, self.pseq_total, a0l, rho0l, mu0, alpha0)

                # estimate
                coxvb_model.coordinate_ascent(iterate=50, tol=1e-3)
                # update
                unk_beta_temp, _, unk_surv_stack[epis_curr - 1] = coxvb_model.get_estimation(self.pseq)
                unk_Var_beta_temp = coxvb_model.get_Var_beta()
                # truncate beta to [a_trunc, b_trunc] via rejection sampling
                truncated_samples = self.truncated_normal_sample(unk_beta_temp, unk_Var_beta_temp, a_trunc, b_trunc, 10000)
                if len(truncated_samples) > 0:
                    unk_beta_stack[epis_curr - 1], _ = self.estimate_moments(truncated_samples)
                else:
                    # every draw was rejected: avoid a NaN estimate by
                    # clipping the untruncated mean into the box instead
                    unk_beta_stack[epis_curr - 1] = np.clip(np.ravel(unk_beta_temp), a_trunc, b_trunc)

                # next episode starts where this one ended
                cur_start = cur_end
            # end episode
            self.unk_reg_iter_stack[iter_] = opt_rew_stack - unk_rew_stack
            self.oers[iter_] = opt_rew_stack
            self.uoers[iter_] = unk_rew_stack
            self.urrs[iter_] = unk_rewreal_stack

        if not self.tuning:
            # save as csv file
            pd.DataFrame(self.unk_reg_iter_stack).to_csv(f'{self.save_dir}/cumRev_unk_reg_coxph_{self.name_flag}.csv', index=False)
            pd.DataFrame(self.oers).to_csv(f'{self.save_dir}/cumRev_unk_opt_coxph_{self.name_flag}.csv', index=False)
            pd.DataFrame(self.uoers).to_csv(f'{self.save_dir}/cumRev_unk_coxph_{self.name_flag}.csv', index=False)
        else:
            unk_rr_mean = np.mean(np.cumsum(self.urrs, axis=1), axis=0)
            self.unk_cumRev_real = unk_rr_mean[-1]

    # probability of exploring (random price) in episode ``epis_curr`` >= 2
    def _exploration_prob(self, epis_curr):
        k = epis_curr - 1
        rate = min(self.deg_gam_2 * math.sqrt(self.ngrid / (2 ** k)), 2 ** (-k / 3))
        return min(self.deg_gam * rate * math.sqrt(k * math.log(2)), 1)

    # true purchase probability at ``price`` for linear predictor ``lin_pred``
    def _purchase_prob(self, price, lin_pred):
        return 1 - self.cdf_true(self.cdf0, price, lin_pred)

    # survival function
    def func_survival(self, x_t):
        """Return the true purchase probability at every interior grid price."""
        return 1 - self.cdf_true(self.cdf0, self.pseq, np.sum(x_t * self.beta0))

    # min_v function (avoid log 0)
    def min_v(self, x):
        """Cap ``x`` at ``1 - 1e-8``."""
        return min(x, 1 - 1e-8)

    # cumulative hazard function
    def func_hazard(self, p_t):
        """Return ``-log(1 - cdf0(p_t))`` with the cdf capped below 1."""
        # element-wise cap keeps the logarithm finite
        return - np.log(1 - np.minimum(self.cdf0(p_t), 1 - 1e-8))

    # define episode
    def define_episode(self):
        """Assign every round to an episode of doubling length.

        Episode 1 has ``ini`` rounds, episode 2 has ``2 * ini`` rounds and so
        on (at most 11 doublings); the last episode is cut off at ``TH``.

        Returns
        -------
        epis : numpy.ndarray of int, shape ``(TH,)``
            1-based episode label of every round.
        epi : int
            Number of episodes.
        """
        start_num = self.ini
        tmp_jump = start_num * (2 ** np.arange(11))
        # 0-based index of the first round of episodes 2, 3, ...
        epi_doubling = np.cumsum(tmp_jump)
        epi_doubling = epi_doubling[epi_doubling < self.TH].astype(int)

        epis = np.zeros(self.TH, dtype=int)
        epis[epi_doubling] = 1
        epis = np.cumsum(epis) + 1
        epi = int(epis[-1])

        return epis, epi

    # get real reward stack (tuning results)
    def get_unk_cum_real(self):
        """Return the mean cumulative realised revenue stored by ``fitting``."""
        return self.unk_cumRev_real

    def truncated_normal_sample(self, mean, cov, lower, upper, n_samples):
        """Rejection-sample a normal distribution truncated to a box.

        Draws ``n_samples`` points from ``N(mean, cov)`` and keeps those whose
        coordinates all lie in ``[lower, upper]``. The result always has shape
        ``(n_kept, dim)`` and may contain fewer than ``n_samples`` rows.
        """
        samples = multivariate_normal.rvs(mean=mean, cov=cov, size=n_samples)
        # rvs squeezes singleton axes (dim == 1 or n_samples == 1); restore 2-D
        samples = np.asarray(samples, dtype=float).reshape(-1, np.size(mean))
        mask = np.all((samples >= lower) & (samples <= upper), axis=1)
        truncated_samples = samples[mask]
        return truncated_samples

    def estimate_moments(self, truncated_samples):
        """Return the sample mean and covariance (rows are observations)."""
        mean_est = np.mean(truncated_samples, axis=0)
        cov_est = np.cov(truncated_samples, rowvar=False)
        return mean_est, cov_est