import random
import pickle
from datetime import datetime

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

# Place models and tensors on the GPU when PyTorch can see one, else on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Launch timestamp; tags the TensorBoard log directory and saved checkpoint names.
now_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')


def save_variable(v, filename):
    """Pickle ``v`` into ``filename`` (overwriting it) and return ``filename``.

    The file handle is closed even if pickling fails part-way through.
    """
    with open(filename, 'wb') as f:
        pickle.dump(v, f)
    return filename


class GoalFixWrapper(gym.Wrapper):
    """Reacher wrapper that puts the target at the same fixed goal on every reset.

    ``sim``, ``init_qpos``, ``init_qvel``, ``set_state`` and ``get_body_com`` are
    not defined here. They are resolved on the wrapped env through ``gym.Wrapper``
    attribute forwarding. This assumes a MuJoCo Reacher whose last two
    ``qpos``/``qvel`` entries belong to the target (TODO: confirm for other envs).
    """

    def __init__(self, env, goal=(0.0202, 0.1109)):
        super().__init__(env)
        goal = np.asarray(goal, dtype=np.float64)
        # Enforce the same reachability radius (norm < 0.2) that the goal check used;
        # an out-of-range constant previously made reset spin forever.
        if goal.shape != (2,) or np.linalg.norm(goal) >= 0.2:
            raise ValueError(f"goal must be a 2-D point with norm < 0.2, got {goal!r}")
        self.fixed_goal = goal

    def reset(self, **kwargs):
        """Reset the wrapped env, then overwrite its state so the target sits at the goal."""
        self.env.reset(**kwargs)
        self.sim.reset()
        return self.reset_model()

    def reset_model(self):
        """Restore the env's initial joint state with the target at the fixed goal."""
        # Copy so the wrapped env's stored initial state is not modified in place.
        qpos = self.init_qpos.copy()
        qvel = self.init_qvel.copy()
        self.goal = self.fixed_goal.copy()
        qpos[-2:] = self.goal
        qvel[-2:] = 0
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _get_obs(self):
        """Build the observation for the current simulator state.

        The observation concatenates cos and sin of the first two qpos entries, the
        remaining qpos, the first two qvel entries, and the fingertip-minus-target
        body-COM offset.
        """
        theta = self.sim.data.qpos.flat[:2]
        return np.concatenate([
            np.cos(theta),
            np.sin(theta),
            self.sim.data.qpos.flat[2:],
            self.sim.data.qvel.flat[:2],
            self.get_body_com("fingertip") - self.get_body_com("target")
        ])


class RewardShapeWrapper(gym.Wrapper):
    """Make the reward sparse: 0 on every non-terminal step, ``info['reward_dist']`` at the end."""

    def step(self, action):
        """Step the wrapped env and replace its reward with the sparse terminal signal."""
        observation, reward, done, info = self.env.step(action)
        # Only the terminal step carries signal. 'reward_dist' is presumably Reacher's
        # negative fingertip-to-target distance (confirm against the gym env).
        reward = info['reward_dist'] if done else 0
        return observation, reward, done, info


class ReplayBuffer:
    """Fixed-capacity ring buffer of 5-field records with uniform random sampling.

    Once ``capacity`` records are stored, each new push overwrites the oldest slot.
    In this file every pushed record is a whole episode (lists of per-step values).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []    # stored (state, action, reward, next_state, done) tuples
        self.position = 0   # slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one record, overwriting the oldest one when the buffer is full."""
        record = (state, action, reward, next_state, done)
        if len(self.buffer) >= self.capacity:
            self.buffer[self.position] = record
        else:
            self.buffer.append(record)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct records; return each field stacked with ``np.stack``."""
        picked = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = (np.stack(column) for column in zip(*picked))
        return state, action, reward, next_state, done

    def __len__(self):
        """Number of records currently stored."""
        return len(self.buffer)


class NormalizedActions(gym.ActionWrapper):
    """Expose a [-1, 1] action interface over an env with arbitrary Box bounds."""

    def action(self, action):
        """Map an action from [-1, 1] onto [low, high] and clip it to the env bounds."""
        low = self.action_space.low
        high = self.action_space.high
        action = low + (action + 1.0) * 0.5 * (high - low)
        return np.clip(action, low, high)

    def reverse_action(self, action):
        """Map an env action from [low, high] back to the normalized [-1, 1] range."""
        low = self.action_space.low
        high = self.action_space.high
        action = 2 * (action - low) / (high - low) - 1
        # The result lives in the normalized space, so clip to [-1, 1], not to the env bounds.
        return np.clip(action, -1.0, 1.0)


class Transaction(nn.Module):
    """Four-layer MLP mapping a (state, action) pair to a state-sized vector.

    CFN.train fits it to predict the previous state from (next_state, action).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 256
        # l1..l4 are the state_dict keys of saved checkpoints; keep the names.
        self.l1 = nn.Linear(state_dim + action_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, width)
        self.l4 = nn.Linear(width, state_dim)

    def forward(self, state, action):
        """Concatenate inputs on the last dim; three ReLU layers, then a linear head."""
        hidden = torch.cat([state, action], -1)
        for layer in (self.l1, self.l2, self.l3):
            hidden = F.relu(layer(hidden))
        return self.l4(hidden)


class Critic(nn.Module):
    """Edge-flow network F(s, a): a two-hidden-layer MLP with a softplus output.

    Softplus keeps every predicted flow non-negative. The output has a trailing dim of 1.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        # Layer attribute names double as state_dict keys of saved critics.
        self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return the flow of edge (state, action), with inputs joined on the last dim."""
        x = torch.cat((state, action), dim=-1)
        x = F.relu(self.l2(F.relu(self.l1(x))))
        return F.softplus(self.l3(x))


class CFN(object):
    """Continuous Flow Network agent: an edge-flow critic plus a backward transition model.

    ``critic`` scores (state, action) edges with a non-negative flow. ``transaction``
    is trained to predict the parent state from (next_state, action). ``train`` uses it
    to generate parents for randomly sampled incoming actions, so that the in-flow of
    each visited state can be matched against its out-flow.

    Candidate actions are drawn uniformly from ``[-max_action, max_action)`` in every
    action dimension.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, max_action, uniform_action_size,
                 discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        self.critic = Critic(state_dim, action_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
        self.transaction = Transaction(state_dim, action_dim).to(device)
        self.transaction_optimizer = torch.optim.Adam(self.transaction.parameters(), lr=3e-5)
        self.max_action = max_action
        self.state_dim = state_dim
        self.action_dim = action_dim
        # NOTE(review): discount, tau, policy_noise, noise_clip, policy_freq and total_it
        # are stored but never read by this class.
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0
        self.uniform_action_size = uniform_action_size
        self.uniform_action = self._sample_uniform_actions((uniform_action_size,))

    def _sample_uniform_actions(self, leading_shape):
        """Return a float tensor on ``device`` of shape ``(*leading_shape, action_dim)``.

        Entries are drawn uniformly from ``[-max_action, max_action)``.
        """
        actions = np.random.uniform(low=-self.max_action, high=self.max_action,
                                    size=tuple(leading_shape) + (self.action_dim,))
        return torch.Tensor(actions).to(device)

    def select_action(self, state, is_max, num_samples=10000):
        """Pick an action for one ``state`` from ``num_samples`` uniform candidates.

        Args:
            state: a single observation; it is flattened with ``np.asarray(...).reshape(1, -1)``.
            is_max: falsy samples a candidate with probability proportional to its
                edge flow; truthy returns the highest-flow candidate.
            num_samples: number of uniform candidate actions to score.

        Returns:
            A 1-D numpy array with ``action_dim`` entries.
        """
        with torch.no_grad():
            candidates = self._sample_uniform_actions((num_samples,))
            states = torch.FloatTensor(np.asarray(state).reshape(1, -1)).to(device).repeat(num_samples, 1)
            edge_flow = self.critic(states, candidates).reshape(-1)
            if is_max:
                action = candidates[edge_flow.argmax()]
            else:
                # Categorical normalizes the non-negative flows into probabilities.
                idx = Categorical(edge_flow.float()).sample(torch.Size([1]))
                action = candidates[idx[0]]
        return action.cpu().numpy().flatten()

    def set_uniform_action(self):
        """Redraw the ``(uniform_action_size, action_dim)`` candidate pool and return it."""
        self.uniform_action = self._sample_uniform_actions((self.uniform_action_size,))
        return self.uniform_action

    def _log_total_flow(self, states, actions, batch_size, max_episode_steps):
        """Score the edges with the critic and return log(1 + sum of flows) per (episode, step)."""
        edge_flow = self.critic(states, actions).reshape(batch_size, max_episode_steps, -1)
        # Sum the flows directly. A log/exp round-trip gives NaN gradients once a
        # softplus output underflows to exactly 0.
        return torch.log(edge_flow.sum(-1) + 1.0)

    def train(self, replay_buffer, frame_idx, done_true, batch_size=256, max_episode_steps=50, sample_flow_num=100):
        """Run one critic update, plus one transition-model update when ``frame_idx % 5 == 0``.

        Args:
            replay_buffer: its ``sample`` returns whole episodes. Each field is expected
                to stack to ``(batch_size, max_episode_steps, ...)`` (TODO confirm).
            frame_idx: current episode counter; gates the transition-model update.
            done_true: ``(batch_size, max_episode_steps)`` mask selecting the terminal step.
            sample_flow_num: random actions used per state to estimate in-flow and
                out-flow. One observed action is added to each set.
        """
        k = sample_flow_num
        state, action, reward, next_state, not_done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).to(device)
        not_done = torch.FloatTensor(np.float32(not_done)).to(device)

        with torch.no_grad():
            # In-flow of s_{t+1}: k random actions whose parents come from the backward
            # model, plus the observed edge (s_t, a_t).
            in_actions = self._sample_uniform_actions((batch_size, max_episode_steps, k))
            children = next_state.repeat(1, 1, k).reshape(batch_size, max_episode_steps, k, -1)
            parents = self.transaction(children, in_actions)
            parents = torch.cat(
                [parents, state.reshape(batch_size, max_episode_steps, -1, self.state_dim)], -2)
            in_actions = torch.cat(
                [in_actions, action.reshape(batch_size, max_episode_steps, -1, self.action_dim)], -2)
        inflow = self._log_total_flow(parents, in_actions, batch_size, max_episode_steps)

        with torch.no_grad():
            # Out-flow of s_{t+1}: k random actions plus the action actually taken next
            # (a_{t+1}). A zero action pads the final step.
            out_actions = self._sample_uniform_actions((batch_size, max_episode_steps, k))
            out_states = next_state.repeat(1, 1, k + 1).reshape(batch_size, max_episode_steps, k + 1, -1)
            next_action = torch.cat(
                [action[:, 1:, :], torch.zeros(batch_size, 1, self.action_dim, device=device)], -2)
            out_actions = torch.cat(
                [out_actions, next_action.reshape(batch_size, max_episode_steps, -1, self.action_dim)], -2)
        outflow = self._log_total_flow(out_states, out_actions, batch_size, max_episode_steps)

        # Terminal target log(1 + (k + 1) * reward) in the last column; the other columns
        # only matter where done_true is non-zero.
        terminal_target = torch.cat(
            [reward[:, :-1], torch.log(reward[:, -1] * (k + 1) + 1.0).reshape(batch_size, -1)], -1)
        flow_match_loss = F.mse_loss(inflow * not_done, outflow * not_done, reduction='none')
        terminal_loss = F.mse_loss(inflow * done_true, terminal_target * done_true, reduction='none')
        critic_loss = torch.mean(torch.sum(flow_match_loss + terminal_loss, dim=1))

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        if frame_idx % 5 == 0:
            # Fit the backward transition model: (next_state, action) -> state.
            predicted_state = self.transaction(next_state, action)
            transaction_loss = F.mse_loss(predicted_state, state)
            self.transaction_optimizer.zero_grad()
            transaction_loss.backward()
            self.transaction_optimizer.step()


def reward_shaping(reward, low=-1.0, high=0.0):
    """Affinely rescale ``reward`` so that ``low`` maps to 0 and ``high`` maps to 1.

    With the defaults, rewards in [-1, 0] land in [0, 1]. Values outside
    [low, high] are not clipped. Works elementwise on numpy arrays and tensors.
    """
    return (reward - low) / (high - low)


def _evaluate_greedy(policy, env, num_episodes, max_episode_steps):
    """Return the mean return of ``num_episodes`` rollouts using the highest-flow action."""
    total_return = 0
    for _ in range(num_episodes):
        state = env.reset()
        episode_return = 0
        for _ in range(max_episode_steps):
            action = policy.select_action(np.array(state), 1)
            state, reward, done, _info = env.step(action)
            episode_return += reward
            if done:
                break
        total_return += episode_return
    return total_return / num_episodes


def _sample_trajectories(policy, env, num_episodes, max_episode_steps):
    """Roll out ``num_episodes`` flow-sampled episodes.

    Returns a pair ``(state lists, terminal rewards)``. Only episodes that reach
    ``done`` within ``max_episode_steps`` are recorded.
    """
    all_states, all_rewards = [], []
    for _ in range(num_episodes):
        state = env.reset()
        states = []
        for _ in range(max_episode_steps):
            action = policy.select_action(np.array(state), 0)
            next_state, reward, done, _info = env.step(action)
            states.append(state)
            state = next_state
            if done:
                all_states.append(states)
                all_rewards.append(reward)
                break
    return all_states, all_rewards


def main():
    """Train a CFN agent on sparse-reward Reacher-v2 and evaluate it every 10 episodes.

    After ``start_timesteps`` collected episodes, each evaluation does three things:
    logs the mean greedy return to TensorBoard, checkpoints the critic under ``runs/``,
    and pickles the visited states and terminal rewards of ``sample_episode_num``
    flow-sampled episodes.
    """
    # The writer also creates runs/, which the critic checkpoint below is saved into.
    with SummaryWriter(log_dir="runs/gfn_reacher_newloss_" + now_time) as writer:
        # --------  define env --------
        env = RewardShapeWrapper(gym.make('Reacher-v2'))
        test_env = RewardShapeWrapper(gym.make('Reacher-v2'))
        action_dim = env.action_space.shape[0]
        state_dim = env.observation_space.shape[0]
        max_action = float(env.action_space.high[0])
        hidden_dim = 256
        uniform_action_size = 2000

        # --------  define policy --------
        policy = CFN(state_dim, action_dim, hidden_dim, max_action, uniform_action_size)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        policy.transaction.load_state_dict(
            torch.load('transaction_reacher_sparse.pkl', map_location=device))

        # --------  define replay_buffer --------
        replay_buffer_size = 2000
        replay_buffer = ReplayBuffer(replay_buffer_size)

        max_frames = 2000          # frame_idx counts finished episodes
        start_timesteps = 150      # episodes collected before training/evaluation begin
        batch_size = 128
        sample_flow_num = 99       # random actions per state in the flow estimates
        repeat_episode_num = 5     # greedy episodes per evaluation
        sample_episode_num = 1000  # flow-sampled episodes pickled per evaluation
        max_episode_steps = 50

        # Terminal-step mask for CFN.train: 1 at the last step, 0 elsewhere.
        done_true = torch.zeros(batch_size, max_episode_steps, device=device)
        done_true[:, -1] = 1.0

        frame_idx = 0
        while frame_idx < max_frames:
            state = env.reset()
            state_buf, action_buf, reward_buf, next_state_buf, done_buf = [], [], [], [], []
            for step in range(max_episode_steps):
                action = policy.select_action(state, 0)
                next_state, reward, done, _ = env.step(action)
                state_buf.append(state)
                action_buf.append(action)
                reward_buf.append(reward_shaping(reward))
                next_state_buf.append(next_state)
                done_buf.append(float(1. - done))
                state = next_state

                if done:
                    frame_idx += 1
                    replay_buffer.push(state_buf, action_buf, reward_buf, next_state_buf, done_buf)
                    break

                # Once warmed up, take a gradient step on every 7th environment step.
                if frame_idx >= start_timesteps and step % 7 == 0:
                    policy.train(replay_buffer, frame_idx, done_true, batch_size,
                                 max_episode_steps, sample_flow_num)

            if frame_idx >= start_timesteps and frame_idx % 10 == 0:
                avg_return = _evaluate_greedy(policy, test_env, repeat_episode_num, max_episode_steps)
                torch.save(policy.critic.state_dict(), "runs/gfn_reacher_all_" + now_time + '.pkl')
                writer.add_scalar("gfn_reacher_max_reward", avg_return,
                                  global_step=frame_idx * max_episode_steps)

                total_state_buf, total_reward_buf = _sample_trajectories(
                    policy, test_env, sample_episode_num, max_episode_steps)
                save_variable(total_state_buf, 'gfn_reacher_state_ceta_' + str(frame_idx) + '_1000.data')
                save_variable(total_reward_buf, 'gfn_reacher_reward_ceta_' + str(frame_idx) + '_1000.data')


def test():
    """Measure the trajectory diversity of a saved critic on the fixed-goal Reacher.

    Steps:
    1. Load critic weights from 'gfn_normal_reacher_all.pkl'.
    2. Collect ``test_episode_num`` flow-sampled episodes from ``GoalFixWrapper``.
    3. Keep the episodes whose terminal (dense) reward is above -0.5.
    4. Log to TensorBoard the mean, over kept episodes, of the variance of pairwise
       MSEs. This is computed for the full observations and for the x/y of the
       last-three observation entries.

    Nothing is logged if no episode passes the filter.
    """
    max_episode_steps = 50
    test_episode_num = 100
    env = GoalFixWrapper(gym.make('Reacher-v2'))
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    max_action = float(env.action_space.high[0])
    hidden_dim = 256
    uniform_action_size = 2000
    policy = CFN(state_dim, action_dim, hidden_dim, max_action, uniform_action_size)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    policy.transaction.load_state_dict(torch.load('transaction_reacher_sparse.pkl', map_location=device))
    policy.critic.load_state_dict(torch.load('gfn_normal_reacher_all.pkl', map_location=device))

    all_episode_state, all_episode_position, all_reward = [], [], []
    frame_idx = 0
    while frame_idx < test_episode_num:
        state = env.reset()
        state_buf, position_buf = [], []
        for _ in range(max_episode_steps):
            action = policy.select_action(state, 0)
            next_state, reward, done, _info = env.step(action)
            state_buf.append(np.array(state))
            # x/y of the last three obs entries (fingertip-minus-target offset in
            # GoalFixWrapper._get_obs).
            position_buf.append(np.array([state[-3], state[-2]]))
            state = next_state
            if done:
                frame_idx += 1
                all_reward.append(reward)
                all_episode_state.append(state_buf)
                all_episode_position.append(position_buf)
                break

    # Keep only episodes whose terminal reward exceeds -0.5.
    kept = [i for i, r in enumerate(all_reward) if r > -0.5]
    if not kept:
        return
    states = [np.array(all_episode_state[i]) for i in kept]
    positions = [np.array(all_episode_position[i]) for i in kept]

    var, final_position_var = [], []
    for state_i, position_i in zip(states, positions):
        state_mse = [((state_i - state_j) ** 2).mean(axis=1).mean(axis=0) for state_j in states]
        position_mse = [((position_i - position_j) ** 2).mean(axis=0) for position_j in positions]
        var.append(np.var(state_mse))
        final_position_var.append(np.var(position_mse))

    with SummaryWriter(log_dir="runs/gfn_reacher_newloss_" + now_time) as writer:
        writer.add_scalar("gfn_reacher_traj_max_var", np.mean(var), global_step=frame_idx * max_episode_steps)
        writer.add_scalar("gfn_reacher_final_position_max_var", np.mean(final_position_var),
                          global_step=frame_idx * max_episode_steps)


if __name__ == '__main__':
    # Train by default; test() instead evaluates a saved critic ('gfn_normal_reacher_all.pkl').
    main()
