from torch import nn
import torch
from utils.utils import layer_init, Discretizer, ActionConverter
from torch.distributions import Categorical


class Agent(nn.Module):
    """Actor-critic network with a shared torso and separate policy/value heads.

    Exactly one architecture is built, chosen by the first true flag in the
    order ``image``, ``safety``, ``trade``, ``cage``:

    * ``image``  -- convolutional torso over 4-channel inputs, divided by 255
      before use (the ``64 * 7 * 7`` flatten size assumes inputs that reduce
      to 7x7 after the conv stack, e.g. 84x84 frames -- TODO confirm).
    * ``safety`` -- two 256-unit ReLU layers; the actor emits one logit per
      entry of a fixed ``Discretizer`` table.
    * ``trade``  -- two 64-unit ReLU layers over the env's discrete actions.
    * ``cage``   -- two 64-unit ReLU layers; the actor emits one logit per id
      in a fixed list of action ids, also wrapped in ``self.act``
      (an ``ActionConverter``).

    Attributes set in every mode: ``network`` (torso), ``actor`` (policy
    logits head), ``critic`` (scalar value head), ``norm`` (input divisor),
    ``n_actions`` (number of policy logits) and ``safety`` (True only in
    safety mode).

    Args:
        envs: vectorized environment; reads ``single_action_space.n`` and/or
            ``single_observation_space.shape[0]`` depending on the mode.
        image, safety, trade, cage: architecture selectors (see above).

    Raises:
        ValueError: if none of the selector flags is true.
    """

    def __init__(self, envs, image=True, safety=False, trade=False, cage=False):
        super().__init__()
        # Always defined so callers can branch on it in any mode; only the
        # safety architecture switches it on.
        self.safety = False
        if image:
            self.network = nn.Sequential(
                layer_init(nn.Conv2d(4, 32, 8, stride=4)),
                nn.ReLU(),
                layer_init(nn.Conv2d(32, 64, 4, stride=2)),
                nn.ReLU(),
                layer_init(nn.Conv2d(64, 64, 3, stride=1)),
                nn.ReLU(),
                nn.Flatten(),
                layer_init(nn.Linear(64 * 7 * 7, 512)),
                nn.ReLU(),
            )
            self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01)
            self.critic = layer_init(nn.Linear(512, 1), std=1)
            # Presumably scales uint8 pixels into [0, 1] -- confirm with the env wrapper.
            self.norm = 255
            self.n_actions = envs.single_action_space.n
        elif safety:
            self.safety = True
            # The actor emits one logit per entry of this table (len(self.discretizer)).
            # Alternative 9-entry table covering every pair in {-1, 0, 1}^2:
            # Discretizer(torch.tensor([[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1],
            #                           [-1, 1], [-1, -1], [1, -1], [1, 1]]))
            self.discretizer = Discretizer(torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]]))
            obs_space = envs.single_observation_space.shape[0]
            print(obs_space)
            self.network = nn.Sequential(
                layer_init(nn.Linear(obs_space, 256)),
                nn.ReLU(),
                layer_init(nn.Linear(256, 256)),
                nn.ReLU(),
            )
            self.norm = 1
            self.actor = layer_init(nn.Linear(256, len(self.discretizer)), std=0.01)
            self.critic = layer_init(nn.Linear(256, 1), std=1)
            self.n_actions = len(self.discretizer)
        elif trade:
            obs_space = envs.single_observation_space.shape[0]
            self.network = nn.Sequential(
                layer_init(nn.Linear(obs_space, 64)),
                nn.ReLU(),
                layer_init(nn.Linear(64, 64)),
                nn.ReLU(),
            )
            self.norm = 1
            self.actor = layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01)
            self.critic = layer_init(nn.Linear(64, 1), std=1)
            self.n_actions = envs.single_action_space.n
        elif cage:
            # Fixed subset of action ids; the policy index i refers to action_space[i].
            action_space = torch.tensor([1, 133, 134, 135, 139, 3, 4, 5, 9, 16, 17, 18, 22, 11, 12, 13, 14,
                                         141, 142, 143, 144, 132, 2, 15, 24, 25, 26, 27])
            obs_space = envs.single_observation_space.shape[0]
            print(obs_space)
            self.network = nn.Sequential(
                layer_init(nn.Linear(obs_space, 64)),
                nn.ReLU(),
                layer_init(nn.Linear(64, 64)),
                nn.ReLU(),
            )
            self.norm = 1
            self.actor = layer_init(nn.Linear(64, len(action_space)), std=0.01)
            self.critic = layer_init(nn.Linear(64, 1), std=1)
            self.act = ActionConverter(action_space)
            self.n_actions = len(action_space)
        else:
            # Without this, the module would have no network/actor/critic and
            # fail later with an obscure AttributeError.
            raise ValueError(
                "Agent requires one of image, safety, trade or cage to be True"
            )

    def get_value(self, x):
        """Return the critic's value estimate for observations ``x`` (scaled by ``1 / self.norm``)."""
        return self.critic(self.network(x / self.norm))

    def get_action_dist(self, x):
        """Return the policy's action probabilities (softmax of the actor logits) for ``x``."""
        hidden = self.network(x / self.norm)
        logits = self.actor(hidden)
        return Categorical(logits=logits).probs

    def get_action_and_value(self, x, action=None):
        """Run the policy and value heads on a batch of observations.

        Args:
            x: observation batch; divided by ``self.norm`` before the torso.
            action: optional tensor of action indices to evaluate. When
                ``None``, an action is sampled from the policy.

        Returns:
            ``(action, log_prob, entropy, value)``. ``action`` holds raw policy
            indices; no ``discretizer``/``act`` mapping is applied here.

        Raises:
            ValueError: if ``Categorical`` rejects the logits (e.g. they contain
                NaN). The message includes the input, hidden activations and
                logits to aid debugging.
        """
        hidden = self.network(x / self.norm)
        logits = self.actor(hidden)
        try:
            probs = Categorical(logits=logits)
        except ValueError as err:
            # Previously this printed the tensors and called sys.exit(), which
            # hid the traceback and could not be handled by callers.
            raise ValueError(
                f"NaN Error: invalid policy logits\nx={x}\n--\nhidden={hidden}\n--\nlogits={logits}"
            ) from err
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)

# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """Q-value network mapping observations to one value per discrete action.

    Exactly one architecture is built, chosen by the first true flag in the
    order ``image``, ``safety``, ``trade``, ``cage``:

    * ``image``  -- convolutional stack over 4-channel inputs, divided by 255
      before use (the 3136 = 64*7*7 flatten size assumes inputs that reduce
      to 7x7 after the convs, e.g. 84x84 frames -- TODO confirm).
    * ``safety`` -- two 256-unit ReLU layers; one output per entry of a fixed
      ``Discretizer`` table.
    * ``trade``  -- two 64-unit ReLU layers; one output per env action.
    * ``cage``   -- two 64-unit ReLU layers; one output per id in a fixed list
      of action ids, also wrapped in ``self.act`` (an ``ActionConverter``).

    Attributes set in every mode: ``network``, ``norm`` (input divisor),
    ``n_actions`` (number of Q-value outputs) and ``safety`` (True only in
    safety mode).

    Args:
        envs: vectorized environment; reads ``single_action_space.n`` and/or
            ``single_observation_space.shape[0]`` depending on the mode.
        image, safety, trade, cage: architecture selectors. Defaults match
            ``Agent`` (image mode).

    Raises:
        ValueError: if none of the selector flags is true.
    """

    def __init__(self, envs, image=True, safety=False, trade=False, cage=False):
        super().__init__()
        # Always defined so callers can branch on it in any mode.
        self.safety = False
        if image:
            self.network = nn.Sequential(
                nn.Conv2d(4, 32, 8, stride=4),
                nn.ReLU(),
                nn.Conv2d(32, 64, 4, stride=2),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(3136, 512),
                nn.ReLU(),
                nn.Linear(512, envs.single_action_space.n),
            )
            # Presumably scales uint8 pixels into [0, 1] -- confirm with the env wrapper.
            self.norm = 255
            self.n_actions = envs.single_action_space.n
        elif safety:
            self.safety = True
            # One Q-value output per entry of this table (len(self.discretizer)).
            # Alternative 9-entry table covering every pair in {-1, 0, 1}^2:
            # Discretizer(torch.tensor([[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1],
            #                           [-1, 1], [-1, -1], [1, -1], [1, 1]]))
            self.discretizer = Discretizer(torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]]))
            self.n_actions = len(self.discretizer)
            obs_space = envs.single_observation_space.shape[0]
            print(obs_space)
            self.network = nn.Sequential(
                nn.Linear(obs_space, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, len(self.discretizer)),
            )
            self.norm = 1
        elif trade:
            obs_space = envs.single_observation_space.shape[0]
            self.network = nn.Sequential(
                nn.Linear(obs_space, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, envs.single_action_space.n),
            )
            self.norm = 1
            # Previously missing in this branch only; every other mode (and Agent) sets it.
            self.n_actions = envs.single_action_space.n
        elif cage:
            # Fixed subset of action ids; output index i refers to action_space[i].
            action_space = torch.tensor([1, 133, 134, 135, 139, 3, 4, 5, 9, 16, 17, 18, 22, 11, 12, 13, 14,
                                         141, 142, 143, 144, 132, 2, 15, 24, 25, 26, 27])
            obs_space = envs.single_observation_space.shape[0]
            self.n_actions = len(action_space)
            print(obs_space)
            self.network = nn.Sequential(
                nn.Linear(obs_space, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, len(action_space)),
            )
            self.norm = 1
            self.act = ActionConverter(action_space)
        else:
            # Without this, forward() would fail later with an obscure AttributeError.
            raise ValueError(
                "QNetwork requires one of image, safety, trade or cage to be True"
            )

    def forward(self, x):
        """Return Q-values for observations ``x`` (scaled by ``1 / self.norm``)."""
        return self.network(x / self.norm)