from collections import namedtuple

# One tournament entrant.
#   path: checkpoint directory (may end in an inclusive '<first>:<last>' index
#         range that is expanded into one directory per index), or the literal
#         'random' / 'builtin'.
#   it:   iteration index into the saved histX.pt / histY.pt (a list is expanded
#         into one entrant per element).
#   id:   index within that iteration (a list is expanded; for gomoku a tuple
#         gives separate indices for the x and y seats), or the builtin id.
Agent = namedtuple('Agent', 'path it id')

def process_eval(qin, qout, args, proc_idx):
    """Worker loop: evaluate agent pairs taken from ``qin`` and report to ``qout``.

    Each item read from ``qin`` is either ``None`` (stop the worker) or a tuple
    ``(i, j, a, b)`` of two matrix indices and two ``Agent`` records. ``a``
    plays seat x (player 0) and ``b`` plays seat y (player 1). The pair is
    evaluated with ``runner.test_rollout`` and ``(i, j, r, r_mean)`` is put on
    ``qout``, where ``r`` is the reward array cast to int8 and ``r_mean`` its
    mean.

    Args:
        qin: queue of ``(i, j, a, b)`` work items, terminated by ``None``.
        qout: queue receiving ``(i, j, r, r_mean)`` results.
        args: parsed command-line namespace; reads ``gpu``, ``env``, ``gamma``,
            ``T``, ``batch`` and (gomoku only) ``temp``.
        proc_idx: worker index; selects the CUDA device when ``args.gpu``.

    Raises:
        ValueError: if ``args.env`` is neither ``'soccer'`` nor ``'gomoku'``.
    """
    import numpy as np
    import torch    # import torch here to ensure randomness
    import envs
    import model
    import algo
    import runner

    device = torch.device("cuda:%d" % (
        proc_idx % torch.cuda.device_count()) if args.gpu else "cpu")
    print ("[proc %d] started" % proc_idx, device, torch.randn(()), np.random.rand())

    if args.env == 'soccer':
        env = envs.SoccerEnv(0)
        nact = 5

        xtab = model.TabularSoccer().to(device)
        ytab = model.TabularSoccer().to(device)

        from envs import SoccerBuiltinAgent
        # One builtin agent per seat. reset_agent() is called with the seat
        # index, so a single shared instance would let the y-side reset
        # overwrite the x-side one when both players are 'builtin'.
        builtin_agents = (SoccerBuiltinAgent(), SoccerBuiltinAgent())

        def get_act(a, player_a_or_b):
            """Return ``act_func(state)`` for agent ``a`` in seat ``player_a_or_b``."""
            if a.it is None and a.path == 'random':
                def act_func(state):
                    return np.random.randint(nact)
            elif a.it is None and a.path == 'builtin':
                builtin = builtin_agents[player_a_or_b]
                def act_func(state):
                    # A None state signals an episode reset.
                    if state is None:
                        builtin.reset_agent(player_a_or_b, a.id)
                    else:
                        return builtin.act(state)
            else:
                if player_a_or_b == 0:
                    hist = torch.load(a.path + '/histX.pt', map_location=device)
                    net = xtab
                else:
                    hist = torch.load(a.path + '/histY.pt', map_location=device)
                    net = ytab
                net.load_state_dict(hist[a.it][a.id])
                def act_func(state):
                    if state is not None:
                        return runner.act(net, state)
            return act_func

    elif args.env == 'gomoku':
        env = envs.RenjuEnv()

        MODEL = model.FCN

        pretrain = MODEL(pretrain='./gomoku/k553_epoch0/0.pt', noise=False, verbose=0).to(device)

        # Scratch networks that checkpoints are loaded into, one per seat.
        xnn = MODEL(pretrain=None, noise=False).to(device)
        ynn = MODEL(pretrain=None, noise=False).to(device)

        def get_act(a, player_a_or_b):
            """Return ``act_func(state)`` for agent ``a`` in seat ``player_a_or_b``."""
            if a.path == 'random':
                def act_func(state):
                    if state is not None:
                        # pick uniformly among cells equal to 0
                        valid = np.where(state == 0)[0]
                        return np.random.choice(valid)
            elif a.path == 'builtin':
                if a.id is None:
                    net = pretrain
                else:
                    net = xnn if player_a_or_b == 0 else ynn
                    net.load_state_dict(
                        torch.load('./gomoku/k553_epoch0/%d.pt' % a.id, map_location=device)
                    )
                def act_func(state):
                    if state is not None:
                        return runner.act(net, state, device=device, temp=args.temp)
            else:
                if player_a_or_b == 0:
                    hist = torch.load(a.path + '/histX.pt', map_location=device)
                    net = xnn
                else:
                    hist = torch.load(a.path + '/histY.pt', map_location=device)
                    net = ynn
                # A tuple id holds one index per seat: [0] for x, [1] for y.
                idx = a.id[player_a_or_b] if isinstance(a.id, tuple) else a.id
                net.load_state_dict(hist[a.it][idx])
                def act_func(state):
                    if state is not None:
                        return runner.act(net, state, device=device, temp=args.temp)
            return act_func

    else:
        raise ValueError("unknown env: %r" % (args.env,))

    def test(x, y, gamma=args.gamma):
        """Play ``x`` (seat 0) against ``y`` (seat 1); return (rewards, mean)."""
        act_x = get_act(x, player_a_or_b=0)
        act_y = get_act(y, player_a_or_b=1)
        rews = runner.test_rollout(env, args.T, args.batch, act_x, act_y, gamma)
        return rews, rews.mean()

    while True:
        params = qin.get()
        if params is None:
            break
        i, j, a, b = params

        r, r_mean = test(a, b)
        # Assumes r \in {-1, 0, 1}; int8 saves space in the result matrix.
        # NOTE(review): with --gamma != 1 discounted rewards would be truncated here.
        r = r.astype(np.int8)
        qout.put((i, j, r, r_mean))



if __name__ == '__main__':
    # Round-robin tournament driver.
    #  1. Evaluate every ordered pair of distinct agents from
    #     tour_config.<--config> with a pool of process_eval workers.
    #  2. Print the mean-reward matrix.
    #  3. Fit Elo ratings by logistic regression.
    #  4. Pickle everything under tour/.
    import time, os, sys
    from tqdm import tqdm
    import numpy as np
    import multiprocessing as mp

    import tour_config, pickle

    import argparse

    parser = argparse.ArgumentParser()

    # Misc
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save', type=str, default="")
    parser.add_argument('--gpu', action='store_true', default=False)

    # MDP definition
    parser.add_argument('--env', type=str, default='soccer')
    parser.add_argument('--T', type=int, default=100, help='maximum episode length')
    parser.add_argument('--gamma', type=float, default=1.0, help='discount factor')

    # Meta-level algorithm
    parser.add_argument('--batch', type=int, default=10, help='niter*ninner*batch = total num episodes')
    parser.add_argument('--nproc', type=int, default=1)

    parser.add_argument('--config', type=str, default='soccer_tab')
    parser.add_argument('--temp', type=float, default=None, help='evaluation temperature')

    args = parser.parse_args()

    ###########################
    ## build the agent list
    # Done before spawning workers so that a bad --config fails fast instead of
    # leaving non-daemon worker processes blocked on qin.get().
    agents = getattr(tour_config, args.config)

    # expand path ranges: '<dir>/<first>:<last>' -> '<dir>/<first>' ... '<dir>/<last>'
    expanded = []
    for a in agents:
        if ':' in a.path:
            cut = a.path.rfind('/')
            first, last = a.path[cut+1:].split(':')
            for k in range(int(first), int(last) + 1):
                expanded.append(Agent(a.path[:cut+1] + str(k), a.it, a.id))
        else:
            expanded.append(a)
    agents = expanded

    # expand a list-valued iteration into one agent per iteration
    expanded = []
    for a in agents:
        if isinstance(a.it, list):
            expanded.extend(Agent(a.path, it, a.id) for it in a.it)
        else:
            expanded.append(a)
    agents = expanded

    # expand a list-valued id into one agent per id
    expanded = []
    for a in agents:
        if isinstance(a.id, list):
            expanded.extend(Agent(a.path, a.it, agent_id) for agent_id in a.id)
        else:
            expanded.append(a)
    agents = expanded

    n = len(agents)

    ###########################
    ## play all ordered pairs
    mp.set_start_method('spawn')
    qin = mp.Queue()
    qout = mp.Queue()
    procs = [mp.Process(target=process_eval, args=(qin, qout, args, proc_idx))
             for proc_idx in range(args.nproc)]
    for p in procs:
        p.start()

    # np.object was removed in NumPy 1.24; the builtin `object` is the same dtype.
    results_mat = np.empty((n, n), dtype=object)
    results_r = np.zeros((n, n))

    # Skip by index, not identity. If the config lists the same Agent object
    # twice, an `is` test would enqueue fewer than n*(n-1) jobs while the loop
    # below waits for n*(n-1) results and hangs.
    for i, a in enumerate(agents):
        for j, b in enumerate(agents):
            if i == j:
                continue
            qin.put((i, j, a, b))

    for _ in tqdm(range(n * (n - 1))):
        i, j, r, r_mean = qout.get()
        results_mat[i, j] = r
        results_r[i, j] = r_mean

    # All results are collected. Stop the workers now, so a later error in the
    # Elo fit or while saving cannot leave them blocked on qin.get().
    for _ in procs:
        qin.put(None)
    for p in procs:
        p.join()
    qin.close()
    qout.close()

    print (np.array2string(results_r*100, precision=1))

    # Self-play (the diagonal) is never evaluated. Mask it so nanmean averages
    # over the n-1 real opponents instead of also counting a phantom 0.
    played = results_r.copy()
    np.fill_diagonal(played, np.nan)
    results_mean_x = np.nanmean(played, 1)
    results_mean_y = np.nanmean(played, 0)
    results_mean = (results_mean_x - results_mean_y) / 2
    for a, r, rx, ry in zip(agents, results_mean.tolist(), results_mean_x.tolist(), results_mean_y.tolist()):
        print (a, r, rx, ry)

    ###########################
    ## run logistic regression
    import torch

    m = n * (n - 1) // 2

    # One row k per unordered pair i < j: A[k] = e_j - e_i, so A @ elo = elo_j - elo_i.
    # b[k] is i's average result over both seatings, mapped from [-1,1] to [0,1].
    A = torch.zeros((m, n))
    b = torch.zeros((m, 1))

    k = 0
    for i in range(n):
        for j in range(i + 1, n):
            score = (results_r[i, j] - results_r[j, i]) * 0.25 + 0.5
            A[k, i] = -1
            A[k, j] = 1
            b[k] = score
            k += 1

    elo = torch.nn.Parameter(torch.zeros(n, 1))

    opt = torch.optim.Adam([elo], 1.0)
    last_loss = 0.0
    for step in range(10000):
        # logits = (elo_i - elo_j) * ln(10) / 400, so sigmoid(logits) is the
        # Elo expected score of i against j.
        logits = torch.mm(A, elo) * (-np.log(10) / 400)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(input=logits, target=b)
        loss_value = loss.item()

        if step % 100 == 0:
            print (step, loss_value)
        if abs(last_loss - loss_value) < 1e-5:
            break
        last_loss = loss_value

        opt.zero_grad()
        loss.backward()
        opt.step()

    # Report ratings anchored so that the first agent sits at 1000.
    elo_list = elo[:, 0].tolist()
    for a, e in zip(agents, elo_list):
        print (a, e - elo_list[0] + 1000)

    ###########################
    ## save
    os.makedirs('tour', exist_ok=True)
    name = 'tour/%s_b%d_%s.pkl' % (args.config, args.batch, time.strftime("%Y%m%d-%H%M%S"))
    print ('saved to ' + name)
    with open(name, 'wb') as f:
        d = dict(results_mat=results_mat, results_r=results_r, results_mean=results_mean,
            elo_list=elo_list, agents=agents, args=args, A=A, b=b, m=m, n=n)
        pickle.dump(d, f)
