import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import xml
import xml.etree.ElementTree as ET
from gym.envs.mujoco.inverted_double_pendulum import InvertedDoublePendulumEnv
from gym.envs.mujoco.inverted_pendulum import InvertedPendulumEnv
from tqdm import tqdm
import ipdb
from PIL import Image


def get_image(env):
    """Render ``env`` offscreen and return the frame shrunk to 64x64 pixels.

    The frame comes from ``env.render(mode='rgb_array')``. It is resized through
    PIL and handed back as a NumPy array. For an RGB render this is presumably
    of shape (64, 64, 3) — TODO confirm against the env's render output.
    """
    frame = env.render(mode='rgb_array')
    thumbnail = Image.fromarray(frame).resize((64, 64))
    return np.asarray(thumbnail)

def change_param_pendulum(pendulum_dir, xml_file, new_length=0.6, new_gravity=-9.81,
                          output_name='inverted_pendulum_tmp.xml'):
    """Write a copy of a single inverted-pendulum MuJoCo XML with new length and gravity.

    Every ``<option>`` element gets ``gravity="0 0 <new_gravity>"``. Every
    ``<geom>`` that has a ``fromto`` attribute is set to
    ``"0 0 0 0.001 0 <new_length>"``. All other attributes are left untouched.

    Args:
        pendulum_dir: Directory holding ``xml_file``; the output is written here too.
        xml_file: File name of the source XML inside ``pendulum_dir``.
        new_length: Pole length, used as the z end point of ``fromto`` (around 0.6).
        new_gravity: z component of gravity; negative, around -9.81.
        output_name: File name of the written XML inside ``pendulum_dir``.
            It may equal ``xml_file``, in which case the file is rewritten in place.

    Returns:
        Path of the written XML file.
    """
    tree = ET.parse(os.path.join(pendulum_dir, xml_file))
    root = tree.getroot()

    for option in root.iter('option'):
        option.set('gravity', f'0 0 {new_gravity}')

    for geom in root.iter('geom'):
        # Only geoms defined by a fromto segment (the pole) are resized.
        if 'fromto' in geom.attrib:
            # The 0.001 x offset is kept from the original constant — presumably
            # a slight tilt of the pole; TODO confirm against the stock XML.
            geom.set('fromto', f"0 0 0 0.001 0 {new_length}")

    output_file = os.path.join(pendulum_dir, output_name)
    tree.write(output_file, encoding='latin-1')
    return output_file

def change_param_double_pendulum(pendulum_dir, xml_file, new_length=0.6, new_gravity=-9.81,
                                 output_name='double_pendulum_tmp.xml'):
    """Write a copy of a double inverted-pendulum MuJoCo XML with new length and gravity.

    In a single pass over the tree:
      * every ``<option>`` gets ``gravity="1e-5 0 <new_gravity>"``;
      * every ``<geom>`` with a ``fromto`` becomes ``"0 0 0 0 0 <new_length>"``;
      * every ``<site>`` with a ``pos`` is moved to ``"0 0 <new_length>"``;
      * the ``<body name="pole2">`` (if it has a ``pos``) is moved to
        ``"0 0 <new_length>"``, so the second link hangs off the first one's end.
    Bodies without a ``name`` attribute are skipped rather than raising KeyError.

    Args:
        pendulum_dir: Directory holding ``xml_file``; the output is written here too.
        xml_file: File name of the source XML inside ``pendulum_dir``.
        new_length: Length of each pole segment (around 0.6).
        new_gravity: z component of gravity; negative, around -9.81.
        output_name: File name of the written XML inside ``pendulum_dir``.
            It may equal ``xml_file``, in which case the file is rewritten in place.

    Returns:
        Path of the written XML file.
    """
    tree = ET.parse(os.path.join(pendulum_dir, xml_file))
    root = tree.getroot()

    for elem in root.iter():
        attrs = elem.attrib
        if elem.tag == 'option':
            # The tiny nonzero x component is kept from the original value —
            # presumably to break perfect symmetry; TODO confirm.
            elem.set('gravity', f'1e-5 0 {new_gravity}')
        elif elem.tag == 'geom' and 'fromto' in attrs:
            elem.set('fromto', f"0 0 0 0 0 {new_length}")
        elif elem.tag == 'site' and 'pos' in attrs:
            elem.set('pos', f"0 0 {new_length}")
        elif elem.tag == 'body' and 'pos' in attrs and attrs.get('name') == 'pole2':
            elem.set('pos', f"0 0 {new_length}")

    output_file = os.path.join(pendulum_dir, output_name)
    tree.write(output_file, encoding='latin-1')
    return output_file


def generate_trajectory(env, random_actions, reset_every=None):
    """Roll ``random_actions`` through ``env`` and record one transition per action.

    Args:
        env: Environment using the old gym API (``reset()`` returns the state and
            ``step()`` returns a 4-tuple); it must support ``get_image``.
        random_actions: Iterable of actions. Each action is concatenated onto
            the state, so it must be a 1-D array-like.
        reset_every: If truthy, the environment is reset after every
            ``reset_every`` transitions, so each segment contains exactly that many
            steps (the last segment may be shorter). ``None`` or ``0`` disables
            resets.

    Returns:
        ``(inputs, targets)``, two lists of equal length. ``inputs[i]`` is
        ``[state_image, action, concat(state, action)]`` and ``targets[i]`` is
        ``[next_state_image, next_state]``.
    """
    state = env.reset()
    state_img = get_image(env)

    inputs = []
    targets = []
    for step, action in enumerate(random_actions, start=1):
        # Reward, done and info are ignored: the rollout keeps stepping past
        # termination until the next scheduled reset.
        next_state, _, _, _ = env.step(action)
        next_img = get_image(env)

        inputs.append([state_img, action, np.concatenate([state, action])])
        targets.append([next_img, next_state])

        # Counting from 1 makes the first reset happen after `reset_every`
        # transitions, not right after the very first one.
        if reset_every and step % reset_every == 0:
            state = env.reset()
            state_img = get_image(env)
        else:
            state = next_state
            state_img = next_img

    return inputs, targets

def generate_dataset(num_environments, root_dir, seed, file_name_prefix=None):
    """Sample pendulum tasks, roll them out with random actions, and pickle the result.

    For each of ``num_environments`` tasks, a pole length and a gravity value are
    drawn. Both MuJoCo XMLs under ``<root_dir>/helper/pendulum_xml`` are rewritten
    with them, and random-action rollouts are recorded. The single inverted
    pendulum provides the support set and the double inverted pendulum provides
    the query set.

    Args:
        num_environments: Number of tasks to generate.
        root_dir: Project root. XMLs are read from ``helper/pendulum_xml`` and
            the pickle is written to ``data`` (created if missing).
        seed: Seed for the task parameters, the random actions, and the
            environments' own reset randomness.
        file_name_prefix: Optional prefix for the output file name.

    The pickle holds a list with one entry per task:
    ``[support_x, support_y, query_x, query_y, [length, gravity]]``.
    """
    from contextlib import closing

    data_dir = os.path.join(root_dir, 'data')
    os.makedirs(data_dir, exist_ok=True)

    # Earlier experiments used num_support=30, LENGTH_MAX=1.4 and
    # GRAVITY_MAX=-4.0 or -2.0.
    num_support = 50
    num_query = 50

    LENGTH_MIN = 0.6
    LENGTH_MAX = 0.8

    GRAVITY_MIN = -9.81
    GRAVITY_MAX = -7.0

    pendulum_dir = os.path.join(root_dir, 'helper', 'pendulum_xml')
    xml_file_single = 'inverted_pendulum_tmp.xml'
    xml_file_double = 'double_pendulum_tmp.xml'
    single_xml_path = os.path.join(pendulum_dir, xml_file_single)
    double_xml_path = os.path.join(pendulum_dir, xml_file_double)

    rng = np.random.default_rng(seed)
    # Independent child stream for seeding the environments. The draws from
    # `rng` (parameters and actions) stay exactly as before.
    env_seed_rng = np.random.default_rng(np.random.SeedSequence(seed).spawn(1)[0])

    prefix = f'{file_name_prefix}_' if file_name_prefix is not None else ''
    file_path = os.path.join(data_dir, f'{prefix}single_double_pendulum_seed_{seed}.pkl')

    all_tasks = []
    for _ in tqdm(range(num_environments)):
        new_length = np.round(rng.uniform(LENGTH_MIN, LENGTH_MAX), 5)
        new_gravity = np.round(rng.uniform(GRAVITY_MIN, GRAVITY_MAX), 5)

        # The *_tmp.xml files are rewritten in place; the environments below load them.
        change_param_pendulum(pendulum_dir, xml_file_single, new_length, new_gravity)
        change_param_double_pendulum(pendulum_dir, xml_file_double, new_length, new_gravity)

        random_single_actions = rng.uniform(-3, 3, (num_support, 1))
        random_double_actions = rng.uniform(-3, 3, (num_query, 1))

        # closing() calls env.close() at the end of every task instead of leaking
        # one pair of environments (and any viewer state) per iteration.
        with closing(InvertedPendulumEnv(single_xml_path)) as env_single:
            with closing(InvertedDoublePendulumEnv(double_xml_path)) as env_double:
                # Without this, the initial-state noise comes from an unseeded RNG
                # and the "seed_N" dataset cannot be regenerated.
                env_single.seed(int(env_seed_rng.integers(2**31 - 1)))
                env_double.seed(int(env_seed_rng.integers(2**31 - 1)))

                # generate_trajectory() resets each env before its first step.
                support_x, support_y = generate_trajectory(
                    env_single, random_single_actions, reset_every=10)
                query_x, query_y = generate_trajectory(
                    env_double, random_double_actions, reset_every=10)

        all_tasks.append([support_x, support_y, query_x, query_y, [new_length, new_gravity]])

    with open(file_path, 'wb') as f:
        pickle.dump(all_tasks, f)

    print(f"Saved {file_path}")


if __name__ == '__main__':
    # Machine-specific location of the project checkout; adjust before running elsewhere.
    root_dir = '/media/gustaf/039f6885-460f-4da2-92d0-1828ceba36e2/function-shift'
    num_environments = 600

    # One dataset per seed, both written with the 'multi' file-name prefix.
    for seed in (0, 1):
        generate_dataset(num_environments, root_dir, seed, 'multi')


     