import torch
import numpy as np
import time, os, sys
from copy import deepcopy
# from pprint import pprint
import envs
import model
import algo
import runner
import utils

if __name__ == '__main__':

    import argparse
    parser = argparse.ArgumentParser()

    # Misc
    # NOTE(review): --seed is parsed but torch seeding below is commented out,
    # so it currently has no effect.
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save', type=str, default="")
    parser.add_argument('--gpu', type=int, default=-1)

    # MDP definition
    parser.add_argument('--env', type=str, default='soccer')
    parser.add_argument('--T', type=int, default=50, help='maximum episode length')
    parser.add_argument('--gamma', type=float, default=0.97, help='discount factor')

    # Meta-level algorithm
    parser.add_argument('--niter', type=int, default=50, help='num of outer iterations')
    parser.add_argument('--ninner', type=int, default=10, help='num of inner gd steps')
    parser.add_argument('--method', type=str, default='base',
        choices=['singlex', 'singley', 'base', 'baserand', 'baserandl', 'basebest', 'const', 'constv'])
    parser.add_argument('--nagent', type=int, default=1)
    parser.add_argument('--batch', type=int, default=32, help='niter*ninner*batch = total num episodes')
    parser.add_argument('--nproc', type=int, default=1)

    parser.add_argument('--test_T', type=int, default=100)
    parser.add_argument('--test_batch', type=int, default=100)
    parser.add_argument('--test_gamma', type=float, default=1.0)

    # parser.add_argument('--model', type=str, default='tabular', choices=['tabular', 'nn'])
    parser.add_argument('--tabular', action='store_true', default=False)
    parser.add_argument('--delta', type=float, default=0.0)

    # Algorithm parameters
    parser.add_argument('--algo', type=str, default='a2c', choices=['a2c', 'ppo'])
    parser.add_argument('--opt', type=str, default='rmsprop', choices=['rmsprop', 'adam', 'sgd'])
    parser.add_argument('--eps', type=float, default=1e-5, help='rmsprop eps')
    parser.add_argument('--alpha', type=float, default=0.99, help='rmsprop alpha, adam beta2')
    parser.add_argument('--beta', type=float, default=0.5, help='adam beta1 (momentum)')
    parser.add_argument('--gae', action='store_true', default=True)
    # '--gae' is store_true with default=True, so on its own it can never turn
    # GAE off; '--no_gae' (same dest) provides the off switch, like --no_pretrain.
    parser.add_argument('--no_gae', action='store_false', dest='gae')
    parser.add_argument('--gae_lambda', type=float, default=0.95)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--anneal_lr', action='store_true')
    parser.add_argument('--value_coef', type=float, default=0.5)
    parser.add_argument('--value_lr_mult', type=float, default=1.0)
    parser.add_argument('--entropy_coef', type=float, default=0.01)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--opt_epochs', type=int, default=1)
    parser.add_argument('--minibatch_size', type=int, default=128)
    parser.add_argument('--adv_norm', action='store_true')
    parser.add_argument('--clip', type=float, default=0.2)
    parser.add_argument('--sep_opt', action='store_true')
    parser.add_argument('--init_std', type=float, default=None)
    parser.add_argument('--fix_std', action='store_true')

    # Read by the 'const' method (extra-gradient rescind); it must be defined,
    # otherwise that method fails with AttributeError on args.c_eg.
    parser.add_argument('--c_eg', action='store_true', default=False, help='extra gradient method')
    parser.add_argument('--c_adalr', action='store_true', default=False, help='adaptive lr based on duality gap')
    parser.add_argument('--target_gap', type=float, default=None)
    parser.add_argument('--no_pretrain', action='store_false', dest='pretrain', default=True)
    parser.add_argument('--freeze', type=int, default=0, help='freeze the first n layers')
    # parser.add_argument('--fix_x', action='store_true')
    # parser.add_argument('--fix_y', action='store_true')

    args = parser.parse_args()

    if len(args.save):
        if args.save[-1] == '!':
            # A trailing '!' means: reuse an existing directory without asking.
            args.save = args.save[:-1]
        elif os.path.exists(args.save):
            # Any answer other than 'n' continues into the existing directory.
            answer = utils.input_with_timeout('path exist! continue? [y/n, timeout 3s] ', 3)
            if answer == 'n':
                sys.exit(-1)

        # presumably returns the directory path with a trailing separator,
        # since file names are appended directly below — TODO confirm
        args.save = utils.make_dir_exists(args.save)
        utils.setup_log(args.save + 'log.txt')
        # From here on, every print() in this script goes through utils.print.
        print = utils.print

    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed_all(args.seed)
    print (' '.join(sys.argv))
    print (args)

    class Meter:
        """Exponential moving average of a stream of values.

        ``mean`` starts as NaN. The first ``update`` stores the value as-is;
        every later one blends it in as
        ``mean = alpha * mean + (1 - alpha) * value``.
        With the default ``alpha=0`` the meter simply holds the latest value.
        """

        def __init__(self, alpha=0):
            self.alpha = alpha
            self.mean = np.nan

        def update(self, numbers):
            if np.isnan(self.mean):
                # First observation seeds the average directly.
                self.mean = numbers
                return
            keep = self.alpha
            self.mean = keep * self.mean + (1 - keep) * numbers

    # Compute device: the CUDA ordinal given by --gpu when it is >= 0, else CPU.
    device = torch.device("cpu") if args.gpu < 0 else torch.device(f"cuda:{args.gpu}")
    if device.type == "cuda":
        # Make this GPU the current CUDA device.
        torch.cuda.set_device(device)

    if args.env == 'sumo':
        assert not args.tabular

        state_dim = 120
        nact = 8
        MODEL = model.MLPControl

        if args.nproc > 1:
            # env = envs.ParallelEnvWrapper(envs.SumoEnv, args.nproc, **kwargs)
            env = None
            # Rebinds the module name `runner` to a ParallelRunner instance; all
            # later runner.rollout / runner.test_rollout calls go through it.
            runner = runner.ParallelRunner(args.nproc)

            def get_act(x, player_a_or_b):
                """Describe policy `x` as a tuple spec for the parallel runner.

                'random' -> ('random',), 'builtin' -> ('builtin',), otherwise a
                model -> ('mlp', x.policy.state_dict()). `player_a_or_b` is
                not used in this variant.
                """
                if x == 'random':
                    return ('random',)
                elif x == 'builtin':
                    return ('builtin',)
                else:
                    return ('mlp', x.policy.state_dict())
        else:
            kwargs = {}
            if args.method.startswith('single'):
                # single_i = 0 for 'singlex', 1 for 'singley' (presumably tells
                # the env which side is the learner — TODO confirm in SumoEnv)
                kwargs['single_i'] = int(args.method == 'singley')
            env = envs.SumoEnv(**kwargs)
            # state_dim = env.state_dim
            # nact = env.nact

            def get_act(x, player_a_or_b):
                """Build a callable ``state -> action`` for player `player_a_or_b` (0 or 1).

                `x` is 'random', 'builtin' or a model. The returned callable
                returns None when `state` is None.
                """
                if x == 'random':
                    def act_func(state):
                        if state is not None:
                            return env.env.action_space[player_a_or_b].sample()
                elif x == 'builtin':
                    def act_func(state):
                        if state is not None:
                            # running into opponent? (currently a zero action: a sample scaled by 0)
                            return env.env.action_space[player_a_or_b].sample() * 0
                else:
                    def act_func(state):
                        if state is not None:
                            # [0]raw_act, [1]tanh(raw_act)
                            return runner.act_cont(x, state[player_a_or_b], device=device)[1]
                return act_func

        if args.pretrain:
            # Pretrained-agent offset parsed from the last directory component of
            # --save (e.g. 'runs/3/' -> 3); fall back to a random offset when the
            # path is empty or that component is not an integer.
            try:
                i0 = int(args.save.split('/')[-2])
            except (ValueError, IndexError):
                print ("[warning] failed to parse pretrain agent id")
                i0 = np.random.randint(8)

            # x agents take even slots and y agents odd slots, starting from i0.
            _candidates = ['x0_4', 'y0_4', 'x1_4', 'y1_4', 'x2_4', 'y2_4', 'x3_4', 'y3_4',
                           'x4_4', 'y4_4', 'x5_4', 'y5_4', 'x6_4', 'y6_4', 'x7_4', 'y7_4']
            X = [ MODEL(state_dim, nact,
                        pretrain='./sumo/pretrain_v3/pretrain%s.pt' % (_candidates[(i0 + i*2)%len(_candidates)]),
                        init_std=args.init_std, trainable_std=not args.fix_std).to(device)
                    for i in range(args.nagent) ]
            Y = [ MODEL(state_dim, nact,
                        pretrain='./sumo/pretrain_v3/pretrain%s.pt' % (_candidates[(i0 + i*2 + 1)%len(_candidates)]),
                        init_std=args.init_std, trainable_std=not args.fix_std).to(device)
                    for i in range(args.nagent) ]
        else:
            # NOTE(review): --init_std / --fix_std are not forwarded here, unlike
            # the pretrain branch — confirm this is intended.
            X = [ MODEL(state_dim, nact).to(device) for i in range(args.nagent) ]
            Y = [ MODEL(state_dim, nact).to(device) for i in range(args.nagent) ]
    else:
        # Only 'sumo' is wired up in this script; without this, any other env
        # (including the default 'soccer') fails later with a NameError on X/Y.
        raise ValueError("unsupported env '%s' (only 'sumo' is implemented)" % args.env)


    def test(x, y, gamma=args.test_gamma, s=""):
        """Evaluate `x` (player 0) against `y` (player 1) and print summary stats.

        `x` / `y` are models or the strings 'random' / 'builtin' (see get_act).
        The printed line contains the mean/std/min/max of the per-episode returns
        from runner.test_rollout, plus the percentage of episodes with
        positive / zero / negative return, labelled win / draw / lose.
        `gamma` is forwarded to runner.test_rollout; `s` prefixes the line.
        """
        act_x = get_act(x, player_a_or_b=0)
        act_y = get_act(y, player_a_or_b=1)
        # win_stat = np.empty((args.test_batch,), dtype=np.byte)
        # rews = runner.test_rollout(env, args.test_T, args.test_batch, act_x, act_y, gamma, win_stat=win_stat)
        rews = runner.test_rollout(env, args.test_T, args.test_batch, act_x, act_y, gamma)
        # Outcome is classified by the sign of each episode's return.
        print (s + "\ttest rew %7.4f +/- %6.4f [%7.4f,%7.4f] (win %.1f%% draw %.1f%% lose %.1f%%)" % (
            rews.mean(), rews.std(), rews.min(), rews.max(),
            (rews > 0).mean()*100, (rews == 0).mean()*100, (rews < 0).mean()*100)
        )

    # test_sequence(): evaluation run after training steps.
    # NOTE(review): both variants read the module-level names x and y, and the
    # sumo variant also reads i, i.e. whatever the active training loop last
    # bound. The sumo variant overwrites the scratch opponents tmpx/tmpy with
    # agent i's initial snapshot (histX[0]/histY[0]); the training loops reload
    # tmpx/tmpy before using them again.
    if args.env == 'sumo':
        def test_sequence():
            """Evaluate x vs y, then each against agent i's initial (histX/histY[0]) snapshot."""
            test(x, y, s="x vs y   ")
            tmpy.load_state_dict(histY[0][i])
            tmpx.load_state_dict(histX[0][i])
            test(x, tmpy, s="x vs init")
            test(tmpx, y, s="init vs y")
    else:
        def test_sequence():
            """Evaluate x vs y, and each against the 'random' and 'builtin' players."""
            test(x, y, s="x vs y   ")
            test(x, 'random', s="x vs rand")
            test('random', y, s="rand vs y")
            test(x, 'builtin', s="x vs built")
            test('builtin', y, s="built vs y")

    # One A2C/PPO learner per x agent, a list of parameter snapshots, and an
    # eval-mode scratch copy into which opponent snapshots are loaded.
    rlX = [algo.A2C_PPO(agent, args) for agent in X]
    histX = []
    tmpx = deepcopy(X[0]).eval()

    # Same for the y population.
    rlY = [algo.A2C_PPO(agent, args) for agent in Y]
    histY = []
    tmpy = deepcopy(Y[0]).eval()

    num_episodes = 0   # total episodes simulated so far, across all agents
    hist_num_ep = []   # value of num_episodes at each checkpoint

    def checkpoint():
        """Append deep copies of every agent's state_dict to histX / histY.

        Also records the current episode count. When --save is set, writes
        histX and histY to ``histX.pt`` / ``histY.pt`` in the save directory.
        """
        histX.append([deepcopy(agent.state_dict()) for agent in X])
        histY.append([deepcopy(agent.state_dict()) for agent in Y])
        hist_num_ep.append(num_episodes)
        if not args.save:
            return
        for fname, hist in (('histX.pt', histX), ('histY.pt', histY)):
            torch.save(hist, args.save + fname)

    # Snapshot 0: the initial agents.
    checkpoint()

    # Sanity-check evaluations of the initial agents before any training.
    for _first, _second, _label in (
            (X[0], 'random', "x     vs rand"),
            (X[0], 'builtin', "x     vs built"),
            (X[0], Y[0], "x     vs y    "),
            ('random', Y[0], "rand  vs y"),
            ('builtin', Y[0], "built vs y"),
    ):
        test(_first, _second, s=_label)

    start_time = time.time()

    # 'singlex' / 'singley': train only X[0] (resp. Y[0]) against the fixed
    # opponent() policy below; the other side never learns.
    if args.method[:6] == 'single':
        assert args.nagent == 1
        def opponent(states, av):
            """Fixed opponent: zero mean and log-std -5.0 for every state.

            NOTE(review): presumably consumed by runner.rollout as a Gaussian
            (mean, log-std) policy head. The tensors are created on the default
            device regardless of `device` — confirm this is intended.
            `av` is not used.
            """
            mean = torch.zeros((len(states), nact))
            logstd = torch.full((len(states), nact), -5.0)
            return mean, logstd

        if args.method == 'singlex':
            x, y = X[0], opponent
            for it in range(args.niter):
                print ('\n', it, time.ctime())
                print ("x: compete w/ fixed agent")

                for inner in range(args.ninner):
                    # record=1: only x's slots (0-3) of the rollout are unpacked
                    sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                        env, args.T, args.batch, state_dim, x, y, record=1, device=device)
                    num_episodes += args.batch
                    raw_ep_rew_x = rx.sum(0)
                    lx = mx.float().sum(0).mean().item()

                    # rx is rebound to update()'s last output (printed below as d_rew)
                    action_loss_x, value_loss_x, entropy_x, grad_norm_x, rx = rlX[0].update(sx, ax, rx, mx)
                    dis_ep_rew = rx[0].mean().item()

                    print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_x, value_loss_x, entropy_x, grad_norm_x,
                         raw_ep_rew_x.mean(), (raw_ep_rew_x>0).float().mean()*100, (raw_ep_rew_x<0).float().mean()*100, dis_ep_rew, lx))

                checkpoint()

                print ("episodes %d" % num_episodes)
                test(x, y, s="x vs y   ")
                test(x, 'random', s="x vs rand")
                # test(x, 'builtin', s="x vs built")
        #elif args.method == 'singley':
        else:
            x, y = opponent, Y[0]
            for it in range(args.niter):
                print ('\n', it, time.ctime())
                print ("y: compete w/ fixed agent")

                for inner in range(args.ninner):
                    # use the rewards directly from the env
                    # NOTE(review): ry comes from slot 2 (the x-reward slot in the
                    # other methods), while sy/ay/my come from slots 4/5/7. This
                    # relies on SumoEnv(single_i=1) / runner.rollout(record=2)
                    # filling slot 2 with y's reward — confirm.
                    _, _, ry, _,  sy, ay, _, my = runner.rollout(
                        env, args.T, args.batch, state_dim, x, y, record=2, device=device)
                    num_episodes += args.batch
                    raw_ep_rew_y = ry.sum(0)
                    ly = my.float().sum(0).mean().item()

                    action_loss_y, value_loss_y, entropy_y, grad_norm_y, ry = rlY[0].update(sy, ay, ry, my)
                    dis_ep_rew = ry[0].mean().item()

                    print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_y, value_loss_y, entropy_y, grad_norm_y,
                         raw_ep_rew_y.mean(), (raw_ep_rew_y>0).float().mean()*100, (raw_ep_rew_y<0).float().mean()*100, dis_ep_rew, ly))

                checkpoint()

                print ("episodes %d" % num_episodes)
                test(x, y, s="x vs y   ")
                test('random', y, s="rand vs y")
                # test('builtin', y, s="built vs y")


    # 'base': one x/y pair trained alternately. Each side runs ninner updates
    # against a frozen copy of the other's latest checkpoint (hist*[-1]).
    if args.method == 'base':
        assert args.nagent == 1
        i = 0  # agent index read by test_sequence()

        x, y = X[0], Y[0]
        for it in range(args.niter):
            print ('\n', it, time.ctime())

            # if not args.fix_x:
            idx = -1
            tmpy.load_state_dict(histY[idx][0])
            print ("x: compete w/ y[it=%d]" % (idx))

            for inner in range(args.ninner):
                # record=1: only x's slots (0-3) of the rollout are unpacked
                sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                    env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
                num_episodes += args.batch
                raw_ep_rew_x = rx.sum(0)
                lx = mx.float().sum(0).mean().item()

                action_loss_x, value_loss_x, entropy_x, grad_norm_x, rx = rlX[0].update(sx, ax, rx, mx)
                dis_ep_rew = rx[0].mean().item()

                # print (rx[0].max(), rx[0].min())
                print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                    (action_loss_x, value_loss_x, entropy_x, grad_norm_x,
                     raw_ep_rew_x.mean(), (raw_ep_rew_x>0).float().mean()*100, (raw_ep_rew_x<0).float().mean()*100, dis_ep_rew, lx))

            # if not args.fix_y:
            # histX[-1] was appended by the previous checkpoint(), so y trains
            # against x as it was before this iteration's x updates.
            idx = -1
            tmpx.load_state_dict(histX[idx][0])
            print ("y: compete w/ x[it=%d]" % (idx))

            for inner in range(args.ninner):
                # record=2: only y's slots (4-7) of the rollout are unpacked
                _, _, _, _,  sy, ay, ry, my = runner.rollout(
                    env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
                num_episodes += args.batch
                raw_ep_rew_y = ry.sum(0)
                ly = my.float().sum(0).mean().item()

                action_loss_y, value_loss_y, entropy_y, grad_norm_y, ry = rlY[0].update(sy, ay, ry, my)
                dis_ep_rew = ry[0].mean().item()

                print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                    (action_loss_y, value_loss_y, entropy_y, grad_norm_y,
                     raw_ep_rew_y.mean(), (raw_ep_rew_y>0).float().mean()*100, (raw_ep_rew_y<0).float().mean()*100, dis_ep_rew, ly))

            test_sequence()
            checkpoint()
            # elapsed time in hours; eta extrapolates the mean time per iteration
            elapsed_time = (time.time() - start_time) / 3600
            print ("episodes %d per-agent %d time %.2f eta %.2f" % (num_episodes, num_episodes // args.nagent,
                elapsed_time, elapsed_time / (it+1) * (args.niter-1-it)))


    # 'baserand': each x_i / y_i trains against a uniformly random agent j of
    # the other population, taken from a uniformly random past checkpoint.
    # 'baserandl': the same, but always the latest checkpoint (idx = -1).
    if args.method == 'baserand' or args.method == 'baserandl':
        # assert args.nagent == 1
        always_last = args.method == 'baserandl'

        for it in range(args.niter):
            print ('\n', it, time.ctime())

            for i in range(args.nagent):
                x = X[i]
                j = np.random.randint(args.nagent)
                idx = -1 if always_last else np.random.randint(len(histY))
                tmpy.load_state_dict(histY[idx][j])
                print ("\nx%d: compete w/ y%d[it=%d]" % (i, j, idx))

                for inner in range(args.ninner):
                    # record=1: only x's slots (0-3) of the rollout are unpacked
                    sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                        env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
                    num_episodes += args.batch
                    raw_ep_rew_x = rx.sum(0)
                    lx = mx.float().sum(0).mean().item()

                    action_loss_x, value_loss_x, entropy_x, grad_norm_x, rx = rlX[i].update(sx, ax, rx, mx)
                    dis_ep_rew = rx[0].mean().item()

                    print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_x, value_loss_x, entropy_x, grad_norm_x,
                         raw_ep_rew_x.mean(), (raw_ep_rew_x>0).float().mean()*100, (raw_ep_rew_x<0).float().mean()*100, dis_ep_rew, lx))


                y = Y[i]
                j = np.random.randint(args.nagent)
                idx = -1 if always_last else np.random.randint(len(histX))
                tmpx.load_state_dict(histX[idx][j])
                print ("y%d: compete w/ x%d[it=%d]" % (i, j, idx))

                for inner in range(args.ninner):
                    # record=2: only y's slots (4-7) of the rollout are unpacked
                    _, _, _, _,  sy, ay, ry, my = runner.rollout(
                        env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
                    num_episodes += args.batch
                    raw_ep_rew_y = ry.sum(0)
                    ly = my.float().sum(0).mean().item()

                    action_loss_y, value_loss_y, entropy_y, grad_norm_y, ry = rlY[i].update(sy, ay, ry, my)
                    dis_ep_rew = ry[0].mean().item()

                    print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_y, value_loss_y, entropy_y, grad_norm_y,
                         raw_ep_rew_y.mean(), (raw_ep_rew_y>0).float().mean()*100, (raw_ep_rew_y<0).float().mean()*100, dis_ep_rew, ly))

                # uses the current loop's x, y and i
                test_sequence()

            checkpoint()
            # elapsed time in hours; eta extrapolates the mean time per iteration
            elapsed_time = (time.time() - start_time) / 3600
            print ("episodes %d per-agent %d time %.2f eta %.2f" % (num_episodes, num_episodes // args.nagent,
                elapsed_time, elapsed_time / (it+1) * (args.niter-1-it)))


    # 'basebest': like 'base', but both sides train against the snapshot pair
    # histX/histY[idx_best]. idx_best advances to the newest checkpoint only when
    # the average of x's and y's last-inner-step mean returns exceeds --delta.
    if args.method == 'basebest':
        assert args.nagent == 1
        i = 0  # agent index read by test_sequence()

        x, y = X[0], Y[0]
        idx_best = 0

        for it in range(args.niter):
            print ('\n', it, time.ctime())
            print ("[idx_best=%d]" % idx_best)

            tmpy.load_state_dict(histY[idx_best][0])
            for inner in range(args.ninner):
                # record=1: only x's slots (0-3) of the rollout are unpacked
                sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                    env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
                num_episodes += args.batch
                raw_ep_rew_x = rx.sum(0); raw_ep_rew_x_mean = raw_ep_rew_x.mean()
                lx = mx.float().sum(0).mean().item()

                action_loss_x, value_loss_x, entropy_x, grad_norm_x, rx = rlX[0].update(sx, ax, rx, mx)
                dis_ep_rew = rx[0].mean().item()

                print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                    (action_loss_x, value_loss_x, entropy_x, grad_norm_x,
                     raw_ep_rew_x_mean, (raw_ep_rew_x>0).float().mean()*100, (raw_ep_rew_x<0).float().mean()*100, dis_ep_rew, lx))


            tmpx.load_state_dict(histX[idx_best][0])
            for inner in range(args.ninner):
                # record=2: only y's slots (4-7) of the rollout are unpacked
                _, _, _, _,  sy, ay, ry, my = runner.rollout(
                    env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
                num_episodes += args.batch
                raw_ep_rew_y = ry.sum(0); raw_ep_rew_y_mean = raw_ep_rew_y.mean()
                ly = my.float().sum(0).mean().item()

                action_loss_y, value_loss_y, entropy_y, grad_norm_y, ry = rlY[0].update(sy, ay, ry, my)
                dis_ep_rew = ry[0].mean().item()

                print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                    (action_loss_y, value_loss_y, entropy_y, grad_norm_y,
                     raw_ep_rew_y_mean, (raw_ep_rew_y>0).float().mean()*100, (raw_ep_rew_y<0).float().mean()*100, dis_ep_rew, ly))

            # Average of both sides' mean raw returns from their last inner step.
            r = 0.5 * (raw_ep_rew_x_mean + raw_ep_rew_y_mean)
            if r > args.delta:
                print ('update best %d -> %d %7.4f(%7.4f, %7.4f)' % (idx_best, it + 1,
                    r, raw_ep_rew_x_mean, raw_ep_rew_y_mean))
                # histX/histY[it + 1] is the snapshot the checkpoint() below appends
                idx_best = it + 1
            else:
                # The threshold is --delta (not 0), and equality keeps the old best.
                print ('keep best %d since %7.4f(%7.4f, %7.4f) <= %.4f' % (idx_best,
                    r, raw_ep_rew_x_mean, raw_ep_rew_y_mean, args.delta))

            test_sequence()
            checkpoint()
            # elapsed time in hours; eta extrapolates the mean time per iteration
            elapsed_time = (time.time() - start_time) / 3600
            print ("episodes %d per-agent %d time %.2f eta %.2f" % (num_episodes, num_episodes // args.nagent,
                elapsed_time, elapsed_time / (it+1) * (args.niter-1-it)))


    # 'const': population self-play. Each round, every x_i plays every y_j.
    # Each agent then trains against the opponent that does best against it,
    # reusing that pairing's round-robin rollout for its first inner step.
    if args.method == 'const':
        # assert args.nagent > 1
        # --c_eg may be missing from the parser; treat an absent flag as off
        # instead of failing with AttributeError.
        use_eg = getattr(args, 'c_eg', False)
        if args.c_adalr and args.target_gap is None:
            # No fixed target: normalise the gap by its running average instead.
            a_dual_gap = Meter(0.9)
        f = 1.0  # multiplier passed to rl*.update(); only changed by --c_adalr

        for it in range(args.niter):
            print ('\n', it, time.ctime())

            # pairwise_r[i, j]: mean undiscounted return (x's reward slot) of x_i vs y_j.
            # dtype=object: the np.object alias was removed in NumPy 1.24.
            pairwise_rolls = np.empty((args.nagent, args.nagent), dtype=object)
            pairwise_r = np.zeros((args.nagent, args.nagent))
            for i in range(args.nagent):
                for j in range(args.nagent):
                    # x_i vs y_j
                    roll = runner.rollout(env, args.T, args.batch, state_dim, X[i], Y[j], record=3, device=device)
                    num_episodes += args.batch
                    pairwise_rolls[i, j] = roll
                    pairwise_r[i, j] = roll[2].sum(0).mean()
            fitness = (pairwise_r.mean(1) - pairwise_r.mean(0))/2
            print (np.array2string(pairwise_r, precision=2), f'best {fitness.argmax()} [{fitness.max():.3f}]')

            for i in range(args.nagent):
                x = X[i]

                # x_i's opponent: the y_j that holds x_i to the lowest return.
                min_j = np.argmin(pairwise_r[i, :])
                min_r = pairwise_r[i, min_j]
                sx, ax, rx, mx = pairwise_rolls[i, min_j][:4]
                tmpy.load_state_dict(histY[-1][min_j])

                # extra-gradient method, rescind one step if the current agent is chosen as opponent
                if use_eg and i == min_j and it > 0:
                    X[i].load_state_dict(histX[-2][i])
                    print ("x%d: compete w/ y%d [rew %7.4f] rescind" % (i, min_j, min_r))
                else:
                    print ("x%d: compete w/ y%d [rew %7.4f]" % (i, min_j, min_r))

                # y_i's opponent: the x_j that reaches the highest return against y_i.
                max_j = np.argmax(pairwise_r[:, i])
                max_r = pairwise_r[max_j, i]
                sy, ay, ry, my = pairwise_rolls[max_j, i][4:8]
                tmpx.load_state_dict(histX[-1][max_j])

                # extra-gradient method, rescind one step if the current agent is chosen as opponent
                if use_eg and i == max_j and it > 0:
                    Y[i].load_state_dict(histY[-2][i])
                    print ("y%d: compete w/ x%d [rew %7.4f] rescind" % (i, max_j, max_r))
                else:
                    print ("y%d: compete w/ x%d [rew %7.4f]" % (i, max_j, max_r))

                if args.c_adalr:
                    dual_gap = max_r - min_r
                    if args.target_gap is None:
                        a_dual_gap.update(dual_gap)
                        f = np.clip(dual_gap / (a_dual_gap.mean + 1e-8), 1.0, 1.4)
                        avg_gap = a_dual_gap.mean
                    else:
                        f = np.clip(dual_gap / args.target_gap, 1.0, 1.4)
                        # No running average exists with a fixed --target_gap.
                        avg_gap = np.nan
                    print ("dual gap: %7.4f (a %7.4f), mult %.3f" % (dual_gap, avg_gap, f))

                for inner in range(args.ninner):
                    # inner step 0 reuses the round-robin rollout selected above
                    if inner > 0:
                        sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                            env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
                        num_episodes += args.batch
                    raw_ep_rew_x = rx.sum(0)
                    lx = mx.float().sum(0).mean().item()

                    action_loss_x, value_loss_x, entropy_x, grad_norm_x, rx = rlX[i].update(sx, ax, rx, mx, f)
                    dis_ep_rew = rx[0].mean().item()

                    print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_x, value_loss_x, entropy_x, grad_norm_x,
                         raw_ep_rew_x.mean(), (raw_ep_rew_x>0).float().mean()*100, (raw_ep_rew_x<0).float().mean()*100, dis_ep_rew, lx))


                y = Y[i]

                for inner in range(args.ninner):
                    # inner step 0 reuses the round-robin rollout selected above
                    if inner > 0:
                        _, _, _, _, sy, ay, ry, my = runner.rollout(
                            env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
                        num_episodes += args.batch

                    raw_ep_rew_y = ry.sum(0)
                    ly = my.float().sum(0).mean().item()

                    action_loss_y, value_loss_y, entropy_y, grad_norm_y, ry = rlY[i].update(sy, ay, ry, my, f)
                    dis_ep_rew = ry[0].mean().item()

                    print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_y, value_loss_y, entropy_y, grad_norm_y,
                         raw_ep_rew_y.mean(), (raw_ep_rew_y>0).float().mean()*100, (raw_ep_rew_y<0).float().mean()*100, dis_ep_rew, ly))

                test_sequence()

            checkpoint()
            # elapsed time in hours; eta extrapolates the mean time per iteration
            elapsed_time = (time.time() - start_time) / 3600
            print ("episodes %d per-agent %d time %.2f eta %.2f" % (num_episodes, num_episodes // args.nagent,
                elapsed_time, elapsed_time / (it+1) * (args.niter-1-it)))


    # 'constv': population variant of 'const'. Each x_i / y_i first updates on
    # the concatenation of its rollouts against the whole opposite population.
    # The per-opponent returns from that update then choose (with --delta
    # hysteresis) the single opponent used for the remaining inner steps.
    if args.method == 'constv':
        if args.c_adalr and args.target_gap is None:
            # fx/fy below divide by --target_gap; fail fast instead of raising a
            # TypeError after the first round of rollouts.
            raise ValueError('--c_adalr with --method constv requires --target_gap')
        inner_steps = args.ninner
        print (f'constv [n={args.nagent}, inner_steps={inner_steps}]')
        a_dual_gap = Meter(0.9)
        # opp_id[i] = [y index currently opposing x_i, x index currently opposing y_i]
        opp_id = [[i,i] for i in range(args.nagent)]

        for it in range(args.niter):
            print ('\n', it, time.ctime())

            # collect pairwise competitions
            # (dtype=object: the np.object alias was removed in NumPy 1.24)
            pairwise_rolls = np.empty((args.nagent, args.nagent), dtype=object)
            pairwise_r = np.zeros((args.nagent, args.nagent))
            for i in range(args.nagent):
                for j in range(args.nagent):
                    # x_i vs y_j
                    sx,ax,rx,mx, sy,ay,ry,my = runner.rollout(env, args.T, args.batch, state_dim, X[i], Y[j], record=3, device=device)
                    num_episodes += args.batch
                    # store
                    pairwise_rolls[i, j] = (sx,ax,rx,mx, sy,ay,ry,my)
                    pairwise_r[i, j] = rx.sum(0).mean()
            # best among the population fit(i,j) = (f(xi,yj) - f(xj,yi))/2
            fitness_x, fitness_y = pairwise_r.mean(1), pairwise_r.mean(0)
            mean_r = fitness_x.mean()
            fitness = (fitness_x - fitness_y)/2
            print (np.array2string(pairwise_r, precision=2),
                f'best {fitness.argmax()} [{fitness.max():.3f}]',
                f'best_x {fitness_x.argmax()} [{fitness_x.max():.3f}]',
                f'best_y {fitness_y.argmin()} [{fitness_y.min():.3f}]',
                f'mean_r {mean_r:.4f}')

            # train all agents
            for i in range(args.nagent):
                fx = fy = 1.0
                max_r, min_r = pairwise_r[i, :].max(), pairwise_r[:, i].min()
                dual_gap = max_r - min_r
                a_dual_gap.update(dual_gap)
                print ("\ndual gap: %7.4f=(%.4f)-(%.4f) (a %7.4f)" % (dual_gap, max_r, min_r, a_dual_gap.mean))
                x, y = X[i], Y[i]

                # train X[i]
                for inner in range(inner_steps):
                    if inner >= 1:
                        sx, ax, rx, mx,  _, _, _, _ = runner.rollout(
                            env, args.T, args.batch, state_dim, x, tmpy, record=1, device=device)
                        num_episodes += args.batch
                    else:
                        # x_i's rollouts vs y_0 .. y_{n-1}, concatenated along dim 1
                        sx, ax, rx, mx = [torch.cat(
                            [pairwise_rolls[i, j][_] for j in range(args.nagent)], 1)
                            for _ in (0,1,2,3)]

                    raw_ep_rew_x = rx.sum(0)
                    lx = mx.float().sum(0).mean().item()

                    action_loss_x, value_loss_x, entropy_x, grad_norm_x, dr = rlX[i].update(sx, ax, rx, mx, f=fx)
                    dis_ep_rew = dr[0].mean().item()

                    print ("x: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_x, value_loss_x, entropy_x, grad_norm_x,
                         raw_ep_rew_x.mean(), (raw_ep_rew_x>0).float().mean()*100, (raw_ep_rew_x<0).float().mean()*100, dis_ep_rew, lx))

                    if inner == 0:
                        # Per-opponent mean of dr[0]; assumes dr[0] holds one value
                        # per episode in the concatenation order above (opponent-major).
                        dr = dr[0].view(args.nagent, args.batch).mean(1)
                        min_r, min_j = dr.min(0)
                        min_j = int(min_j)  # plain int for list indexing / opp_id
                        old_j = opp_id[i][0]
                        # switch only if the new opponent is lower by more than --delta
                        if min_r < dr[old_j] - args.delta:
                            opp_id[i][0] = min_j
                        else:
                            # maintain the old opponent
                            min_r, min_j = dr[old_j], old_j
                        tmpy.load_state_dict(histY[-1][min_j])
                        if args.c_adalr: fx = np.clip((dr.mean() - min_r) / args.target_gap, 1.0, 1.2)
                        print ("x%d: compete w/ y%d [old %d, rew %7.4f], fx %.3f, %s %.3f" % (i, min_j, old_j, min_r, fx,
                            np.array2string(dr.detach().cpu().numpy(), precision=2), dr.mean()))

                # train Y[i]
                for inner in range(inner_steps):
                    if inner >= 1:
                        _, _, _, _,  sy, ay, ry, my = runner.rollout(
                            env, args.T, args.batch, state_dim, tmpx, y, record=2, device=device)
                        num_episodes += args.batch
                    else:
                        # y_i's rollouts vs x_0 .. x_{n-1}, concatenated along dim 1
                        sy, ay, ry, my = [torch.cat(
                            [pairwise_rolls[j, i][_] for j in range(args.nagent)], 1)
                            for _ in (4,5,6,7)]

                    raw_ep_rew_y = ry.sum(0)
                    ly = my.float().sum(0).mean().item()

                    action_loss_y, value_loss_y, entropy_y, grad_norm_y, dr = rlY[i].update(sy, ay, ry, my, f=fy)
                    dis_ep_rew = dr[0].mean().item()

                    print ("y: act %7.4f\tval %.4f\tent %.4f\tgrad %.4f\trew %7.4f[%2.0f%%,%2.0f%%]\td_rew %7.4f\tlen %.2f" %
                        (action_loss_y, value_loss_y, entropy_y, grad_norm_y,
                         raw_ep_rew_y.mean(), (raw_ep_rew_y>0).float().mean()*100, (raw_ep_rew_y<0).float().mean()*100, dis_ep_rew, ly))

                    if inner == 0:
                        dr = dr[0].view(args.nagent, args.batch).mean(1)
                        max_r, max_j = dr.min(0)  # dr.max(0) # this is already a negative sign
                        max_j = int(max_j)  # plain int for list indexing / opp_id
                        old_j = opp_id[i][1]
                        if max_r < dr[old_j] - args.delta:
                            opp_id[i][1] = max_j
                        else:
                            # maintain the old opponent
                            max_r, max_j = dr[old_j], old_j
                        tmpx.load_state_dict(histX[-1][max_j])
                        if args.c_adalr: fy = np.clip((dr.mean() - max_r) / args.target_gap, 1.0, 1.2)
                        print ("y%d: compete w/ x%d [old %d, rew %7.4f], fy %.3f, %s %.3f" % (i, max_j, old_j, max_r, fy,
                            np.array2string(dr.detach().cpu().numpy(), precision=2), dr.mean()))

                test_sequence()

            checkpoint()
            # elapsed time in hours; eta extrapolates the mean time per iteration
            elapsed_time = (time.time() - start_time) / 3600
            print ("episodes %d per-agent %d time %.2f eta %.2f" % (num_episodes, num_episodes // args.nagent,
                elapsed_time, elapsed_time / (it+1) * (args.niter-1-it)))


    # Drop the module-level `runner` binding (a ParallelRunner instance when
    # --nproc > 1) before closing the log.
    del runner

    if args.save:
        utils.close_log()
