import numpy as np
import torch.nn as nn
import os
import torch 
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from collections import OrderedDict

class PPOMemory:
    """Rollout buffer that keeps one Python list per transition field.

    ``push`` appends exactly one entry to every list, so position ``i``
    refers to the same transition in all of them.
    """

    # Per-transition lists, in the order in which ``sample`` returns them.
    _FIELDS = ("states", "actions", "probs", "vals", "rewards", "dones", "labels")

    def __init__(self, batch_size):
        """Create an empty buffer whose samples are split into ``batch_size`` chunks."""
        self.batch_size = batch_size
        self.clear()

    def sample(self):
        """Return the stored data as arrays plus a shuffled minibatch partition.

        Returns:
            An 8-tuple: the states, actions, probs, vals, rewards, dones and
            labels converted with ``np.array`` (in that order), followed by a
            list of int64 index arrays. The index arrays are consecutive
            slices of one random permutation of all stored positions, each
            at most ``batch_size`` long.
        """
        count = len(self.states)
        starts = np.arange(0, count, self.batch_size)
        order = np.arange(count, dtype=np.int64)
        np.random.shuffle(order)
        batches = [order[start:start + self.batch_size] for start in starts]
        arrays = tuple(np.array(getattr(self, field)) for field in self._FIELDS)
        return (*arrays, batches)

    def push(self, state, action, probs, vals, reward, done, label):
        """Append one transition; every argument goes to its own list."""
        entry = (state, action, probs, vals, reward, done, label)
        for field, item in zip(self._FIELDS, entry):
            getattr(self, field).append(item)

    def clear(self):
        """Replace every per-transition list with a fresh empty list."""
        for field in self._FIELDS:
            setattr(self, field, [])

class Linear(nn.Linear):
    """``nn.Linear`` whose parameters are multiplied by mask buffers.

    ``weight_mask`` (and ``bias_mask`` when the layer has a bias) are
    registered as buffers filled with ones and applied element-wise to the
    parameters on every forward pass, so zeroing a mask entry removes the
    corresponding parameter from the computation.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Build the layer and create all-ones masks shaped like its parameters."""
        super().__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', torch.ones(self.weight.size()))
        if self.bias is None:
            # No bias parameter, hence no bias mask either.
            return
        self.register_buffer('bias_mask', torch.ones(self.bias.size()))

    def forward(self, input):
        """Apply the affine map using the masked weight and masked bias."""
        masked_bias = None if self.bias is None else self.bias_mask * self.bias
        return F.linear(input, self.weight_mask * self.weight, masked_bias)

class Actor(nn.Module):
    """Policy network: three masked ``Linear`` layers ending in a softmax.

    The forward pass returns a probability vector over ``action_dim``
    actions along the last dimension.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        """Assemble the ``state_dim -> hidden_dim -> hidden_dim -> action_dim`` stack."""
        super().__init__()
        # Layers are created in forward order and keep their positional
        # indices inside the Sequential container.
        layers = [
            Linear(state_dim, hidden_dim),
            nn.ReLU(),
            Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.actor = nn.Sequential(*layers)

    def forward(self, state):
        """Return the action probabilities produced for ``state``."""
        return self.actor(state)

class Critic(nn.Module):
    """Value network: three masked ``Linear`` layers with a single output.

    The two hidden layers are ``2 * hidden_dim`` units wide.
    """

    def __init__(self, state_dim, hidden_dim):
        """Assemble the ``state_dim -> 2*hidden_dim -> 2*hidden_dim -> 1`` stack."""
        super().__init__()
        width = hidden_dim * 2
        layers = [
            Linear(state_dim, width),
            nn.ReLU(),
            Linear(width, width),
            nn.ReLU(),
            Linear(width, 1),
        ]
        self.critic = nn.Sequential(*layers)

    def forward(self, state):
        """Return the value estimate for ``state`` (last dimension has size 1)."""
        return self.critic(state)
    

class PPOAgent:
    """PPO actor-critic agent for a single intersection.

    In addition to the clipped PPO objective, ``update_1`` trains the actor
    on a supervised loss against the ``label`` stored with each transition
    and returns that loss, so a caller can combine it across agents and
    apply it again through ``update_2``.
    """

    # Weight of the PPO (actor + 0.5 * critic) term in the combined loss; the
    # supervised term receives the remaining ``1 - _RL_LOSS_WEIGHT``.
    _RL_LOSS_WEIGHT = 0.0001
    # Lower bound applied to action probabilities before taking their log.
    _PROB_EPS = 1e-8

    def __init__(self, intersection_id, state_dim, action_dim, cfg, phase_list):
        """Build the networks, optimizers and rollout memory.

        Args:
            intersection_id: Identifier stored on the agent.
            state_dim: Input size of the actor and critic.
            action_dim: Number of discrete actions produced by the actor.
            cfg: Object providing ``gamma``, ``policy_clip``, ``n_epochs``,
                ``gae_lambda``, ``device``, ``hidden_dim``, ``actor_lr``,
                ``critic_lr`` and ``batch_size``.
            phase_list: Stored on the agent; not used by this class itself.
        """
        self.intersection_id = intersection_id
        self.phase_list = phase_list
        self.gamma = cfg.gamma
        self.policy_clip = cfg.policy_clip
        self.n_epochs = cfg.n_epochs
        self.gae_lambda = cfg.gae_lambda
        self.device = cfg.device
        self.actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(self.device)
        self.critic = Critic(state_dim, cfg.hidden_dim).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)
        self.memory = PPOMemory(cfg.batch_size)
        self.loss = 0

    def choose_action(self, observation):
        """Sample an action for ``observation`` from the current policy.

        Args:
            observation: Data convertible to a float tensor. The results are
                extracted with ``.item()``, so it must describe exactly one
                state.

        Returns:
            Tuple ``(action, log_prob, value)`` of Python scalars: the
            sampled action index, its log-probability and the critic value.
        """
        state = torch.tensor(observation, dtype=torch.float).to(self.device)
        # Only Python scalars leave this method, so no autograd graph is needed.
        with torch.no_grad():
            action_probs = self.actor(state)
            value = self.critic(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = torch.squeeze(dist.log_prob(action)).item()
        action = torch.squeeze(action).item()
        value = torch.squeeze(value).item()
        return action, log_prob, value

    def remember(self, state, action, prob, val, reward, done, label):
        """Store one transition in the rollout memory."""
        self.memory.push(state, action, prob, val, reward, done, label)

    def _compute_advantages(self, reward_arr, values, dones_arr):
        """Return the discounted sum of TD residuals for every stored step.

        ``A[t] = delta[t] + gamma * gae_lambda * A[t + 1]`` with
        ``delta[t] = r[t] + gamma * V[t + 1] * (1 - done[t]) - V[t]`` and the
        last step fixed to zero. This backward recursion yields the same
        values as summing the residuals forward from every ``t`` but needs
        only one pass over the rollout.
        """
        n_steps = len(reward_arr)
        advantage = np.zeros(n_steps, dtype=np.float32)
        decay = self.gamma * self.gae_lambda
        running = 0.0
        # NOTE(review): the running sum is not reset where ``done`` is set, so
        # residuals after an episode end still contribute to earlier steps —
        # confirm whether a rollout can span more than one episode.
        for t in reversed(range(n_steps - 1)):
            delta = (reward_arr[t]
                     + self.gamma * values[t + 1] * (1 - int(dones_arr[t]))
                     - values[t])
            running = delta + decay * running
            advantage[t] = running
        return advantage

    def update_1(self, epoch):
        """Run one optimisation step on one minibatch of the stored rollout.

        Only the first batch of the freshly shuffled partition returned by
        the memory is used. The combined loss is
        ``w * (ppo_actor + 0.5 * critic) + (1 - w) * supervised`` with
        ``w = _RL_LOSS_WEIGHT``.

        Args:
            epoch: Unused; kept so existing callers keep working.

        Returns:
            The supervised actor loss as a scalar tensor. The graph is
            retained so the caller can back-propagate through it again.
        """
        (state_arr, action_arr, old_prob_arr, vals_arr,
         reward_arr, dones_arr, labels_arr, batches) = self.memory.sample()

        advantage = torch.tensor(
            self._compute_advantages(reward_arr, vals_arr, dones_arr)
        ).to(self.device)
        # float32 keeps the critic loss in the same dtype as the networks.
        values = torch.tensor(vals_arr, dtype=torch.float).to(self.device)

        batch = batches[0]
        states = torch.tensor(state_arr[batch], dtype=torch.float).to(self.device).squeeze(1)
        labels = torch.tensor(labels_arr[batch], dtype=torch.long).to(self.device)
        old_log_probs = torch.tensor(old_prob_arr[batch], dtype=torch.float).to(self.device)
        actions = torch.tensor(action_arr[batch], dtype=torch.long).to(self.device)

        # One forward pass serves both the supervised and the PPO term.
        action_probs = self.actor(states)

        # The actor already ends in a softmax, so the supervised loss is the
        # negative log-likelihood of its probabilities. ``cross_entropy``
        # expects unnormalised logits and would apply softmax a second time.
        supervised_loss = F.nll_loss(
            torch.log(action_probs.clamp_min(self._PROB_EPS)), labels
        )

        dist = Categorical(action_probs)
        critic_value = torch.squeeze(self.critic(states))
        new_log_probs = dist.log_prob(actions)
        # exp(new - old) equals exp(new) / exp(old) without forming either
        # exponential on its own.
        prob_ratio = (new_log_probs - old_log_probs).exp()
        batch_advantage = advantage[batch]
        unclipped = batch_advantage * prob_ratio
        clipped = torch.clamp(
            prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip
        ) * batch_advantage
        actor_loss = -torch.min(unclipped, clipped).mean()

        returns = batch_advantage + values[batch]
        critic_loss = ((returns - critic_value) ** 2).mean()

        weight = self._RL_LOSS_WEIGHT
        total_loss = weight * (actor_loss + 0.5 * critic_loss) + supervised_loss * (1 - weight)

        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        # The graph is kept because the returned supervised loss shares it.
        total_loss.backward(retain_graph=True)
        self.actor_optimizer.step()
        self.critic_optimizer.step()
        return supervised_loss

    def update_2(self, avg_loss):
        """Back-propagate ``avg_loss`` and step both optimizers.

        Args:
            avg_loss: Scalar tensor with an attached graph. The graph is
                retained so the same tensor can be passed to other agents.
        """
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        avg_loss.backward(retain_graph=True)
        self.actor_optimizer.step()
        # NOTE(review): the critic optimizer is stepped even when ``avg_loss``
        # does not depend on the critic — confirm this is intended.
        self.critic_optimizer.step()

    def save(self, path):
        """Write ``ppo_actor.pt`` and ``ppo_critic.pt`` into directory ``path``.

        The directory is created first if it does not exist.
        """
        if path:
            os.makedirs(path, exist_ok=True)
        actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
        critic_checkpoint = os.path.join(path, 'ppo_critic.pt')
        torch.save(self.actor.state_dict(), actor_checkpoint)
        torch.save(self.critic.state_dict(), critic_checkpoint)

    def load(self, path):
        """Load ``ppo_actor.pt`` and ``ppo_critic.pt`` from directory ``path``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on
        another device can still be restored.
        """
        actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
        critic_checkpoint = os.path.join(path, 'ppo_critic.pt')
        self.actor.load_state_dict(
            torch.load(actor_checkpoint, map_location=self.device))
        self.critic.load_state_dict(
            torch.load(critic_checkpoint, map_location=self.device))

    @staticmethod
    def _zero_random_entries(mask, ratio):
        """Set ``int(ratio * mask.numel())`` randomly chosen entries of ``mask`` to 0."""
        total = mask.numel()
        num_zeros = int(ratio * total)
        indices = torch.randperm(total)[:num_zeros]
        mask.view(-1)[indices] = 0

    @classmethod
    def _mask_network(cls, network, pruning_ratio):
        """Randomly zero mask entries of every ``Linear`` layer in ``network``."""
        idx = 0
        for module in network.modules():
            if not isinstance(module, Linear):
                continue
            ratio = pruning_ratio[idx]
            cls._zero_random_entries(module.weight_mask, ratio)
            # Layers built without a bias have no bias mask.
            if module.bias is not None:
                cls._zero_random_entries(module.bias_mask, ratio)
            idx += 1

    def random_initialize(self, pruning_ratio):
        """Randomly zero a fraction of the mask entries in actor and critic.

        Args:
            pruning_ratio: Sequence indexed by the position of each
                ``Linear`` layer inside its network. The same sequence is
                applied to the actor and to the critic, each starting at
                index 0, and the ratio is used for both the weight mask and
                the bias mask of a layer.
        """
        self._mask_network(self.actor, pruning_ratio)
        self._mask_network(self.critic, pruning_ratio)

class MPPOAgent:
    """Container that drives one ``PPOAgent`` per intersection id.

    Per-agent inputs and outputs are exchanged as dictionaries keyed by the
    ids listed in ``intersection``.
    """

    def __init__(self, intersection, state_size, cfg, phase_list):
        """Store the ids and build one agent for each of them.

        Args:
            intersection: Iterable of intersection ids; re-iterated by every
                method of this class.
            state_size: State dimension handed to every agent.
            cfg: Configuration object; ``n_epochs`` is read here and the
                object is forwarded to each agent.
            phase_list: Mapping from id to that intersection's phases; its
                length is used as the agent's action dimension.
        """
        self.intersection = intersection
        self.n_epochs = cfg.n_epochs
        self.agents = {}
        self.make_agents(state_size, cfg, phase_list)

    def _each_agent(self):
        """Yield ``(id, agent)`` pairs in the order given by ``intersection``."""
        for id_ in self.intersection:
            yield id_, self.agents[id_]

    def make_agents(self, state_size, cfg, phase_list):
        """Create a ``PPOAgent`` for every id and register it in ``agents``."""
        for id_ in self.intersection:
            phases = phase_list[id_]
            self.agents[id_] = PPOAgent(
                intersection_id=id_,
                state_dim=state_size,
                action_dim=len(phases),
                cfg=cfg,
                phase_list=phases,
            )

    def remember(self, state, action, prob, val, reward, done, labels):
        """Hand each agent its own entry of every per-id mapping."""
        fields = (state, action, prob, val, reward, done, labels)
        for id_, agent in self._each_agent():
            agent.remember(*(field[id_] for field in fields))

    def random_initialize(self, pruning_ratio):
        """Apply the same ``pruning_ratio`` to every agent's networks."""
        for _, agent in self._each_agent():
            agent.random_initialize(pruning_ratio)

    def choose_action(self, state):
        """Query every agent with its own state.

        Returns:
            Three dictionaries keyed by id: actions, log-probabilities and
            value estimates.
        """
        action, prob, val = {}, {}, {}
        for id_, agent in self._each_agent():
            chosen, log_prob, value = agent.choose_action(state[id_])
            action[id_] = chosen
            prob[id_] = log_prob
            val[id_] = value
        return action, prob, val

    def replay(self, epoch):
        """Train all agents for ``n_epochs`` rounds, then clear their memories.

        In each round every agent performs ``update_1``; the losses they
        return are averaged, and every agent then performs ``update_2`` on
        that average.
        """
        for _ in range(self.n_epochs):
            losses = [agent.update_1(epoch) for _, agent in self._each_agent()]
            avg_loss = torch.mean(torch.stack(losses))
            for _, agent in self._each_agent():
                agent.update_2(avg_loss)
        for _, agent in self._each_agent():
            agent.memory.clear()

    def load(self, name):
        """Ask every agent to load its checkpoint from ``name``."""
        for _, agent in self._each_agent():
            agent.load(name)
        print("\nloading model successfully!\n")

    def save(self, path):
        """Ask every agent to save into ``path`` with a trailing slash appended."""
        target = path + "/"
        for _, agent in self._each_agent():
            agent.save(target)
