import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time, os, sys
from copy import deepcopy
from pprint import pprint
from tqdm import tqdm
sys.path.insert(0,'..')
import envs, gym
import model
import algo
import runner

log_file = open('./log.txt', 'w')
print_ = print


def print(*args, **kwargs):
    """Print as usual and mirror the same text into ``log_file``.

    Shadows the builtin ``print`` for the rest of this module, so every
    ``print(...)`` call in the script is also recorded in ./log.txt.
    Keyword arguments are forwarded unchanged to the normal print.
    """
    print_(*args, **kwargs)
    # The log copy must not receive the caller's ``file`` (passing it twice,
    # once via **kwargs and once as file=log_file, raises TypeError) nor the
    # caller's ``flush`` (we always flush the log ourselves).
    log_kwargs = {k: v for k, v in kwargs.items() if k not in ('file', 'flush')}
    # Flush so the log is readable while a long run is in progress and is not
    # lost if the process is killed.
    print_(*args, **log_kwargs, file=log_file, flush=True)

# state_dim = 15*15
# nact = 15*15

class RenjuDataset(Dataset):
    """(board observation, next move) pairs extracted from recorded Renju games.

    On first use every game in ``boards_path`` is replayed through the Renju
    gym environment. For each move, the observation *before* the move is
    stored in ``X`` (as uint8) and the move itself in ``Y`` as a flat action
    index ``row * board_size + col``. Both arrays are cached to ``data_path``
    and reloaded from there on later runs.

    Args:
        board_size: side length of the board (default 15).
        data_path: .npz cache file holding arrays ``X`` and ``Y``.
        boards_path: text file with one game per line, moves separated by
            whitespace (e.g. ``"h8 i9 j10"``); ``--`` marks a pass.

    Raises:
        ValueError: if a move token is malformed/off the board, or no moves
            at all could be parsed from ``boards_path``.
    """

    # Token used in the game records for a pass; it has no board coordinate.
    PASS_MOVE = '--'

    def __init__(self, board_size=15, data_path='data.npz', boards_path='boards.txt'):
        if os.path.exists(data_path):
            # Indexing an NpzFile copies the arrays into memory, so the
            # underlying archive handle can be closed immediately.
            with np.load(data_path) as d:
                self.X = d['X']
                # Class-index targets for cross_entropy must be int64.
                self.Y = d['Y'].astype(np.int64, copy=False)
            print('load existing %s (n = %d)' % (data_path, len(self.X)))
            return

        env = gym.make('Renju{0}x{0}-learning-v0'.format(board_size))

        with open(boards_path, 'r') as f:
            # splitlines() also removes the '\r' of CRLF line endings.
            lines = f.read().splitlines()
        X, Y = [], []
        for line in tqdm(lines):
            # split() with no argument ignores repeated/trailing whitespace;
            # split(' ') would yield empty tokens that crash on move[0].
            moves = line.split()
            if not moves:
                continue
            o = env.reset()
            for move in moves:
                if move == self.PASS_MOVE:
                    continue
                action = self._move_to_action(move, board_size)
                X.append(o.astype(np.uint8))
                Y.append(action)
                o, _, t, _ = env.step(action)
                if t:
                    break
        if not X:
            raise ValueError('no moves parsed from %s' % boards_path)
        self.X = np.stack(X)
        # Explicit int64: np.stack of Python ints can give int32 on some
        # platforms (e.g. Windows with NumPy < 2), which cross_entropy rejects.
        self.Y = np.array(Y, dtype=np.int64)

        np.savez(data_path, X=self.X, Y=self.Y)

    @staticmethod
    def _move_to_action(move, board_size=15):
        """Convert a move token such as ``'h8'`` (column letter starting at
        'a', 1-based row number) to the flat index ``row * board_size + col``.

        Raises:
            ValueError: if the row part is not an integer or the coordinate
                lies outside a ``board_size`` x ``board_size`` board.
        """
        col = ord(move[0]) - ord('a')
        row = int(move[1:]) - 1
        if not (0 <= col < board_size and 0 <= row < board_size):
            # Without this check an off-board move maps to a wrong but valid
            # action index and silently corrupts the training labels.
            raise ValueError('move %r is off a %dx%d board' % (move, board_size, board_size))
        return row * board_size + col

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(observation, action_index)`` for sample ``idx``."""
        return self.X[idx], self.Y[idx]


dataset = RenjuDataset()
# NOTE(review): num_workers > 0 starts worker processes; on spawn-based
# platforms (Windows, macOS) they re-import this script, which has no
# `if __name__ == "__main__":` guard — confirm this only runs on Linux/fork.
dataloader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=1)

# Fall back to CPU instead of failing on machines without a GPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

epochs = 1
ent_coef = 0.01  # weight of the entropy bonus subtracted from the loss

# Train 16 independently initialised policy networks to imitate the recorded
# moves, saving a checkpoint per (epoch, seed). The log file is closed even if
# training raises.
try:
    for seed in range(16):
        print('seed %d' % seed)
        # Previously the seed only named the checkpoint; apply it so weight
        # init and DataLoader shuffle order are reproducible per seed.
        torch.manual_seed(seed)
        np.random.seed(seed)

        agent = model.FCN(board_size=15, pretrain=None, pad_m1=False)
        opt = torch.optim.RMSprop(agent.parameters(), 0.001, eps=1e-5, alpha=0.99)

        for epoch in range(epochs):
            # The checkpoint save below moves the model to CPU, so move it
            # (back) to the training device at the start of every epoch.
            agent.to(device)
            batch_n = len(dataloader)
            for batch_idx, (batch_x, batch_y) in enumerate(dataloader):
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                act_logits = agent(batch_x, av=1)
                loss = F.cross_entropy(act_logits, batch_y)
                acc = (act_logits.argmax(-1) == batch_y).float().mean()
                # Mean entropy of the predicted move distribution, from a
                # single log_softmax pass.
                log_probs = act_logits.log_softmax(-1)
                entropy = -(log_probs.exp() * log_probs).sum(-1).mean()
                if batch_idx % 100 == 0:
                    print("[%5d / %5d] loss %.4f\tentropy %.4f\tacc %.4f" % (
                        batch_idx, batch_n, loss.item(), entropy.item(), acc.item()))

                opt.zero_grad()
                # Entropy is subtracted, i.e. rewarded — presumably to keep
                # the imitation policy from becoming over-confident.
                (loss - ent_coef * entropy).backward()
                opt.step()

            out_dir = 'k553_epoch%d' % epoch
            os.makedirs(out_dir, exist_ok=True)
            filename = '%s/%d.pt' % (out_dir, seed)
            torch.save(agent.cpu().state_dict(), filename)
            print("saved " + filename)
finally:
    log_file.close()
