import math
import pandas as pd
import numpy as np

class Importance_Sampling(object):
    """Off-policy value estimates via trajectory-wise importance sampling.

    Each distinct ``userID`` in ``raw_data`` forms one trajectory (its rows in
    DataFrame order). At every step, the behavior-policy probabilities come
    straight from the data. The evaluation-policy probabilities are derived
    from the step's Q values according to ``policy``.

    ``IS`` returns the ordinary importance-sampling estimate (mean over users
    of weight * discounted return). ``WIS`` returns the weighted
    (self-normalized) estimate (sum of weight * return divided by sum of
    weights).

    Columns read from ``raw_data`` by ``readData``:
        ``userID``, ``real_action`` (used as an index into the action list),
        ``inferred_rew``, the Q-value columns ``ps``, ``fwe``, ``we``, and the
        behavior-probability columns ``prob_ps``, ``prob_fwe``, ``prob_we``.
        The ``critical`` column is also read, but only for the ``'SOCHRL'``
        policy.
    """

    # Evaluation-policy names understood by readData().
    _POLICIES = ('SOCHRL', 'FHRL', 'expert')

    def __init__(self, raw_data, theta, gamma, policy, step_offset):
        """Store the data and estimator settings.

        Call readData() before calling IS() or WIS().

        Args:
            raw_data: pandas DataFrame of logged steps (see class docstring).
            theta: stored as ``self.theta``; not read by any method in this
                class.
            gamma: per-step discount factor.
            policy: evaluation policy name: 'SOCHRL', 'FHRL' or 'expert'.
            step_offset: the reward at step i of a trajectory is discounted by
                ``gamma ** (i + step_offset)``.
        """
        self.raw_data = raw_data
        self.theta = theta
        self.gamma = gamma
        # One list of (action, reward, Qs, beh_probs, eva_probs) tuples per user.
        self.traces = []
        self.n_action = 0
        self.n_user = 0
        self.random_prob = 0  # set to 1 / n_action by readData(); unused by IS/WIS
        self.policy = policy
        self.alpha = 0.5  # not read by any method in this class
        self.step_offset = step_offset  # start the discount exponent at step h

    def readData(self):
        """Build one trajectory per user from ``self.raw_data``.

        Replaces ``self.traces``, so calling this again does not duplicate
        trajectories. Also sets ``n_action``, ``n_user`` and ``random_prob``.

        Raises:
            ValueError: if ``self.policy`` is not one of the supported names.
        """
        if self.policy not in self._POLICIES:
            raise ValueError(
                "unknown policy %r; expected one of %s"
                % (self.policy, self._POLICIES))

        raw_data = self.raw_data
        Q_list = ['ps', 'fwe', 'we']
        beh_prob_list = ['prob_ps', 'prob_fwe', 'prob_we']
        user_list = list(raw_data['userID'].unique())
        self.n_action = len(Q_list)
        self.n_user = len(user_list)
        self.random_prob = 1.0 / self.n_action
        # Start fresh: n_user was just overwritten. Appending to old traces
        # would make IS() divide the enlarged sum by the new user count.
        self.traces = []

        for user in user_list:
            user_data = raw_data.loc[raw_data['userID'] == user]
            # Pull whole columns once instead of doing one .loc lookup per cell.
            actions = user_data['real_action'].tolist()
            rewards = user_data['inferred_rew'].tolist()
            q_rows = user_data[Q_list].to_numpy().tolist()
            beh_rows = user_data[beh_prob_list].to_numpy().tolist()
            if self.policy == 'SOCHRL':
                critical_flags = user_data['critical'].tolist()
            else:
                # Other policies ignore 'critical'; the column may be absent.
                critical_flags = [None] * len(actions)

            user_sequence = [
                (action, reward, Qs, beh_probs,
                 self._evaluation_probs(Qs, critical))
                for action, reward, Qs, beh_probs, critical
                in zip(actions, rewards, q_rows, beh_rows, critical_flags)
            ]
            self.traces.append(user_sequence)

    def _greedy_probs(self, greedy_action):
        """Put 0.8 on ``greedy_action`` and 0.1 on every other action."""
        return [0.8 if a == greedy_action else 1e-1
                for a in range(self.n_action)]

    def _evaluation_probs(self, Qs, critical):
        """Return the evaluation policy's action distribution for one step.

        Rules by policy:
            'FHRL': soft-greedy on max(Qs).
            'SOCHRL': soft-greedy on min(Qs) at steps where ``critical == 1``,
                uniform at all other steps.
            'expert': uniform.
        """
        uniform = [1 / self.n_action] * self.n_action
        if self.policy == 'SOCHRL':
            if critical == 1:
                # NOTE(review): argmin here but argmax for FHRL. Confirm this
                # is intended (e.g. these Q columns may encode costs).
                return self._greedy_probs(Qs.index(min(Qs)))
            return uniform
        if self.policy == 'FHRL':
            return self._greedy_probs(Qs.index(max(Qs)))
        # 'expert': readData() has already rejected any other policy name.
        return uniform

    def _weight_and_return(self, trace):
        """Return (importance weight, discounted return) for one trajectory.

        weight = prod_t eva_probs[a_t] / prod_t beh_probs[a_t]
        return = sum_t gamma ** (t + step_offset) * r_t

        An empty trajectory yields (1.0, 0.0). Raises ZeroDivisionError if the
        logged behavior probabilities multiply to zero.
        """
        policy_prob = 1.0
        behavior_prob = 1.0
        discounted_return = 0.0
        for step, (action, reward, _qs, beh_probs, eva_probs) in enumerate(trace):
            policy_prob *= eva_probs[action]
            behavior_prob *= beh_probs[action]
            discounted_return += (
                math.pow(self.gamma, step + self.step_offset) * reward)
        return policy_prob / behavior_prob, discounted_return

    def IS(self):
        """Ordinary importance-sampling estimate of the discounted return.

        Sums weight * return over all trajectories and divides by
        ``self.n_user``. Raises ZeroDivisionError if no users were loaded.
        """
        total = 0.0
        for trace in self.traces:
            weight, discounted_return = self._weight_and_return(trace)
            total += discounted_return * weight
        return float(total) / self.n_user

    def WIS(self):
        """Weighted (self-normalized) importance-sampling estimate.

        Sums weight * return over all trajectories and divides by the sum of
        the weights. Raises ZeroDivisionError if there are no trajectories.
        """
        weighted_sum = 0.0
        total_weight = 0.0
        for trace in self.traces:
            weight, discounted_return = self._weight_and_return(trace)
            total_weight += weight
            weighted_sum += discounted_return * weight
        return float(weighted_sum) / total_weight