##########################################################################################
#### https://github.com/pytorch/examples/blob/main/reinforcement_learning/reinforce.py ###
##########################################################################################


import argparse
import gym
import numpy as np
from itertools import count
import math 
import os


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
# from simple_maze import SimpleMaze
from concave_maze import SimpleMazeConcave

from io_utils import get_result_folder, dump_output, get_result_path, save_model

import mosgd
# Command-line options for the training script. Parsing happens at import
# time, so `args` is available to every function below as a module global.
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
# The help text reports the real default (1.0); it previously claimed 0.99.
parser.add_argument('--gamma', type=float, default=1., metavar='G',
                    help='discount factor (default: 1.0)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()

# Environment instance shared by the whole script (constructor flags are
# defined by the project-local concave_maze module).
env = SimpleMazeConcave(intrinsic_reward=True, flip_rewards=True)
# NOTE: seeding is currently disabled, so --seed is parsed but has no effect.
# env.seed(args.seed)
# torch.manual_seed(args.seed)


class Policy(nn.Module):
    """Two-layer policy network mapping a 15-feature state to 4 action probabilities.

    Besides the layers, the module carries the per-episode bookkeeping used
    for the policy-gradient update: ``saved_log_probs`` (one entry per step
    taken) and ``rewards`` (one reward list per objective, two in total).

    Args:
        temp: Softmax temperature; the action scores are divided by it before
            the softmax, so larger values flatten the distribution.
    """

    def __init__(self, temp=1):
        super().__init__()
        self.temp = temp

        # Layers are created in this order on purpose: it fixes both the
        # parameter order and the order in which initial weights are drawn.
        self.affine1 = nn.Linear(15, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 4)

        # Episode buffers, filled while acting and emptied after each update.
        self.saved_log_probs = []
        self.rewards = [[], []]

    def forward(self, x):
        """Return action probabilities of shape (batch, 4) for a (batch, 15) input."""
        hidden = F.relu(self.dropout(self.affine1(x)))
        scaled_scores = self.affine2(hidden) / self.temp
        return F.softmax(scaled_scores, dim=1)


# Module-level training state shared by select_action() and finish_episode().
policy = Policy(temp=100)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)
# float32 machine epsilon; keeps the return normalisation from dividing by zero.
eps = np.finfo(np.float32).eps.item()

# Named target trade-offs. "threshold" is handed to mosgd.direction() in
# finish_episode(); training stops in main() once the running reward of
# objective 1 / objective 2 exceeds "stopping_point_1" / "stopping_point_2".
desired_policy_dict = dict(
    Fast=dict(threshold=-7.5, stopping_point_1=-4.05, stopping_point_2=-4.05),
    Mid=dict(threshold=-3.5, stopping_point_1=-3.05, stopping_point_2=-6.05),
    Careful=dict(threshold=-1.0, stopping_point_1=0.95, stopping_point_2=-8.05),
)

# Which entry of desired_policy_dict this run trains towards.
desired_policy = "Careful"
# Forwarded unchanged to mosgd.direction(active_const=...).
active_const = True

def select_action(state):
    """Sample an action for ``state`` from the current policy.

    The log-probability of the sampled action is appended to
    ``policy.saved_log_probs`` so finish_episode() can build the loss.

    Args:
        state: NumPy array holding a single (unbatched) observation.

    Returns:
        The sampled action index as a Python int.
    """
    # Add a leading batch dimension; the network expects (batch, features).
    batch = torch.from_numpy(state).float().unsqueeze(0)
    distribution = Categorical(policy(batch))
    chosen = distribution.sample()
    policy.saved_log_probs.append(distribution.log_prob(chosen))
    return chosen.item()


def _discounted_returns(rewards, gamma):
    """Return (discounted return-to-go per step, undiscounted reward total)."""
    running = 0
    total = 0
    returns = []
    # Walk backwards so each step's return is built from its successor's.
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
        total += r
    returns.reverse()  # back to chronological order (avoids O(n^2) insert(0))
    return returns, total


def finish_episode():
    """Run one multi-objective policy-gradient update and reset the episode buffers.

    For every objective (one reward list in ``policy.rewards``) the REINFORCE
    loss is back-propagated separately and the resulting gradient is flattened
    into a single vector. The per-objective gradients and undiscounted episode
    returns are passed to ``mosgd.direction``; if it returns a direction, that
    vector is written into the parameters' ``.grad`` and the optimizer takes a
    step. ``policy.rewards`` and ``policy.saved_log_probs`` are emptied in
    every case.
    """
    if not policy.saved_log_probs:
        # No step was taken: there is nothing to differentiate (torch.cat on
        # an empty list would raise). Just drop any stale rewards.
        for objective_rewards in policy.rewards:
            del objective_rewards[:]
        return

    n_objectives = len(policy.rewards)
    gs = []  # one flattened gradient vector per objective
    fs = []  # one undiscounted episode return per objective
    for obj, objective_rewards in enumerate(policy.rewards):
        returns, undiscounted_return = _discounted_returns(objective_rewards, args.gamma)
        fs.append(undiscounted_return)

        returns = torch.tensor(returns)
        # torch.std() uses Bessel's correction and is NaN for a single sample;
        # normalising a one-step episode would poison the gradients (and the
        # Adam state) with NaN, so the raw return is used in that case.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + eps)

        policy_loss = [-log_prob * R
                       for log_prob, R in zip(policy.saved_log_probs, returns)]
        optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        # The same graph is reused for the next objective, so keep it alive
        # until the last backward pass.
        policy_loss.backward(retain_graph=obj < n_objectives - 1)

        # np.concatenate copies, so the stored vector survives zero_grad().
        gs.append(np.concatenate(
            [param.grad.detach().numpy().flatten() for param in policy.parameters()]))
        policy.zero_grad()
        del objective_rewards[:]

    # presumably returns a flat NumPy direction over all parameters, or None
    # when no update should be made -- TODO confirm against mosgd.
    d = mosgd.direction(gs, fs, ths=[desired_policy_dict[desired_policy]["threshold"]], active_const=active_const, delta=math.pi/18)
    if d is not None:
        offset = 0
        with torch.no_grad():
            for param in policy.parameters():
                size = torch.numel(param)
                layer_grad = torch.from_numpy(d[offset:offset + size].reshape(param.shape))
                # Assigning .grad requires a matching dtype; this is a no-op
                # when the direction is already in the parameter's dtype.
                param.grad = layer_grad.to(dtype=param.dtype)
                offset += size
        optimizer.step()
    del policy.saved_log_probs[:]


def main():
    """Train the policy until both running rewards exceed their stopping points.

    Each episode runs for at most 99 environment steps, after which
    finish_episode() performs the update. Per-episode and exponentially
    smoothed rewards of both objectives are collected in ``reward_lists`` and
    written with dump_output() every ``5 * log_interval`` episodes and once
    more when training stops, so the trailing episodes are not lost.
    """
    result_folder = get_result_folder()
    result_path = get_result_path(os.path.join(result_folder, "rewards.pkl"))
    targets = desired_policy_dict[desired_policy]

    running_reward1 = -10
    running_reward2 = -10
    reward_lists = {"running_obj1": [], "running_obj2": [],
                    "ep_obj1": [], "ep_obj2": [],
                    "eval_obj1": [], "eval_obj2": []}
    for i_episode in count(1):
        state, ep_reward1, ep_reward2 = env.reset(), 0, 0
        for t in range(1, 100):  # Don't infinite loop while learning
            action = select_action(state)
            state, reward1, reward2, done, _ = env.step(action)
            if args.render:
                env.render()
            policy.rewards[0].append(reward1)
            policy.rewards[1].append(reward2)
            ep_reward1 += reward1
            ep_reward2 += reward2
            if done:
                break

        # Exponential moving averages (smoothing factor 0.05) of episode reward.
        running_reward1 = 0.05 * ep_reward1 + (1 - 0.05) * running_reward1
        running_reward2 = 0.05 * ep_reward2 + (1 - 0.05) * running_reward2
        finish_episode()

        reward_lists["running_obj1"].append(running_reward1)
        reward_lists["running_obj2"].append(running_reward2)
        reward_lists["ep_obj1"].append(ep_reward1)
        reward_lists["ep_obj2"].append(ep_reward2)

        solved = (running_reward1 > targets["stopping_point_1"]
                  and running_reward2 > targets["stopping_point_2"])

        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f} \tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                  i_episode, ep_reward1, running_reward1,  ep_reward2, running_reward2))

        # Periodic checkpoint of the reward history, plus a final write when
        # training stops (previously the episodes after the last checkpoint
        # were never saved).
        if solved or i_episode % (5 * args.log_interval) == 0:
            dump_output(reward_lists, result_path)

        if solved:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward1, t))
            break


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()