import numpy as np


class environment():
    """Toy two-asset allocation environment (liquid vs. non-liquid asset).

    The portfolio is ``[non_liquid_fraction, liquid_fraction]``.  On each
    step the agent may open a new non-liquid position of size ``eta``
    (``action == 1``).  Every open position is aged by one per call to
    :meth:`step`, may default independently with probability
    ``prob_default`` on each step, and pays out when it reaches age ``W``.

    The non-liquid rate is ``set_interest_rate_non_liquid[non_liquid_state]``
    where ``non_liquid_state`` is a two-state regime.  Despite its name,
    ``prob_switch`` is the probability that the regime is *kept* on a step;
    it flips with probability ``1 - prob_switch``.

    Observation (length ``W + 2``):
        state[0]      current liquid fraction ``portfolio[1]``
        state[1:W+1]  the values of ``state[0:W]`` from the previous step,
                      i.e. a short history of the liquid fraction
        state[-1]     realized minus expected payoff of the non-liquid
                      position due this step (0 if none is due); forced to
                      0 on any step where no new position is opened

    Randomness comes from ``np.random.choice`` (NumPy's global random
    state); seed it with ``np.random.seed`` for reproducible runs.
    """

    def __init__(self):
        # Candidate non-liquid rate multipliers, indexed by non_liquid_state.
        self.set_interest_rate_non_liquid = [1.1, 2]
        # Multiplier applied to the liquid fraction on every step.
        self.interest_rate_liquid = 1.001

        # Probability that the non-liquid regime is *kept* on a step.
        self.prob_switch = 0.9
        # Per-step, per-position default probability.
        self.prob_default = 0.05

        # Number of steps until a non-liquid position matures.
        self.W = 4

        # Size of each non-liquid position as a portfolio fraction.
        self.eta = 0.2

        # NOTE(review): not read anywhere in this class.
        self.cash = 100000

        # Episode state: state, non_liquid_state, portfolio, log_position.
        self.reset()

    def reset(self):
        """Restore the initial episode state.

        Returns:
            A copy of the initial observation: liquid fraction 1 in slot 0,
            zeros elsewhere.
        """
        self.state = np.zeros(self.W + 2)
        self.state[0] = 1
        self.non_liquid_state = 1

        # [non-liquid fraction, liquid fraction]
        self.portfolio = [0., 1.]

        # Ages (in steps) of the currently open non-liquid positions.
        self.log_position = []

        # Return a copy so later in-place updates by step() cannot change
        # an observation the caller has already stored.
        return self.state.copy()

    def step(self, action):
        """Advance the environment by one time step.

        Args:
            action: ``1`` opens a new non-liquid position of size ``eta``
                (ignored when ``portfolio[0] >= 1 - eta``); any other value
                keeps the current allocation.

        Returns:
            ``(state, reward)`` where ``state`` is a copy of the new
            observation and ``reward`` is ``spot_return - 1``.
        """
        # A position of age W - 1 right now becomes age W after the ageing
        # below and therefore matures during this step.  (Positions are
        # removed as soon as they reach age W, so age W itself can never be
        # present at this point.)
        if (self.W - 1) in self.log_position:
            rates = self.set_interest_rate_non_liquid
            stay_term = self.prob_switch * rates[self.non_liquid_state]
            flip_term = (1 - self.prob_switch) * \
                rates[1 - self.non_liquid_state]
            # Expected payoff given the current regime, discounted by the
            # chance the position defaults on this step.
            expected_non_liquid_reward = \
                (1 - self.prob_default) * (stay_term + flip_term)
        else:
            expected_non_liquid_reward = 0

        # Age every open position by one step.
        self.log_position[:] = [age + 1 for age in self.log_position]

        # Regime update: keep with probability prob_switch, else flip.
        index_switch = np.random.choice(
            [1, 0], p=[self.prob_switch, 1 - self.prob_switch])
        if index_switch == 0:
            self.non_liquid_state = 0 if self.non_liquid_state == 1 else 1

        interest_rate_non_liquid = self.set_interest_rate_non_liquid[
            self.non_liquid_state]

        spot_return = self.portfolio[1] * self.interest_rate_liquid

        # Independent default draw for every open position.
        if self.log_position:
            default_flags = np.random.choice(
                [0, 1], size=len(self.log_position),
                p=[1 - self.prob_default, self.prob_default])
            defaulted = [age for age, flag
                         in zip(self.log_position, default_flags)
                         if flag == 1]
            for age in defaulted:
                # A defaulted position is dropped without paying anything;
                # its allocation is moved back to the liquid side.
                self.log_position.remove(age)
                self.portfolio[0] -= self.eta
                self.portfolio[1] += self.eta

        if self.W in self.log_position:
            # NOTE(review): pays the rate on the whole non-liquid fraction
            # portfolio[0], not only on the maturing position of size eta;
            # confirm this is intended when several positions are open.
            spot_return += interest_rate_non_liquid * self.portfolio[0]
            self.state[-1] = interest_rate_non_liquid - \
                expected_non_liquid_reward

            self.portfolio[0] -= self.eta
            self.portfolio[1] += self.eta
            self.log_position.remove(self.W)
        else:
            # The due position (if any) defaulted: realized payoff is 0.
            self.state[-1] = -expected_non_liquid_reward

        if action == 1 and self.portfolio[0] < 1 - self.eta:
            self.portfolio[0] += self.eta
            self.portfolio[1] -= self.eta
            self.log_position.append(0)
        else:
            # NOTE(review): this discards the maturity/default signal set
            # above whenever no new position is opened; confirm intended.
            self.state[-1] = 0

        # Shift the liquid-fraction history right by one slot
        # (state[1:W+1] <- state[0:W]); copy to avoid overlapping views.
        self.state[1:self.W + 1] = self.state[:self.W].copy()
        self.state[0] = self.portfolio[1]

        # Return a copy: self.state is mutated in place on the next step.
        return self.state.copy(), spot_return - 1
