import pandas as pd
import numpy as np
import copy
import random, math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical

from collections import namedtuple, deque

# Print NumPy floating-point values with two digits of precision.
np.set_printoptions(precision=2)

# Module-level verbosity flag; not referenced by the code visible in this file.
verbose = 0

'''
   - A cyclic buffer that holds the transition tuples (state, action, next_state, reward) the learner has experienced.
   - When updating the Q-network, the learner samples a batch from the replay memory and
     makes an update using it. Random sampling ensures the transitions in a batch are decorrelated, which speeds up learning.

'''
class ReplayMemory(object):
    """Bounded FIFO store of experience records.

    The store is a ``deque`` created with ``maxlen=capacity``, so once it is
    full each newly pushed record evicts the oldest one.
    """

    def __init__(self, capacity, Transition):
        # Callable (e.g. a namedtuple class) used to build every stored record.
        self.transitions = Transition
        self.memory = deque([], maxlen=capacity)

    def __len__(self):
        """Return the number of records currently held."""
        return len(self.memory)

    def push(self, *args):
        """Build a record from ``args`` and store it as the newest entry."""
        record = self.transitions(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored records drawn without replacement."""
        return random.sample(self.memory, batch_size)
    
class QNet(nn.Module):
    """Linear Q-function whose single affine layer is shared by all actions.

    ``forward`` reads each input row as ``num_actions`` consecutive slices of
    width ``action_feature_dim`` (the feature phi(s, a) of one action) and
    scores every slice with the same layer, so Q(s, a) is linear in phi(s, a).
    """

    def __init__(self, action_feature_dim, num_actions):
        super().__init__()
        self.action_feature_dim = action_feature_dim
        self.n_actions = num_actions
        # One output unit; the weight vector starts at zero, the bias keeps
        # nn.Linear's default initialisation.
        self.affine = nn.Linear(action_feature_dim, 1)
        self.affine.weight.data = torch.zeros((1, action_feature_dim))

    def forward(self, x):
        """Return one Q-value column per action for each row of ``x``."""
        width = self.action_feature_dim
        per_action = [
            self.affine(x[:, k * width:(k + 1) * width])
            for k in range(self.n_actions)
        ]
        return torch.cat(per_action, dim=1)
    
class DQN():
    
    #def __init__(self, env, exp, all_data_df, weight_df, args):
    def __init__(self, env, exp, args):
        """Set up hyper-parameters, logging frames, networks and optimizer.

        Args:
            env: Environment object. ``env.action_space.n`` and
                ``env.action_feature_dim`` are read here.
            exp: Stored unchanged on ``self.exp``.
            args: Mapping of settings. Required keys: GRAD_BATCH_SIZE, GAMMA,
                EPS_START, EPS_END, EPS_DECAY, TARGET_UPDATE_FREQ,
                TRAIN_HORIZON, LR, WEIGHT_SAVE_FREQ, ACT, RUN_ID, RECORD,
                REPLAY_BUFFER_SIZE and OPTIMIZER (one of 'ADAM', 'SGD',
                'RMSPROP'). Optional key: LR_GAMMA (defaults to 1).

        Raises:
            ValueError: If ``args['OPTIMIZER']`` is not a supported name.
        """
        self.transitions = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

        # The agent uses an epsilon-greedy policy. Epsilon starts at
        # 'eps_start' and decays exponentially towards 'eps_end' as steps
        # accumulate, at a rate controlled by 'eps_decay'.
        self.batch_size = args['GRAD_BATCH_SIZE']
        self.gamma = args['GAMMA']
        self.eps_start = args['EPS_START']
        self.eps_end = args['EPS_END']
        self.eps_decay = args['EPS_DECAY']
        self.target_update_freq = args['TARGET_UPDATE_FREQ']

        # Training is pinned to the CPU.
        self.device = torch.device("cpu")
        self.n_actions, self.steps_done = env.action_space.n, 0
        self.train_horizon = args['TRAIN_HORIZON']
        self.lr = args['LR']
        self.weight_save_freq = args['WEIGHT_SAVE_FREQ']
        self.env, self.act, self.run_id = env, args['ACT'], args['RUN_ID']

        # Per-step logs filled in by train().
        self.all_data_df = pd.DataFrame(columns=['episode', 'time', 'action_type', 'action', 'reward', 'done', 'other'])
        self.weight_df = pd.DataFrame(columns=['episode', 'time', 'board', 'weight', 'gradient'])
        self.exp, self.record = exp, args['RECORD']

        # Validate the optimizer name before any network is allocated so a bad
        # configuration fails fast with a clear error.
        optimizer_factories = {
            'ADAM': optim.Adam,
            'SGD': optim.SGD,
            'RMSPROP': optim.RMSprop,
        }
        optimizer_name = args['OPTIMIZER']
        if optimizer_name not in optimizer_factories:
            raise ValueError(
                "Unsupported optimizer {!r}; expected one of {}".format(
                    optimizer_name, ', '.join(optimizer_factories)))

        # - The policy network is the actual Q-network; its weights are
        #   updated at every time step.
        # - The target network is used to compute V(s_t+1). Its weights are
        #   refreshed only every 'target_update_freq' episodes, which
        #   stabilizes the updates of the policy network.
        self.policy_net = QNet(env.action_feature_dim, self.n_actions).to(self.device)
        self.target_net = QNet(env.action_feature_dim, self.n_actions).to(self.device)
        self.replay_memory = ReplayMemory(args['REPLAY_BUFFER_SIZE'], self.transitions)

        self.optimizer = optimizer_factories[optimizer_name](self.policy_net.parameters(), lr=self.lr)

        # NOTE(review): the scheduler is created here, but no code visible in
        # this class calls scheduler.step() - confirm whether a caller does.
        lr_gamma = args.get('LR_GAMMA', 1)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, lr_gamma)

        print('*'*50,'\nCode running on ', self.device,'\n')
        print('*'*50)

    def set_policy_net_weight(self, weights):
        self.policy_net.affine.weight.data = torch.Tensor(weights)

    def set_target_net_weight(self, weight):
        self.target_net.affine.weight.data = torch.Tensor(weights)
            
    def reward_mapping(self, old_reward):
        """Map a raw reward to a binary one: -1 becomes 0, anything else 1."""
        return 0 if old_reward == -1 else 1

    def train(self, train_batch_begin_episode_index, train_batch_end_episode_index):
        # set the target network weights to policy network weight initially.
        self.target_net.load_state_dict(self.policy_net.state_dict())

        episode_ampp = -1
        for episode in range(train_batch_begin_episode_index, train_batch_end_episode_index):
            episode_return = 0
            state = self.env.reset()
            state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
            self.weight_df = self.weight_df.append({'episode':episode, 'time':-1, 'weight': '', \
                                        'board':copy.deepcopy(self.env.board), 'gradient' : ''}, ignore_index=True)
            for t in range(self.train_horizon):

                # select action
                act, action_type = self.select_action(state)
                action = int(act)

                next_state, reward, done, _ = self.env.step(action)
                
                #reward = self.reward_mapping(reward)

                episode_return += reward
                action, reward = torch.IntTensor([action]).to(self.device), torch.FloatTensor([reward]).to(self.device)
                next_state = torch.from_numpy(next_state).float().unsqueeze(0).to(self.device)

                if(done):   
                    next_state = None   # setting this to None so that V(.) is assigned a value of 0.
                
                self.replay_memory.push(state, action, next_state, reward)

                # save weights at only particular intervals
                if(episode < 200):
                    if(self.policy_net.affine.weight.grad == None):     # to handle the fact that gradient is calculated only after replay buffer has some data.
                        self.weight_df = self.weight_df.append({'episode':episode, 'time':t, 'weight': '', \
                                                            'board':copy.deepcopy(self.env.board), 'gradient' : ''}, ignore_index=True)
                    else:
                        self.weight_df = self.weight_df.append({'episode':episode, 'time':t, 'weight': copy.deepcopy(self.policy_net.affine.weight.cpu().data.numpy().tolist()), \
                                                            'board':copy.deepcopy(self.env.board), 'gradient' : copy.deepcopy(self.policy_net.affine.weight.grad.cpu().data.numpy().tolist())}, ignore_index=True)
                
                self.all_data_df = self.all_data_df.append({'episode':episode, 'time':t, 'action_type':action_type, 'action':int(action), 'reward':int(reward), 'done':done, 'other':'none'}, ignore_index=True)            
               
                state = next_state

                # update the network on a sampled batch after every time step
                self.optimize_model()
                
                if done:
                    episode_ampp = (t+1) / self.env.initial_object_count
                    break

                self.steps_done += 1

            # update the target network weight to the current policy network after every 'target_update' episodes.
            if (episode > 0 and (episode % self.target_update_freq == 0)):
                self.target_net.load_state_dict(self.policy_net.state_dict())

            
            if(episode%5==0):
                print("Run : ", self.run_id, "Episode ", episode, ' ended at ', t+1, "Return : ", episode_return)

                

    def select_action(self, state):
        sample = random.random()

        # epsilon greedy policy - epsilon decays exponentially with time
        eps_threshold = self.eps_end + (self.eps_start-self.eps_end)*math.exp(-1*self.steps_done/self.eps_decay) 
        
        if(self.act == 'EPS_GREEDY'):
            if sample > eps_threshold:
                with torch.no_grad():
                    action = (self.policy_net(state).max(1)[1]).to(self.device)    # greedy action
                    return action, 'greedy'
            else:
                action = torch.Tensor([[random.randrange(self.n_actions)]]).to(self.device)   # exploratory action
                return action, 'random'
        else:
            nnProbs = torch.softmax(self.policy_net(state), dim=1)
            dist = Categorical(nnProbs)
            action = dist.sample().view(-1,1).float()
            return action, ''


    def optimize_model(self):

        # if size of data in memory is smaller than the batch size do nothing
        if len(self.replay_memory) < self.batch_size:
            return

        # sample a batch transitions from the replay memory
        transitions = self.replay_memory.sample(self.batch_size)    # returns a list of transitions
        batch = self.transitions(*zip(*transitions))    # converts list of transitions into Transition of list - transition is a list

        # map is used to apply the lambda function to each element in the iterable 'batch.next_state' and it returns an iterable with the result
        # tuple converts the iterable to tuple type and then to torch.tensor type. Producing tensor stating which of the transitions have next state not None.
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=self.device, dtype=torch.bool)

        
        # collect next_states which are not not None
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action).type('torch.LongTensor').to(self.device)
        reward_batch = torch.cat(batch.reward)

        all_state_action_values = self.policy_net(state_batch)

        state_action_values = all_state_action_values.gather(1, action_batch.view(-1,1))
        next_state_values = torch.zeros(self.batch_size, device=self.device)

        # max along dimension dim return max tensor at index [0] and index tensor at index [1]
        next_state_values_for_non_final_states = self.target_net(non_final_next_states).max(dim=1)[0].detach() # detach the target since it does not require grad
        next_state_values[non_final_mask] =  next_state_values_for_non_final_states # the next state value for final_states remain 0.

        expected_state_action_values = (next_state_values * self.gamma) + reward_batch  # one-step target estimate from the batch data

        # use huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()  # zero the previous gradient
        loss.backward()             # calculate the gradient by backprop

        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)   # clamp the gradients between -1, 1
        self.optimizer.step()       # make a gradient update


# Script entry point: intentionally a no-op, so running this file directly
# does nothing beyond executing the module-level definitions above.
if __name__ == "__main__":
    pass
