import os
import argparse
import numpy as np
import torch
from optimal_agents.utils.loader import BASE, get_env, ModelParams
from optimal_agents.morphology import Morphology
from torchvision.utils import save_image

def get_image(path, morphology_index=0, num_steps=100, height=450, width=640):
    """Render a frame of a saved morphology after a short random rollout.

    If ``path`` does not itself contain ``params.json``, the first
    subdirectory (in sorted order) is used instead, provided it meets all of:

    - its name does not start with ``gen``;
    - it contains both ``params.json`` and ``0.morphology.pkl``.

    If no subdirectory qualifies, ``path`` is used unchanged.

    Args:
        path: A run directory, or a directory that contains run directories.
        morphology_index: Index N of the ``N.morphology.pkl`` file to load.
        num_steps: Number of random-action ``env.step`` calls made before
            rendering (default 100, as before).
        height: Rendered frame height in pixels.
        width: Rendered frame width in pixels.

    Returns:
        The ``rgb_array`` frame returned by ``env.render`` divided by 255.0.
        Callers transpose it with axes (2, 0, 1), so it is presumably an
        (H, W, C) array. Returns ``None`` when the morphology file does not
        exist.
    """
    if 'params.json' not in os.listdir(path):
        # Sorted so the chosen run is deterministic (os.listdir order is arbitrary).
        for name in sorted(os.listdir(path)):
            candidate = os.path.join(path, name)
            if name.startswith('gen') or not os.path.isdir(candidate):
                continue
            contents = os.listdir(candidate)
            if 'params.json' in contents and '0.morphology.pkl' in contents:
                path = candidate
                break

    print(path)
    morphology_path = os.path.join(path, f"{morphology_index}.morphology.pkl")
    if not os.path.isfile(morphology_path):
        print("CANNOT FIND", morphology_path)
        return None
    morphology = Morphology.load(morphology_path)

    params = ModelParams.load(path)
    # Any arena whose name contains 'Terrain' is dropped before building the env.
    if params['arena'] and 'Terrain' in params['arena']:
        params['arena'] = None
    # Keep only the first space-separated token of the env name.
    if ' ' in params['env']:
        params['env'] = params['env'].split(' ')[0]

    env = get_env(params, morphology)
    try:
        env.reset()
        for _ in range(num_steps):
            env.step(env.action_space.sample())
        img = env.render(mode='rgb_array', height=height, width=width) / 255.0
    finally:
        # Release the environment (and any render context) even if rendering fails;
        # this function is called once per run in a loop.
        env.close()
    return img

def main_single(paths, output_path, rows=8):
    """Render one frame per run directory and save them as a single image grid.

    Args:
        paths: Run directories. Relative paths are resolved against ``BASE``.
            The caller's list is not modified.
        output_path: Directory in which ``fig.png`` is written.
        rows: Forwarded to ``save_image`` as ``nrow``, which torchvision
            interprets as the number of images per grid row.

    Runs for which ``get_image`` returns ``None`` are skipped. If no image
    is produced, nothing is written.
    """
    output_path = os.path.join(output_path, 'fig.png')
    # Build a new list rather than rewriting the caller's list in place.
    paths = [p if os.path.isabs(p) else os.path.join(BASE, p) for p in paths]
    print(len(paths))

    images = [img for img in (get_image(p) for p in paths) if img is not None]
    print(len(images))
    if not images:
        # save_image cannot build a grid from an empty list.
        print("No images to save")
        return

    # HWC numpy frames -> CHW tensors; .copy() makes the transposed array contiguous.
    tensors = [torch.from_numpy(np.transpose(img, (2, 0, 1)).copy()) for img in images]
    save_image(tensors, output_path, nrow=rows, padding=10, pad_value=0.7)

def main_batch(paths, output_path, rows=8, per_path=6):
    """Render up to ``per_path`` runs from each generation directory into one grid.

    Args:
        paths: Directories whose subdirectories are individual runs. Relative
            paths are resolved against ``BASE``. The caller's list is not
            modified.
        output_path: Directory in which ``fig.png`` is written.
        rows: Forwarded to ``save_image`` as ``nrow``, which torchvision
            interprets as the number of images per grid row.
        per_path: Maximum number of images taken from each directory in
            ``paths`` (default 6, as before).

    Runs are visited in sorted name order, and only directories are
    considered. Runs for which ``get_image`` returns ``None`` do not count
    toward ``per_path``. If no image is produced, nothing is written.
    """
    output_path = os.path.join(output_path, 'fig.png')
    # Build a new list rather than rewriting the caller's list in place.
    paths = [p if os.path.isabs(p) else os.path.join(BASE, p) for p in paths]

    images = []
    for path in paths:
        # Plain files (logs, configs) are not runs; get_image would fail listing them.
        runs = [
            os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if os.path.isdir(os.path.join(path, name))
        ]
        added = 0
        for run in runs:
            if added >= per_path:
                break
            img = get_image(run)
            if img is not None:
                images.append(img)
                added += 1

    print(len(images))
    if not images:
        # save_image cannot build a grid from an empty list.
        print("No images to save")
        return

    # HWC numpy frames -> CHW tensors; .copy() makes the transposed array contiguous.
    tensors = [torch.from_numpy(np.transpose(img, (2, 0, 1)).copy()) for img in images]
    save_image(tensors, output_path, nrow=rows, padding=10, pad_value=0.2)


if __name__ == "__main__":
    # CLI: one or more generation directories in, a single fig.png grid out.
    cli = argparse.ArgumentParser()
    # At least one directory is mandatory; relative ones are resolved by main_batch.
    cli.add_argument("--generation-path", "-p", type=str, nargs='+',
                     required=True, default=None)
    # Destination directory for fig.png (current directory by default).
    cli.add_argument("--output-path", "-o", type=str, required=False,
                     default='.')
    # Forwarded to save_image's nrow.
    cli.add_argument("--rows", "-r", type=int, required=False, default=8)
    opts = cli.parse_args()
    main_batch(opts.generation_path, opts.output_path, opts.rows)
