import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from repaly_buffer import ReplayBuffer
from point_env import PointEnvMultiStepTwogoal


class Critic(nn.Module):
    """MLP that maps a (state, action) pair to a vector of size ``state_dim``.

    The network is three ReLU-activated hidden layers of width 256 followed
    by a linear output layer. Note that the output has ``state_dim`` entries,
    not a single scalar value.
    """

    def __init__(self, state_dim, action_dim):
        """Build the four linear layers.

        Args:
            state_dim: Size of the state vector (also the output size).
            action_dim: Size of the action vector.
        """
        super().__init__()
        hidden = 256
        # The attribute names l1..l4 are the state_dict keys, so they are
        # kept as-is to stay compatible with previously saved checkpoints.
        self.l1 = nn.Linear(state_dim + action_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, hidden)
        self.l4 = nn.Linear(hidden, state_dim)

    def forward(self, state, action):
        """Return the network output for ``state`` and ``action``.

        The two inputs are concatenated along their last dimension before
        being fed through the layers.
        """
        features = torch.cat((state, action), dim=-1)
        for hidden_layer in (self.l1, self.l2, self.l3):
            features = F.relu(hidden_layer(features))
        return self.l4(features)


class PolicyG(object):
    """
    Trainer for the state-prediction network (``Critic``).

    The class owns the network, a deep copy of it (``critic_target``) and an
    Adam optimizer. ``train`` regresses the network output for
    ``(next_state, action)`` onto ``state`` with an MSE loss.

    ``critic_target``, ``max_action``, ``discount``, ``tau``,
    ``policy_noise``, ``noise_clip`` and ``policy_freq`` are stored but not
    used by any method of this class.
    """

    def __init__(self, state_dim, action_dim, max_action, device,
                 discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        """Create the network, its copy and the optimizer.

        Args:
            state_dim: Size of the state vector.
            action_dim: Size of the action vector.
            max_action: Stored on the instance; not used by this class.
            device: Torch device the network is moved to. It is also used
                as ``map_location`` when loading checkpoints.
            discount, tau, policy_noise, noise_clip, policy_freq: Stored on
                the instance; not used by this class.
        """
        # Remembered so that load() can remap checkpoints onto this device.
        self.device = device

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.total_it = 0

    def train(self, replay_buffer, batch_size=256):
        """Run one optimization step on a batch sampled from ``replay_buffer``.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns the
                five values ``(state, action, next_state, reward, not_done)``.
            batch_size: Number of transitions requested from the buffer.
        """
        self.total_it += 1
        # Reward and termination flag are not needed for this regression.
        state, action, next_state, _reward, _not_done = replay_buffer.sample(batch_size)
        # NOTE(review): the network input is (next_state, action) and the
        # target is `state`, i.e. this looks like a backward/inverse model
        # that predicts the previous state -- confirm this is intentional.
        pre_state = self.critic(next_state, action)
        critic_loss = F.mse_loss(pre_state, state)
        # Standard gradient step on the prediction loss.
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def save(self, filename):
        """Write network and optimizer state to files prefixed by ``filename``.

        Always writes ``<filename>_critic`` and ``<filename>_critic_optimizer``.
        ``<filename>_actor`` and ``<filename>_actor_optimizer`` are written
        only when the instance has the corresponding attribute; this class
        does not create an actor itself, so unconditional access would raise
        ``AttributeError``.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        actor = getattr(self, "actor", None)
        if actor is not None:
            torch.save(actor.state_dict(), filename + "_actor")
        actor_optimizer = getattr(self, "actor_optimizer", None)
        if actor_optimizer is not None:
            torch.save(actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Restore state previously written by ``save``.

        The network and optimizer are always restored and ``critic_target``
        is rebuilt as a deep copy of the loaded network. Actor files are read
        only when the instance has the corresponding attribute (see ``save``).
        Tensors are remapped onto ``self.device`` so a checkpoint written on
        one device can be loaded on another.
        """
        self.critic.load_state_dict(
            torch.load(filename + "_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=self.device))
        self.critic_target = copy.deepcopy(self.critic)

        actor = getattr(self, "actor", None)
        if actor is not None:
            actor.load_state_dict(
                torch.load(filename + "_actor", map_location=self.device))
            self.actor_target = copy.deepcopy(actor)
        actor_optimizer = getattr(self, "actor_optimizer", None)
        if actor_optimizer is not None:
            actor_optimizer.load_state_dict(
                torch.load(filename + "_actor_optimizer", map_location=self.device))


def main(max_steps=100000, start_timesteps=1000, batch_size=256,
         max_episode_steps=12, save_interval=10000,
         checkpoint_path='transaction_MA_twogoal_12.pkl'):
    """Collect random transitions and train the state-prediction network.

    Actions are sampled uniformly from the environment's action space for
    the whole run; every transition is stored in a replay buffer, and once
    ``start_timesteps`` environment steps have elapsed the network is updated
    once per step.

    Args:
        max_steps: Total number of environment steps to run.
        start_timesteps: Number of steps collected before training starts.
        batch_size: Batch size passed to ``PolicyG.train``.
        max_episode_steps: Episode length at which a ``done`` signal is
            stored as non-terminal (0) in the replay buffer.
        save_interval: Write a checkpoint whenever the step index is a
            multiple of this value (after training has started). Must be a
            positive integer.
        checkpoint_path: File the network ``state_dict`` is written to; each
            checkpoint overwrites the previous one.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --------  define env --------
    env = PointEnvMultiStepTwogoal()
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    max_action = float(env.action_space.high[0])

    # --------  define policy --------
    policy = PolicyG(state_dim, action_dim, max_action, device)

    # --------  define replay_buffer --------
    replay_buffer = ReplayBuffer(state_dim, action_dim)

    state = env.reset()
    episode_timesteps = 0

    for step_idx in range(max_steps):
        episode_timesteps += 1
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        # A `done` raised once the episode has reached max_episode_steps is
        # recorded as 0 (non-terminal) instead of float(done).
        done_bool = float(done) if episode_timesteps < max_episode_steps else 0
        replay_buffer.add(state, action, next_state, reward, done_bool)
        state = next_state

        if step_idx >= start_timesteps:
            policy.train(replay_buffer, batch_size)
            if step_idx % save_interval == 0:
                torch.save(policy.critic.state_dict(), checkpoint_path)

        if done:
            state = env.reset()
            episode_timesteps = 0

    # The periodic checkpoints stop at the last multiple of save_interval, so
    # without this final save the updates made after it would be lost.
    if max_steps > start_timesteps:
        torch.save(policy.critic.state_dict(), checkpoint_path)


# Run the training loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
