#!/usr/bin/env python
# coding: utf-8

# In[6]:


import sys
sys.path.insert(0, '..')
import numpy as np
import itertools

# Agent version names: each architecture prefix combined with each numeric
# suffix, i.e. 'mlp1', 'mlp2', 'mlp3', 'lstm1', 'lstm2', 'lstm3'.
# Kept as a mutable list because the script entry point extends it in place.
versions = [f'{arch}{idx}' for arch, idx in itertools.product(['mlp', 'lstm'], [1, 2, 3])]

import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import model
from algo import evaluate_actions
from utils import print, setup_log

class ExpertDataset(Dataset):
    """Observation/action pairs loaded from ``data/{file}.npz``.

    The archive must provide ``obs`` and ``act`` entries. Each entry is a
    sequence of arrays that gets concatenated along the first axis and cast
    to float32, so item ``i`` of the dataset is ``(obs[i], act[i])``.
    """

    def __init__(self, file, player=2):
        """Load the archive and optionally keep a single player's slice.

        Args:
            file: Archive name without the ``data/`` directory or the
                ``.npz`` extension.
            player: When less than 2, the index taken along axis 1 of both
                arrays (presumably the player axis -- confirm against the
                data recorder). With the default of 2 both arrays are kept
                whole.

        Raises:
            AssertionError: If ``obs`` and ``act`` end up with different
                lengths.
        """
        print (f'loading {file}.. ', end='')
        # np.load on an .npz returns a lazy NpzFile that keeps the file open;
        # the with-block closes it as soon as both arrays are materialized.
        # NOTE: allow_pickle=True unpickles object arrays -- only load
        # archives from a trusted source.
        with np.load(f'data/{file}.npz', allow_pickle=True) as z:
            self.obs = np.concatenate(z['obs']).astype(np.float32)
            self.act = np.concatenate(z['act']).astype(np.float32)
        if player < 2:
            self.obs = self.obs[:, player]
            self.act = self.act[:, player]
        assert len(self.obs) == len(self.act), (
            f'{file}: {len(self.obs)} observations but {len(self.act)} actions')
        print (len(self.obs))

    def __len__(self):
        """Return the number of observation/action pairs."""
        return len(self.obs)

    def __getitem__(self, idx):
        """Return the ``(observation, action)`` pair at position ``idx``."""
        return self.obs[idx], self.act[idx]

def pretrain(agent='lstm1', x_or_y='x', save='pretrainx0', batch_size=256, epochs=5, entropy_coef=0.01, weight_decay_coef=1e-3):
    """Fit an ``MLPControl`` policy to recorded state-action pairs of one agent.

    The loss minimized per batch is the negative mean log-probability of the
    recorded actions, minus an entropy bonus, plus an L2 penalty on the
    parameters of ``m.policy``. A checkpoint of the model's ``state_dict`` is
    written after every epoch.

    Args:
        agent: Name of the agent whose recordings are loaded. It is only used
            to build data file names; the trained model is always an
            ``MLPControl`` regardless of this name.
        x_or_y: ``'x'`` loads ``{agent}_vs_{v}`` archives and keeps index 0 of
            axis 1; any other value loads ``{v}_vs_{agent}`` archives and
            keeps index 1. ``v`` ranges over the module-level ``versions``.
        save: Path prefix for checkpoints; epoch ``e`` is written to
            ``{save}_{e}.pt``. A missing parent directory is created.
        batch_size: Mini-batch size passed to the ``DataLoader``.
        epochs: Number of passes over the concatenated dataset.
        entropy_coef: Weight of the entropy bonus.
        weight_decay_coef: Weight of the L2 penalty.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    # Create the checkpoint directory before any work is done; otherwise a
    # missing folder is only discovered by torch.save after a full epoch.
    save_dir = os.path.dirname(save)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if x_or_y == 'x':
        datasets = [ExpertDataset(f'{agent}_vs_{v}', 0) for v in versions]
    else:
        datasets = [ExpertDataset(f'{v}_vs_{agent}', 1) for v in versions]
    dataset = ConcatDataset(datasets)
    print(f'training {agent} on {len(dataset)} state-action pairs')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
    # Presumably 120 is the observation size and 8 the action size -- confirm
    # against model.MLPControl.
    m = model.MLPControl(120, 8, init_std=1.0, trainable_std=True)
    opt = torch.optim.Adam(m.parameters(), lr=0.001)
    for epoch in range(epochs):
        for i, (obs, act) in enumerate(dataloader):
            opt.zero_grad()
            # NOTE(review): the meaning of av=1 is defined by MLPControl and
            # is not visible here.
            mean, std = m(obs, av=1)

            # Restrict recorded actions to [-1, 1] before scoring them.
            act.clamp_(-1., 1.)
            logprob, entropy = evaluate_actions((mean, std), act)

            action_loss = -logprob.mean()
            entropy = entropy_coef * entropy.mean()
            weight_decay = weight_decay_coef * sum(p.pow(2).sum() for p in m.policy.parameters())
            # Entropy is subtracted so that higher-entropy policies are rewarded.
            (action_loss - entropy + weight_decay).backward()
            opt.step()
            if i % 1000 == 0:
                print(epoch, i, '\t', action_loss.item(), entropy.item(), weight_decay.item())
            # Alternative (unused) regression loss on squashed means:
            # a = mean.tanh()
            # action_loss = (a - act).pow(2).mean()
        save_epoch = save + f'_{epoch}.pt'
        torch.save(m.state_dict(), save_epoch)
        print(f'saved to {save_epoch}\n')

# Script entry point.
#   <= 2 CLI arguments: dry run that only prints the (agent, side, save prefix)
#      combinations a full sweep would use.
#   >= 3 CLI arguments: agent, x_or_y, save[, batch_size, epochs,
#      entropy_coef, weight_decay_coef] -> one pretraining run, logged to
#      '<save>.txt'.
if __name__ == '__main__':
    if len(sys.argv) <= 3:
        # 'mlp1' and 'lstm1' are appended a second time, so they show up twice
        # in the sweep (with distinct indices in the save prefix).
        versions += ['mlp1', 'lstm1']
        for i, agent in enumerate(versions):
            for x_or_y in ['x','y']:
                # pretrain(agent, x_or_y, 'pretrain/pretrain' + x_or_y + str(i))
                print (agent, x_or_y, 'pretrain/pretrain' + x_or_y + str(i))
    else:
        setup_log(sys.argv[3] + '.txt')
        agent, x_or_y, save, *extra = sys.argv[1:]
        # sys.argv only holds strings; convert the optional positionals to the
        # types pretrain expects (batch_size, epochs, entropy_coef,
        # weight_decay_coef). Passing them through as str would break the
        # DataLoader, range(epochs) and the loss arithmetic.
        converters = (int, int, float, float)
        typed = [convert(value) for convert, value in zip(converters, extra)]
        # Surplus arguments are forwarded untouched so pretrain still rejects
        # them with a TypeError, as before.
        pretrain(agent, x_or_y, save, *typed, *extra[len(converters):])
