'''
Implementation of the REINFORCE policy-gradient algorithm
with episodic mini-batches.
This REINFORCE implementation is a slightly modified version of the one in:
https://towardsdatascience.com/learning-reinforcement-learning-reinforce-with-pytorch-5e8ad7fc7da0
Please refer to the above link for a description of the algorithm.
'''

import os
import gym
import sys
import time
import torch
import random
import numpy as np
import pandas as pd 
from torch import nn
from torch import optim
#from torchviz import make_dot
import matplotlib.pyplot as plt
import torch.nn.functional as F

class reinforceFcnn(nn.Module):
    '''Fully-connected softmax policy network for REINFORCE.

    The network has three ReLU hidden layers, each of width ``hiddenFactor * nInputs``,
    followed by a linear output layer and a softmax. The defaults (8 inputs, 4 outputs,
    hidden width 32) are the configuration used for the size-aware and deadline scheduling
    problems.

    The recovering-bandits experiments used a different, deeper network that this class
    does not build: 4 inputs, five ReLU hidden layers of widths 24, 24, 24, 24 and 20,
    then 4 outputs with a softmax.
    '''
    def __init__(self, nInputs=8, nOutputs=4, hiddenFactor=4):
        '''
        Args:
            nInputs: length of one state vector (8 for the deadline / size-aware cases).
            nOutputs: number of actions, i.e. length of the returned probability vector.
            hiddenFactor: each hidden layer has ``hiddenFactor * nInputs`` neurons.
        '''
        super().__init__()
        self.nInputs = nInputs
        self.nOutputs = nOutputs
        hidden = self.nInputs * hiddenFactor
        self.linear1 = nn.Linear(self.nInputs, hidden)  # input layer (state size, output neurons)
        self.linear2 = nn.Linear(hidden, hidden)
        self.linear3 = nn.Linear(hidden, hidden)
        self.linear4 = nn.Linear(hidden, self.nOutputs)  # action logits

        #self.printNumParams()

    def forward(self, x):
        '''Return action probabilities for one state or a batch of states.

        Args:
            x: a state of length ``nInputs``, or a batch shaped ``(batch, nInputs)``.
                Any array-like accepted by ``torch.as_tensor`` works (list, numpy array, tensor).
        Returns:
            A float32 tensor of softmax probabilities over the last dimension, shaped
            ``(nOutputs,)`` or ``(batch, nOutputs)``.
        '''
        # as_tensor converts lists and numpy arrays of any numeric dtype, and does not copy
        # an input that is already a float32 tensor.
        x = torch.as_tensor(x, dtype=torch.float32)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = self.linear4(x)
        return F.softmax(x, dim=-1)

    def printNumParams(self):
        '''Print the total and the trainable number of parameters in the network.'''
        total_params = sum(p.numel() for p in self.parameters())
        total_params_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f'Total number of parameters: {total_params}')
        print(f'Total number of trainable parameters: {total_params_trainable}')


class REINFORCE(object):
    '''Episodic mini-batch REINFORCE trainer.

    The trainer plays complete episodes in ``env`` with a stochastic softmax policy. After every
    ``batchSize`` finished episodes it takes one Adam step on the policy-gradient loss built
    from all transitions of that mini-batch. It checkpoints the policy parameters every
    ``episodeSaveInterval`` episodes.
    '''

    def __init__(self, lr, env, seed, numEpisodes, batchSize, discountFactor, saveDir, activateArms, episodeSaveInterval):
        '''
        Args:
            lr: Adam learning rate.
            env: environment with ``reset()`` and ``step(action)``; ``step`` is unpacked as
                the 4-tuple ``(observation, reward, done, info)``.
            seed: seed for ``random``, ``numpy`` and ``torch``; also used in checkpoint names.
            numEpisodes: total number of training episodes.
            batchSize: number of episodes per gradient step.
            discountFactor: discount factor ``beta`` used by ``discountRewards``.
            saveDir: output prefix. Paths are built by plain string concatenation, so this
                presumably needs a trailing path separator (TODO confirm with callers).
            activateArms: stored as ``self.activateArms``; not read by any method in this class.
            episodeSaveInterval: save a checkpoint every this many episodes.
        '''
        #-------------constants-------------
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        self.seed = seed
        self.numEpisodes = numEpisodes
        self.episodeRanges = np.arange(0, self.numEpisodes+episodeSaveInterval, episodeSaveInterval) # save trained model every N episodes
        self.batchSize = batchSize
        self.beta = discountFactor
        self.env = env
        self.directory = saveDir
        self.nn = reinforceFcnn() #policy_estimator(self.env)
        self.LearningRate = lr
        self.activateArms = activateArms
        self.numActions = 4
        self.optimizer = torch.optim.Adam(self.nn.parameters(), lr=self.LearningRate)
        #-------------counters-------------
        self.batchCounter = 0
        self.episodeRewards = []
        self.lossFunctionVals = []
        self.plotRewards = []
        self.totalRewards = []

    def discountRewards(self, rewards):
        '''Return the mean-centred, discounted rewards-to-go of one episode.

        Entry ``t`` is ``sum_{i >= t} beta**i * rewards[i]``, i.e. ``beta**t`` times the
        return from step ``t``. The ``beta**t`` weighting is the discounted
        policy-gradient form. The mean over the episode is then subtracted as a simple
        baseline, which reduces variance.
        '''
        r = np.array([self.beta**i * rewards[i] for i in range(len(rewards))])
        r = r[::-1].cumsum()[::-1] # reverse cumulative sum: each step gets the rewards from itself onwards
        result = r - r.mean() # subtracting mean gives faster convergence
        return result

    def learn(self):
        '''Train the policy for ``self.numEpisodes`` episodes.

        Each episode is collected with the current policy. Once ``self.batchSize`` episodes
        have been gathered, one gradient-ascent step is taken on the whole batch. A
        checkpoint is saved whenever the episode count is in ``self.episodeRanges``,
        starting with the untrained network at episode 0. One more checkpoint is saved after
        training, and then ``trainingInfo.txt`` is written.

        NOTE: if ``numEpisodes`` is not a multiple of ``batchSize``, the episodes of the last,
        incomplete batch are played but never used for an update.
        '''
        self.start = time.time()
        self.currentEpisode = 0
        self.batchCounter = 0
        self.totalTimestep = 0
        batchRewards = []
        batchActions = []
        batchStates = []
        actionSpace = np.arange(self.numActions)

        while self.currentEpisode < self.numEpisodes:
            if self.currentEpisode in self.episodeRanges:
                self.close(self.currentEpisode) # first saved model is the untrained one (zero episodes)

            states, actions, rewards = self._runEpisode(actionSpace)

            print(f"Finished Episode: {self.currentEpisode+1}")
            self.totalRewards.append(sum(rewards))
            batchRewards.extend(self.discountRewards(rewards))
            batchStates.extend(states)
            batchActions.extend(actions)
            self.batchCounter += 1
            self.currentEpisode += 1

            # If batch is complete, update network
            if self.batchCounter == self.batchSize:
                self._updatePolicy(batchStates, batchActions, batchRewards)
                print(f'did gradient descent step')

                batchRewards = []
                batchActions = []
                batchStates = []
                self.batchCounter = 0

        self.end = time.time()
        self.close(self.numEpisodes)
        self.trainingEnding()
        print(f'---------------------------\nDONE. Time taken: {self.end - self.start:.5f} seconds.')
        print(f'total timesteps taken: {self.totalTimestep}')

    def _runEpisode(self, actionSpace):
        '''Play one episode with the current policy and return ``(states, actions, rewards)``.'''
        observation = self.env.reset()
        states = []
        actions = []
        rewards = []
        done = False

        while not done:
            # Sampling only: the update step recomputes the probabilities with gradients.
            with torch.no_grad():
                actionProbs = self.nn(torch.as_tensor(observation, dtype=torch.float32)).numpy()
            # The softmax output is float32, and its sum can miss 1 by more than
            # np.random.choice tolerates. Renormalise in float64 before sampling.
            actionProbs = actionProbs.astype(np.float64)
            actionProbs /= actionProbs.sum()
            action = np.random.choice(actionSpace, p=actionProbs)
            nextState, reward, done, _ = self.env.step(action)

            # Store a copy in case the environment reuses and mutates its observation array.
            states.append(np.array(observation, dtype=np.float32))
            rewards.append(reward)
            actions.append(action)

            observation = nextState
            self.totalTimestep += 1

        return states, actions, rewards

    def _updatePolicy(self, batchStates, batchActions, batchRewards):
        '''Take one optimizer step on ``-mean(advantage * log pi(action | state))`` over the batch.'''
        self.optimizer.zero_grad()
        # Stack into single ndarrays first; building a tensor from a list of ndarrays is slow.
        stateBatchTensor = torch.as_tensor(np.asarray(batchStates, dtype=np.float32))  # move to GPU here if needed
        rewardBatchTensor = torch.as_tensor(np.asarray(batchRewards, dtype=np.float32))
        actionBatchTensor = torch.as_tensor(np.asarray(batchActions, dtype=np.int64))
        # NOTE(review): log of a softmax yields -inf if a probability underflows to 0;
        # a log-softmax of the logits would be numerically safer.
        logProb = torch.log(self.nn(stateBatchTensor))

        # Pick log pi(a_t | s_t) for the action actually taken at each step.
        selectedLogProbs = rewardBatchTensor * torch.gather(logProb, 1, actionBatchTensor.unsqueeze(1)).squeeze(1)

        loss = -selectedLogProbs.mean() # since we're doing gradient ascent, we multiply by -1
        self.lossFunctionVals.append(loss.detach().numpy())

        loss.backward() # backpropagation
        self.optimizer.step() # modify parameters

    def close(self, episode):
        '''Save the current policy parameters to ``<checkpoint dir>/trained_model.pt``.

        The checkpoint directory name encodes the seed, learning rate, batch size and
        ``episode``, and is appended to ``self.directory`` by plain concatenation. It is
        created if missing.
        '''
        directory = (f'{self.directory}seed_{self.seed}'
                     f'_lr_{self.LearningRate}_batchSize_{self.batchSize}'
                     f'_trainedNumEpisodes_{episode}')
        # exist_ok avoids the check-then-create race of an os.path.exists guard.
        os.makedirs(directory, exist_ok=True)

        torch.save(self.nn.state_dict(), directory+'/trained_model.pt')

    def trainingEnding(self):
        '''Write a short summary of the finished training run to ``trainingInfo.txt``.'''
        with open(self.directory+'trainingInfo.txt', 'w') as infoFile:
            infoFile.write(f'training time: {self.end - self.start:.5f} seconds\n')
            infoFile.write(f'training episodes: {self.numEpisodes}\n')
            infoFile.write(f'Mini-batch size: {self.batchSize}\n')
            infoFile.write(f'Total timesteps: {self.totalTimestep}\n')


