import torch
import numpy as np
import copy
import networks as nets


# Compute device shared by every agent in this module: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SD2PC():
    """Entropy-regularised actor-critic agent with a discretised policy.

    Every action dimension is represented by ``action_slices`` evenly spaced
    bin values inside [-1, 1]. The policy network output is consumed as one
    probability vector per action dimension (shape
    ``[batch, action_dim, action_slices]``), while a twin critic scores the
    continuous action vectors assembled from the chosen bins.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 action_slices=20,
                 net_width=256,
                 gamma=0.99,
                 batchsize=256,
                 c_lr=3e-4,
                 d_lr=3e-4,
                 target_entropy=1.5,
                 alpha_lr=3e-4,
                 tau=0.005,
                 ):
        """Build the networks, their target copies and the optimizers.

        Args:
            state_dim: size of a state vector.
            action_dim: number of action dimensions.
            action_slices: number of bins per action dimension.
            net_width: width forwarded to the network constructors.
            gamma: discount factor used in the critic target.
            batchsize: number of transitions requested from the replay buffer.
            c_lr: learning rate of the critic optimizer.
            d_lr: learning rate of the discrete policy optimizer.
            target_entropy: set point compared against the batch mean of
                ``sum(p * log(p / action_slice))`` in the temperature loss.
            alpha_lr: learning rate of the temperature optimizer.
            tau: interpolation factor of the soft target updates.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_slices = action_slices

        # Spacing between neighbouring bin values and the value of the lowest
        # bin; bin j maps to action_base + j * action_slice.
        self.action_slice = 2 / action_slices
        self.action_base = - self.action_slice * (action_slices - 1) / 2

        self.policy = nets.discrete_actor(state_dim, action_dim, action_slices, net_width).to(device)
        self.a_optimizer = torch.optim.Adam(self.policy.parameters(), lr=d_lr)
        self.target_policy = copy.deepcopy(self.policy)

        self.critic = nets.critic(state_dim, action_dim, net_width).to(device)
        self.c_optimizer = torch.optim.Adam(self.critic.parameters(), lr=c_lr)
        self.target_critic = copy.deepcopy(self.critic)

        # The temperature is optimised in log space so that it stays positive.
        self.log_alpha = torch.tensor(-5.0).to(device)
        self.log_alpha.requires_grad = True
        self.target_entropy = target_entropy
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)

        self.gamma = gamma
        self.bc = batchsize
        self.loss = torch.nn.MSELoss()
        self.tau = tau
        self.training_steps = 0

    @staticmethod
    def _soft_update(target_net, net, tau):
        """Move ``target_net`` parameters a fraction ``tau`` towards ``net``."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), net.parameters()):
                target_param.data.copy_(tau * param + (1 - tau) * target_param)

    def _sample_bins(self, distribution):
        """Draw one bin index per probability vector by inverse-CDF sampling.

        ``distribution`` holds probabilities along its last axis and is left
        untouched. The index is the number of cumulative-probability edges at
        or below a uniform draw, clamped to the last bin so that a cumulative
        sum rounded slightly below the draw cannot silently select bin 0.
        """
        cumulative = torch.cumsum(distribution, dim=-1)
        draws = torch.rand(list(distribution.shape[:-1]) + [1]).to(distribution.device)
        index = torch.sum((draws >= cumulative).long(), dim=-1)
        return torch.clamp(index, max=self.action_slices - 1)

    def _decomposed_q(self, s, distribution):
        """Score every single-dimension deviation from the modal action.

        For each sample the most probable bin of every action dimension forms
        a reference action. Dimension ``i`` is then swept over all bin values
        while the other dimensions keep their reference value, and the smaller
        of the two critic estimates is recorded.

        Returns:
            Tensor of shape ``[batch, action_dim, action_slices]`` that carries
            no gradient.
        """
        batch = s.shape[0]
        rows = batch * self.action_dim * self.action_slices
        with torch.no_grad():
            _, action_index = torch.max(distribution, dim=2)
            best_action = action_index * self.action_slice + self.action_base
            bin_values = torch.tensor(
                [self.action_base + j * self.action_slice for j in range(self.action_slices)],
                dtype=best_action.dtype, device=best_action.device)
            # swept[i, k] is True when component k is the one varied in sweep i.
            swept = torch.eye(self.action_dim, device=best_action.device).bool()
            action_chart = torch.where(
                swept.reshape(1, self.action_dim, 1, self.action_dim),
                bin_values.reshape(1, 1, self.action_slices, 1),
                best_action.reshape(batch, 1, 1, self.action_dim))
            action_chart = action_chart.reshape(rows, self.action_dim)
            s_chart = s.unsqueeze(1).repeat(1, self.action_slices * self.action_dim, 1)
            s_chart = s_chart.reshape(rows, self.state_dim)
            qsa1, qsa2 = self.critic(s_chart, action_chart)
            return torch.min(qsa1, qsa2).reshape(batch, self.action_dim, self.action_slices)

    def select_action(self, state, greedy=False):
        """Return an action for a single state as a NumPy array.

        Args:
            state: one state vector (anything ``torch.FloatTensor`` accepts).
            greedy: pick the most probable bin of every dimension when True,
                otherwise sample each dimension from the policy distribution.

        Returns:
            Array of length ``action_dim`` holding the selected bin values.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        # Acting never needs gradients, and the policy output must not be
        # modified in place.
        with torch.no_grad():
            distribution = self.policy(state).reshape([self.action_dim, self.action_slices])
            if greedy:
                _, action_index = torch.max(distribution, dim=1)
            else:
                action_index = self._sample_bins(distribution)
            action = action_index * self.action_slice + self.action_base
        return action.cpu().numpy()

    def train(self, replaybuffer):
        """Run one update of the critic, the policy, the temperature and the targets.

        Args:
            replaybuffer: object whose ``sample(n)`` returns tensors
                ``(s, a, r, s_, d)``. The batch size is read from the returned
                tensors. ``r`` and ``d`` are assumed to be 1-D with one entry
                per transition -- TODO confirm against the buffer.
        """
        self.training_steps += 1
        with torch.no_grad():
            alpha = torch.exp(self.log_alpha).detach()
            s, a, r, s_, d = replaybuffer.sample(self.bc)
            target_distribution = self.target_policy(s_)
            target_logprob = torch.log(target_distribution / self.action_slice + 1e-7)
            # sum(p * log p) per dimension, i.e. the negated entropy term.
            target_entropy = torch.sum(target_logprob * target_distribution, dim=2)
            target_action_index = self._sample_bins(target_distribution)
            target_action = target_action_index * self.action_slice + self.action_base
            target_q1, target_q2 = self.target_critic(s_, target_action)
            target_q = torch.min(target_q1.squeeze(1), target_q2.squeeze(1)) \
                - torch.sum(target_entropy, dim=1) * alpha
            target_q = r + self.gamma * (1 - d) * target_q

        # critic iteration
        current_q1, current_q2 = self.critic(s, a)
        c_loss = self.loss(current_q1.squeeze(1), target_q) + self.loss(current_q2.squeeze(1), target_q)
        self.c_optimizer.zero_grad()
        c_loss.backward()
        self.c_optimizer.step()

        # decomposed value iteration
        distribution = self.policy(s)
        logprob = torch.log(distribution / self.action_slice + 1e-7)
        entropy = torch.sum(logprob * distribution, dim=2)
        entropy_mean = torch.mean(entropy).detach()
        target_qsa = self._decomposed_q(s, distribution.detach())
        a_loss = torch.sum(target_qsa * distribution, dim=2) - entropy * alpha
        a_loss = -torch.mean(a_loss)

        self.a_optimizer.zero_grad()
        a_loss.backward()
        self.a_optimizer.step()

        # temperature iteration
        alpha_generate = torch.exp(self.log_alpha)
        alpha_loss = alpha_generate * (self.target_entropy - entropy_mean)
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self._soft_update(self.target_critic, self.critic, self.tau)
        self._soft_update(self.target_policy, self.policy, self.tau)

class D3PC():
    """Agent pairing a continuous-action twin critic with a discretised Q network.

    Every action dimension is represented by ``action_slices`` evenly spaced
    bin values inside [-1, 1]. The discrete network emits
    ``action_dim * action_slices`` values per state, read as one value per
    (dimension, bin) pair, and is regressed onto critic estimates.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 action_slices=10,
                 net_width=256,
                 gamma=0.99,
                 batchsize=256,
                 c_lr=1e-3,
                 Q_lr=1e-3,
                 tau=0.005,
                 epsilon=0.1,
                 exploration_noise=0.1,
                 ):
        """Build the networks, their target copies and the optimizers.

        Args:
            state_dim: size of a state vector.
            action_dim: number of action dimensions.
            action_slices: number of bins per action dimension.
            net_width: width forwarded to the network constructors.
            gamma: discount factor used in the critic target.
            batchsize: number of transitions requested from the replay buffer.
            c_lr: learning rate of the critic optimizer.
            Q_lr: learning rate of the discrete Q network optimizer.
            tau: interpolation factor of the soft target updates.
            epsilon: initial per-dimension probability of a random bin.
            exploration_noise: standard deviation of the Gaussian action noise.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_slices = action_slices

        # Spacing between neighbouring bin values and the value of the lowest
        # bin; bin j maps to action_base + j * action_slice.
        self.action_slice = 2 / action_slices
        self.action_base = - self.action_slice * (action_slices - 1) / 2

        self.discreteQ = nets.DeepQnetwork(state_dim, action_dim * action_slices, net_width).to(device)
        self.d_optimizer = torch.optim.Adam(self.discreteQ.parameters(), lr=Q_lr)
        self.target_discrete = copy.deepcopy(self.discreteQ)

        self.critic = nets.critic(state_dim, action_dim, net_width).to(device)
        self.c_optimizer = torch.optim.Adam(self.critic.parameters(), lr=c_lr)
        self.target_critic = copy.deepcopy(self.critic)

        self.gamma = gamma
        self.epsilon = epsilon
        self.bc = batchsize
        self.loss = torch.nn.MSELoss()
        self.tau = tau
        self.exploration_noise = exploration_noise
        # Multiplicative decay applied to epsilon on every select_action call.
        self.eps_decay = 1 - 5e-6

        self.training_steps = 0

    @staticmethod
    def _soft_update(target_net, net, tau):
        """Move ``target_net`` parameters a fraction ``tau`` towards ``net``."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), net.parameters()):
                target_param.data.copy_(tau * param + (1 - tau) * target_param)

    def _decomposed_q(self, s, qvalues):
        """Score every single-dimension deviation from the greedy action.

        ``qvalues`` is the discrete network output for ``s``. The arg-max bin
        of every action dimension forms a reference action; dimension ``i`` is
        then swept over all bin values while the other dimensions keep their
        reference value, and the smaller of the two critic estimates is
        recorded.

        Returns:
            Tensor of shape ``[batch, action_dim * action_slices]`` that
            carries no gradient.
        """
        batch = s.shape[0]
        rows = batch * self.action_dim * self.action_slices
        with torch.no_grad():
            per_dim = qvalues.reshape(batch, self.action_dim, self.action_slices)
            _, action_index = torch.max(per_dim, dim=2)
            best_action = action_index * self.action_slice + self.action_base
            bin_values = torch.tensor(
                [self.action_base + j * self.action_slice for j in range(self.action_slices)],
                dtype=best_action.dtype, device=best_action.device)
            # swept[i, k] is True when component k is the one varied in sweep i.
            swept = torch.eye(self.action_dim, device=best_action.device).bool()
            action_chart = torch.where(
                swept.reshape(1, self.action_dim, 1, self.action_dim),
                bin_values.reshape(1, 1, self.action_slices, 1),
                best_action.reshape(batch, 1, 1, self.action_dim))
            action_chart = action_chart.reshape(rows, self.action_dim)
            s_chart = s.unsqueeze(1).repeat(1, self.action_slices * self.action_dim, 1)
            s_chart = s_chart.reshape(rows, self.state_dim)
            qsa1, qsa2 = self.critic(s_chart, action_chart)
            return torch.min(qsa1, qsa2).reshape(batch, self.action_dim * self.action_slices)

    def select_action(self, state, greedy=False):
        """Choose an action for a single state.

        Args:
            state: one state vector (anything ``torch.FloatTensor`` accepts).
            greedy: when True return the arg-max bin values unchanged. When
                False each dimension is replaced by a uniformly random bin
                with probability ``epsilon``, Gaussian noise is added and the
                result is clipped to [-1, 1].

        Returns:
            ``(action, action_index)``: the action values and the bin indices
            they were built from (including any random substitutions).
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        # Acting never needs gradients.
        with torch.no_grad():
            qvalue = self.discreteQ(state).reshape([self.action_dim, self.action_slices])
            _, action_index = torch.max(qvalue, dim=1)
        action_index = action_index.cpu().numpy().astype(np.float32)
        if not greedy:
            random_sign = (np.random.uniform(0, 1, self.action_dim) < self.epsilon).astype(np.float32)
            random_index = np.random.randint(0, self.action_slices, self.action_dim)
            action_index = random_sign * random_index + (1 - random_sign) * action_index
            action = action_index * self.action_slice + self.action_base
            noise = np.random.normal(0, self.exploration_noise, self.action_dim)
            action = np.clip(action + noise, -1, 1)
        else:
            action = action_index * self.action_slice + self.action_base
        # Anneal exploration on every call; previously this only ran on greedy
        # calls, so exploratory acting never reduced epsilon.
        self.epsilon *= self.eps_decay
        return action, action_index

    def train(self, replaybuffer):
        """Run one update of the critic, the discrete Q network and the targets.

        Args:
            replaybuffer: object whose ``sample(n)`` returns tensors
                ``(s, a, r, s_, d)``. The batch size is read from the returned
                tensors. ``r`` and ``d`` are assumed to be 1-D with one entry
                per transition -- TODO confirm against the buffer.
        """
        self.training_steps += 1
        with torch.no_grad():
            s, a, r, s_, d = replaybuffer.sample(self.bc)
            next_q = self.target_discrete(s_)
            next_q = next_q.reshape(s_.shape[0], self.action_dim, self.action_slices)
            _, max_index = torch.max(next_q, dim=2)
            target_action = max_index * self.action_slice + self.action_base
            target_q1, target_q2 = self.target_critic(s_, target_action)
            target_q = torch.min(target_q1, target_q2)
            target_q = r + self.gamma * (1 - d) * target_q.squeeze(1)

        # critic iteration
        current_q1, current_q2 = self.critic(s, a)
        c_loss = self.loss(current_q1.squeeze(1), target_q) + self.loss(current_q2.squeeze(1), target_q)
        self.c_optimizer.zero_grad()
        c_loss.backward()
        self.c_optimizer.step()

        # discrete value iteration: regress onto the critic's per-bin estimates
        qsa_generate = self.discreteQ(s)
        target_qsa = self._decomposed_q(s, qsa_generate.detach())
        qsa_loss = self.loss(qsa_generate, target_qsa)
        self.d_optimizer.zero_grad()
        qsa_loss.backward()
        self.d_optimizer.step()

        self._soft_update(self.target_critic, self.critic, self.tau)
        self._soft_update(self.target_discrete, self.discreteQ, self.tau)

class QPC():
    """Agent blending a discretised Q network with a continuous deterministic policy.

    Every action dimension is represented by ``action_slices`` evenly spaced
    bin values inside [-1, 1]. Actions and critic targets are a convex mix,
    weighted by ``beta``, of the discrete network's greedy action and the
    continuous policy's action. ``beta`` starts at 1 and, once more than
    ``discrete_initial_step`` training steps have run, decays geometrically
    towards ``beta_min``.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 action_slices=10,
                 net_width=256,
                 gamma=0.99,
                 batchsize=256,
                 c_lr=1e-3,
                 d_lr=1e-3,
                 a_lr=1e-3,
                 tau=0.005,
                 delay_freq=2,
                 policy_smooth_noise=0.1,
                 epsilon=0.05,
                 exploration_noise=0.1
                 ):
        """Build the networks, their target copies and the optimizers.

        Args:
            state_dim: size of a state vector.
            action_dim: number of action dimensions.
            action_slices: number of bins per action dimension.
            net_width: width forwarded to the network constructors.
            gamma: discount factor used in the critic target.
            batchsize: number of transitions requested from the replay buffer.
            c_lr: learning rate of the critic optimizer.
            d_lr: learning rate of the discrete Q network optimizer.
            a_lr: learning rate of the continuous policy optimizer.
            tau: interpolation factor of the soft target updates.
            delay_freq: the continuous policy is updated every ``delay_freq``
                training steps.
            policy_smooth_noise: scale of the noise added to the target
                continuous action.
            epsilon: initial per-dimension probability of a random bin.
            exploration_noise: standard deviation of the Gaussian action noise.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_slices = action_slices

        # Spacing between neighbouring bin values and the value of the lowest
        # bin; bin j maps to action_base + j * action_slice.
        self.action_slice = 2 / action_slices
        self.action_base = - self.action_slice * (action_slices - 1) / 2

        self.discreteQ = nets.DeepQnetwork(state_dim, action_dim * action_slices, net_width).to(device)
        self.d_optimizer = torch.optim.Adam(self.discreteQ.parameters(), lr=d_lr)
        self.target_discrete = copy.deepcopy(self.discreteQ)

        self.critic = nets.critic(state_dim, action_dim, net_width).to(device)
        self.c_optimizer = torch.optim.Adam(self.critic.parameters(), lr=c_lr)
        self.target_critic = copy.deepcopy(self.critic)

        self.continuouspolicy = nets.TD3_Actor(state_dim, action_dim, net_width).to(device)
        self.a_optimizer = torch.optim.Adam(self.continuouspolicy.parameters(), lr=a_lr)
        self.target_continuous = copy.deepcopy(self.continuouspolicy)

        self.gamma = gamma
        self.epsilon = epsilon
        self.bc = batchsize
        self.loss = torch.nn.MSELoss()
        self.tau = tau
        self.delay_freq = delay_freq
        self.exploration_noise = exploration_noise
        self.policy_smooth_noise = policy_smooth_noise

        self.training_steps = 0
        # Number of training steps during which only the discrete branch is used.
        self.discrete_initial_step = 1e5
        self.beta_decay = 1 - 5e-6
        self.eps_decay = 1 - 5e-6
        # Weight of the discrete branch in actions and critic targets.
        self.beta = 1
        self.beta_min = 0.5

    @staticmethod
    def _soft_update(target_net, net, tau):
        """Move ``target_net`` parameters a fraction ``tau`` towards ``net``."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), net.parameters()):
                target_param.data.copy_(tau * param + (1 - tau) * target_param)

    def _decomposed_q(self, s, qvalues):
        """Score every single-dimension deviation from the greedy action.

        ``qvalues`` is the discrete network output for ``s``. The arg-max bin
        of every action dimension forms a reference action; dimension ``i`` is
        then swept over all bin values while the other dimensions keep their
        reference value, and the smaller of the two critic estimates is
        recorded.

        Returns:
            Tensor of shape ``[batch, action_dim * action_slices]`` that
            carries no gradient.
        """
        batch = s.shape[0]
        rows = batch * self.action_dim * self.action_slices
        with torch.no_grad():
            per_dim = qvalues.reshape(batch, self.action_dim, self.action_slices)
            _, action_index = torch.max(per_dim, dim=2)
            best_action = action_index * self.action_slice + self.action_base
            bin_values = torch.tensor(
                [self.action_base + j * self.action_slice for j in range(self.action_slices)],
                dtype=best_action.dtype, device=best_action.device)
            # swept[i, k] is True when component k is the one varied in sweep i.
            swept = torch.eye(self.action_dim, device=best_action.device).bool()
            action_chart = torch.where(
                swept.reshape(1, self.action_dim, 1, self.action_dim),
                bin_values.reshape(1, 1, self.action_slices, 1),
                best_action.reshape(batch, 1, 1, self.action_dim))
            action_chart = action_chart.reshape(rows, self.action_dim)
            s_chart = s.unsqueeze(1).repeat(1, self.action_slices * self.action_dim, 1)
            s_chart = s_chart.reshape(rows, self.state_dim)
            qsa1, qsa2 = self.critic(s_chart, action_chart)
            return torch.min(qsa1, qsa2).reshape(batch, self.action_dim * self.action_slices)

    def select_action(self, state, greedy=False):
        """Choose an action for a single state.

        The base action is ``beta * discrete_greedy + (1 - beta) * policy``.

        Args:
            state: one state vector (anything ``torch.FloatTensor`` accepts).
            greedy: when True return the base action unchanged. When False
                each dimension is replaced by a uniformly random bin value
                with probability ``epsilon``, Gaussian noise is added and the
                result is clipped to [-1, 1].

        Returns:
            ``(action, action_index)``: the action values and the greedy bin
            indices of the discrete network. ``epsilon`` is decayed on every
            call.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        # Acting never needs gradients.
        with torch.no_grad():
            qvalue = self.discreteQ(state).reshape([self.action_dim, self.action_slices])
            _, action_index = torch.max(qvalue, dim=1)
            policy_action = self.continuouspolicy(state).reshape(self.action_dim)
        action_index = action_index.cpu().numpy().astype(np.float32)
        policy_action = policy_action.cpu().numpy()
        value_action = action_index * self.action_slice + self.action_base
        action = self.beta * value_action + (1 - self.beta) * policy_action
        if not greedy:
            random_sign = (np.random.uniform(0, 1, self.action_dim) < self.epsilon).astype(np.float32)
            random_index = np.random.randint(0, self.action_slices, self.action_dim)
            random_action = random_index * self.action_slice + self.action_base
            action = random_action * random_sign + (1 - random_sign) * action
            noise = np.random.normal(0, self.exploration_noise, self.action_dim)
            action = np.clip(action + noise, -1, 1)
        self.epsilon *= self.eps_decay
        return action, action_index

    def train(self, replaybuffer):
        """Run one update of the critic, the discrete Q network, the policy and the targets.

        Args:
            replaybuffer: object whose ``sample(n)`` returns tensors
                ``(s, a, r, s_, d)``. The batch size is read from the returned
                tensors. ``r`` and ``d`` are assumed to be 1-D with one entry
                per transition -- TODO confirm against the buffer.
        """
        self.training_steps += 1
        use_policy = self.training_steps > self.discrete_initial_step
        if use_policy:
            self.beta = self.beta_min + (self.beta_decay * (self.beta - self.beta_min))
        with torch.no_grad():
            s, a, r, s_, d = replaybuffer.sample(self.bc)
            next_q = self.target_discrete(s_)
            next_q = next_q.reshape(s_.shape[0], self.action_dim, self.action_slices)
            _, max_index = torch.max(next_q, dim=2)
            target_action = max_index * self.action_slice + self.action_base
            target_q1, target_q2 = self.target_critic(s_, target_action)
            target_q = torch.min(target_q1, target_q2)
            if use_policy:
                # Smoothed target action of the continuous policy, mixed into
                # the target value with weight (1 - beta).
                target_ca = self.target_continuous(s_)
                target_noise = torch.clamp(torch.randn_like(target_ca) * self.policy_smooth_noise, -0.5, 0.5)
                target_ca = torch.clamp(target_ca + target_noise, -1, 1)
                target_cq1, target_cq2 = self.target_critic(s_, target_ca)
                target_cq = torch.min(target_cq1, target_cq2)
                target_q = self.beta * target_q + (1 - self.beta) * target_cq
            target_q = r + self.gamma * (1 - d) * target_q.squeeze(1)

        # critic iteration
        current_q1, current_q2 = self.critic(s, a)
        c_loss = self.loss(current_q1.squeeze(1), target_q) + self.loss(current_q2.squeeze(1), target_q)
        self.c_optimizer.zero_grad()
        c_loss.backward()
        self.c_optimizer.step()

        # discrete value iteration: regress onto the critic's per-bin estimates
        qsa_generate = self.discreteQ(s)
        target_qsa = self._decomposed_q(s, qsa_generate.detach())
        qsa_loss = self.loss(qsa_generate, target_qsa)
        self.d_optimizer.zero_grad()
        qsa_loss.backward()
        self.d_optimizer.step()

        # delayed continuous policy iteration
        if self.training_steps % self.delay_freq == 0:
            cq_generate, _ = self.critic(s, self.continuouspolicy(s))
            a_loss = -torch.mean(cq_generate)
            self.a_optimizer.zero_grad()
            a_loss.backward()
            self.a_optimizer.step()

        self._soft_update(self.target_critic, self.critic, self.tau)
        self._soft_update(self.target_continuous, self.continuouspolicy, self.tau)
        self._soft_update(self.target_discrete, self.discreteQ, self.tau)

