import torch
import numpy as np
import time, os, sys
from copy import deepcopy
# from pprint import pprint
import envs
import model
import algo
import runner
import utils

# Echo the command line that launched this script.
print (' '.join(sys.argv))

import argparse

# Command-line configuration for the whole script; the parsed namespace `args`
# is read directly below and also handed to the learners (algo.A2C).
parser = argparse.ArgumentParser()

# Misc
# NOTE(review): --seed is parsed but no use of it is visible in this file
# (the torch.manual_seed calls are commented out).
parser.add_argument('--seed', type=int, default=0)
# Output directory; empty disables saving/logging. A trailing '!' is stripped
# after parsing and skips the "path exist!" warning.
parser.add_argument('--save', type=str, default="")
# Selects the "cuda" torch device instead of "cpu".
parser.add_argument('--gpu', action='store_true', default=False)

# MDP definition
parser.add_argument('--env', type=str, default='soccer')
parser.add_argument('--T', type=int, default=50, help='maximum episode length')
parser.add_argument('--gamma', type=float, default=0.97, help='discount factor')

# Meta-level algorithm
parser.add_argument('--niter', type=int, default=50, help='num of outer iterations')
parser.add_argument('--ninner', type=int, default=10, help='num of inner gd steps')
# NOTE(review): 'constbest' is an accepted choice but no branch handling it is
# visible in this file.
parser.add_argument('--method', type=str, default='base', choices=['base', 'baserand', 'basebest', 'const', 'constbest'])
# Number of X agents and of Y agents; the 'base*' methods assert it equals 1.
parser.add_argument('--nagent', type=int, default=1)
parser.add_argument('--batch', type=int, default=32, help='niter*ninner*batch = total num episodes')
# NOTE(review): --nproc is only referenced from a commented-out line in this file.
parser.add_argument('--nproc', type=int, default=1)

# Evaluation rollouts (used by test()).
parser.add_argument('--test_T', type=int, default=100)
parser.add_argument('--test_batch', type=int, default=100)
parser.add_argument('--test_gamma', type=float, default=1.0)

# --tabular picks the tabular soccer model; the gomoku branch asserts it is off.
parser.add_argument('--tabular', action='store_true', default=False)
# Threshold on the averaged episode reward used by the 'basebest' method.
parser.add_argument('--delta', type=float, default=0.0)

# Algorithm parameters
# NOTE(review): this file always constructs algo.A2C; presumably --algo is
# interpreted inside the algo module - confirm.
parser.add_argument('--algo', type=str, default='a2c', choices=['a2c', 'ppo'])
parser.add_argument('--opt', type=str, default='rmsprop', choices=['rmsprop', 'adam'])
parser.add_argument('--eps', type=float, default=1e-5, help='rmsprop eps')
parser.add_argument('--alpha', type=float, default=0.99, help='rmsprop alpha, adam beta2')
parser.add_argument('--beta', type=float, default=0.5, help='adam beta1 (momentum)')
# NOTE(review): store_true combined with default=True means this flag is always
# True; the non-GAE return estimator cannot be selected from the command line.
parser.add_argument('--gae', action='store_true', default=True)
parser.add_argument('--gae_lambda', type=float, default=0.95)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--value_coef', type=float, default=0.5)
parser.add_argument('--entropy_coef', type=float, default=0.01)
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--anneal', type=int, default=None)

# Originally labelled "not used", but c_eg, c_adalr and target_gap are read by
# the 'const' method in this file {
parser.add_argument('--c_eg', action='store_true', default=False, help='extra gradient method')
parser.add_argument('--c_adalr', action='store_true', default=False, help='adaptive lr based on duality gap')
parser.add_argument('--target_gap', type=float, default=0.3)
# }
# The following three options are only read by the gomoku setup in this file.
# --no_pretrain stores False into args.pretrain (which defaults to True).
parser.add_argument('--no_pretrain', action='store_false', dest='pretrain', default=True)
parser.add_argument('--freeze', type=int, default=0, help='freeze the first n layers')
parser.add_argument('--shared', type=int, default=1, help='share the policy and value backbone')

args = parser.parse_args()

# Output directory and logging are only set up when --save is non-empty.
if args.save:
    # A trailing '!' is stripped and suppresses the existing-path warning.
    if args.save.endswith('!'):
        args.save = args.save[:-1]
    elif os.path.exists(args.save):
        print ('path exist!')
        # exit(-1)

    # From here on args.save holds whatever utils.make_dir_exists returns;
    # file names are appended to it directly.
    args.save = utils.make_dir_exists(args.save)
    utils.setup_log(args.save + 'log.txt')
    # Route every later print in this module through the utils logger.
    print = utils.print

# torch.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)
print (args)

def evaluate_values(states, model):
    """Run the value head of ``model`` over a batch of states.

    ``states`` is converted to a tensor, all leading dimensions are flattened
    into one batch dimension, ``model`` is called with ``av=2``, and the
    result is reshaped back to the leading dimensions of ``states``.
    """
    batch = torch.as_tensor(states)
    leading_shape, feature_dim = batch.shape[:-1], batch.shape[-1]
    flat_values = model(batch.view(-1, feature_dim), av=2)
    return flat_values.view(leading_shape)


# compute_returns(states, returns, masks, model) overwrites `returns` in place
# with return targets and hands back the value predictions for `states`.
# The estimator is chosen once, at import time, from args.gae.
if args.gae:
    @torch.no_grad()
    def compute_returns(states, returns, masks, model):
        """Turn per-step rewards into GAE(lambda) return targets, in place.

        ``returns`` holds rewards on entry and is overwritten step by step
        with ``advantage + value``. ``states`` and ``masks`` are indexed at
        ``step + 1``, so they must have at least one more entry along dim 0
        than ``returns``.

        Returns the value predictions for every entry of ``states``.
        """
        value_preds = evaluate_values(states, model)
        gae = 0
        for step in reversed(range(returns.shape[0])):
            # TD residual; masks[step + 1] gates the bootstrap term.
            delta = returns[step] + args.gamma * value_preds[step + 1] * masks[step + 1] - value_preds[step]
            gae = delta + args.gamma * args.gae_lambda * masks[step + 1] * gae
            returns[step] = gae + value_preds[step]
        return value_preds
else:
    @torch.no_grad()
    def compute_returns(states, returns, masks, model):
        """Turn per-step rewards into discounted bootstrapped returns, in place.

        ``returns`` holds rewards on entry and is overwritten with the
        discounted sum of later rewards, bootstrapped from the value of the
        final state. ``masks`` is indexed at ``step + 1``.

        Returns the value predictions for every entry of ``states``, matching
        the GAE variant.
        """
        # The previous code read `value_preds` without ever defining it in this
        # branch; evaluate all states once and bootstrap from the last one.
        value_preds = evaluate_values(states, model)
        next_value = value_preds[-1]
        for step in reversed(range(returns.shape[0])):
            next_value = next_value * args.gamma * masks[step + 1] + returns[step]
            returns[step] = next_value
        return value_preds

class Meter:
    """Tracks an exponentially smoothed mean and std of batches of numbers.

    ``alpha`` is the weight kept from the previous estimate; with the default
    ``alpha=0`` each update simply replaces the stored statistics.
    """

    def __init__(self, alpha=0):
        self.alpha = alpha
        # NaN marks "no data seen yet".
        self.mean = np.nan
        self.std = np.nan

    def update(self, numbers):
        """Fold the mean and std of ``numbers`` into the running estimates."""
        batch_mean = np.mean(numbers)
        batch_std = np.std(numbers)
        if np.isnan(self.mean):
            # First update (or a NaN estimate): take the batch statistics as-is.
            self.mean, self.std = batch_mean, batch_std
            return
        keep = self.alpha
        self.mean = keep * self.mean + (1 - keep) * batch_mean
        self.std = keep * self.std + (1 - keep) * batch_std

    def __repr__(self):
        return "%7.4f(%7.4f)" % (self.mean, self.std)


# Torch device on which all models are placed.
device = torch.device("cuda" if args.gpu else "cpu")

# Environment-specific setup. Every supported branch defines env, state_dim,
# nact, get_act(), MODEL and the agent lists X and Y used by the rest of the
# script.
if args.env == 'soccer':
    # env = [envs.SoccerEnv(0) for _ in range(args.nproc)]
    env = envs.SoccerEnv(0)
    state_dim = 5
    nact = 5

    from envs import SoccerBuiltinAgent
    # A single builtin agent instance is shared by every 'builtin' act_func.
    builtin = SoccerBuiltinAgent()

    def get_act(x, player_a_or_b):
        """Build an ``act_func(state)`` for the player described by ``x``.

        ``x`` is 'random', 'builtin', or a model:
        * 'random'  - uniform random action; ``state`` is ignored.
        * 'builtin' - ``state=None`` resets the shared builtin agent for side
          ``player_a_or_b`` (and returns None); otherwise the builtin agent acts.
        * model     - ``runner.act(x, state)`` for non-None states, else None.
        """
        if x == 'random':
            def act_func(state):
                return np.random.randint(nact)
        elif x == 'builtin':
            # builtin = SoccerBuiltinAgent(player_a_or_b, 0.5)
            def act_func(state):
                if state is None:
                    builtin.reset_agent(player_a_or_b, 0.5)
                else:
                    return builtin.act(state)
        else:
            def act_func(state):
                if state is not None:
                    return runner.act(x, state)
        return act_func

    if args.tabular:
        MODEL = model.TabularSoccer
    else:
        MODEL = model.MLPSoccer

    X = [ MODEL().to(device) for i in range(args.nagent) ]
    Y = [ MODEL().to(device) for i in range(args.nagent) ]

elif args.env == 'gomoku':
    assert not args.tabular

    env = envs.RenjuEnv()
    state_dim = 81
    nact = 81

    MODEL = model.FCN

    # Fixed pretrained model that plays the role of the 'builtin' opponent.
    pretrain_x = MODEL(pretrain='./gomoku/k553_epoch0/0.pt', noise=False).to(device)

    def get_act(x, player_a_or_b):
        """Build an ``act_func(state)`` for the player described by ``x``.

        ``x`` is 'random', 'builtin', or a model. Every variant returns None
        when called with ``state=None``:
        * 'random'  - uniform choice among the indices where ``state == 0``.
        * 'builtin' - the pretrained model ``pretrain_x`` acts.
        * model     - ``runner.act(x, state, device=device)``.
        ``player_a_or_b`` is accepted for interface parity and not used here.
        """
        if x == 'random':
            def act_func(state):
                if state is not None:
                    valid = np.where(state == 0)[0]
                    return np.random.choice(valid)
        elif x == 'builtin':
            def act_func(state):
                if state is not None:
                    return runner.act(pretrain_x, state, device=device)
        else:
            def act_func(state):
                if state is not None:
                    return runner.act(x, state, device=device)
        return act_func

    if args.pretrain:
        # The checkpoint offset comes from the second-to-last '/'-separated
        # component of --save when it is an integer, scaled by the number of
        # agents (a no-op for nagent == 1); otherwise a random offset is drawn.
        # Only the two expected failures are caught: a missing component
        # (IndexError) or a non-integer one (ValueError).
        try:
            i0 = int(args.save.split('/')[-2]) * args.nagent
        except (ValueError, IndexError):
            i0 = np.random.randint(8)
        # NOTE(review): the Y index (8 + i0 + i) % 16 wraps into 0..7 once
        # i0 + i >= 8, i.e. into the range X draws from - confirm this is
        # intended rather than 8 + (i0 + i) % 8.
        X = [ MODEL(pretrain='./gomoku/k553_epoch0/%d.pt' % ((i0 + i)%8), freeze=args.freeze, shared=args.shared).to(device)
                for i in range(args.nagent) ]
        Y = [ MODEL(pretrain='./gomoku/k553_epoch0/%d.pt' % ((8 + i0 + i)%16), freeze=args.freeze, shared=args.shared).to(device)
                for i in range(args.nagent) ]
    else:
        X = [ MODEL(freeze=args.freeze, shared=args.shared).to(device) for i in range(args.nagent) ]
        Y = [ MODEL(freeze=args.freeze, shared=args.shared).to(device) for i in range(args.nagent) ]

else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError("unsupported --env %r (expected 'soccer' or 'gomoku')" % args.env)


def test(x, y, gamma=args.test_gamma, s=""):
    """Evaluate ``x`` against ``y`` and print a one-line summary.

    ``x`` and ``y`` are anything ``get_act`` accepts ('random', 'builtin' or
    a model). ``gamma`` is forwarded to ``runner.test_rollout`` and ``s`` is
    prepended to the printed line. Nothing is returned.
    """
    policy_x = get_act(x, player_a_or_b=0)
    policy_y = get_act(y, player_a_or_b=1)
    rews = runner.test_rollout(env, args.test_T, args.test_batch, policy_x, policy_y, gamma)

    # Outcome percentages by the sign of the reward.
    win_pct = (rews > 0).mean() * 100
    draw_pct = (rews == 0).mean() * 100
    lose_pct = (rews < 0).mean() * 100

    summary = "\ttest rew %7.4f +/- %6.4f (win %.1f%% draw %.1f%% lose %.1f%%)" % (
        rews.mean(), rews.std(), win_pct, draw_pct, lose_pct)
    print (s + summary)



# One learner per X agent, built from the model, the parsed args and the
# return estimator selected above.
# NOTE(review): algo.A2C is constructed regardless of --algo; presumably it
# dispatches on args.algo internally - confirm.
rlX = [ algo.A2C(x, args, compute_returns) for x in X ]
# histX[k][i] is the state_dict snapshot of X[i] taken at checkpoint k.
histX = []
# Scratch model into which past X snapshots are loaded when X acts as a
# fixed opponent for Y.
# NOTE(review): built with default MODEL() arguments, whereas the gomoku
# agents pass freeze/shared - confirm the state_dicts stay compatible.
tmpx = MODEL().to(device)
# tmpx = deepcopy(X[0])

# Same three objects for the Y side.
rlY = [ algo.A2C(y, args, compute_returns) for y in Y ]
histY = []
tmpy = MODEL().to(device)
# tmpy = deepcopy(Y[0])

# Running count of training episodes, and its value at each checkpoint.
num_episodes = 0
hist_num_ep = []

def checkpoint():
    """Snapshot every agent and, when --save is set, rewrite the history files.

    Appends deep copies of all X and Y state_dicts to ``histX`` / ``histY``
    and the current ``num_episodes`` to ``hist_num_ep``. With a save
    directory, the complete histories are written to ``histX.pt`` and
    ``histY.pt`` on every call.
    """
    snapshot_x = [deepcopy(agent.state_dict()) for agent in X]
    histX.append(snapshot_x)
    snapshot_y = [deepcopy(agent.state_dict()) for agent in Y]
    histY.append(snapshot_y)
    hist_num_ep.append(num_episodes)

    if not len(args.save):
        return
    for file_name, history in (('histX.pt', histX), ('histY.pt', histY)):
        torch.save(history, args.save + file_name)

# Record the initial agents as checkpoint 0 before any training.
checkpoint()

# sanity checks: reference match-ups between the untrained first agents, the
# random player and the builtin player, evaluated in this order.
_sanity_matchups = [
    ('builtin', 'random', "built vs rand"),
    ('random', 'builtin', "rand vs built"),
    (X[0], Y[0], "x vs y   "),
    (X[0], 'random', "x vs rand"),
    ('random', Y[0], "rand vs y"),
    (X[0], 'builtin', "x vs built"),
    ('builtin', Y[0], "built vs y"),
]
for _left, _right, _label in _sanity_matchups:
    test(_left, _right, s=_label)


# 'base' method: a single x/y pair, each trained against the most recent
# checkpoint of the other.
if args.method == 'base':
    assert args.nagent == 1

    x, y = X[0], Y[0]
    for it in range(args.niter):
        print ('\n', it)

        # Freeze the opponent: load the latest Y snapshot into the scratch model.
        idx = -1
        tmpy.load_state_dict(histY[idx][0])
        print ("x: compete w/ y[it=%d]" % (idx))

        for inner in range(args.ninner):
            # record=1: only the first four outputs (x's data) are kept here.
            sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
            num_episodes += args.batch
            # Sum over dim 0, then mean over the rest (printed as 'rew' / 'len').
            raw_ep_rew_x = rx.sum(0).mean().item()
            lx = mx.sum(0).mean().item()

            # compute_returns(sx, rx, mx, x)
            action_loss_x, value_loss_x, entropy_x, grad_norm_x = rlX[0].update(sx, ax, rx, mx)
            # Read after update(); presumably rx was converted to returns in
            # place by then (printed as 'd_rew') - confirm in algo.A2C.update.
            dis_ep_rew = rx[0].mean().item()

            print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" % 
                (action_loss_x, value_loss_x, entropy_x, grad_norm_x, raw_ep_rew_x, dis_ep_rew, lx))


        # Same procedure for y, against the latest X snapshot. histX[-1] was
        # taken before this iteration's x updates.
        idx = -1
        tmpx.load_state_dict(histX[idx][0])
        print ("y: compete w/ x[it=%d]" % (idx))

        for inner in range(args.ninner):
            # record=2: only the last four outputs (y's data) are kept here.
            _, _, _, _,  sy, ay, ry, my = runner.rollout(
                env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
            num_episodes += args.batch
            raw_ep_rew_y = ry.sum(0).mean().item()
            ly = my.sum(0).mean().item()

            # compute_returns(sy, ry, my, y)
            action_loss_y, value_loss_y, entropy_y, grad_norm_y = rlY[0].update(sy, ay, ry, my)
            dis_ep_rew = ry[0].mean().item()

            print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" % 
                (action_loss_y, value_loss_y, entropy_y, grad_norm_y, raw_ep_rew_y, dis_ep_rew, ly))

        # Snapshot both agents, then evaluate against each other and the
        # reference players.
        checkpoint()

        print ("episodes %d" % num_episodes)
        test(x, y, s="x vs y   ")
        test(x, 'random', s="x vs rand")
        test('random', y, s="rand vs y")
        test(x, 'builtin', s="x vs built")
        test('builtin', y, s="built vs y")


# 'baserand' method: a single x/y pair, each trained against a uniformly
# sampled past checkpoint of the other.
if args.method == 'baserand':
    assert args.nagent == 1

    x, y = X[0], Y[0]
    for it in range(args.niter):
        print ('\n', it)

        # Freeze the opponent: load a random Y snapshot into the scratch model.
        idx = np.random.randint(len(histY))
        tmpy.load_state_dict(histY[idx][0])
        print ("x: compete w/ y[it=%d]" % (idx))

        for inner in range(args.ninner):
            # record=1: only the first four outputs (x's data) are kept here.
            sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
            num_episodes += args.batch
            # Sum over dim 0, then mean over the rest (printed as 'rew' / 'len').
            raw_ep_rew_x = rx.sum(0).mean().item()
            lx = mx.sum(0).mean().item()

            # compute_returns(sx, rx, mx, x)
            action_loss_x, value_loss_x, entropy_x, grad_norm_x = rlX[0].update(sx, ax, rx, mx)
            # Read after update(); presumably rx was converted to returns in
            # place by then (printed as 'd_rew') - confirm in algo.A2C.update.
            dis_ep_rew = rx[0].mean().item()

            print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" %
                (action_loss_x, value_loss_x, entropy_x, grad_norm_x, raw_ep_rew_x, dis_ep_rew, lx))


        # Sample the X opponent from the X history. This previously used
        # len(histY); both histories grow together in checkpoint(), so the
        # sampled range is unchanged.
        idx = np.random.randint(len(histX))
        tmpx.load_state_dict(histX[idx][0])
        print ("y: compete w/ x[it=%d]" % (idx))

        for inner in range(args.ninner):
            # record=2: only the last four outputs (y's data) are kept here.
            _, _, _, _,  sy, ay, ry, my = runner.rollout(
                env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
            num_episodes += args.batch
            raw_ep_rew_y = ry.sum(0).mean().item()
            ly = my.sum(0).mean().item()

            # compute_returns(sy, ry, my, y)
            action_loss_y, value_loss_y, entropy_y, grad_norm_y = rlY[0].update(sy, ay, ry, my)
            dis_ep_rew = ry[0].mean().item()

            print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" %
                (action_loss_y, value_loss_y, entropy_y, grad_norm_y, raw_ep_rew_y, dis_ep_rew, ly))

        # Snapshot both agents, then evaluate against each other and the
        # reference players.
        checkpoint()

        print ("episodes %d" % num_episodes)
        test(x, y, s="x vs y   ")
        test(x, 'random', s="x vs rand")
        test('random', y, s="rand vs y")
        test(x, 'builtin', s="x vs built")
        test('builtin', y, s="built vs y")



# 'basebest' method: a single x/y pair, each trained against the other's
# snapshot at checkpoint index idx_best, which only advances when the averaged
# episode reward exceeds --delta.
if args.method == 'basebest':
    assert args.nagent == 1

    x, y = X[0], Y[0]
    # Index into histX/histY of the current opponent snapshot; 0 is the
    # initial checkpoint.
    idx_best = 0

    for it in range(args.niter):
        print ('\n', it)

        # x trains against the frozen Y snapshot at idx_best.
        tmpy.load_state_dict(histY[idx_best][0])
        for inner in range(args.ninner):
            # record=1: only the first four outputs (x's data) are kept here.
            sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
            num_episodes += args.batch
            # Sum over dim 0, then mean over the rest (printed as 'rew' / 'len').
            raw_ep_rew_x = rx.sum(0).mean().item()
            lx = mx.sum(0).mean().item()

            # compute_returns(sx, rx, mx, x)
            action_loss_x, value_loss_x, entropy_x, grad_norm_x = rlX[0].update(sx, ax, rx, mx)
            # Read after update(); presumably rx was converted to returns in
            # place by then (printed as 'd_rew') - confirm in algo.A2C.update.
            dis_ep_rew = rx[0].mean().item()

            print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" % 
                (action_loss_x, value_loss_x, entropy_x, grad_norm_x, raw_ep_rew_x, dis_ep_rew, lx))


        # y trains against the frozen X snapshot at the same index.
        tmpx.load_state_dict(histX[idx_best][0])
        for inner in range(args.ninner):
            # record=2: only the last four outputs (y's data) are kept here.
            _, _, _, _,  sy, ay, ry, my = runner.rollout(
                env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
            num_episodes += args.batch
            raw_ep_rew_y = ry.sum(0).mean().item()
            ly = my.sum(0).mean().item()

            # compute_returns(sy, ry, my, y)
            action_loss_y, value_loss_y, entropy_y, grad_norm_y = rlY[0].update(sy, ay, ry, my)
            dis_ep_rew = ry[0].mean().item()

            print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" % 
                (action_loss_y, value_loss_y, entropy_y, grad_norm_y, raw_ep_rew_y, dis_ep_rew, ly))

        # Average of the two sides' episode rewards from the LAST inner step
        # of each loop (these names are undefined if --ninner is 0).
        r = 0.5 * (raw_ep_rew_x + raw_ep_rew_y)
        if r > args.delta:
            print ('update best %d -> %d %7.4f(%7.4f, %7.4f)' % (idx_best, it + 1,
                r, raw_ep_rew_x, raw_ep_rew_y))
            # it + 1 is the history index that the checkpoint() call just
            # below is about to create.
            idx_best = it + 1

        checkpoint()

        print ("episodes %d" % num_episodes)
        test(x, y, s="x vs y   ")
        test(x, 'random', s="x vs rand")
        test('random', y, s="rand vs y")
        test(x, 'builtin', s="x vs built")
        test('builtin', y, s="built vs y")


# 'const' method: a population of nagent x-agents and nagent y-agents. Each
# outer iteration plays every x_i against every y_j, then trains each agent
# starting from the pooled rollouts and continuing against its selected
# opponent snapshot.
if args.method == 'const':
    # assert args.nagent > 1
    # Smoothed duality gap; only reported (and printed next to the adaptive lr).
    a_dual_gap = Meter(0.9)

    for it in range(args.niter):
        print ('\n', it)

        # pairwise_rolls[i, j]: full rollout tuple of x_i vs y_j.
        # pairwise_r[i, j]: x's mean episode reward in that match-up.
        # Builtin `object` is used as dtype because the `np.object` alias was
        # removed from NumPy.
        pairwise_rolls = np.empty((args.nagent, args.nagent), dtype=object)
        pairwise_r = np.zeros((args.nagent, args.nagent))
        for i in range(args.nagent):
            for j in range(args.nagent):
                # x_i vs y_j
                roll = runner.rollout(env, args.T, args.batch, state_dim, X[i], Y[j], record=3, device=device)
                num_episodes += args.batch
                pairwise_rolls[i, j] = roll
                # .item() yields a plain float, as for the other logged metrics.
                pairwise_r[i, j] = roll[2].sum(0).mean().item()
        #pprint (pairwise_r)
        print (pairwise_r)

        # # extra-gradient method, rescind one step if the current agent is chosen as opponent
        # if args.c_eg and it > 0:
        #     for i in range(args.nagent):
        #         X[i].load_state_dict(histX[-2][i])
        #         Y[i].load_state_dict(histY[-2][i])

        for i in range(args.nagent):
            x = X[i]

            # Opponent for x_i: the y_j against which x_i earned the least.
            min_j = np.argmin(pairwise_r[i, :])
            min_r = pairwise_r[i, min_j]
            # sx, ax, rx, mx = pairwise_rolls[i, min_j][:4]
            # The first inner step reuses x_i's data from ALL match-ups,
            # concatenated along dim 1.
            sx, ax, rx, mx = [torch.cat([roll[k] for roll in pairwise_rolls[i, :]], 1) for k in [0,1,2,3]]
            tmpy.load_state_dict(histY[-1][min_j])

            # extra-gradient method, rescind one step if the current agent is chosen as opponent
            if args.c_eg and i == min_j and it > 0:
                X[i].load_state_dict(histX[-2][i])
                print ("x%d: compete w/ y%d [rew %7.4f] rescind" % (i, min_j, min_r))
            else:
                print ("x%d: compete w/ y%d [rew %7.4f]" % (i, min_j, min_r))

            # Opponent for y_i: the x_j that earned the most against y_i.
            max_j = np.argmax(pairwise_r[:, i])
            max_r = pairwise_r[max_j, i]
            # sy, ay, ry, my = pairwise_rolls[max_j, i][4:8]
            sy, ay, ry, my = [torch.cat([roll[k] for roll in pairwise_rolls[:, i]], 1) for k in [4,5,6,7]]
            tmpx.load_state_dict(histX[-1][max_j])

            # extra-gradient method, rescind one step if the current agent is chosen as opponent
            if args.c_eg and i == max_j and it > 0:
                Y[i].load_state_dict(histY[-2][i])
                print ("y%d: compete w/ x%d [rew %7.4f] rescind" % (i, max_j, max_r))
            else:
                print ("y%d: compete w/ x%d [rew %7.4f]" % (i, max_j, max_r))

            dual_gap = max_r - min_r
            a_dual_gap.update(dual_gap)
            if args.c_adalr:
                # lr = args.lr * np.clip(dual_gap / (a_dual_gap.mean + 1e-8), 1.0, 1.2)
                # Scale the base lr by the gap relative to --target_gap,
                # clipped to the range [1.0, 1.5].
                lr = args.lr * np.clip(dual_gap / args.target_gap, 1.0, 1.5)
                rlX[i].set_lr(lr)
                rlY[i].set_lr(lr)
                print ("dual gap: %7.4f (a %7.4f), lr set to %.5f" % (dual_gap, a_dual_gap.mean, lr))
            else:
                print ("dual gap: %7.4f (a %7.4f)" % (dual_gap, a_dual_gap.mean))

            for inner in range(args.ninner):
                # Step 0 consumes the pooled pairwise data; later steps roll
                # out fresh episodes against the selected frozen opponent.
                if inner > 0:
                    sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                        env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
                    num_episodes += args.batch
                raw_ep_rew_x = rx.sum(0).mean().item()
                lx = mx.sum(0).mean().item()

                # compute_returns(sx, rx, mx, x)
                # k is nagent for the pooled batch (nagent rollouts) and 1 afterwards.
                action_loss_x, value_loss_x, entropy_x, grad_norm_x = rlX[i].update(sx, ax, rx, mx, k=args.nagent if inner==0 else 1)
                dis_ep_rew = rx[0].mean().item()

                print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" %
                    (action_loss_x, value_loss_x, entropy_x, grad_norm_x, raw_ep_rew_x, dis_ep_rew, lx))


            y = Y[i]

            for inner in range(args.ninner):
                if inner > 0:
                    _, _, _, _, sy, ay, ry, my = runner.rollout(
                        env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
                    num_episodes += args.batch

                raw_ep_rew_y = ry.sum(0).mean().item()
                ly = my.sum(0).mean().item()

                # compute_returns(sy, ry, my, y)
                action_loss_y, value_loss_y, entropy_y, grad_norm_y = rlY[i].update(sy, ay, ry, my, k=args.nagent if inner==0 else 1)
                dis_ep_rew = ry[0].mean().item()

                print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f\td_rew %7.4f\tlen %.2f" %
                    (action_loss_y, value_loss_y, entropy_y, grad_norm_y, raw_ep_rew_y, dis_ep_rew, ly))

            # Evaluate the freshly updated pair (x_i, y_i).
            test(x, y, s="x vs y   ")
            test(x, 'random', s="x vs rand")
            test('random', y, s="rand vs y")
            test(x, 'builtin', s="x vs built")
            test('builtin', y, s="built vs y")

        # Snapshot the whole population once per outer iteration.
        checkpoint()

        print ("episodes %d" % num_episodes)


# Close the log that was set up when a save directory was given.
if len(args.save) > 0:
    utils.close_log()
