import logging
from collections import deque, defaultdict
from pathlib import Path
from omegaconf import OmegaConf
import einops
import gym
from gym.wrappers import RecordVideo
import hydra
import numpy as np
import torch
from models.action_ae.generators.base import GeneratorDataParallel
from models.latent_generators.latent_generator import LatentGeneratorDataParallel
import utils
import wandb
import os
from tqdm import tqdm

from utils import get_goal_name_in_order
from matplotlib import pyplot as plt
from dataloaders.trajectory_loader import MetaWorldVideoTrajectoryDataset
from metaworld.envs.mujoco.sawyer_xyz.v2 import TestMTEnvV2
import random
import time

# Names of the kitchen sub-tasks. `Workspace._get_finised_task` iterates this
# list and looks up `info[f"finish_{task}"]` for every entry.
ALL_TASKS = ["bottom burner","top burner","light switch","slide cabinet","hinge cabinet","microwave","kettle"]

class Workspace:
    def __init__(self, cfg):
        self.work_dir = Path.cwd()
        print("Saving to {}".format(self.work_dir))
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        utils.set_seed_everywhere(cfg.seed)
        self.dataset = hydra.utils.call(
            cfg.env.dataset_fn,
            train_fraction=cfg.train_fraction,
            random_seed=cfg.seed,
            device=self.device,
        )
        self.train_set, self.eval_set = self.dataset
        if self.cfg.eval_on == "eval":
            self.test_set = self.eval_set.dataset
        elif self.cfg.eval_on == "train":
            self.test_set = self.train_set.dataset
        self.goal_seq_lenth, self.obs_seq_lenth = self.test_set.get_seq_length()
        self.window_size = cfg.window_size
        self.env = self._init_env()
        # Create the model
        self.action_ae = None
        self.obs_encoding_net = None
        self.state_prior = None
        if not self.cfg.lazy_init_models:
            self._init_action_ae()
            self._init_obs_encoding_net()
            self._init_state_prior()

        self.wandb_run = wandb.init(
            dir=self.work_dir,
            project=cfg.project,
            config=OmegaConf.to_container(cfg, resolve=True),
        )
        logging.info("wandb run url: %s", self.wandb_run.get_url())
        self.epoch = 0
        self.load_snapshot()
        print(self.action_ae)
        print(self.obs_encoding_net)
        print(self.state_prior)

        # Set up rolling window contexts.
        self.img_context = deque(maxlen=self.window_size)
        self.obs_context = deque(maxlen=self.window_size)

        if self.cfg.flatten_obs:
            self.env = gym.wrappers.FlattenObservation(self.env)

        if self.cfg.plot_interactions:
            self._setup_plots()

        if self.cfg.start_from_seen:
            self._setup_starting_state()

        self._setup_action_sampler()

    def _init_env(self):
        """Create and seed the ``kitchen-all-v0`` environment.

        If ``gym.make`` hands back a ``TimeLimit`` wrapper, it is replaced by
        a fresh ``TimeLimit`` around the inner env whose step limit is
        ``self.goal_seq_lenth``.

        Returns:
            The (possibly re-wrapped) environment, seeded with ``cfg.seed``.
        """
        env = gym.make("kitchen-all-v0")
        # isinstance instead of ``env.class_name() == "TimeLimit"``: the
        # class_name helper only exists on wrappers, so an unwrapped env
        # would raise AttributeError, and string comparison is brittle.
        if isinstance(env, gym.wrappers.TimeLimit):
            env = gym.wrappers.TimeLimit(env.env, self.goal_seq_lenth)
        env.seed(self.cfg.seed)
        return env

    def _init_action_ae(self):
        """Instantiate ``self.action_ae`` from ``cfg.action_ae`` if it is unset.

        Does nothing when ``self.action_ae`` is already populated (possibly
        already initialized from snapshot). The new module is moved to
        ``self.device`` and wrapped in ``GeneratorDataParallel`` when
        ``cfg.data_parallel`` is truthy.
        """
        if self.action_ae is not None:
            return

        built = hydra.utils.instantiate(self.cfg.action_ae, _recursive_=False)
        self.action_ae = built.to(self.device)
        if self.cfg.data_parallel:
            self.action_ae = GeneratorDataParallel(self.action_ae)

    def _init_obs_encoding_net(self):
        """Instantiate ``self.obs_encoding_net`` from ``cfg.encoder`` if unset.

        Does nothing when the encoder is already populated (possibly already
        initialized from snapshot). The new module is moved to
        ``self.device`` and wrapped in ``torch.nn.DataParallel`` when
        ``cfg.data_parallel`` is truthy.
        """
        if self.obs_encoding_net is not None:
            return

        self.obs_encoding_net = hydra.utils.instantiate(self.cfg.encoder).to(
            self.device
        )
        if self.cfg.data_parallel:
            self.obs_encoding_net = torch.nn.DataParallel(self.obs_encoding_net)

    def _init_state_prior(self):
        """Instantiate ``self.state_prior`` and its optimizer if it is unset.

        Does nothing when ``self.state_prior`` is already populated (possibly
        already initialized from snapshot); in that case no optimizer is
        created either. Reads ``latent_dim`` and ``num_latents`` from
        ``self.action_ae``, so the action autoencoder must exist first.
        """
        if self.state_prior is not None:
            return

        prior = hydra.utils.instantiate(
            self.cfg.state_prior,
            latent_dim=self.action_ae.latent_dim,
            vocab_size=self.action_ae.num_latents,
        )
        self.state_prior = prior.to(self.device)
        if self.cfg.data_parallel:
            self.state_prior = LatentGeneratorDataParallel(self.state_prior)

        optimizer_kwargs = {
            "learning_rate": self.cfg.lr,
            "weight_decay": self.cfg.weight_decay,
            "betas": tuple(self.cfg.betas),
        }
        self.state_prior_optimizer = self.state_prior.get_optimizer(
            **optimizer_kwargs
        )

    def _setup_action_sampler(self):
        def sampler(actions):
            idx = np.random.randint(len(actions))
            return actions[idx]

        self.sampler = sampler

    def _get_finised_task(self, complete_tasks, info):
        """Append newly finished tasks to ``complete_tasks`` in place.

        For every name in ``ALL_TASKS``, ``info[f"finish_{name}"]`` is read;
        when it equals 1 and the name is not yet in ``complete_tasks``, the
        name is appended. Returns ``None``.
        """
        for name in ALL_TASKS:
            finished = info[f"finish_{name}"] == 1
            if finished and name not in complete_tasks:
                complete_tasks.append(name)

    def _get_matched_goals(self, complete_tasks, expected_tasks):
        """Return how many distinct items appear in both collections."""
        shared = set(complete_tasks) & set(expected_tasks)
        return len(shared)

    def _show_img(self, img):
        """Display ``img`` with matplotlib (``plt.imshow`` followed by ``plt.show``)."""
        plt.imshow(img)
        plt.show()

    def run_single_episode(self, goal, goal_mask, expected_tasks, expert_actions):
        self.obs_context.clear()
        self.img_context.clear()

        obs = self.env.reset()
        img = self.env.render(mode='rgb_array').copy()
        done = False
        action = self._get_action(obs, img, goal=goal, goal_mask=goal_mask, sample=True)

        for i in range(1, self.cfg.num_eval_steps):
            if done:
                break
            obs, _, done, info = self.env.step(action)
            img = info['images'].copy()
            action = self._get_action(obs, img, goal=goal, goal_mask=goal_mask, sample=True)

        complete_tasks = self.env.all_completions
        reward = len(complete_tasks)
        result = self._get_matched_goals(complete_tasks, expected_tasks)
        logging.info(f"complete_tasks: {complete_tasks}")
        logging.info(f"Reward: {reward}")
        logging.info(f"Result: {result}")

        return reward, result

    def _get_action(self, obs, img, goal, goal_mask, sample=True):
        with utils.eval_mode(
            self.action_ae, self.obs_encoding_net, self.state_prior, no_grad=True
        ):
            obs = torch.from_numpy(obs).float().to(self.cfg.device)
            img = torch.from_numpy(img).float().to(self.cfg.device)
            obs = obs[:9]
            obs = obs.unsqueeze(0)
            img = img.unsqueeze(0)
            # Now, add to history. This automatically handles the case where
            # the history is full.
            self.obs_context.append(obs)
            self.img_context.append(img)
            if self.cfg.goal_conditional in ["video", "future"]:
                goal = goal.to(device=self.cfg.device).float().unsqueeze(0)
                goal_mask = goal_mask.to(device=self.cfg.device).unsqueeze(0)
            else:
                print(f"goal conditional must be in [video, future]")
                exit()
            obs_rep = torch.stack(tuple(self.obs_context), dim=1)
            img_rep = torch.stack(tuple(self.img_context), dim=1)
            action_latents = self.state_prior.generate_latents(
                    self_obs_rep=obs_rep,
                    img_obs_rep = img_rep,
                    obs_mask=torch.ones_like(obs_rep).mean(dim=-1),
                    goal=goal,
                    goal_mask = goal_mask
                )
            actions = self.action_ae.decode_actions(
                latent_action_batch=action_latents,
            )
            actions = actions.cpu().numpy()
            # Take the last action; this assumes that SPiRL (CVAE) already selected the first action
            actions = actions[:, -1, :]
            # Now action shape is (batch_size, action_dim)
            if sample:
                actions = self.sampler(actions)
            return actions
    
    def run(self):
        """Evaluate the policy on every entry of ``self.test_set``.

        For each dataset item, ``cfg.num_eval_eps`` episodes are rolled out
        against that item's goal video. Items whose goal key has already
        collected at least ``cfg.num_eval_eps * cfg.num_eval_goals`` results
        are skipped. The environment is closed afterwards and the per-goal
        mean result is logged.

        Returns:
            ``(rewards, results)``: two flat lists with one entry per episode,
            in execution order.
        """
        rewards = []
        results = []
        results_dict = {}
        # Indexed access is kept on purpose: the dataset is only known to
        # support ``len`` and ``__getitem__``.
        for i in range(len(self.test_set)):
            _, expert_actions, onehot_goals, _, video, video_mask = self.test_set[i]
            task_key = get_goal_name_in_order(onehot_goals)
            if task_key in results_dict:
                if len(results_dict[task_key]) >= self.cfg.num_eval_eps * self.cfg.num_eval_goals:
                    print(f"cur goal: {task_key}, avg performance {np.mean(results_dict[task_key])}")
                    continue
            for _ in range(self.cfg.num_eval_eps):
                reward, result = self.run_single_episode(goal=video, goal_mask=video_mask, expected_tasks=task_key, expert_actions=expert_actions.numpy())
                rewards.append(reward)
                results.append(result)
                # setdefault replaces a bare ``except:`` that would also have
                # swallowed unrelated errors such as KeyboardInterrupt.
                results_dict.setdefault(task_key, []).append(result)
                print(f"cur iter goal {task_key}, performance {result}")
        self.env.close()
        # Collapse each goal's result list into its mean for the summary log.
        for key in results_dict:
            results_dict[key] = np.mean(results_dict[key])
        logging.info(results_dict)
        return rewards, results

    @property
    def snapshot(self):
        return Path(self.cfg.model_path)

    def load_snapshot(self):
        keys_to_load = ["action_ae", "obs_encoding_net", "state_prior"]
        logging.info(f"Load snapshot from: {self.snapshot}")
        with self.snapshot.open("rb") as f:
            payload = torch.load(f, map_location=self.device)
        loaded_keys = []
        for k, v in payload.items():
            if k in keys_to_load:
                loaded_keys.append(k)
                self.__dict__[k] = v.to(self.cfg.device)

        if len(loaded_keys) != len(keys_to_load):
            raise ValueError(
                "Snapshot does not contain the following keys: "
                f"{set(keys_to_load) - set(loaded_keys)}"
            )
