import PPO
from PPO import collect_trajectories, expert_collect_trajectories, testPPO
from utils import *
import torch
import cvxpy as cp
import numpy as np
import pandas as pd
from multiprocessing import Process, Pipe
from datetime import datetime
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import random

# Seed every random number generator this script touches (Python, NumPy,
# torch CPU and all torch CUDA devices), always in the same order, so that
# runs are reproducible.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(0)
del _seed_rng

# Map-size switch: the mini map has 36 states, the full map 132.
USE_MINI_MAP = False
if USE_MINI_MAP:
    state_num = 36
else:
    state_num = 132

def _to_numpy(value):
    """Return ``value`` as a float64 NumPy array (torch tensors are detached and moved to CPU first)."""
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float64)


# Child process of the consensus-ADMM solver: compute this expert's local
# iterate x_i^{k+1}, send it to the master (which averages all iterates into
# xbar), receive xbar back, update the scaled dual variable u, and repeat.
def run_worker(occupancy_expert, occupancy_agent, preference, pipe, rho=5):
    """Run the ADMM worker loop for a single expert.

    Every iteration solves the proximal problem

        minimize_x  c @ x + (rho / 2) * ||x - xbar + u||^2

    with the linear cost
    ``c[s*8 + a*2 + o] = (occupancy_agent - occupancy_expert)[s, a] * preference[o]``,
    sends ``x`` to the master, receives the new ``xbar`` and updates ``u``.

    Args:
        occupancy_expert: expert occupancy measure, indexed ``[state, action]``
            (assumed shape ``(state_num, 4)``; torch tensor or array-like).
        occupancy_agent: agent occupancy measure with the same layout.
        preference: weights of the two objectives (length 2).
        pipe: this worker's end of a ``multiprocessing.Pipe`` to the master.
        rho: ADMM penalty parameter (defaults to the previously hard-coded 5).

    Raises:
        ValueError: if the inputs do not yield ``state_num * 4 * 2`` cost entries.

    The loop only ends when the master closes its end of the pipe (EOF);
    otherwise the master is expected to terminate this process.
    """
    n = state_num * 4 * 2

    # Build the linear cost explicitly. This is the same objective the previous
    # cp.reshape / cp.sum expression encoded, but it no longer depends on
    # cp.reshape's default memory order ('F', announced to change to 'C').
    occupancy_diff = _to_numpy(occupancy_agent) - _to_numpy(occupancy_expert)
    weights = _to_numpy(preference)
    cost = (occupancy_diff[..., None] * weights).reshape(-1)
    if cost.shape != (n,):
        raise ValueError(f"expected {n} cost entries, got shape {cost.shape}")

    x = cp.Variable(n)
    xbar = cp.Parameter(n, value=np.zeros(n))
    u = cp.Parameter(n, value=np.zeros(n))
    objective = cp.sum(cp.multiply(cost, x)) + (rho / 2) * cp.sum_squares(x - xbar + u)
    prox = cp.Problem(cp.Minimize(objective))

    # ADMM loop.
    while True:
        prox.solve(solver=cp.MOSEK)
        pipe.send(x.value)
        try:
            xbar.value = pipe.recv()
        except EOFError:
            # The master closed its end of the pipe: no further iterations.
            return
        # Scaled dual update.
        u.value = u.value + x.value - xbar.value


if __name__ == "__main__":
    # Each run writes agents, reward signals and test logs into its own
    # timestamped directory.
    path = os.getcwd() + '/checkpoints/' + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    # exist_ok: create every sub-directory even when ``path`` already exists
    # (e.g. two runs started within the same second); previously they were
    # only created when ``path`` itself was new.
    for subdir in ('reward_signals', 'testing_logs', 'agents'):
        os.makedirs(path + '/' + subdir, exist_ok=True)
    dev = torch.device('cpu')
    # how many rounds of (PPO -> ADMM)
    rounds = 1
    numOfExperts = 3
    # Initial reward signal: state_num * 4 * 2 random values (saved below
    # reshaped to (state_num, 4, 2)).
    reward_signal = torch.rand(state_num * 4 * 2).to(dev)
    # One weight vector over the two objectives per expert.
    preferences = torch.FloatTensor([[0.1, 0.9], [0.9, 0.1], [0.5, 0.5]])
    trajectories_experts = []
    occupancy_experts = []
    # Per-expert histories used for the CSV logs and the plots.
    plot_rounds = [[] for _ in range(numOfExperts)]
    l1_norms = [[] for _ in range(numOfExperts)]
    rewards = [[] for _ in range(numOfExperts)]
    lengths = [[] for _ in range(numOfExperts)]
    # run (PPO -> ADMM)
    for k in range(rounds):
        print("Round", k)
        ppo_agents = []
        trajectories_agents = []
        occupancy_agents = []
        for i in range(numOfExperts):
            # Expert data does not depend on the reward signal, so it is only
            # collected in the first round.
            if k == 0:
                # collect expert's trajectories for 500 episodes
                trajectories_experts.append(expert_collect_trajectories(preferences[i], 500))
                # expert's occupancy measure estimated from those trajectories
                occupancy_experts.append(occupancy(
                    trajectories_experts[i].states,
                    trajectories_experts[i].actions,
                    trajectories_experts[i].is_terminals,
                    1.0,
                    100,
                    state_num,
                    4,
                    0.4,
                ))

            # train the agent with the current (ADMM-derived) reward signal
            ppo_agents.append(PPO.runPPO(reward_signal, preferences[i], i, True))
            torch.save(ppo_agents[-1].policy.state_dict(), path + '/agents/agent' + str(i) + '_' + str(k) + '.pth')
            # collect agent's trajectories for 500 episodes
            trajectories_agents.append(collect_trajectories(ppo_agents[i], 500))
            # agent's occupancy measure estimated from those trajectories
            occupancy_agents.append(occupancy(
                trajectories_agents[i].states,
                trajectories_agents[i].actions,
                trajectories_agents[i].is_terminals,
                1.0,
                100,
                state_num,
                4,
                0.4,
            ))

        # ---- ADMM: one worker process per expert; this process is the master ----
        pipes = []
        procs = []
        for i in range(numOfExperts):
            local, remote = Pipe()
            pipes.append(local)
            procs.append(Process(
                target=run_worker,
                args=(
                    occupancy_experts[i],
                    occupancy_agents[i],
                    preferences[i],
                    remote,
                ),
            ))
            procs[-1].start()

        MAX_ITER = 300
        xbar = None
        for i in range(MAX_ITER):
            # Gather every worker's local iterate and average them.
            pre_xbar = xbar
            xbar = sum(pipe.recv() for pipe in pipes) / numOfExperts
            # Change since the previous iteration; inf on the first iteration
            # so the loop never stops before a second round trip.
            diff = np.inf if pre_xbar is None else np.linalg.norm(xbar - pre_xbar)

            # Scatter xbar. It is sent even on the final iteration: each worker
            # answers it with one more local iterate, which is collected below.
            for pipe in pipes:
                pipe.send(xbar)
            if diff < 0.001:
                break

        # New reward signal: average of the workers' iterates computed from the
        # final xbar.
        old_reward_signal = reward_signal
        reward_signal = torch.FloatTensor(sum(pipe.recv() for pipe in pipes) / numOfExperts).to(dev)
        # Workers loop forever: stop them, reap them (no zombie processes) and
        # close the master's ends of the pipes.
        for proc in procs:
            proc.terminate()
        for proc in procs:
            proc.join()
        for pipe in pipes:
            pipe.close()
        # Combine with the previous signal. mean_reward is presumably provided
        # by the ``utils`` star import (likely a running mean over k + 1
        # rounds -- confirm).
        reward_signal = mean_reward(old_reward_signal, reward_signal, k + 1)
        # Save the signal this round's agents were trained with.
        torch.save(old_reward_signal.reshape(state_num, 4, 2), path + '/reward_signals/reward_signal' + str(k) + '.pt')
        for i in range(numOfExperts):
            # L1 distance between the expert's and the agent's occupancy measures
            norm = torch.norm(occupancy_experts[i] - occupancy_agents[i], p=1).item()
            plot_rounds[i].append(k)
            l1_norms[i].append(norm)
        for i in range(numOfExperts):
            ep_reward, ep_length = testPPO(preferences[i], i, False, False, path + '/agents/agent' + str(i) + '_' + str(k) + '.pth')
            rewards[i].append(ep_reward)
            lengths[i].append(ep_length)

    # One CSV of test results per agent, one row per round.
    for i in range(numOfExperts):
        df = pd.DataFrame(list(zip(rewards[i], lengths[i])), columns=['Reward', 'Length'])
        df.to_csv(path + '/testing_logs/agent' + str(i) + '.csv')

    # NOTE(review): the figures are written to the current working directory,
    # not to the run's checkpoint directory -- confirm this is intended.
    colors = ["red", "blue", "yellow"]
    for fig_name, series, ylabel, filename in (
        ("L1 norm", l1_norms, "L1 norm", "9 norms25.png"),
        ("Rewards", rewards, "reward", "9 rewards25.png"),
        ("Lengths", lengths, "length", "9 lengths25.png"),
    ):
        plt.figure(fig_name)
        for i in range(numOfExperts):
            plt.plot(plot_rounds[i], series[i], color=colors[i % len(colors)], linestyle="-", linewidth="2", markersize="10", marker=".")
        plt.xlabel("round", fontsize="10")
        plt.ylabel(ylabel, fontsize="10")
        plt.savefig(filename)

