from collections import namedtuple

# One tournament participant.
#   path: 'random', 'builtin', or a run directory holding histX.pt / histY.pt.
#         The driver also accepts the shorthands 'prefix/b:e' (inclusive integer
#         range) and 'prefix/a,b,...' and expands them into one agent per run.
#   it:   checkpoint iteration (or a list of them, expanded by the driver);
#         abs(it) indexes the history, the sign is used by the job scheduler.
#   id:   index into the checkpoint at that iteration, a list of indices
#         (expanded by the driver), or None for the zero-action builtin.
Agent = namedtuple('Agent', 'path it id')

def process_eval(qin, qout, args, proc_idx):
    """Worker loop: play agent pairings taken from ``qin`` and report to ``qout``.

    Every item on ``qin`` is either ``None`` (shutdown sentinel) or a tuple
    ``(i, j, a, b)`` where agent ``a`` plays side 0 and agent ``b`` plays
    side 1. For each job the tuple ``(i, j, rews, rews.mean())`` is put on
    ``qout``; ``rews`` is whatever ``runner.test_rollout`` returns and is
    asserted to be float32.

    Args:
        qin: queue of jobs, terminated by ``None``.
        qout: queue receiving one result tuple per job.
        args: parsed command-line namespace; this function reads ``gpu``,
            ``env``, ``temp``, ``gamma``, ``T`` and ``batch``.
        proc_idx: worker index, used to choose a CUDA device round-robin.

    Raises:
        ValueError: if ``args.env`` is not ``'sumo'`` (the only environment
            wired up here).
        NotImplementedError: if a learned agent carries a tuple ``id``.
    """
    import numpy as np
    import torch    # import torch here to ensure randomness
    import envs, model, algo, runner

    # Fail at start-up rather than with a NameError on the first job.
    if args.env != 'sumo':
        raise ValueError("unsupported env %r: only 'sumo' is implemented" % (args.env,))

    device = torch.device("cuda:%d" % (
        proc_idx % torch.cuda.device_count()) if args.gpu else "cpu")
    # ensure proper randomness; need manual seeding if 'fork', otherwise no need if 'spawn'

    checkpoints = {}

    def buffered_load(path):
        """Load ``path`` once per worker and serve later requests from memory."""
        if path not in checkpoints:
            # map_location keeps checkpoints saved on another device loadable
            # here (e.g. GPU-trained runs evaluated on CPU).
            checkpoints[path] = torch.load(path, map_location=device)
        return checkpoints[path]

    env = envs.SumoEnv()

    # One reusable network per side. get_act() overwrites their weights, so an
    # act function is only valid until the next get_act() call for that side.
    xnn = model.MLPControl().to(device).eval()
    ynn = model.MLPControl().to(device).eval()

    def get_act(a, player_a_or_b):
        """Return ``act_func(state)`` for agent ``a`` on side ``player_a_or_b`` (0 or 1).

        The returned function yields ``None`` when ``state`` is ``None``.
        """
        if a.path == 'random':
            def act_func(state):
                if state is not None:
                    return env.env.action_space[player_a_or_b].sample()
        elif a.path == 'builtin':
            if a.id is None:
                def act_func(state):
                    if state is not None:
                        # all-zero action with the right shape
                        return env.env.action_space[player_a_or_b].sample() * 0
            else:
                s = buffered_load('./sumo/pretrain_v3/pretrain{}{}_4.pt'.format(
                    'x' if player_a_or_b == 0 else 'y', a.id))
                net = xnn if player_a_or_b == 0 else ynn
                net.load_state_dict(s)
                def act_func(state):
                    if state is not None:
                        return runner.act_cont(net, state[player_a_or_b])[1]
        else:
            if player_a_or_b == 0:
                net, hist_name = xnn, '/histX.pt'
            else:
                net, hist_name = ynn, '/histY.pt'
            if isinstance(a.id, tuple):
                raise NotImplementedError(
                    'tuple agent ids are not supported: %r' % (a,))
            s = buffered_load(a.path + hist_name)[abs(a.it)][a.id]
            # Only the policy sub-module is restored; the 7-character prefix
            # ('policy' plus separator) is stripped from each key.
            net.policy.load_state_dict(
                {k[7:]: v for k, v in s.items() if k[:6] == 'policy'})
            if args.temp:
                # evaluation temperature: shrink the policy's log-std
                net.policy.logstd.data -= args.temp
            def act_func(state):
                if state is not None:
                    # [0]raw_act, [1]tanh(raw_act)
                    return runner.act_cont(net, state[player_a_or_b])[1]
        return act_func

    def test(x, y, gamma=args.gamma, s=""):
        """Roll out ``x`` (side 0) against ``y`` (side 1); return (rewards, mean).

        ``s`` is a free-form label kept for debugging output; it is not used.
        """
        act_x = get_act(x, player_a_or_b=0)
        act_y = get_act(y, player_a_or_b=1)
        rews = runner.test_rollout(env, args.T, args.batch, act_x, act_y, gamma)
        return rews, rews.mean()

    while True:
        params = qin.get()
        if params is None:
            break
        i, j, a, b = params

        r, r_mean = test(a, b, s=a.path + ' ' + b.path)
        assert r.dtype == np.float32
        qout.put((i, j, r, r_mean))



def expand_agents(agents, agent_cls=None):
    """Expand shorthand agent specs into one concrete agent per run/iteration/id.

    Three passes are applied in order:
      1. path: ``'prefix/b:e'`` becomes ``prefix/b`` ... ``prefix/e`` (inclusive
         integer range) and ``'prefix/a,b'`` becomes ``prefix/a``, ``prefix/b``;
      2. it: a list of iterations becomes one agent per iteration;
      3. id: a list of ids becomes one agent per id.

    Args:
        agents: iterable of objects with ``path``, ``it`` and ``id`` attributes.
        agent_cls: callable ``(path, it, id)`` used to build expanded agents;
            defaults to this module's ``Agent``.

    Returns:
        A new list of agents; specs that need no expansion are passed through.
    """
    if agent_cls is None:
        agent_cls = Agent

    by_path = []
    for a in agents:
        if ':' in a.path:
            cut = a.path.rfind('/') + 1
            lo, hi = a.path[cut:].split(':')
            for run in range(int(lo), int(hi) + 1):
                by_path.append(agent_cls(a.path[:cut] + str(run), a.it, a.id))
        elif ',' in a.path:
            cut = a.path.rfind('/') + 1
            for run in a.path[cut:].split(','):
                by_path.append(agent_cls(a.path[:cut] + run, a.it, a.id))
        else:
            by_path.append(a)

    by_it = []
    for a in by_path:
        if isinstance(a.it, list):
            by_it.extend(agent_cls(a.path, it, a.id) for it in a.it)
        else:
            by_it.append(a)

    by_id = []
    for a in by_it:
        if isinstance(a.id, list):
            by_id.extend(agent_cls(a.path, a.it, one_id) for one_id in a.id)
        else:
            by_id.append(a)
    return by_id


def build_jobs(agents, results_mat, fast=False):
    """List the ``(i, j, a, b)`` match-ups that still have to be played.

    A pair (i <= j) is eligible when one side is 'random' or 'builtin', or
      * fast:     both are at the same iteration, or share path and id;
      * not fast: both iterations are >= 0, or they share path and id and
                  exactly one of the two iterations is negative.
    Eligible pairs are played in both seat orders. Each direction is checked
    against ``results_mat`` on its own, so a preloaded file that holds (i, j)
    but not (j, i) still gets the missing direction scheduled.

    Args:
        agents: expanded agent list.
        results_mat: (n, n) object array; ``None`` marks an unplayed cell.
        fast: select the reduced schedule.

    Returns:
        List of ``(row, col, row_agent, col_agent)`` tuples.
    """
    special = ('random', 'builtin')
    jobs = []
    n = len(agents)
    for i in range(n):
        a = agents[i]
        for j in range(i, n):
            b = agents[j]
            if a.path in special or b.path in special:
                eligible = True
            elif fast:
                eligible = (a.it == b.it) or (a.path == b.path and a.id == b.id)
            else:
                eligible = (
                    (a.it >= 0 and b.it >= 0) or
                    (a.path == b.path and a.id == b.id and (
                        (a.it < 0 and b.it >= 0) or (a.it >= 0 and b.it < 0)))
                )
            if not eligible:
                continue
            if results_mat[i, j] is None:
                jobs.append((i, j, a, b))
            if i != j and results_mat[j, i] is None:
                jobs.append((j, i, b, a))
    return jobs


def build_elo_system(results_r):
    """Build the logistic-regression system used for the Elo fit.

    One row is produced per unordered pair (i < j) for which both seat orders
    have a result. The row has -1 in column i and +1 in column j; the target
    maps ``results_r[i, j] - results_r[j, i]`` through ``x * 0.25 + 0.5``.

    Args:
        results_r: (n, n) float array of mean returns, NaN where unplayed.

    Returns:
        ``(A, b)`` as float32 arrays of shape (m, n) and (m, 1).
    """
    import numpy as np

    n = results_r.shape[0]
    pairs = []
    scores = []
    for i in range(n):
        for j in range(i + 1, n):
            if np.isnan(results_r[i, j]) or np.isnan(results_r[j, i]):
                continue
            pairs.append((i, j))
            # [-1,1] -> [0,1]
            scores.append((results_r[i, j] - results_r[j, i]) * 0.25 + 0.5)

    # A = [-1  1       ...
    #      -1     1    ...
    #      ...
    #         -1  1    ...
    #      ...            ]
    # b = [0.6 0.5 0.4 ...]^T
    A = np.zeros((len(pairs), n), dtype=np.float32)
    for k, (i, j) in enumerate(pairs):
        A[k, i] = -1
        A[k, j] = 1
    b = np.asarray(scores, dtype=np.float32).reshape(-1, 1)
    return A, b


def _next_result(qout, procs, poll_seconds=5.0):
    """Return the next worker result, raising if a worker has died meanwhile.

    Workers only exit after receiving the shutdown sentinel, so a dead worker
    while results are still expected means its job is lost and a plain
    ``qout.get()`` would block forever.
    """
    import queue

    while True:
        try:
            return qout.get(timeout=poll_seconds)
        except queue.Empty:
            dead = [p.pid for p in procs if not p.is_alive()]
            if dead:
                raise RuntimeError(
                    'evaluation worker(s) %s exited before all results arrived' % dead)


if __name__ == '__main__':
    import time, os, sys, itertools
    from pprint import pprint
    from tqdm import tqdm
    import numpy as np
    import multiprocessing as mp

    import tour_config, pickle

    import argparse

    parser = argparse.ArgumentParser()

    # Misc
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save', type=str, default="")
    parser.add_argument('--gpu', action='store_true', default=False)

    # MDP definition
    parser.add_argument('--env', type=str, default='soccer')
    parser.add_argument('--T', type=int, default=100, help='maximum episode length')
    parser.add_argument('--gamma', type=float, default=1.0, help='discount factor')

    # Meta-level algorithm
    parser.add_argument('--batch', type=int, default=10)
    parser.add_argument('--nproc', type=int, default=1)

    parser.add_argument('--config', type=str, default='soccer_tab')
    parser.add_argument('--temp', type=float, default=None, help='evaluation temperature')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--preload', type=str, default='')

    args = parser.parse_args()
    if not args.save:
        save_file = 'tour/%s_b%d_%s.pkl' % (args.config, args.batch, time.strftime("%Y%m%d-%H%M%S"))
    else:
        save_file = args.save
    # Create the output directory up front so the first save cannot fail on it.
    save_dir = os.path.dirname(save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    agents = expand_agents(getattr(tour_config, args.config))
    n = len(agents)

    if args.preload:
        # NOTE: pickle.load executes arbitrary code; only preload trusted files.
        with open(args.preload, 'rb') as f: d = pickle.load(f)
        if d['agents'] != agents:
            raise ValueError('agents stored in %s do not match config %s'
                             % (args.preload, args.config))
        results_mat = d['results_mat']
        results_r = d['results_r']
    else:
        # np.object / np.int were removed from NumPy; use the builtins.
        results_mat = np.empty((n, n), dtype=object)
        results_r = np.full((n, n), np.nan)

    jobs = build_jobs(agents, results_mat, fast=args.fast)
    np.random.shuffle(jobs)  # decorrelate

    # Workers are started only after the config/preload checks above, so a
    # failure there cannot leave child processes blocked on an empty queue.
    mp.set_start_method('spawn')
    qin = mp.Queue()
    qout = mp.Queue()
    procs = [mp.Process(target=process_eval, args=(qin, qout, args, idx))
             for idx in range(args.nproc)]
    for p in procs: p.start()

    try:
        for job in jobs:
            qin.put(job)

        for done in tqdm(range(len(jobs))):
            i, j, r, r_mean = _next_result(qout, procs)
            results_mat[i, j] = r
            results_r[i, j] = r_mean
            if done % 20 == 0:
                # save partial results
                print ('saved partial to ' + save_file)
                with open(save_file, 'wb') as f:
                    d = dict(results_mat=results_mat, results_r=results_r, agents=agents, args=args)
                    pickle.dump(d, f)
    except BaseException:
        # Do not wait for workers to drain the remaining jobs, and do not let
        # the queue's feeder thread block interpreter exit.
        for p in procs: p.terminate()
        qin.cancel_join_thread()
        raise
    else:
        for p in procs: qin.put(None)
    finally:
        for p in procs: p.join()
        qin.close()
        qout.close()

    print (np.array2string(results_r*100, precision=1))

    results_mean_x = np.nanmean(results_r,1)
    results_mean_y = np.nanmean(results_r,0)
    results_mean = (results_mean_x - results_mean_y) / 2
    for a, r, rx, ry in zip(agents, results_mean.tolist(), results_mean_x.tolist(), results_mean_y.tolist()):
        print (a, r, rx, ry)

    ###############################
    ### run logistic regression
    import torch

    A_np, b_np = build_elo_system(results_r)
    m = A_np.shape[0]   # number of pairings that enter the fit
    A = torch.from_numpy(A_np)
    b = torch.from_numpy(b_np)

    elo = torch.nn.Parameter(torch.zeros(n, 1))

    if m == 0:
        print ('no completed pairings; skipping Elo fit')
    else:
        elo_last = 100.0
        opt = torch.optim.Adam([elo], 1.0)
        for step in range(10000):
            logits = torch.mm(A, elo) * (-np.log(10)/400)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(input=logits, target=b)
            # sigmoid(x) = 1/(1+exp(-x))
            # sigmoid[(eb - ea)*(-ln(10)/400)] = 1/(1+exp([ (eb-ea) * ln(10)/400 ])) = 1/(1 + 10^[(eb-ea)/400])
            # which is the Elo expression.

            if step % 100 == 0:
                print (step, loss.item())
            # stop once the ratings moved by less than 1e-2 in the last update
            if (elo.data - elo_last).norm() < 1e-2:
                break
            elo_last = elo.data.clone()

            opt.zero_grad()
            loss.backward()
            opt.step()
        print (step, loss.item())

    # Ratings are reported relative to the first agent, anchored at 1000.
    elo_list = elo[:,0].tolist()
    for a, e in zip(agents, elo_list):
        print (a, e - elo_list[0] + 1000)

    print ('saved to ' + save_file)
    with open(save_file, 'wb') as f:
        d = dict(results_mat=results_mat, results_r=results_r, results_mean=results_mean,
            elo_list=elo_list, agents=agents, args=args, A=A_np, b=b_np, m=m, n=n)
        pickle.dump(d, f)
