import argparse
import json
from curses.ascii import controlnames
from typing import List, Dict

from tqdm import tqdm
import h5py
import numpy as np
import os
from copy import deepcopy
import torch

import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from DataGenerators.DataGenerator import DataGenerator, seed_everything
from robomimic.envs.env_base import EnvBase
from robomimic.algo import RolloutPolicy

import urllib.request
from config import PERSISTENT_DATA_PATH

def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None, noise=0.0, return_trajectories=False, expert_percentage=1.0):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video,
    and can return the observations visited during the rollout.

    At every step the policy action is optionally perturbed with uniform noise, and with probability
    ``1 - expert_percentage`` it is replaced by an action drawn uniformly from [-1, 1] per component.

    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
        noise (float): if > 0, noise drawn uniformly from [-noise, noise] is added to each
            component of the policy action
        return_trajectories (bool): if True, also return the observations at which actions were taken
        expert_percentage (float): probability, per step, of executing the (possibly noisy) policy
            action instead of a uniformly random one
    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
        list_of_obs (list): only when ``return_trajectories`` is True; deep copies of the observations
            at which an action was taken. The observation of the step that ended the episode
            (done or success) is not included.
    """
    assert isinstance(env, EnvBase)
    assert isinstance(policy, RolloutPolicy)
    assert not (render and (video_writer is not None))

    list_of_obs = []
    policy.start_episode()
    obs = env.reset()
    state_dict = env.get_state()

    # hack that is necessary for robosuite tasks for deterministic action playback
    obs = env.reset_to(state_dict)

    video_count = 0  # video frame counter
    total_reward = 0.
    # Initialised up front so the stats below are well-defined when horizon == 0 or when a
    # rollout exception is raised before the first step has finished.
    step_i = -1
    success = False
    try:
        for step_i in range(horizon):

            # get action from policy
            expert_act = policy(ob=obs)

            if noise > 0:
                # Uniform perturbation in [-noise, noise] per action component.
                expert_act += np.random.uniform(low=-noise, high=noise, size=expert_act.shape)

            # Drawn on every step (even when unused), so the number of RNG draws per step is fixed.
            random_act = np.random.uniform(low=-1, high=1, size=expert_act.shape)

            if np.random.rand() < expert_percentage:
                act = expert_act
            else:
                act = random_act

            # play action
            next_obs, r, done, _ = env.step(act)

            # compute reward
            total_reward += r
            success = env.is_success()["task"]

            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1)  # concatenate horizontally
                    video_writer.append_data(video_img)
                video_count += 1

            # break if done or if success (before recording this step's observation)
            if done or success:
                break

            if return_trajectories:
                list_of_obs.append(deepcopy(obs))

            # update for next iter
            obs = deepcopy(next_obs)

    except env.rollout_exceptions as e:
        print("WARNING: got rollout exception {}".format(e))

    stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))

    if not return_trajectories:
        return stats
    return stats, list_of_obs




class RobomimicDataGenerator(DataGenerator):
    """
    Generate synthetic feedback data for a robomimic task from a pretrained expert policy.

    States are collected by rolling out the expert, mixed with uniformly random actions according
    to ``distribution``. ``sample_data`` then pairs each state with expert and perturbed actions
    and emits records whose layout depends on ``type``.
    """

    def __init__(self, env="Robomimic", type="binary_feedback", distribution=1, expert="/lift_ph_low_dim_epoch_1000_succ_100.pth"):
        """
        Load the expert checkpoint and build the environment stored in it.

        :param env: ignored; the environment is always created from the expert checkpoint
        :param type: record format produced by ``sample_data``: "binary_feedback", "preference",
            "action_advising" or "delta_action"
        :param distribution: probability of executing the expert action at each rollout step
            while sampling states (1 = fully expert, 0 = fully random)
        :param expert: checkpoint file name, appended to ``PERSISTENT_DATA_PATH + "/expert/"``
        """
        # NOTE(review): DataGenerator.__init__ is not called (the call was commented out);
        # confirm the base class needs no initialisation.
        self.expert_path = PERSISTENT_DATA_PATH + "/expert/" + expert
        device = TorchUtils.get_torch_device(try_to_use_cuda=True)
        policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=self.expert_path, device=device, verbose=True)
        self.expert = policy
        self.ckpt_dict = ckpt_dict
        env, _ = FileUtils.env_from_checkpoint(
            ckpt_dict=ckpt_dict,
            render=False,  # no on-screen rendering
            render_offscreen=True,  # render to RGB images for video
            verbose=True,
        )
        self.env = env
        self.type = type
        self.distribution = distribution

    def sample_array_of_states(self, number=2000, cutoff_length=400, seed_list=None) -> np.array:
        """
        Roll out the (partially randomised) expert and collect the visited observations.

        One rollout is performed per seed, in order, until at least ``number`` observations have
        been collected or ``seed_list`` is exhausted. The result is not truncated, so it may hold
        more than ``number`` observations.

        :param number: number of states after which no further rollouts are started
        :param cutoff_length: maximum horizon of every rollout
        :param seed_list: seeds passed to ``seed_everything`` before each rollout (default: none)
        :return: np.array of the concatenated observations (also stored in ``self.state_array``);
            empty when no seeds are given
        """
        if seed_list is None:
            seed_list = []
        expert_action_percentage = self.distribution
        total_state_list = []
        total_length = 0
        for seed in seed_list:
            seed_everything(seed)
            stats, list_of_obs = rollout(policy=self.expert, env=self.env, horizon=cutoff_length, render=False, video_writer=None, noise=0, expert_percentage=expert_action_percentage, return_trajectories=True)
            total_length += len(list_of_obs)
            print(stats)
            total_state_list.append(list_of_obs)
            if total_length >= number:
                break
        if total_state_list:
            self.state_array = np.concatenate(total_state_list)
        else:
            # np.concatenate raises on an empty sequence; no rollouts means no states.
            self.state_array = np.array([], dtype=object)
        return self.state_array

    def sample_array_of_trajectories(self, number=48) -> List[Dict]:
        """Not implemented for this environment; returns None."""
        pass

    def sample_data(self, number=48, cutoff_length=100, seed_list=None, threshold_distance=0.28) -> List[Dict]:
        """
        Build feedback records for states visited by the expert.

        For each sampled state, the first expert action is compared with a uniformly random action.
        The random action counts as "good" when its L2 distance to the expert action is at most
        ``threshold_distance``. Record layout per ``self.type``:

        * "binary_feedback": two records per state. The expert action gets feedback 1; the random
          action gets 1 if good, else -1.
        * "preference": one record comparing the expert and random actions. Feedback is 0 if the
          random action is good; otherwise 1 (expert is action1) or -1 (expert is action2), chosen
          at random.
        * "action_advising": one record whose feedback is the list of expert actions.
        * "delta_action": one record containing:
          - "wrong_action": the expert action with one component shifted by +-threshold_distance;
          - "delta_actions": every single-component +-threshold_distance correction of it;
          - "feedback": the index of the correction that undoes the shift.

        :param number: number of states to aim for (see ``sample_array_of_states``)
        :param cutoff_length: maximum horizon of every rollout
        :param seed_list: rollout seeds (default: none, which yields no records)
        :param threshold_distance: distance threshold and perturbation size described above
        :return: list of record dicts
        """
        if seed_list is None:
            seed_list = []
        ret_list = []

        self.sample_array_of_states(number, cutoff_length, seed_list=seed_list)

        for state in tqdm(self.state_array):
            expert_actions = self.get_expert_actions(state)
            expert_action = expert_actions[0]

            # Drawn for every record type (even when unused), so RNG draws per state are fixed.
            random_action = np.random.uniform(low=-1, high=1, size=expert_action.shape)
            distance = np.linalg.norm(random_action - expert_action, ord=2)
            random_action_good_flag = True
            if distance > threshold_distance:
                random_action_good_flag = False

            if self.type == "binary_feedback":
                ret_list.append({"state": state, "action": expert_action, "feedback": 1, "expert_actions": expert_actions})
                ret_list.append({"state": state, "action": random_action, "feedback": 1 if random_action_good_flag else -1, "expert_actions": expert_actions})
            elif self.type == "preference":
                if random_action_good_flag:
                    ret_list.append({"state": state, "action1": expert_action, "action2": random_action, "feedback": 0, "expert_actions": expert_actions})
                elif np.random.rand() < 0.5:
                    ret_list.append({"state": state, "action1": expert_action, "action2": random_action, "feedback": 1, "expert_actions": expert_actions})
                else:
                    # Carries expert_actions too, so every preference record has the same keys.
                    ret_list.append({"state": state, "action1": random_action, "action2": expert_action, "feedback": -1, "expert_actions": expert_actions})
            elif self.type == "action_advising":
                ret_list.append({"state": state, "feedback": expert_actions, "expert_actions": expert_actions})
            elif self.type == "delta_action":
                ret_list.append(self._delta_action_record(state, expert_actions, threshold_distance))

        return ret_list

    @staticmethod
    def _delta_action_record(state, expert_actions, threshold_distance) -> Dict:
        """Build one "delta_action" record for ``state`` (layout documented in ``sample_data``)."""
        expert_action = expert_actions[0]

        # Shift one randomly chosen component by +-threshold_distance. Retry until the shifted
        # value stays inside [-1, 1].
        while True:
            wrong_index = np.random.choice(expert_action.shape[0])
            wrong_action = deepcopy(expert_action)
            positive_offset = np.random.rand() < 0.5
            if positive_offset:
                wrong_action[wrong_index] += threshold_distance
            else:
                wrong_action[wrong_index] -= threshold_distance
            if -1 <= wrong_action[wrong_index] <= 1:
                break

        # The correct answer is the opposite offset applied to the shifted component. It is
        # identified structurally rather than by comparing floats, because (x + t) - t is not
        # always exactly x in floating point.
        undo_offset = -1 if positive_offset else +1
        delta_actions = []
        correct_answer_index = 0
        for i in range(expert_action.shape[0]):
            for offset in [+1, -1]:
                delta_action = deepcopy(wrong_action)
                delta_action[i] += offset * threshold_distance
                if i == wrong_index and offset == undo_offset:
                    correct_answer_index = len(delta_actions)
                delta_actions.append(delta_action)

        return {"state": state, "feedback": correct_answer_index, "expert_actions": expert_actions, "delta_actions": delta_actions, "wrong_action": wrong_action}

    def get_expert_actions(self, state) -> List:
        """
        Return the expert's action(s) for a given state.

        :param state: an observation, as collected by ``sample_array_of_states``
        :return: single-element list holding the expert policy's action for ``state``
        """
        return [self.expert(state)]

    def get_expert_qvalue(self, state, action) -> float:
        """
        Heuristic (not learned) Q-value for a state-action pair.

        :param state: an observation
        :param action: action to score
        :return: 1 if ``action`` exactly equals one of the expert actions for ``state``, else -1
        """
        expert_actions = self.get_expert_actions(state)
        # `action in expert_actions` would compare arrays with ==, whose element-wise result has
        # an ambiguous truth value for actions with more than one component.
        if any(np.array_equal(action, expert_action) for expert_action in expert_actions):
            return 1
        return -1

    def get_expert_value(self, state) -> float:
        """Not implemented for this environment; returns the ``NotImplemented`` sentinel."""
        return NotImplemented


if __name__ == "__main__":
    # (distribution, feedback type) pairs, processed in this order. Each pair loads a fresh
    # generator and writes <PERSISTENT_DATA_PATH>/robomimic/robomimic{type}_{distribution}.npy.
    configs = [
        (1, "binary_feedback"), (0.5, "binary_feedback"), (0, "binary_feedback"),
        (0, "preference"), (0.5, "preference"), (1, "preference"),
        (0, "action_advising"), (0.5, "action_advising"), (1, "action_advising"),
    ]
    output_dir = PERSISTENT_DATA_PATH + "/robomimic"
    # np.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    for distribution, feedback_type in configs:
        mgdg = RobomimicDataGenerator("robomimic", feedback_type, distribution)
        data = mgdg.sample_data(2000, seed_list=list(range(100)))
        np.save(output_dir + f"/robomimic{feedback_type}_{distribution}.npy", data)