import copy
import gc
import numpy as np
import os
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.tensor_utils as TensorUtils
import time
import multiprocessing
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import barrier

from libero.libero.envs import OffScreenRenderEnv, SubprocVectorEnv, DummyVectorEnv
from libero.libero.utils.time_utils import Timer
from libero.libero.utils.video_utils import VideoWriter
from libero.lifelong.utils import *
from libero.lifelong.DistributedMetrics   import SuccessMetric, LossMetric
from libero.lifelong.datasets import EvalDataset

def raw_obs_to_tensor_obs(obs, task_emb, cfg):
    """
    Prepare the tensor observations as input for the algorithm.

    Args:
        obs: sequence holding one raw observation mapping per environment.
            Each mapping is indexed with the raw keys listed in
            cfg.data.obs_key_mapping and its values are fed to
            torch.from_numpy.
        task_emb: task embedding tensor. A 1-D tensor is shared by all
            environments (repeated len(obs) times along a new first
            dimension); a 2-D tensor must already have one row per
            environment.
        cfg: config providing data.obs.modality, data.obs_key_mapping and
            device.

    Returns:
        dict with two entries: "obs", mapping every observation name to the
        per-environment tensors stacked along dim 0, and "task_emb". All
        tensors are passed through safe_device with cfg.device.

    Raises:
        ValueError: if task_emb is neither 1-D nor 2-D.
        AssertionError: if a 2-D task_emb does not have len(obs) rows.
    """
    env_num = len(obs)
    emb_dim = task_emb.dim()
    if emb_dim == 1:
        batched_task_emb = task_emb.repeat(env_num, 1)
    elif emb_dim == 2:
        assert task_emb.shape[0] == env_num, (
            f"task_emb has {task_emb.shape[0]} rows but there are {env_num} environments"
        )
        batched_task_emb = task_emb
    else:
        # Fail early with a clear message instead of hitting an unbound
        # `data` variable further down.
        raise ValueError(
            f"task_emb must be a 1-D or 2-D tensor, got {emb_dim} dimension(s)"
        )

    data = {
        "obs": {},
        "task_emb": batched_task_emb,
    }

    # Collect the observation names of every modality, in config order.
    all_obs_keys = []
    for modality_list in cfg.data.obs.modality.values():
        for obs_name in modality_list:
            data["obs"][obs_name] = []
        all_obs_keys += modality_list

    for k in range(env_num):
        for obs_name in all_obs_keys:
            raw_value = obs[k][cfg.data.obs_key_mapping[obs_name]]
            data["obs"][obs_name].append(
                ObsUtils.process_obs(
                    torch.from_numpy(raw_value),
                    obs_key=obs_name,
                ).float()
            )

    # Turn each per-environment list into a single batched tensor.
    for key in data["obs"]:
        data["obs"][key] = torch.stack(data["obs"][key])

    data = TensorUtils.map_tensor(data, lambda x: safe_device(x, device=cfg.device))
    return data


def _create_eval_env(env_args, env_num, max_attempts=5, retry_delay=5):
    """
    Build the vectorized evaluation environment, retrying on failure.

    A DummyVectorEnv is used when env_num == 1, otherwise a SubprocVectorEnv
    with env_num OffScreenRenderEnv workers. Each failed attempt is followed
    by a sleep of retry_delay seconds.

    Raises:
        RuntimeError: if all max_attempts attempts fail; the last underlying
            error is attached as the cause.
    """
    vector_env_cls = DummyVectorEnv if env_num == 1 else SubprocVectorEnv
    last_error = None
    for _ in range(max_attempts):
        try:
            return vector_env_cls(
                [lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)]
            )
        except Exception as err:  # keep KeyboardInterrupt / SystemExit propagating
            last_error = err
            time.sleep(retry_delay)
    raise RuntimeError("Failed to create environment") from last_error


def evaluate_one_task_success(
    cfg, algo, benchmark, task, task_emb, task_id, logger, sim_states=None, task_str=""
):
    """
    Evaluate a single task's success rate

    The multiprocessing start method is switched to "spawn" for the duration
    of the call and restored afterwards, also when evaluation raises.

    Args:
        cfg:        config; reads lifelong.algo, bddl_folder,
                    init_states_folder, data.img_h / img_w and the eval.*
                    settings (use_mp, num_procs, n_eval, num_workers,
                    use_augmentation, max_steps), plus whatever
                    raw_obs_to_tensor_obs needs.
        algo:       algorithm whose policy produces the actions; for
                    "PackNet" it is replaced by algo.get_eval_algo(task_id).
        benchmark:  provides get_task_emb / get_augmented_task_emb.
        task:       task description with problem_folder, bddl_file and
                    init_states_file attributes.
        task_emb:   recomputed from benchmark for every batch, so the value
                    passed in is not used.
        task_id:    index of the task inside the benchmark.
        logger:     object with a log(message[, style]) method.
        sim_states: if not None, will keep track of all simulated states
                    during evaluation, mainly for visualization and
                    debugging purpose
        task_str:   the key to access sim_states dictionary; recording only
                    happens when it is non-empty.

    Returns:
        The value of SuccessMetric.compute() converted with .item().

    Raises:
        RuntimeError: if the evaluation environment cannot be created.
    """
    original_multi = multiprocessing.get_start_method(allow_none=True)
    if original_multi != "spawn":
        multiprocessing.set_start_method("spawn", force=True)
    try:
        with Timer() as t:
            if cfg.lifelong.algo == "PackNet":  # need preprocess weights for PackNet
                algo = algo.get_eval_algo(task_id)

            algo.eval()

            # initiate evaluation envs
            env_args = {
                "bddl_file_name": os.path.join(
                    cfg.bddl_folder, task.problem_folder, task.bddl_file
                ),
                "camera_heights": cfg.data.img_h,
                "camera_widths": cfg.data.img_w,
            }
            init_states_path = os.path.join(
                cfg.init_states_folder, task.problem_folder, task.init_states_file
            )

            env_num = cfg.eval.num_procs if cfg.eval.use_mp else 1

            success_metric = SuccessMetric()
            eval_dataset = EvalDataset(init_states_path, cfg.eval.n_eval)
            eval_dataloader = DataLoader(
                eval_dataset,
                batch_size=env_num,
                num_workers=cfg.eval.num_workers,
                shuffle=False,
                pin_memory=True,
                sampler=DistributedSampler(eval_dataset),
            )

            env = _create_eval_env(env_args, env_num)
            try:
                barrier()
                eval_dataloader.sampler.set_epoch(0)
                logger.log(f"start evaluate task {task_id}")

                for (i, init_states_) in enumerate(eval_dataloader):

                    if cfg.eval.use_augmentation:
                        if env_num == 1:
                            task_emb = benchmark.get_augmented_task_emb(task_id)
                        else:
                            # one independently augmented embedding per environment
                            rollout_emb_list = []
                            for _ in range(env_num):
                                rollout_emb = benchmark.get_augmented_task_emb(task_id).unsqueeze(0)
                                rollout_emb_list.append(rollout_emb)
                            task_emb = torch.cat(rollout_emb_list, dim=0)
                    else:
                        task_emb = benchmark.get_task_emb(task_id)

                    env.reset()

                    dones = [False] * env_num
                    steps = 0
                    algo.reset()
                    obs = env.set_init_state(init_states_)

                    # dummy actions [env_num, 7] all zeros for initial physics simulation
                    dummy = np.zeros((env_num, 7))
                    for _ in range(5):
                        obs, _, _, _ = env.step(dummy)

                    # NOTE(review): `i` is the batch index of this process's
                    # DistributedSampler shard, so `i * env_num + k` is a
                    # per-process position -- confirm this is the intended
                    # index for sim_states and for the n_eval bound.
                    if task_str != "":
                        sim_state = env.get_sim_state()
                        for k in range(env_num):
                            if i * env_num + k < cfg.eval.n_eval and sim_states is not None:
                                sim_states[i * env_num + k].append(sim_state[k])

                    while steps < cfg.eval.max_steps:
                        steps += 1

                        data = raw_obs_to_tensor_obs(obs, task_emb, cfg)
                        actions = algo.policy.get_action(data)

                        obs, reward, done, info = env.step(actions)

                        # record the sim states for replay purpose
                        if task_str != "":
                            sim_state = env.get_sim_state()
                            for k in range(env_num):
                                if i * env_num + k < cfg.eval.n_eval and sim_states is not None:
                                    sim_states[i * env_num + k].append(sim_state[k])

                        # check whether succeed: a rollout counts as done once
                        # the env has reported done at any step
                        for k in range(env_num):
                            dones[k] = dones[k] or done[k]

                        if all(dones):
                            break

                    # a new form of success record
                    for k in range(env_num):
                        if i * env_num + k < cfg.eval.n_eval:
                            success_metric.update(int(dones[k]), 1)
                barrier()
                success_rate = success_metric.compute()
            finally:
                # always release the simulator (and its worker processes)
                env.close()
                gc.collect()

        logger.log(f"evaluate task {task_id} takes {t.get_elapsed_time():.1f} seconds", "highlight")
    finally:
        # restore the caller's start method even if evaluation failed
        multiprocessing.set_start_method(original_multi, force=True)
    return success_rate.item()

def evaluate_success(cfg, algo, benchmark, task_ids, logger, result_summary=None):
    """
    Compute the rollout success rate of every task listed in task_ids.

    If result_summary is provided, the entry stored under the key
    "k{task_ids[-1]}_p{task_id}" is forwarded to the per-task evaluation as
    its sim_states container; the same key is always passed as task_str.

    Returns a numpy array with one success rate per task, ordered like
    task_ids.
    """
    algo.eval()
    keep_states = result_summary is not None
    rates = []
    for task_id in task_ids:
        task = benchmark.get_task(task_id)
        emb = benchmark.get_task_emb(task_id)
        summary_key = f"k{task_ids[-1]}_p{task_id}"
        states = result_summary[summary_key] if keep_states else None
        rates.append(
            evaluate_one_task_success(
                cfg,
                algo,
                benchmark,
                task,
                emb,
                task_id,
                logger,
                sim_states=states,
                task_str=summary_key,
            )
        )
    return np.array(rates)


def evaluate_multitask_training_success(cfg, algo, benchmark, task_ids, logger):
    """
    Compute the rollout success rate of every task listed in task_ids and
    log each rate as soon as it is available.

    No simulator states are recorded. Returns a numpy array with one
    success rate per task, ordered like task_ids.
    """
    algo.eval()
    rates = []
    for task_id in task_ids:
        rate = evaluate_one_task_success(
            cfg,
            algo,
            benchmark,
            benchmark.get_task(task_id),
            benchmark.get_task_emb(task_id),
            task_id,
            logger,
        )
        rates.append(rate)
        logger.log(f"task {task_id} success rate: {rate:.3f}")
    return np.array(rates)


@torch.no_grad()
def evaluate_loss(cfg, algo, benchmark, datasets):
    """
    Evaluate the loss on all datasets.

    Args:
        cfg: config; reads lifelong.algo, eval.batch_size, eval.num_workers
            and device.
        algo: algorithm whose policy.compute_loss is evaluated. For
            "PackNet" a task-specific evaluation algorithm is derived from
            it for every dataset via get_eval_algo(task_id=i).
        benchmark: not used by this function.
        datasets: iterable of datasets; the position of a dataset is used as
            its task id.

    Returns:
        numpy array with one entry per dataset: the value of
        LossMetric.compute() after being updated with the summed batch
        losses and the number of batches of this process's shard.
    """
    algo.eval()
    losses = []
    for i, dataset in enumerate(datasets):
        # Always derive the PackNet evaluation copy from the ORIGINAL algo.
        # Rebinding `algo` here would make dataset i+1 start from the copy
        # already pre-processed for dataset i.
        if cfg.lifelong.algo == "PackNet":  # need preprocess weights for PackNet
            eval_algo = algo.get_eval_algo(task_id=i)
        else:
            eval_algo = algo

        test_loss_metric = LossMetric()
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.eval.batch_size,
            num_workers=cfg.eval.num_workers,
            shuffle=False,
            pin_memory=True,
            sampler=DistributedSampler(dataset),
        )

        test_loss = 0
        dataloader.sampler.set_epoch(0)
        for data in dataloader:
            data = TensorUtils.map_tensor(
                data, lambda x: safe_device(x, device=cfg.device)
            )
            loss = eval_algo.policy.compute_loss(data)
            test_loss += loss.item()

        test_loss_metric.update(test_loss, len(dataloader))
        barrier()
        test_loss = test_loss_metric.compute()
        losses.append(test_loss.item())
    return np.array(losses)
