import numpy as np
import torch


class environment:
    """Binomial-tree price environment for a put + call option position.

    The state is a length-2 float array ``[price, time_step]``. Each call to
    ``step`` moves the price up by ``f_u`` with probability ``prob`` (otherwise
    down by ``f_d``) and advances the time step by one.

    The agent holds a single exercisable position whose payoff at price ``S``
    is ``max(0, W_put - S) + max(0, S - W_call)``. Action ``1`` exercises it
    at the current (pre-move) price. If it is still held when the time step
    reaches ``maturity``, it is paid out automatically at that price. Once
    exercised, it never pays again.
    """

    def __init__(self, W_put=1., W_call=1.5, f_u=9/8, f_d=8/9, prob=0.45,
                 maturity=19, rng=None):
        """Create the environment; all defaults reproduce the original setup.

        Args:
            W_put: strike of the put leg.
            W_call: strike of the call leg.
            f_u: multiplicative price factor for an up move.
            f_d: multiplicative price factor for a down move.
            prob: probability of an up move on each step.
            maturity: time step at which a still-held position is paid out.
            rng: object with a ``uniform()`` method returning a draw in
                [0, 1), e.g. ``np.random.default_rng(seed)``. Defaults to the
                global ``np.random`` module (the original behaviour).
        """
        self.W_put = W_put
        self.W_call = W_call

        # Not read inside this class; kept for external callers.
        self.W = 0

        self.f_u = f_u
        self.f_d = f_d

        self.prob = prob
        self.maturity = maturity
        self.rng = np.random if rng is None else rng

        self.state = np.array([1., 0.])
        self.position = True

    def _payoff(self, price):
        """Return the put + call payoff at ``price``."""
        return max([0., self.W_put - price]) + max([0., price - self.W_call])

    def reset(self):
        """Restart at price 1.0, time step 0, holding the position.

        Returns:
            A copy of the new state ``[price, time_step]``.
        """
        self.state = np.array([1., 0.])
        self.position = True
        # Return a copy: step() mutates self.state in place.
        return self.state.copy()

    def step(self, action):
        """Optionally exercise, then advance the price and time by one step.

        Args:
            action: ``1`` to exercise the position (if still held); any other
                value holds.

        Returns:
            ``(state, reward)``. ``state`` is a copy of ``[price, time_step]``
            after the move. ``reward`` is the payoff collected on this step
            (0 if nothing was exercised or paid out).
        """
        reward = 0.

        # Early exercise uses the price *before* this step's move.
        if action == 1 and self.position:
            reward = self._payoff(self.state[0])
            self.position = False

        if self.rng.uniform() < self.prob:
            self.state[0] *= self.f_u
        else:
            self.state[0] *= self.f_d

        self.state[1] += 1.

        # Automatic payout at maturity, only if not already exercised, so the
        # position can never pay twice.
        if self.state[1] == self.maturity and self.position:
            reward = self._payoff(self.state[0])
            self.position = False

        # Return a copy so callers' stored states aren't mutated by later steps.
        return self.state.copy(), reward
