# Largely copy-pasted from a minimal PPO implementation.
import torch
import torch.nn.functional as F
from bigmdp.utils.utils_directory import *
from torch import nn
from torch.distributions import Categorical
from torch import optim , cuda
import numpy as np
from sklearn.decomposition import PCA

# Hyperparameters

class DQN(nn.Module):
    """Q-network with an optional conv trunk, a 256-unit hidden layer and a linear Q head.

    Observations whose ``input_shape`` has more than one dimension go through
    three conv layers (32/64/64 channels) before the hidden layer; flat
    observations feed the hidden layer directly. The network owns its own
    Adam optimizer (``self.optimizer``).

    Args:
        input_shape: observation shape without the batch dimension, e.g.
            ``(4, 84, 84)`` for stacked frames or ``(n_features,)``.
        action_size: number of discrete actions (width of the Q head).
        use_cuda: request GPU placement; honoured only if CUDA is available.
            ``None`` is treated as ``False``.
        learning_rate: Adam learning rate.
    """

    def __init__(self, input_shape, action_size, use_cuda=False, learning_rate=0.0001):
        super(DQN, self).__init__()
        self.data = []
        self.useCNN = len(input_shape) > 1
        # Only move to the GPU when it was requested AND is actually present.
        self.use_cuda = (use_cuda and cuda.is_available()) if use_cuda is not None else False

        if self.useCNN:
            self.model = nn.Sequential(
                nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=4, stride=2),
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3, stride=1),
                nn.ReLU()
            )
            # Size of the flattened conv output, found with a dummy forward pass.
            fc_input_size = self._get_conv_out(input_shape)
        else:
            fc_input_size = input_shape[0]

        self.fc1_net = nn.Linear(fc_input_size, 256)
        self.fc_v = nn.Linear(256, action_size)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        if self.use_cuda:
            self.cuda()

    def forward(self, x):
        """Return Q-values for a batch ``x``; the last dimension is ``action_size``."""
        return self.fc_v(self.encode(x))

    def encode(self, x):
        """Return the 256-d post-ReLU hidden representation of a batch ``x``.

        This is everything in ``forward`` except the final Q head.
        """
        if self.useCNN:
            x = self.model(x)
            # Flatten the conv feature maps per sample.
            x = x.view(x.shape[0], -1)
        return F.relu(self.fc1_net(x))

    def _get_conv_out(self, shape):
        """Return the number of features the conv trunk produces for one input of ``shape``."""
        # Shape probe only: no autograd graph is needed.
        with torch.no_grad():
            o = self.model(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

class DQNAgent:
    """DQN / Double-DQN agent holding an online network and a target network.

    Besides training, the agent exposes the online network's hidden layer
    (``self.net.encode``) as a rounded, hashable state representation
    (tuples), optionally projected through a fitted PCA.

    Args:
        input_shape: observation shape passed to ``DQN``.
        action_size: number of discrete actions.
        wandb_logger: stored on the agent for callers; not used in this class.
        use_cuda: request GPU placement (honoured only if ``DQN`` enables it).
        double_dqn: use the Double-DQN target instead of the vanilla max target.
        learning_rate: Adam learning rate for both networks.
        gamma: discount factor.
        multiplyer: scale applied to encodings before rounding.
    """

    def __init__(self, input_shape, action_size, wandb_logger, use_cuda, double_dqn = True, learning_rate = 0.0001, gamma = 0.99, multiplyer=1):
        self.tgt_net = DQN(input_shape, action_size, use_cuda=use_cuda, learning_rate=learning_rate)
        self.net = DQN(input_shape, action_size, use_cuda=use_cuda, learning_rate=learning_rate)
        # DQN only moves itself to the GPU when CUDA is actually available.
        # Mirror that decision so input tensors land on the same device as the
        # weights (the raw flag would say "cuda" on a CPU-only machine).
        self.use_cuda = self.net.use_cuda
        self.device = "cuda" if self.use_cuda else "cpu"
        self.wandb_logger = wandb_logger
        self.double_dqn = double_dqn
        self.pca_fitted = False
        self.pca_flag = False
        self.gamma = gamma
        self.multiplyer = multiplyer  # spelling kept for backward compatibility
        # One-shot flag for the predict_* warnings.
        self.warn_flag = True
        # Most recent (detached) training loss, set by train_net.
        self.curr_loss = None

    def _as_float_tensor(self, data):
        """Copy array-like ``data`` into a float32 tensor on the agent's device."""
        return torch.tensor(np.asarray(data, dtype=np.float32), device=self.device)

    def calc_loss(self, batch):
        """Return the MSE temporal-difference loss for one batch of transitions.

        Args:
            batch: tuple ``(states, actions, next_states, rewards, dones)``.
                ``actions``, ``rewards`` and ``dones`` may be shaped ``(B,)``
                or ``(B, 1)``; each is flattened to one entry per transition.

        Returns:
            Scalar loss tensor, differentiable w.r.t. ``self.net`` only.
        """
        net, tgt_net = self.net, self.tgt_net
        states, actions, next_states, rewards, dones = batch

        states_v = self._as_float_tensor(states)
        next_states_v = self._as_float_tensor(next_states)
        # Normalise to column vectors so (B,) and (B, 1) inputs behave the same
        # and the target never broadcasts to (B, B).
        actions_v = torch.tensor(np.asarray(actions, dtype=np.int64), device=self.device).reshape(-1, 1)
        rewards_v = self._as_float_tensor(rewards).reshape(-1, 1)
        done_mask = torch.tensor(np.asarray(dones, dtype=bool), device=self.device).reshape(-1)

        # Q(s, a) for the actions actually taken, shape (B, 1).
        state_action_values = net(states_v).gather(1, actions_v)
        with torch.no_grad():
            if self.double_dqn:
                # Double DQN: the online net chooses a* = argmax_a Q(s', a) and
                # the target net evaluates Q_tgt(s', a*).
                best_next_actions = net(next_states_v).argmax(dim=1, keepdim=True)
                next_state_values = tgt_net(next_states_v).gather(1, best_next_actions).squeeze(1)
            else:
                next_state_values = tgt_net(next_states_v).max(dim=1)[0]
            # Terminal transitions do not bootstrap.
            next_state_values[done_mask] = 0.0
            next_state_values = next_state_values.unsqueeze(-1)

        expected_state_action_values = next_state_values * self.gamma + rewards_v
        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def train_net(self, batch):
        """Run one optimizer step of the online network on ``batch``."""
        optimizer = self.net.optimizer
        optimizer.zero_grad()
        loss_t = self.calc_loss(batch)
        loss_t.backward()
        optimizer.step()
        # Detached so the stored value does not pin the autograd graph.
        self.curr_loss = loss_t.detach()

    def sync_target_net(self):
        """Copy the online network's weights into the target network."""
        self.tgt_net.load_state_dict(self.net.state_dict())

    def get_action(self, s, epsilon):
        """Return an epsilon-greedy action (int) for a single observation ``s``."""
        with torch.no_grad():
            q_vals_v = self.net(self._as_float_tensor([s]))
        n_actions = q_vals_v.shape[1]
        action = int(q_vals_v.argmax(dim=1).item())
        if np.random.random() < epsilon:
            # choice(n) samples exactly like choice(list(range(n))).
            action = int(np.random.choice(n_actions))
        return action

    def get_bcq_action(self, s, epsilon, threshold=0.3):
        """Return an epsilon-greedy, imitation-filtered action for observation ``s``.

        Expects ``self.net`` to return a 3-tuple ``(q_values, imitation, _)``
        (a BCQ-style network); the ``DQN`` built in ``__init__`` returns a
        single tensor, so this only works with such a network swapped in.
        ``imitation`` is exponentiated here, so it is presumably log-probabilities.
        Actions whose probability relative to the most likely action is not
        above ``threshold`` are excluded from the greedy argmax.
        """
        with torch.no_grad():
            q, imt, _ = self.net(self._as_float_tensor([s]))
            n_actions = q.shape[1]
            imt = imt.exp()
            imt = (imt / imt.max(1, keepdim=True)[0] > threshold).float()
            # Use large negative number to mask actions from argmax
            action = int((imt * q + (1. - imt) * -1e8).argmax(1))
        if np.random.random() < epsilon:
            action = int(np.random.choice(n_actions))
        return action

    def fit_pca(self, latent_size, data):
        """Fit a PCA with ``latent_size`` components on ``data``.

        This does not enable PCA by itself: set ``pca_flag`` or pass
        ``do_pca=True`` to the encode methods.
        """
        self.pca = PCA(latent_size)
        self.pca = self.pca.fit(data)
        self.pca_fitted = True

    def encode_single(self, obs, do_pca=None):
        """Encode one observation as a tuple of rounded latent features.

        Args:
            obs: a single observation (no batch dimension).
            do_pca: apply the fitted PCA; ``None`` falls back to ``self.pca_flag``.
        """
        if do_pca is None:
            do_pca = self.pca_flag
        obs_tensor = self._as_float_tensor([obs])
        with torch.no_grad():
            latent = self.net.encode(obs_tensor)[0].cpu().numpy()
        latent = (latent * self.multiplyer).round(5)
        if do_pca:
            latent = self.pca.transform(latent[np.newaxis, :])[0]
        return tuple(latent)

    def encode_batch(self, obs_batch, do_pca=None):
        """Encode a batch of observations as a list of tuples (one per observation).

        Args:
            obs_batch: batch of observations (leading batch dimension).
            do_pca: apply the fitted PCA; ``None`` falls back to ``self.pca_flag``.
        """
        if do_pca is None:
            do_pca = self.pca_flag
        obs_tensor = self._as_float_tensor(obs_batch)
        with torch.no_grad():
            latents = self.net.encode(obs_tensor)
            # Keep the batch axis even for a batch of one (squeeze would drop it).
            latents = latents.reshape(latents.shape[0], -1).cpu().numpy()
        latents = (latents * self.multiplyer).round(5)
        if do_pca:
            latents = self.pca.transform(latents)
        return [tuple(row) for row in latents]

    def _warn_prediction_called(self):
        """Print the 'prediction on a DQN agent' warning at most once per agent."""
        # getattr default keeps this safe even if __init__ was bypassed.
        if getattr(self, "warn_flag", True):
            print("#" * 50,
                  "\n Warning, prediction called for dqn network, this should not be happening, returning garbage \n",
                  "#" * 50)
            self.warn_flag = 0

    def predict_single_transition(self, s, a):
        """Placeholder model API: warn once and return ``(0, 0, 0)``."""
        self._warn_prediction_called()
        return 0, 0, 0

    def predict_batch_transition(self, s, a):
        """Placeholder model API: warn once and return three zero lists of ``len(s)``."""
        self._warn_prediction_called()
        return [0] * len(s), [0] * len(s), [0] * len(s)

    def state_dict(self):
        """Return the online network's state dict."""
        return self.net.state_dict()

    def load_state_dict(self, weights):
        """Load ``weights`` into the online network (the target net is untouched)."""
        self.net.load_state_dict(weights)
