#####NOTE:
#This is an updated controller for the T1D simulator that allows
#boluses to be delayed and carbs to be missing.
#It augments the simulator found here:
#https://github.com/jxx123/simglucose
#***We are not the authors of this code; we only made simple modifications.***

from .base import Controller
from .base import Action
import numpy as np
import pandas as pd
import pkg_resources
import logging
import numpy as np

logger = logging.getLogger(__name__)

# Filesystem paths of two tables bundled inside the installed simglucose
# package: the Quest table and the virtual-patient parameter table.
CONTROL_QUEST, PATIENT_PARA_FILE = (
    pkg_resources.resource_filename('simglucose', relative_path)
    for relative_path in ('params/Quest.csv', 'params/vpatient_params.csv')
)

# Per-patient [CR, CF] pairs used by BBController's bolus calculation,
# updated (relative to the stock Quest.csv values) so the controller
# functions properly.
# NOTE(review): units presumably g/U for CR and mg/dL per U for CF - confirm.
CRCF = {
    'adolescent#001': [13.61, 49],
    'adolescent#002': [8.06, 29.02],
    'adolescent#003': [20.62, 74.25],
    'adolescent#004': [14.18, 51.06],
    'adolescent#005': [14.7, 52.93],
    'adolescent#006': [10.08, 36.3],
    'adolescent#007': [11.46, 41.25],
    'adolescent#008': [7.89, 28.4],
    'adolescent#009': [20.77, 74.76],
    'adolescent#010': [15.07, 54.26],
    'adult#001': [9.92, 35.7],
    'adult#002': [8.64, 31.1],
    'adult#003': [8.86, 31.9],
    'adult#004': [14.79, 53.24],
    'adult#005': [7.32, 26.35],
    'adult#006': [8.14, 29.32],
    'adult#007': [11.9, 42.85],
    'adult#008': [11.69, 42.08],
    'adult#009': [7.44, 26.78],
    'adult#010': [7.76, 27.93],
    'child#001': [28.62, 103.02],
    'child#002': [27.51, 99.02],
    'child#003': [31.21, 112.35],
    'child#004': [25.23, 90.84],
    'child#005': [12.21, 43.97],
    'child#006': [24.72, 89],
    'child#007': [13.81, 49.71],
    'child#008': [23.26, 83.74],
    'child#009': [28.75, 103.48],
    'child#010': [24.21, 87.16],
}

class BBController(Controller):
    """Basal-bolus controller with optional random bolus delays and skips.

    Derived from simglucose's ``BBController``. Basal insulin comes from the
    patient's ``u2ss`` and ``BW`` parameters. A meal bolus comes from the
    meal amount and the current CGM reading, using the per-patient
    ``[CR, CF]`` pairs in the module-level ``CRCF`` table. When ``addmiss``
    is truthy, meal boluses are randomly delayed or skipped.
    """

    def __init__(self, addmiss, target=140):
        """Load the parameter tables and initialise the delay state.

        Args:
            addmiss: if truthy, meal boluses may be randomly delayed or skipped.
            target: glucose target used in the correction term of the bolus.
        """
        self.quest = pd.read_csv(CONTROL_QUEST)
        self.patient_params = pd.read_csv(
            PATIENT_PARA_FILE)
        self.target = target
        # Bolus held back by a delay; delivered once waittime reaches 0.
        self.correctionneeded = 0
        # Remaining non-meal policy steps before the held-back bolus is given.
        self.waittime = 0
        # Multiplier applied to every bolus placed in the returned Action.
        self.bolscale = 1
        # Added variable controlling whether missingness/delays are added.
        self.addmiss = addmiss

    def policy(self, observation, reward, done, **kwargs):
        """Return the insulin Action for the current observation.

        Reads ``sample_time`` (default 1), ``patient_name`` and ``meal2``
        from ``kwargs``; ``reward`` and ``done`` are unused.
        NOTE(review): ``meal2`` has no default, so a caller that omits it
        makes ``_bb_policy`` fail on ``None > 0`` - confirm the environment
        always supplies it.
        """
        sample_time = kwargs.get('sample_time', 1)
        pname = kwargs.get('patient_name')
        meal = kwargs.get('meal2')

        action = self._bb_policy(
            pname,
            meal,
            observation.CGM,
            sample_time)
        return action

    def _bb_policy(self, name, meal, glucose, env_sample_time):
        """Compute basal insulin and a (possibly delayed or skipped) bolus.

        Args:
            name: patient name, matched against the Quest/patient tables and
                used as the key into ``CRCF``.
            meal: carbohydrate amount for this step; a new bolus is computed
                only when it is > 0.
            glucose: current CGM reading.
            env_sample_time: environment sample time; the bolus is divided
                by it.

        Returns:
            ``Action(basal=..., bolus=bolus * self.bolscale)``.

        Raises:
            KeyError: if ``meal > 0`` and ``name`` has no entry in ``CRCF``.
        """
        if any(self.quest.Name.str.match(name)):
            params = self.patient_params[self.patient_params.Name.str.match(
                name)]
            # ndarray.item() replaces np.asscalar (removed in NumPy 1.23);
            # like before, exactly one matching row is required.
            u2ss = params.u2ss.values.item()
            BW = params.BW.values.item()
        else:
            # Fallback parameters for patients not found in the Quest table.
            u2ss = 1.43
            BW = 57.0

        basal = u2ss * BW / 6000

        if meal > 0:
            logger.info('Calculating bolus ...')
            logger.debug('glucose = {}'.format(glucose))

            cr, cf = CRCF[name]
            # Scale meal x3 for good control, otherwise fails. The correction
            # term only applies above 150. np.asarray(...).item() accepts
            # Python floats, NumPy scalars and size-1 arrays alike.
            bolus = np.asarray(meal * 3 / cr + (glucose > 150)
                               * (glucose - self.target) / cf).item()

            if self.waittime > 0 or self.correctionneeded > 0:
                # A delayed bolus is still pending (including one released
                # with a zero wait): deliver it now together with this meal's
                # bolus rather than overwriting or discarding it.
                bolus += self.correctionneeded
                self.waittime = 0
                self.correctionneeded = 0
            else:
                # Add delay with probability 5/6 when addmiss is set. The
                # randint draw happens even when addmiss is False, so the
                # RNG stream is the same either way.
                if np.random.randint(6) >= 1 and self.addmiss:
                    print('DELAY')
                    self.correctionneeded = bolus
                    self.waittime = np.random.randint(42)
                    bolus = 0
                # Skip bolus with probability 0.2 (overrides any delay above).
                if self.addmiss and np.random.uniform() < .2:
                    print('no bolus here.')
                    bolus = 0
                    self.correctionneeded = 0
                    self.waittime = 0
        else:
            bolus = 0
            if self.waittime > 0:
                # Count down the delay on steps without a meal.
                self.waittime -= 1
            elif self.correctionneeded > 0:
                # Delay elapsed: release the held-back bolus.
                bolus = self.correctionneeded
                self.correctionneeded = 0

        bolus = bolus / env_sample_time
        action = Action(basal=basal, bolus=bolus * self.bolscale)
        return action

    def reset(self):
        """Clear the delayed-bolus state carried between policy calls."""
        self.correctionneeded = 0
        self.waittime = 0
