import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import common.grid_world as grid_world
import time
import numpy as np
from matplotlib import animation
import matplotlib.pyplot as plt
import os

# Environment switch for the Intel OpenMP runtime. Presumably set so the process
# tolerates more than one copy of that runtime being loaded -- TODO confirm it is
# still needed on the target machine.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Hyperparameters
learning_rate = 5e-4  # step size handed to the Adam optimizer
gamma = 0.99          # discount factor used in the TD target and the advantage recursion
lmbda = 0.95          # decay applied together with gamma in the advantage recursion
eps_clip = 0.1        # probability ratio is clamped to [1 - eps_clip, 1 + eps_clip]
K_epoch = 2           # number of optimisation passes over each collected batch

# Network sizes. The spelling "sequece_length" is kept because the rest of the
# file refers to this exact name.
sequece_length = 64   # output width of the first linear layer == LSTM input size
num_units = 32        # LSTM hidden size


def save_frames_as_gif(frames, path='ppo_lstm/', filename='gym_animation.gif', fps=60):
    """Write a sequence of image frames to an animated GIF.

    Args:
        frames: Non-empty sequence of image arrays accepted by ``plt.imshow``;
            the first frame's ``shape[0]``/``shape[1]`` determine the figure size.
        path: Prefix that is concatenated directly with ``filename`` to form the
            output location (so it should normally end with a path separator).
        filename: Name of the GIF file.
        fps: Frame rate passed to the Pillow writer.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("save_frames_as_gif() needs at least one frame")

    target = path + filename
    # Create the destination directory up front so the writer does not fail on
    # a fresh checkout where it does not exist yet.
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    height, width = frames[0].shape[0], frames[0].shape[1]
    # 36 dpi with a size of pixels / 36 inches gives one figure pixel per image pixel.
    fig = plt.figure(figsize=(width / 36.0, height / 36.0), dpi=36)
    try:
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        anim.save(target, writer='pillow', fps=fps)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; release
        # it even when saving fails.
        plt.close(fig)


class PPO(nn.Module):
    """Recurrent actor-critic network trained with the clipped PPO objective.

    A shared ``Linear -> ReLU -> LSTM`` trunk feeds two heads: ``fc_pi``
    (softmax over ``action_dim`` actions) and ``fc_v`` (scalar state value).
    Transitions are buffered with :meth:`put_data` and consumed by
    :meth:`train_net`, which treats the whole buffer as one sequence.
    """

    def __init__(self):
        super(PPO, self).__init__()
        self.data = []  # rollout buffer; emptied by make_batch()
        self.action_dim = 4
        self.sequece_length = sequece_length
        self.num_units = num_units

        self.fc1 = nn.Linear(4, self.sequece_length)
        self.lstm = nn.LSTM(self.sequece_length, self.num_units)
        self.fc_pi = nn.Linear(self.num_units, self.action_dim)
        self.fc_v = nn.Linear(self.num_units, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def _lstm_features(self, x, hidden):
        """Run the shared trunk; return ``(lstm_output, new_hidden)``."""
        x = F.relu(self.fc1(x))
        # Reshape to (steps, batch=1, features) before feeding the LSTM.
        x = x.view(-1, 1, self.sequece_length)
        return self.lstm(x, hidden)

    def pi(self, x, hidden):
        """Return action probabilities of shape (steps, 1, action_dim) and the new LSTM state."""
        features, lstm_hidden = self._lstm_features(x, hidden)
        prob = F.softmax(self.fc_pi(features), dim=2)
        return prob, lstm_hidden

    def v(self, x, hidden):
        """Return state-value estimates of shape (steps, 1, 1)."""
        features, _ = self._lstm_features(x, hidden)
        return self.fc_v(features)

    def put_data(self, transition):
        """Append one ``(s, a, r, s_prime, prob_a, h_in, h_out, done)`` tuple to the buffer."""
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions into tensors and clear the buffer.

        Returns:
            ``(s, a, r, s_prime, done_mask, prob_a, h_in, h_out)`` where ``a`` is
            int64, every other tensor is float32, ``done_mask`` is 0 for terminal
            steps and 1 otherwise, and ``h_in`` / ``h_out`` are the LSTM states
            recorded for the FIRST transition of the buffer.

        Raises:
            IndexError: If the buffer is empty.
        """
        if not self.data:
            raise IndexError("make_batch() called with an empty rollout buffer")

        states, actions, rewards, next_states, old_probs, h_ins, h_outs, dones = zip(*self.data)
        self.data = []

        # Going through np.array avoids the slow list-of-ndarrays tensor path, and
        # the explicit dtypes keep float64 NumPy scalars (e.g. rewards) from
        # producing float64 tensors that would clash with the float32 network.
        s = torch.tensor(np.array(states), dtype=torch.float)
        s_prime = torch.tensor(np.array(next_states), dtype=torch.float)
        a = torch.tensor(np.array(actions), dtype=torch.long).view(-1, 1)
        r = torch.tensor(np.array(rewards), dtype=torch.float).view(-1, 1)
        prob_a = torch.tensor(np.array(old_probs), dtype=torch.float).view(-1, 1)
        done_mask = torch.tensor([[0.0 if d else 1.0] for d in dones], dtype=torch.float)

        return s, a, r, s_prime, done_mask, prob_a, h_ins[0], h_outs[0]

    def sample_action(self, state, h_in):
        """Sample an action from the policy for a single observation.

        Args:
            state: NumPy array holding one observation.
            h_in: LSTM ``(h, c)`` state to start from.

        Returns:
            ``(action, prob, h_out)``: the sampled action as a Python int, the
            flattened action-probability tensor, and the updated LSTM state.
        """
        # Rollout-time inference needs no autograd graph; without no_grad the
        # graph would chain through the hidden state for the whole episode.
        with torch.no_grad():
            prob, h_out = self.pi(torch.from_numpy(state).float(), h_in)
        prob = prob.view(-1)
        a = Categorical(prob).sample().item()
        return a, prob, h_out

    def test_action(self, state, h_in):
        """Pick the most probable action for a single observation.

        Returns:
            ``(action, h_out)``: the greedy action as a Python int and the
            updated LSTM state.
        """
        with torch.no_grad():
            prob, h_out = self.pi(torch.from_numpy(state).float(), h_in)
        a = int(np.argmax(prob.numpy()))
        return a, h_out

    def train_net(self):
        """Run ``K_epoch`` PPO updates on the buffered rollout, then leave the buffer empty.

        Does nothing when no transitions have been collected.
        """
        if not self.data:
            return

        s, a, r, s_prime, done_mask, prob_a, (h1_in, h2_in), (h1_out, h2_out) = self.make_batch()
        first_hidden = (h1_in.detach(), h2_in.detach())
        second_hidden = (h1_out.detach(), h2_out.detach())

        for _ in range(K_epoch):
            v_prime = self.v(s_prime, second_hidden).squeeze(1)
            # The TD target is a fixed regression target: no gradient flows through it.
            td_target = (r + gamma * v_prime * done_mask).detach()
            v_s = self.v(s, first_hidden).squeeze(1)
            delta = (td_target - v_s).detach().numpy()

            # Discounted backward accumulation of the TD errors (GAE-style).
            advantage_lst = []
            advantage = 0.0
            for item in delta[::-1]:
                advantage = gamma * lmbda * advantage + item[0]
                advantage_lst.append([advantage])
            advantage_lst.reverse()
            advantage = torch.tensor(advantage_lst, dtype=torch.float)

            pi, _ = self.pi(s, first_hidden)
            pi_a = pi.squeeze(1).gather(1, a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a / b == exp(log(a) - log(b))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(v_s, td_target)

            self.optimizer.zero_grad()
            # Every epoch rebuilds its own graph from detached inputs, so the
            # graph does not need to be retained after backward().
            loss.mean().backward()
            self.optimizer.step()


def _initial_hidden():
    """Return a zeroed ``(h, c)`` LSTM state of shape (1, 1, num_units) each."""
    return (torch.zeros([1, 1, num_units], dtype=torch.float),
            torch.zeros([1, 1, num_units], dtype=torch.float))


def _scale_state(state, scale):
    """Divide every component of *state* by *scale* and return the result as a list."""
    return [e / scale for e in state]


def main(mode):
    """Train the PPO-LSTM agent on ``TwoWayGridWorld`` or replay a saved model.

    Args:
        mode: ``'Train'`` runs the training loop; any other value runs the test
            branch, which loads the saved weights, plays 10 rendered episodes
            and writes them to a GIF.
    """
    env = grid_world.TwoWayGridWorld()
    env.seed(0)
    PATH = 'ppo_lstm/' + 'sequece_length_' + str(sequece_length) + "_lstm_hid_" + str(
        num_units) + '.pth'
    model = PPO()
    max_length = 30  # step cap per episode
    best_avg_reward = -100000

    if mode == 'Train':
        # Make sure the checkpoint directory exists before the first save.
        os.makedirs(os.path.dirname(PATH), exist_ok=True)

        score = 0.0
        scored_episodes = 0  # training episodes accumulated into `score` since the last report
        print_interval = 20
        num_episodes = 1000000000

        try:
            for n_epi in range(num_episodes):
                h_out = _initial_hidden()
                s = np.array(_scale_state(env.reset(), env.n_height))

                if n_epi % 100 == 0 and n_epi != 0:
                    # Every 100th episode is a rendered greedy rollout; nothing
                    # is stored or learned from it.
                    ep_length = 0
                    done = False
                    while not (done or ep_length == max_length):
                        h_in = h_out
                        ep_length += 1
                        a, h_out = model.test_action(s, h_in)
                        next_state, reward, done, _ = env.step(a)
                        env.render()
                        s = np.array(_scale_state(next_state, env.n_height))
                else:
                    # One rollout of at most max_length steps, then one update.
                    for t in range(max_length):
                        h_in = h_out
                        a, prob, h_out = model.sample_action(s, h_in)
                        next_state, r, done, info = env.step(a)
                        next_state = _scale_state(next_state, env.n_height)
                        model.put_data((s, a, r / 100.0, next_state, prob[a].item(), h_in, h_out, done))
                        s = np.array(next_state)

                        score += r
                        if done:
                            print('finish')
                            break
                    model.train_net()
                    scored_episodes += 1
                    print(f'------- Epoch {n_epi} ---------')

                if n_epi % print_interval == 0 and n_epi != 0 and scored_episodes:
                    # Average over the episodes that actually contributed to
                    # `score`; evaluation episodes add nothing to it.
                    ave_reward = score / scored_episodes
                    print("# of episode :{}, avg score : {:.1f}".format(n_epi, ave_reward))
                    score = 0.0
                    scored_episodes = 0
                    if best_avg_reward < ave_reward:
                        torch.save(model.state_dict(), PATH)
                        best_avg_reward = ave_reward
                        print('saved the model')
        finally:
            # Also reached when the (practically endless) loop is interrupted.
            env.close()
    else:
        print('test')
        frames = []
        test_episodes = 10
        model.load_state_dict(torch.load(PATH))
        model.eval()
        try:
            for _ in range(test_episodes):
                h_out = _initial_hidden()
                s = np.array(_scale_state(env.reset(), env.n_height))
                ep_length = 0
                done = False
                while not (done or ep_length == max_length):
                    ep_length += 1
                    h_in = h_out
                    # Force the agent through the lower door: scripted actions
                    # for the first six steps, the model afterwards.
                    # NOTE(review): the LSTM state is not advanced during the
                    # scripted steps, so the model starts from a zero state at
                    # step 7 -- confirm this is intended.
                    if ep_length <= 3:
                        a = 1
                    elif ep_length <= 6:
                        a = 2
                    else:
                        a, h_out = model.test_action(s, h_in)
                    next_state, reward, done, _ = env.step(a)
                    env.render()
                    frames.append(env.render(mode='rgb_array'))
                    s = np.array(_scale_state(next_state, env.n_height))
        finally:
            env.close()
        gif_name = 'ppo_lstm.gif'
        save_frames_as_gif(frames, filename=gif_name)


if __name__ == "__main__":
    # Script entry point. main() trains only when mode equals 'Train'; the value
    # passed here is different, so the non-training (test) branch is executed.
    main('Train1')
