"""
A2C-GNN
-------
This file contains the A2C-GNN specifications. In particular, we implement:
(1) GNNParser:
    Converts raw environment observations to agent inputs (s_t).
(2) GNNActor:
    Policy parametrized by Graph Convolution Networks (Section III-C in the paper)
(3) GNNCritic:
    Critic parametrized by Graph Convolution Networks (Section III-C in the paper)
(4) A2C:
    Advantage Actor Critic algorithm using a GNN parametrization for both Actor and Critic.
"""

import numpy as np 
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Dirichlet
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.utils import grid
from collections import namedtuple
from src.algos.opt_solver import solveOpt, RegsolveOpt
from src.misc.utils import dictsum
import json
from collections import defaultdict

# One entry of the rollout buffer: log-probability of the chosen action and
# the critic's value estimate for the state in which it was taken.
SavedAction = namedtuple('SavedAction', field_names=('log_prob', 'value'))

# Global hyper-parameters. NOTE: `args` is the namedtuple *class* itself (it
# is never instantiated); the settings are stored as plain class attributes
# that shadow the generated field descriptors, so they can be read and
# reassigned as `args.gamma`, etc.
args = namedtuple('args', field_names=['render', 'gamma', 'log_interval'])
args.render, args.gamma, args.log_interval = True, 0.97, 10

#########################################
############## PARSER ###################
#########################################

class GNNParser():
    """
    Parser converting raw environment observations to agent inputs (s_t).

    For every region n of ``env.region`` (one graph node each) it builds a
    feature vector of length ``1 + 2 * T``, every entry multiplied by
    ``scale_factor``:

    * ``obs[0][n][time + 1]`` -- presumably the vehicles available in n at the
      next step (TODO confirm against the environment);
    * for each of the next T steps t: ``obs[0][n][time + 1] + env.dacc[n][t]``;
    * for each of the next T steps t: the sum over destinations j != n of
      ``env.scenario.demand_input[n, j][t] * env.price[n, j][t]``.
    """
    def __init__(self, env, T=10, grid_h=4, grid_w=4, scale_factor=0.01):
        """
        Args:
            env: environment exposing ``time``, ``region``, ``nregion``,
                ``dacc``, ``price`` and ``scenario.demand_input``.
            T: look-ahead horizon (number of future steps per feature group).
            grid_h, grid_w: stored on the instance; not used by ``parse_obs``.
            scale_factor: multiplier applied to every feature.
        """
        super().__init__()
        self.env = env
        self.T = T
        self.s = scale_factor
        self.grid_h = grid_h
        self.grid_w = grid_w

    def parse_obs(self, obs, edge_index):
        """
        Build a graph ``Data`` object from a raw observation.

        Args:
            obs: raw observation; only ``obs[0]`` is read, indexed as
                ``obs[0][region][t]``.
            edge_index: graph connectivity, passed through unchanged to ``Data``.

        Returns:
            ``Data(x, edge_index)`` where ``x`` has shape
            ``(env.nregion, 1 + 2 * T)`` (one row per region).
        """
        env, s = self.env, self.s
        regions = env.region
        nregion = env.nregion
        t_next = env.time + 1
        horizon = range(t_next, t_next + self.T)

        # Group 1 (1 feature): current availability.
        current = torch.tensor(
            [obs[0][n][t_next] * s for n in regions]
        ).view(1, 1, nregion).float()
        # Group 2 (T features): availability plus expected arrivals per future step.
        projected = torch.tensor(
            [[(obs[0][n][t_next] + env.dacc[n][t]) * s for n in regions]
             for t in horizon]
        ).view(1, self.T, nregion).float()
        # Group 3 (T features): demand-weighted price leaving each region.
        revenue = torch.tensor(
            [[sum([(env.scenario.demand_input[i, j][t]) * (env.price[i, j][t]) * s
                   for j in regions if i != j])
              for i in regions]
             for t in horizon]
        ).view(1, self.T, nregion).float()

        # Concatenate along the feature axis and transpose so rows are nodes.
        # The feature count is derived from T (it equals 21 only for T=10).
        num_features = 1 + 2 * self.T
        x = torch.cat((current, projected, revenue), dim=1).squeeze(0).view(num_features, nregion).T
        return Data(x, edge_index)
    

class Scalar(nn.Module):
    """
    A single learnable float32 parameter wrapped as a module.

    Calling the module takes no inputs and returns the parameter itself.
    """
    def __init__(self, init_value):
        super().__init__()
        initial = torch.tensor(init_value, dtype=torch.float32)
        self.constant = nn.Parameter(initial)

    def forward(self):
        constant = self.constant
        return constant

    
#########################################
############## ACTOR ####################
#########################################
class GNNActor(nn.Module):
    r"""
    Actor \pi(a_t | s_t) parametrizing the concentration parameters of a Dirichlet Policy.

    Architecture: one GCN layer (``in_channels -> in_channels``) whose ReLU
    output is added back to the input node features (residual connection),
    followed by a 3-layer MLP (``in_channels -> 32 -> 32 -> 1``) applied to
    every node independently. The result is one unconstrained score per node;
    it is not made positive here.
    """
    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: node feature size; also the GCN output size so the
                residual sum is shape-compatible.
            out_channels: accepted for interface compatibility; unused.
        """
        super().__init__()

        self.conv1 = GCNConv(in_channels, in_channels)
        self.lin1 = nn.Linear(in_channels, 32)
        self.lin2 = nn.Linear(32, 32)
        self.lin3 = nn.Linear(32, 1)

    def forward(self, data):
        """
        Args:
            data: graph with ``x`` (assumed ``(num_nodes, in_channels)``) and
                ``edge_index``.

        Returns:
            Tensor of shape ``(num_nodes, 1)`` with one raw score per node.
        """
        h = F.relu(self.conv1(data.x, data.edge_index))
        h = h + data.x  # residual connection to the raw node features
        h = F.relu(self.lin1(h))
        h = F.relu(self.lin2(h))
        return self.lin3(h)

#########################################
############## CRITIC ###################
#########################################

class GNNCritic(nn.Module):
    """
    Critic estimating the state value V(s_t) from the region graph.

    One GCN layer with a residual connection to the input features, a sum
    over all nodes to obtain a single graph-level vector, then a 3-layer MLP
    (``in_channels -> 32 -> 32 -> 1``) producing the scalar value estimate.
    ``out_channels`` is accepted for interface compatibility but unused.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = GCNConv(in_channels, in_channels)
        self.lin1 = nn.Linear(in_channels, 32)
        self.lin2 = nn.Linear(32, 32)
        self.lin3 = nn.Linear(32, 1)

    def forward(self, data):
        # Residual graph convolution over the node features.
        hidden = data.x + F.relu(self.conv1(data.x, data.edge_index))
        # Sum-pool nodes into one graph-level embedding.
        hidden = hidden.sum(dim=0)
        for layer in (self.lin1, self.lin2):
            hidden = F.relu(layer(hidden))
        return self.lin3(hidden)

#########################################
############## A2C AGENT ################
#########################################

class A2C(nn.Module):
    """
    Advantage Actor Critic algorithm for the AMoD control problem.

    Holds a GNN actor (one Dirichlet concentration per region), a GNN critic
    (state value V(s_t)), an observation parser, one Adam optimizer per
    network, and the per-episode buffers ``saved_actions`` / ``rewards``
    consumed by ``training_step``.
    """
    def __init__(self, env, input_size, eps=np.finfo(np.float32).eps.item(), device=torch.device("cpu")):
        """
        Args:
            env: environment whose observations are parsed by ``GNNParser``.
            input_size: node feature size fed to actor and critic; must match
                the parser's feature width (``1 + 2 * T``, i.e. 21 for the
                default ``T=10``).
            eps: small constant added to the return std during normalisation.
            device: device the model and parsed observations are moved to.
        """
        super().__init__()
        self.env = env
        self.eps = eps
        self.input_size = input_size
        self.hidden_size = input_size
        self.device = device

        self.actor = GNNActor(self.input_size, self.hidden_size)
        self.critic = GNNCritic(self.input_size, self.hidden_size)
        self.obs_parser = GNNParser(self.env)

        self.optimizers = self.configure_optimizers()

        # action & reward buffers, filled by select_action / callers and
        # emptied by training_step
        self.saved_actions = []
        self.rewards = []
        self.to(self.device)

    def forward(self, obs, edge_index, jitter=1e-20):
        """
        Run both actor and critic on one observation.

        Returns:
            (concentration, value): a flat 1-D tensor of strictly positive
            Dirichlet concentrations (softplus of the actor output plus
            ``jitter``) and the critic's value estimate.
        """
        x = self.parse_obs(obs, edge_index).to(self.device)

        # actor: concentration parameters of a Dirichlet distribution
        a_out = self.actor(x)
        concentration = F.softplus(a_out).reshape(-1) + jitter

        # critic: estimates V(s_t)
        value = self.critic(x)
        return concentration, value

    def parse_obs(self, obs, edge_index):
        """Convert a raw observation into the graph input of actor and critic."""
        return self.obs_parser.parse_obs(obs, edge_index)

    def select_action(self, obs, edge_index):
        """
        Choose an action and record (log_prob, value) for training.

        The action is deterministic: the Dirichlet mean
        ``concentration / concentration.sum()`` rather than a sample.

        Returns:
            List of per-region action fractions (numpy scalars).
        """
        concentration, value = self.forward(obs, edge_index)

        m = Dirichlet(concentration)

        # detach() gives a graph-free copy without the copy-construct warning
        # that torch.tensor(<tensor>) emits.
        action = (concentration / concentration.sum()).detach()
        self.saved_actions.append(SavedAction(m.log_prob(action), value))
        return list(action.cpu().numpy())

    def training_step(self):
        """
        Perform one actor and one critic update from the buffers, then clear them.

        Discounted returns (factor ``args.gamma``) are normalised to zero mean
        and unit std; the actor minimises ``-log_prob * advantage`` and the
        critic minimises the smooth-L1 loss between value and return. If the
        buffers are empty, nothing is updated.
        """
        if not self.saved_actions or not self.rewards:
            # torch.stack([]) below would raise; just reset the buffers.
            del self.rewards[:]
            del self.saved_actions[:]
            return

        # Discounted returns G_t = r_t + gamma * G_{t+1}, accumulated backwards.
        returns = []
        running = 0
        for r in reversed(self.rewards):
            running = r + args.gamma * running
            returns.append(running)
        returns.reverse()

        # float() keeps the default float dtype even for numpy/tensor rewards.
        returns = torch.tensor([float(g) for g in returns])
        # A single-element sample std is NaN and would poison every
        # parameter; treat a one-step buffer as having zero spread instead.
        std = returns.std() if returns.numel() > 1 else returns.new_zeros(())
        returns = (returns - returns.mean()) / (std + self.eps)

        policy_losses = []  # actor (policy) loss terms
        value_losses = []  # critic (value) loss terms
        for (log_prob, value), ret in zip(self.saved_actions, returns):
            advantage = ret - value.item()  # critic detached from actor loss
            policy_losses.append(-log_prob * advantage)
            # Target matches the value's shape, dtype and device.
            target = torch.full_like(value, float(ret))
            value_losses.append(F.smooth_l1_loss(value, target))

        # take gradient steps
        self.optimizers['a_optimizer'].zero_grad()
        a_loss = torch.stack(policy_losses).sum()
        a_loss.backward()
        self.optimizers['a_optimizer'].step()

        self.optimizers['c_optimizer'].zero_grad()
        v_loss = torch.stack(value_losses).sum()
        v_loss.backward()
        self.optimizers['c_optimizer'].step()

        # reset rewards and action buffer
        del self.rewards[:]
        del self.saved_actions[:]

    def configure_optimizers(self):
        """Return separate Adam optimizers (lr=4e-4) for actor and critic."""
        return {
            'a_optimizer': torch.optim.Adam(list(self.actor.parameters()), lr=4e-4),
            'c_optimizer': torch.optim.Adam(list(self.critic.parameters()), lr=4e-4),
        }

    def save_checkpoint(self, path='ckpt.pth'):
        """Save model and optimizer state dicts to ``path``."""
        checkpoint = {'model': self.state_dict()}
        for key, optimizer in self.optimizers.items():
            checkpoint[key] = optimizer.state_dict()
        torch.save(checkpoint, path)

    def load_checkpoint(self, path='ckpt.pth'):
        """
        Restore model and optimizer state from ``path``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        GPU can be loaded on a CPU-only machine (and vice versa).
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.load_state_dict(checkpoint['model'])
        for key, optimizer in self.optimizers.items():
            optimizer.load_state_dict(checkpoint[key])

    def log(self, log_dict, path='log.pth'):
        """Serialise ``log_dict`` to ``path`` with ``torch.save``."""
        torch.save(log_dict, path)

    def test_agent(self, test_episodes, env, cplexpath, directory, theta_f, theta_g, mu, choice, max_steps, edge_index, cost_ls, price_ls, demandTime_ls, beta):
        """
        Roll out the current policy for ``test_episodes`` episodes.

        ``env.price`` is overwritten from ``price_ls`` (indexed
        ``price_ls[t][edge_index_in_env.edges]``) for ``max_steps + 10`` steps.
        Each step turns the action fractions into integer desired vehicle
        counts per region, then dispatches with ``solveOpt`` (``choice == 1``)
        or ``RegsolveOpt`` (``choice == 2``).

        ``cplexpath`` and ``directory`` are accepted but not used here.
        NOTE: rewards and saved actions are appended to ``self.rewards`` /
        ``self.saved_actions`` and are not cleared by this method.

        Returns:
            (mean episode reward, mean served demand, mean rebalancing cost,
            ``para``): ``para`` is a ``(max_steps, env.nregion)`` array of the
            desired vehicle counts per step (later episodes overwrite earlier
            ones).

        Raises:
            ValueError: if ``choice`` is neither 1 nor 2.
        """
        if choice not in (1, 2):
            raise ValueError(
                f"choice must be 1 (solveOpt) or 2 (RegsolveOpt), got {choice!r}")

        episode_reward = []
        episode_served_demand = []
        episode_rebalancing_cost = []
        para = np.zeros((max_steps, self.env.nregion))

        # Re-index the flat price lists as price[(i, j)][t].
        price = defaultdict(dict)
        for ind, (i, j) in enumerate(env.edges):
            for t in range(max_steps + 10):
                price[i, j][t] = price_ls[t][ind]
        env.price = price

        for _ in range(test_episodes):
            eps_reward = 0
            eps_served_demand = 0
            eps_rebalancing_cost = 0
            env.reset()
            done = False
            step = 0
            while not done:
                obs = env.Initial_step()
                action_rl = self.select_action(obs, edge_index)

                # Desired vehicles per region = action fraction * fleet size,
                # ordered by sorted region key.
                total_acc = dictsum(env.acc, env.time + 1)
                desired_by_region = {
                    env.region[k]: int(action_rl[k] * total_acc)
                    for k in range(len(env.region))
                }
                desiredAcc = np.array(
                    [desired_by_region[key] for key in sorted(desired_by_region)])

                if choice == 1:
                    paxAction, rebAction = solveOpt(env, desiredAcc, cost_ls, price_ls, demandTime_ls, beta)
                else:
                    paxAction, rebAction = RegsolveOpt(env, desiredAcc, theta_f, theta_g, mu, cost_ls, price_ls, demandTime_ls, beta)
                _, reward, done, info, _, _ = env.step(paxAction, rebAction, max_steps, cost_ls, price_ls, demandTime_ls)
                self.rewards.append(reward)

                eps_reward += reward
                eps_served_demand += info["served_demand"]
                eps_rebalancing_cost += info["rebalancing_cost"]
                para[step, :] = desiredAcc
                step += 1
            episode_reward.append(eps_reward)
            episode_served_demand.append(eps_served_demand)
            episode_rebalancing_cost.append(eps_rebalancing_cost)

        return (
            np.mean(episode_reward),
            np.mean(episode_served_demand),
            np.mean(episode_rebalancing_cost),
            para
        )