from modules import PlayLMP
import numpy as np
import torch
import yaml
import joblib

import argparse
import os
from tqdm import tqdm


@torch.no_grad()
def generate_skills(episode, model, min_skill_length, max_skill_length):
    """Encode every (start index, window length) slice of an episode as a skill.

    For each start index and each window length from ``min_skill_length`` up to
    ``max_skill_length`` (clipped so the window plus one following observation
    stay inside the episode), the window's observations and actions are padded
    to ``max_skill_length`` steps, batched per start index, and passed through
    ``model.skill_recognition``; the posterior's ``.mean`` is stored.

    Args:
        episode: mapping with ``'observations'`` and ``'actions'`` sequences
            indexed by time step.
        model: object exposing ``skill_dim``, ``preprocess_batch`` and
            ``skill_recognition``.
        min_skill_length: shortest window length to encode.
        max_skill_length: longest window length; every window is padded to it.

    Returns:
        float32 array of shape
        ``(len(episode['observations']), max_skill_length - min_skill_length + 1, model.skill_dim)``
        whose entry ``[t, k]`` is the skill of length ``min_skill_length + k``
        starting at ``t``. Entries that are never computed remain zero.
    """
    obs_seq = episode['observations']
    act_seq = episode['actions']
    n_steps = len(obs_seq)
    n_lengths = max_skill_length - min_skill_length + 1
    skills = np.zeros([n_steps, n_lengths, model.skill_dim], dtype=np.float32)

    # NOTE(review): start index n_steps - min_skill_length - 1 would still admit
    # one window of length min_skill_length, but this range stops before it, so
    # that row stays zero — confirm this is intentional.
    for start in range(n_steps - min_skill_length - 1):
        # Exclusive upper bound on window length: keeps start + length a valid
        # index (it is read for the observation pad) and caps at max length.
        stop = min(n_steps - start, max_skill_length + 1)
        obs_windows = []
        act_windows = []
        for length in range(min_skill_length, stop):
            end = start + length
            window_obs = obs_seq[start:end]
            window_act = act_seq[start:end]
            missing = max_skill_length - length
            if missing > 0:
                # Pad observations with the observation right after the window.
                obs_fill = np.repeat(obs_seq[end][np.newaxis, ...], missing, axis=0)
                # Pad actions with the window's last action, zeroing every
                # component except the final one.
                act_fill = np.repeat(act_seq[end - 1][np.newaxis, ...], missing, axis=0)
                act_fill[:, :-1] = 0.
                window_obs = np.concatenate([window_obs, obs_fill], axis=0)
                window_act = np.concatenate([window_act, act_fill], axis=0)
            obs_windows.append(torch.as_tensor(window_obs))
            act_windows.append(torch.as_tensor(window_act))

        batch = model.preprocess_batch(
            {
                'observations': torch.stack(obs_windows, dim=0),
                'actions': torch.stack(act_windows, dim=0),
            },
            eval=True,
        )
        posterior = model.skill_recognition(batch['observations'], batch['actions'])
        skills[start, :stop - min_skill_length] = posterior.mean.cpu().detach().numpy()
    return skills


parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='calvin')
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()

# The domain is the part of the env name before the first '-'; it selects the
# config and checkpoint directories below.
domain = args.env.split('-')[0]
exp_name = f'play_lmp_{domain}_s_{args.seed}'

# Use a context manager so the config file handle is closed deterministically.
# (FullLoader is kept for compatibility with the existing local config files.)
with open(f'configs/{domain}/play_lmp.yaml', encoding='utf-8') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = PlayLMP(**config['model_cfg'], dataset_cfg=config['dataset_cfg'], device=device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved from GPU tensors can still be loaded when falling back to CPU.
checkpoint = torch.load(
    f'checkpoints/{domain}/{exp_name}/checkpoint_best.pt',
    map_location=device,
)
model.load_state_dict(checkpoint['model'])
# NOTE(review): model.eval() is never called here; confirm that
# preprocess_batch(..., eval=True) is enough to put the model in inference mode.
print('Best epoch: %d' % checkpoint['epoch'])

# Episode files to annotate. The split sizes (35 train / 6 validation files)
# are hard-coded and must match the files present in data_dir.
train_data_pths = [os.path.join(config['dataset_cfg']['data_dir'], 'train_%d.pkl' % i) for i in range(35)]
val_data_pths = [os.path.join(config['dataset_cfg']['data_dir'], 'validation_%d.pkl' % i) for i in range(6)]


def _annotate_episode_file(pth, model, min_skill_length, max_skill_length):
    """Load the episode at ``pth``, add its ``'skills'`` array, and save it back.

    The updated episode is dumped to ``pth + '.tmp'`` first and then moved over
    ``pth`` with ``os.replace``, so an interruption while writing leaves the
    original file untouched (at worst a stray ``.tmp`` file remains).
    """
    episode = joblib.load(pth)
    episode['skills'] = generate_skills(
        episode, model,
        min_skill_length=min_skill_length,
        max_skill_length=max_skill_length,
    )
    tmp_pth = pth + '.tmp'
    joblib.dump(episode, tmp_pth)
    os.replace(tmp_pth, pth)


# Train and validation splits get identical treatment; process them in order.
for split_pths in (train_data_pths, val_data_pths):
    for pth in tqdm(split_pths):
        _annotate_episode_file(
            pth, model,
            min_skill_length=config['dataset_cfg']['min_skill_length'],
            max_skill_length=config['dataset_cfg']['max_skill_length'],
        )
