from envir.mujoco_env import MujocoEnv as Env
from model.option_policy import OptionPolicy
from model.MHA_option_policy_critic import MHAOptionPolicy

import torch
import random
import numpy as np
from copy import deepcopy
from functools import partial
from multiprocessing import Pipe, Process
from option_nmf_generator import OptionNMFGenerator
import pdb
import wandb

def env_worker(remote, env_fn):
    """Serve environment commands arriving on ``remote`` until a ``close`` command.

    The environment is built inside this process by calling ``env_fn.x()``. Each
    message is a ``(cmd, data)`` pair. Every command except ``set_seed`` and
    ``close`` answers with exactly one dict sent back on ``remote``. An unknown
    command raises ``NotImplementedError``.
    """
    environment = env_fn.x()
    while True:
        command, payload = remote.recv()

        # Terminal command: release the env and the pipe end, then exit the loop.
        if command == "close":
            environment.close()
            remote.close()
            return

        # Seeding is fire-and-forget: no reply is sent.
        if command == "set_seed":
            environment.seed(payload)
            continue

        if command == "step":
            next_state, reward, done = environment.step(payload)
            reply = {"next_state": next_state, "reward": reward, "done": done}
        elif command == "reset":
            reply = {"state": environment.reset(payload['context'], payload['is_expert'])}
        elif command == "render_my":
            reply = {"img": environment.render_my(payload['mode'])}
        elif command == "sample_context":
            reply = {"context": environment.sample_context()}
        elif command == "get_gt_option":
            reply = {"option": environment.get_gt_option()}
        else:
            raise NotImplementedError
        remote.send(reply)

class CloudpickleWrapper:
    """Carry an object to another process by pickling it with ``cloudpickle``.

    ``multiprocessing`` serialises arguments with the standard ``pickle``, which
    rejects objects such as lambdas and closures. Wrapping the payload routes
    serialisation through ``cloudpickle.dumps``. Deserialisation only needs the
    standard ``pickle.loads``.
    """

    def __init__(self, x):
        self.x = x  # the wrapped payload

    def __getstate__(self):
        import cloudpickle
        wrapped = self.x
        return cloudpickle.dumps(wrapped)

    def __setstate__(self, ob):
        import pickle
        restored = pickle.loads(ob)
        self.x = restored

def env_fn(env_id):
    """Construct an ``Env`` for ``env_id``, call its ``init()``, and return it.

    Used, via ``functools.partial``, as the environment factory run inside each
    worker process.
    """
    created = Env(env_id)
    created.init()
    return created

class EnvWrapper(object):
    """Vectorised environment backed by ``env_num`` worker processes running ``env_worker``.

    Every public method sends one command to the relevant workers. It then reads
    back exactly one reply per command sent, in worker order.
    """

    def __init__(self, seed, env_id, env_num):
        """Start ``env_num`` daemon workers; each builds its env via ``env_fn(env_id)``.

        Args:
            seed: currently unused, because per-worker seeding is disabled (see TODO below).
            env_id: forwarded to ``env_fn`` inside every worker.
            env_num: number of parallel environments / worker processes.
        """
        self.env_num = env_num

        self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.env_num)])
        # The factory goes through CloudpickleWrapper so it can be shipped to the workers.
        self.ps = [Process(target=env_worker, args=(worker_conn, CloudpickleWrapper(partial(env_fn, env_id))))
                   for worker_conn in self.worker_conns]

        for p in self.ps:
            p.daemon = True
            p.start()

        # TODO: seeding is disabled, so ``seed`` is accepted but ignored. To enable it,
        # send ('set_seed', <seed>) to each worker here (env_worker sends no reply to it).

    def _send_all(self, cmd, payloads):
        """Send ``(cmd, payloads[i])`` to worker ``i``, for every worker."""
        for conn, payload in zip(self.parent_conns, payloads):
            conn.send((cmd, payload))

    def _recv_all(self, key):
        """Read one reply from every worker and return its ``key`` field, in worker order."""
        return [conn.recv()[key] for conn in self.parent_conns]

    def close(self, join_timeout=5.0):
        """Ask every worker to close its env and exit, then reap the worker processes.

        Args:
            join_timeout: seconds to wait for each worker to exit (``None`` waits forever).
        """
        self._send_all("close", [None] * self.env_num)
        for p in self.ps:
            p.join(join_timeout)

    def sample_context(self):
        """Return a list with one freshly sampled context per worker."""
        self._send_all("sample_context", [None] * self.env_num)
        return self._recv_all("context")

    def reset(self, context_list, is_expert):
        """Reset worker ``i`` with ``context_list[i]`` and return the initial states as an array.

        Raises:
            IndexError: if ``context_list`` has fewer than ``env_num`` entries. This is
                checked before anything is sent, so no worker is left with an unread reply.
        """
        payloads = [{'context': context_list[idx], 'is_expert': is_expert} for idx in range(self.env_num)]
        self._send_all("reset", payloads)
        return np.array(self._recv_all("state"))

    def render_my(self, mode):
        """Render every worker's env with ``mode`` and return the images as one array."""
        self._send_all("render_my", [{'mode': mode}] * self.env_num)
        return np.array(self._recv_all("img"))

    def get_gt_option(self):
        """Return every worker's ``get_gt_option()`` result as one array."""
        self._send_all("get_gt_option", [None] * self.env_num)
        return np.array(self._recv_all("option"))

    def step(self, action_array, done_vec, s):
        """Step every worker whose ``done_vec`` entry is False.

        Args:
            action_array: indexable by worker; one action per worker.
            done_vec: per-worker done flags from the previous step. Done workers are skipped.
            s: current batched states; used only as the template for the ``next_s`` buffer.

        Returns:
            ``(next_s, r, done)``:
            - ``next_s``: float32, shaped like ``s``, with zero rows for skipped workers.
            - ``r``: float32 of shape ``(env_num, 1)``, zero for skipped workers.
            - ``done``: a list, True for skipped workers.
        """
        active = [idx for idx in range(self.env_num) if not done_vec[idx]]
        for idx in active:
            self.parent_conns[idx].send(("step", action_array[idx]))

        next_s = np.zeros_like(s, dtype=np.float32)
        r = np.zeros((self.env_num, 1), dtype=np.float32)
        done = [True] * self.env_num
        for idx in active:
            data = self.parent_conns[idx].recv()
            next_s[idx] = data['next_state']
            r[idx] = data['reward']
            done[idx] = data['done']

        return next_s, r, done


def no_option_loop(env, policy, is_expert, fixed, task_list=None, contain_context=False, option_gt=None):
    """Roll out one episode per worker with a flat (option-free) policy.

    Args:
        env: ``EnvWrapper``-like vectorised env. Uses ``env_num``, ``sample_context``,
            ``reset`` and ``step``.
        policy: exposes ``device`` and ``sample_action(st, fixed=...)``.
        is_expert: forwarded to ``env.reset``.
        fixed: forwarded to ``policy.sample_action``.
        task_list: if given, each worker's context is ``random.choice(task_list)``.
            Otherwise ``env.sample_context()`` is used.
        contain_context: if False, the trailing context entries are stripped from the
            states given to the policy and from the stored states.
        option_gt: unused; accepted for signature parity with ``option_loop``.

    Returns:
        ``(rets, trans_num)``:
        - ``rets``: one ``(s, a, r)`` tuple of float32 tensors per worker, each with
          ``horizons[e] + 1`` rows.
        - ``trans_num``: ``sum(horizons)``, where ``horizons[e]`` counts the steps after
          which worker ``e`` was still not done.
    """
    env_num = env.env_num
    with torch.no_grad():
        if task_list is not None:  # when testing, we will specify the list of tasks
            context_list = [random.choice(task_list) for _ in range(env_num)]
        else:
            context_list = env.sample_context()
        cnt_dim = len(context_list[0])

        s = env.reset(context_list, is_expert=is_expert)  # (env_num, s_dim)
        # Explicit end index: ``[:, :-cnt_dim]`` would select nothing when cnt_dim == 0.
        obs_dim = s.shape[1] - cnt_dim
        horizons = [0] * env_num
        done_vec = [False] * env_num
        s_list, a_list, r_list = [], [], []
        while True:
            st = torch.as_tensor(s, dtype=torch.float32, device=policy.device)  # (env_num, s_dim)
            if not contain_context:
                st = st[:, :obs_dim]
            at = policy.sample_action(st, fixed=fixed).detach().cpu().numpy()  # (env_num, a_dim)
            next_s, rewards, done_vec = env.step(at, done_vec, s)  # rewards: (env_num, 1)

            s_list.append(s.copy() if contain_context else s[:, :obs_dim].copy())
            a_list.append(at.copy())
            r_list.append(rewards.copy())

            s = next_s

            for idx, finished in enumerate(done_vec):
                if not finished:
                    horizons[idx] += 1

            if all(done_vec):
                break

        # Stack once as (steps, env_num, ...), then slice each worker's prefix.
        states = np.stack(s_list)
        actions = np.stack(a_list)
        rewards_all = np.stack(r_list)

        def to_tensor(arr):
            return torch.as_tensor(arr, dtype=torch.float32, device=policy.device)

        rets = []
        for e_id, horizon in enumerate(horizons):
            length = horizon + 1  # steps this worker took, including the terminal one
            rets.append((to_tensor(states[:length, e_id]),
                         to_tensor(actions[:length, e_id]),
                         to_tensor(rewards_all[:length, e_id])))

        trans_num = np.sum(horizons)

    return rets, trans_num


def option_loop(env, policy, is_expert, fixed, task_list=None, contain_context=False, 
                option_gt=False, option_nmf=False, option_nmf_generator=None, render=False, is_eval=False):
    """Roll out one episode per worker with an option (hierarchical) policy.

    The option at each step, including the initial one, comes from the first
    applicable source:
    1. the environment's ``get_gt_option()`` when ``option_gt`` is set;
    2. ``option_nmf_generator.get_current_option`` when ``option_nmf`` is set;
    3. otherwise ``policy.sample_option``.
    The action is then sampled with ``policy.sample_action``.

    Args:
        env: ``EnvWrapper``-like vectorised env.
        policy: exposes ``device``, ``dim_c``, ``sample_option`` and ``sample_action``.
        is_expert: forwarded to ``env.reset``.
        fixed: forwarded to the policy sampling calls.
        task_list: if given, each worker's context is ``random.choice(task_list)``.
            Otherwise ``env.sample_context()`` is used.
        contain_context: if False, the trailing context entries are stripped from the
            states given to the policy and from the stored states.
        option_gt: impose the environment's ground-truth option. Takes precedence.
        option_nmf: impose the NMF generator's option. The generator is ``reset()``
            after the rollout.
        option_nmf_generator: required when ``option_nmf`` is set.
        render: if True, record one channels-first frame per worker for each step on
            which that worker was still active.
        is_eval: unused here; accepted for the caller's convenience.

    Returns:
        ``(rets, trans_num, frames)``:
        - ``rets``: one ``(s, c, a, r)`` tuple per worker. ``s``, ``a`` and ``r`` are
          float32 with ``horizons[e] + 1`` rows. ``c`` is long with one extra leading
          row holding the initial option.
        - ``trans_num``: ``sum(horizons)``, where ``horizons[e]`` counts the steps after
          which worker ``e`` was still not done.
        - ``frames``: a per-worker list of frames, or None when not rendering.
    """
    env_num = env.env_num

    with torch.no_grad():
        if task_list is not None:  # when testing, we will specify the list of tasks
            context_list = [random.choice(task_list) for _ in range(env_num)]
        else:
            context_list = env.sample_context()
        cnt_dim = len(context_list[0])
        s = env.reset(context_list, is_expert=is_expert)  # (env_num, s_dim)
        # Explicit end index: ``[:, :-cnt_dim]`` would select nothing when cnt_dim == 0.
        obs_dim = s.shape[1] - cnt_dim

        def external_option(state):
            """Return the imposed (ground-truth or NMF) option as a long tensor, or None."""
            if option_gt:
                raw = env.get_gt_option()
            elif option_nmf:
                raw = option_nmf_generator.get_current_option(state[:, :obs_dim])
            else:
                return None
            # presumably (env_num, 1) after unsqueeze -- TODO confirm raw is (env_num,)
            return torch.as_tensor(raw, device=policy.device, dtype=torch.long).unsqueeze(1)

        # Initial option, using the same precedence as inside the loop. Without an
        # imposed option, every worker starts from the placeholder value ``policy.dim_c``.
        ct = external_option(s)
        if ct is None:
            ct = torch.full((env_num, 1), policy.dim_c, dtype=torch.long, device=policy.device)

        horizons = [0] * env_num
        done_vec = [False] * env_num
        s_list, a_list, r_list = [], [], []
        c_list = [ct.unsqueeze(1)]
        frames = [[] for _ in range(env_num)] if render else None

        while True:
            st = torch.as_tensor(s, dtype=torch.float32, device=policy.device)  # (env_num, s_dim)
            if not contain_context:
                st = st[:, :obs_dim]
            imposed = external_option(s)
            if imposed is not None:
                ct = imposed
            else:
                ct = policy.sample_option(st, ct, fixed=fixed).detach()  # (env_num, 1)
            at = policy.sample_action(st, ct, fixed=fixed).detach().cpu().numpy()  # (env_num, a_dim)
            was_active = [not finished for finished in done_vec]
            next_s, rewards, done_vec = env.step(at, done_vec, s)  # rewards: (env_num, 1)

            if render:
                # (env_num, h, w, 3) -> (env_num, 3, h, w)
                all_imgs = np.transpose(env.render_my(mode='rgb_array'), (0, 3, 1, 2))
                for idx in range(env_num):
                    # Skip workers that were already done, so each video matches its episode.
                    if was_active[idx]:
                        frames[idx].append(all_imgs[idx])

            s_list.append(s.copy() if contain_context else s[:, :obs_dim].copy())
            a_list.append(at.copy())
            r_list.append(rewards.copy())
            c_list.append(ct.unsqueeze(1))

            s = next_s

            for idx, finished in enumerate(done_vec):
                if not finished:
                    horizons[idx] += 1

            if all(done_vec):
                break

        if option_nmf:
            option_nmf_generator.reset()

        # Stack once as (steps, env_num, ...), then slice each worker's prefix.
        states = np.stack(s_list)
        actions = np.stack(a_list)
        rewards_all = np.stack(r_list)

        def to_float_tensor(arr):
            return torch.as_tensor(arr, dtype=torch.float32, device=policy.device)

        rets = []
        for e_id, horizon in enumerate(horizons):
            length = horizon + 1  # steps this worker took, including the terminal one
            # Options: the initial one plus one per step, hence length + 1 entries.
            c_array = torch.cat([c_list[t][e_id] for t in range(length + 1)], dim=0)
            rets.append((
                to_float_tensor(states[:length, e_id]),
                c_array.to(device=policy.device, dtype=torch.long),
                to_float_tensor(actions[:length, e_id]),
                to_float_tensor(rewards_all[:length, e_id]),
            ))

        trans_num = np.sum(horizons)

    return rets, trans_num, frames


class VecSampler(object):
    """Collect trajectories from an ``EnvWrapper`` using a private copy of a policy."""

    # A rendered episode whose trajectory has at least this many rows is labelled
    # "fail" in the logged videos (presumably the env's time limit -- TODO confirm).
    SUCCESS_HORIZON = 499

    def __init__(self, seed, env_id, env_num, policy, is_expert=False, task_list=None, contain_context=False, 
                 option_gt=False, task_list_target=None, option_nmf=False, option_nmf_generator_base_array=None):
        self.vec_env = EnvWrapper(seed, env_id, env_num)
        self.env_num = env_num
        self.is_expert = is_expert
        self.contain_context = contain_context
        self.task_list = task_list  # tasks used for training rollouts
        self.task_list_target = task_list_target  # tasks used for evaluation rollouts
        self.policy = deepcopy(policy)
        self.option_gt = option_gt
        self.option_nmf = option_nmf
        if option_nmf:
            self.option_nmf_generator = OptionNMFGenerator(option_nmf_generator_base_array)
        else:
            self.option_nmf_generator = None
        if isinstance(policy, (OptionPolicy, MHAOptionPolicy)):
            self.loop_func = option_loop
        else:
            # no_option_loop has a narrower signature and returns no frames. The adapter
            # lets collect() call either rollout function the same way.
            self.loop_func = self._no_option_loop_adapter

    @staticmethod
    def _no_option_loop_adapter(env, policy, is_expert, fixed, task_list=None, contain_context=False,
                                option_gt=False, option_nmf=False, option_nmf_generator=None,
                                render=False, is_eval=False):
        """Call ``no_option_loop`` through the ``option_loop`` signature; frames are always None."""
        rets, trans_num = no_option_loop(env, policy, is_expert, fixed, task_list=task_list,
                                         contain_context=contain_context, option_gt=option_gt)
        return rets, trans_num, None

    def filter_demo(self, sa_array):
        print("No filters are adopted.")
        return sa_array

    def collect(self, policy_param, n_sample, fixed=False, render=False):
        """Load ``policy_param`` into the local policy and sample trajectories.

        Args:
            policy_param: state dict loaded into ``self.policy``.
            n_sample: controls how much is sampled:
                - ``> 0`` (training): roll out on ``self.task_list`` until at least this
                  many transitions, as counted by the rollout function, are collected.
                - ``< 0`` (evaluation): roll out ``-n_sample`` episodes on
                  ``self.task_list_target``, rounded up to a multiple of ``env_num``.
                - ``0``: nothing is sampled.
            fixed: forwarded to the rollout function.
            render: if True, only the first rollout is rendered, and its per-worker
                videos are logged to wandb.

        Returns:
            List of per-episode trajectory tuples produced by the rollout function.
        """
        self.policy.load_state_dict(policy_param)
        is_eval = n_sample < 0
        # Training uses all train-and-test tasks; evaluation uses only the target tasks.
        task_list = self.task_list_target if is_eval else self.task_list

        rets = []
        rendered = None  # (trajs, frames) of the rendered rollout, if any
        counter = n_sample
        while (is_eval and counter < 0) or (not is_eval and counter > 0):
            trajs, trans_num, frames = self.loop_func(self.vec_env, self.policy, self.is_expert, fixed=fixed,
                                                      task_list=task_list, contain_context=self.contain_context,
                                                      option_gt=self.option_gt, option_nmf=self.option_nmf,
                                                      option_nmf_generator=self.option_nmf_generator,
                                                      render=render, is_eval=is_eval)
            if render:
                rendered = (trajs, frames)
                render = False  # render only the first rollout
            rets.extend(trajs)
            if is_eval:
                counter += self.env_num
            else:
                counter -= trans_num

        if rendered is not None and rendered[1] is not None:
            self._log_videos(rendered[0], rendered[1], is_eval)

        return rets

    def _log_videos(self, trajs, frames, is_eval):
        """Log one wandb video per worker of a rendered rollout, tagged success/fail."""
        eval_task = "test" if is_eval else "train"
        for idx in range(self.env_num):
            # Rows of the state tensor = steps this worker was active in the episode.
            ep_len = len(trajs[idx][0])
            success = "success" if ep_len < self.SUCCESS_HORIZON else "fail"
            clip = np.array(frames[idx][:ep_len])
            wandb.log({f"video/{eval_task}/{success}": wandb.Video(clip, fps=8, format="gif")})
            print(f"=> Log video {eval_task}/{success} to wandb")
