import pickle
import torch
import numpy as np
import metaworld
from tqdm import tqdm
from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
from stable_baselines3 import SAC
from envs.mujoco_control_envs.mujoco_control_envs import HalfCheetahVelEnv
from envs.darkroom.darkroom_env import DarkroomEnv
from matplotlib import pyplot as plt

def cheetah_vel_normalize_rewards(rewards, min_reward, max_reward):
    """Min-max scale ``rewards`` using the given reward bounds.

    A reward equal to ``min_reward`` maps to 0 and a reward equal to
    ``max_reward`` maps to 1. Values outside the bounds are not clipped.
    ``min_reward`` and ``max_reward`` must differ, otherwise the division
    is by zero.
    """
    reward_span = max_reward - min_reward
    shifted = rewards - min_reward
    return shifted / reward_span

class Controller:
    """Roll out a policy in one of the supported environments.

    Supported values of ``args.env`` are ``'ml1_pick_place'``,
    ``'cheetah_vel'`` and ``'darkroom'``. The ``mode`` passed to the
    constructor selects which rollout routine is exposed as ``self.deploy``:

    * ``'offline'`` -> :meth:`deploy_offline` (context taken from a stored
      trajectory),
    * ``'online'``  -> :meth:`deploy_online` (context built from the model's
      own previous episode),
    * ``'optimal'`` -> :meth:`deploy_optimal` (reference policy instead of
      the model).
    """

    _SUPPORTED_ENVS = ('ml1_pick_place', 'cheetah_vel', 'darkroom')

    def __init__(self, model, device, args, mode='offline'):
        """Store the model/config and bind ``self.deploy`` according to ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'offline', 'online' or 'optimal'.
        """
        self.model = model
        self.device = device
        self.args = args
        self.horizon = args.horizon
        if mode == 'offline':
            self.deploy = self.deploy_offline
        elif mode == 'online':
            self.deploy = self.deploy_online
        elif mode == 'optimal':
            self.deploy = self.deploy_optimal
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def convert_to_tensor(self, x):
        """Return ``x`` as a float32 tensor on ``self.device``."""
        return torch.tensor(np.asarray(x)).float().to(self.device)

    def set_env(self, env_traj=None, task_id=None):
        """Build the environment for a task and store it in ``self.env``.

        Args:
            env_traj: optional trajectory dict. When given, its ``'task_id'``
                (ml1_pick_place, cheetah_vel) or ``'goal'`` (darkroom) entry
                identifies the task and takes precedence over ``task_id``.
            task_id: task index used when ``env_traj`` is not given.

        Raises:
            ValueError: if ``args.env`` is not a supported environment, or if
                neither ``env_traj`` nor ``task_id`` is provided.
        """
        env_name = self.args.env
        if env_name not in self._SUPPORTED_ENVS:
            # Without this check the method would fail later with an
            # UnboundLocalError when assigning self.env.
            raise ValueError(f"Unsupported env: {env_name}")
        if env_traj is None and task_id is None:
            raise ValueError("set_env requires either env_traj or task_id")

        if env_name == 'ml1_pick_place':
            if env_traj is not None:
                task_id = env_traj['task_id']
            ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
            env = ml1.train_classes['pick-place-v2']()
            env.set_task(ml1.train_tasks[task_id])
            env.max_path_length = self.horizon
        elif env_name == 'cheetah_vel':
            if env_traj is not None:
                task_id = env_traj['task_id']
            task_path = f'{self.args.model_ckpt_path}/task_{task_id}/config_cheetah_vel_task{task_id}.pkl'
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point model_ckpt_path at trusted checkpoints.
            with open(task_path, 'rb') as f:
                task_info = pickle.load(f)
            env = HalfCheetahVelEnv(task_info, include_goal=False)
            # Bounds later used to normalize rewards fed back as context.
            self.min_reward = task_info[0]['initial_reward']
            self.max_reward = task_info[0]['optimal_reward']
        else:  # darkroom
            if env_traj is not None:
                goal = env_traj['goal']
            else:
                goal = np.load(self.args.darkroom_goal_path)[task_id]
            env = DarkroomEnv(dim=self.args.dim, goal=goal, horizon=self.horizon)
        self.env = env

    def set_batch(self, batch):
        """Store the context batch that :meth:`act` passes to the model."""
        self.batch = batch

    def act(self, state):
        """Query the model for an action at ``state``.

        The state is written into ``self.batch['query_states']`` with a
        leading batch axis, the model is called on the batch, and the result
        is returned as a squeezed NumPy array.
        """
        # Use the controller's own device (like every other tensor built
        # here) rather than relying on the model exposing a `.device`.
        query = self.convert_to_tensor(state)
        self.batch['query_states'] = query[None, :]
        action = self.model(self.batch).detach().cpu().numpy()
        return action.squeeze()

    def act_optimal(self, state):
        """Return the reference policy's action for ``state``.

        ``self.policy`` must have been set (see :meth:`deploy_optimal`) for
        ml1_pick_place and cheetah_vel; darkroom asks the env itself.

        Raises:
            ValueError: if ``args.env`` is not a supported environment.
        """
        env_name = self.args.env
        if env_name == 'ml1_pick_place':
            return self.policy.get_action(state)
        if env_name == 'cheetah_vel':
            action, _ = self.policy.predict(state, deterministic=True)
            return action
        if env_name == 'darkroom':
            return self.env.opt_action(state)
        raise ValueError(f"Unsupported env: {env_name}")

    def _deploy(self):
        """Run one episode of ``self.horizon`` steps with ``self.act``.

        The ``done`` flag and ``info`` returned by ``env.step`` are ignored,
        so the rollout always lasts exactly ``self.horizon`` steps.

        Returns:
            Tuple ``(states, actions, next_states, rewards)`` of NumPy arrays
            with one entry per step.
        """
        states = []
        actions = []
        next_states = []
        rewards = []

        # ml1_pick_place and darkroom return the state directly from reset();
        # the remaining env returns a pair whose first element is the state.
        if self.args.env == 'ml1_pick_place' or self.args.env == 'darkroom':
            state = self.env.reset()
        else:
            state, _ = self.env.reset()

        for _ in range(self.horizon):
            action = self.act(state)
            next_state, reward, _done, _info = self.env.step(action)

            states.append(state)
            actions.append(action)
            next_states.append(next_state)
            rewards.append(reward)
            state = next_state

        return (
            np.array(states),
            np.array(actions),
            np.array(next_states),
            np.array(rewards),
        )

    def deploy_offline(self, env_traj):
        """Roll out the model once using ``env_traj`` as context.

        ``env_traj`` must provide ``'context_states'``, ``'context_actions'``,
        ``'context_next_states'`` and ``'context_rewards'`` (plus whatever
        :meth:`set_env` needs). A leading batch axis is added to each, and a
        trailing axis to the rewards.

        Returns:
            The sum of rewards collected in the episode.
        """
        self.set_env(env_traj=env_traj)
        batch = {
            'context_states': self.convert_to_tensor(env_traj['context_states'][None, :, :]),
            'context_actions': self.convert_to_tensor(env_traj['context_actions'][None, :, :]),
            'context_next_states': self.convert_to_tensor(env_traj['context_next_states'][None, :, :]),
            'context_rewards': self.convert_to_tensor(env_traj['context_rewards'][None, :, None]),
        }
        self.set_batch(batch)
        _, _, _, rewards = self._deploy()
        return rewards.sum()

    def deploy_online(self, task_id, Heps):
        """Roll out the model for successive episodes on task ``task_id``.

        The first episode uses an empty context. Every later episode uses the
        transitions of the immediately preceding episode as context. For
        cheetah_vel the context rewards are min-max normalized with the
        bounds loaded in :meth:`set_env`; reported returns are always raw.

        At least one episode is always run, even if ``Heps`` is less than 1.

        Returns:
            List with the raw return of each episode.
        """
        self.set_env(task_id=task_id)
        normalize_context = self.args.env == 'cheetah_vel'

        empty_batch = {
            'context_states': torch.empty((1, 0, self.args.state_dim)).float().to(self.device),
            'context_actions': torch.empty((1, 0, self.args.action_dim)).float().to(self.device),
            'context_next_states': torch.empty((1, 0, self.args.state_dim)).float().to(self.device),
            'context_rewards': torch.empty((1, 0, 1)).float().to(self.device),
        }
        self.set_batch(empty_batch)
        states, actions, next_states, rewards = self._deploy()
        rewards_all = [rewards.sum()]

        for _ in range(1, Heps):
            # Normalize on every episode, not only after the first one, so
            # the model always sees context rewards on the same scale.
            context_rewards = rewards
            if normalize_context:
                context_rewards = cheetah_vel_normalize_rewards(
                    rewards, self.min_reward, self.max_reward)
            batch = {
                'context_states': self.convert_to_tensor(states[None, :, :]),
                'context_actions': self.convert_to_tensor(actions[None, :, :]),
                'context_next_states': self.convert_to_tensor(next_states[None, :, :]),
                'context_rewards': self.convert_to_tensor(context_rewards[None, :, None]),
            }
            self.set_batch(batch)
            states, actions, next_states, rewards = self._deploy()
            rewards_all.append(rewards.sum())
        return rewards_all

    def deploy_optimal(self, task_id):
        """Roll out the reference policy for task ``task_id``.

        Note that this rebinds ``self.act`` to :meth:`act_optimal` on the
        instance, so the controller no longer queries the model afterwards.

        Returns:
            The sum of rewards collected in the episode.
        """
        self.set_env(task_id=task_id)
        if self.args.env == 'ml1_pick_place':
            self.policy = SawyerPickPlaceV2Policy()
        elif self.args.env == 'cheetah_vel':
            model_ckpt_path = f'{self.args.model_ckpt_path}/task_{task_id}'
            self.policy = SAC.load(f'{model_ckpt_path}/sac_checkpoint_task_{task_id}_best')
        # darkroom: the env itself provides the optimal action, no policy to load.
        self.act = self.act_optimal
        _, _, _, rewards = self._deploy()
        return rewards.sum()


def online(model, device, args):
    """Evaluate the model online on every train and test task.

    For each task id in ``args.train_tasks + args.test_tasks`` an
    online-mode ``Controller`` is deployed with ``args.Heps`` and the
    per-task results are stacked into one array.
    """
    runner = Controller(model, device, args, mode='online')
    task_ids = args.train_tasks + args.test_tasks  # FIXME: change for darkroom
    per_task_returns = [runner.deploy(tid, args.Heps) for tid in tqdm(task_ids)]
    return np.array(per_task_returns)  # in shape of (num_tasks, Heps)


def offline(env_trajs, model, device, args):
    """Evaluate the model offline, once per stored trajectory.

    Each element of ``env_trajs`` is handed to an offline-mode
    ``Controller``; the (summed) result of every deployment is collected
    into a single array.
    """
    runner = Controller(model, device, args, mode='offline')
    returns = [runner.deploy(traj).sum() for traj in tqdm(env_trajs)]
    return np.array(returns)  # in shape of (num_tasks)


def optimal(model, device, args):
    """Collect the optimal-mode return for every train and test task.

    An optimal-mode ``Controller`` is deployed on each task id in
    ``args.train_tasks + args.test_tasks``; results are returned as an array
    with one entry per task.
    """
    runner = Controller(model, device, args, mode='optimal')
    task_ids = args.train_tasks + args.test_tasks
    return np.array([runner.deploy(tid) for tid in tqdm(task_ids)])


def evaluation(env_trajs, model, device, epoch, args, writer=None):
    """Run offline, online and optimal evaluation and optionally log figures.

    Args:
        env_trajs: indexable; ``env_trajs[0]`` is evaluated as the "random"
            offline set and ``env_trajs[1]`` as the "expert" offline set.
        model, device, args: forwarded to ``offline``/``online``/``optimal``.
        epoch: step value passed to ``writer.add_figure``.
        writer: optional object with an ``add_figure(tag, figure, step)``
            method (e.g. a TensorBoard SummaryWriter). When ``None`` no
            figures are produced.

    Returns:
        Tuple ``(online_rewards, offline_rewards_random,
        offline_rewards_expert, optimal_rewards)``.
    """
    offline_rewards_random = offline(env_trajs[0], model, device, args)
    offline_rewards_expert = offline(env_trajs[1], model, device, args)
    online_rewards = online(model, device, args)
    optimal_rewards = optimal(model, device, args)

    if writer is not None:
        # Online returns per episode: one faint line per task plus the
        # mean curve with a +/- one standard deviation band.
        # Explicit figure objects are used (and closed in `finally`) so the
        # plots never land on, or leave behind, pyplot's implicit figure.
        fig, ax = plt.subplots()
        try:
            for task_curve in online_rewards:
                ax.plot(task_curve, color='blue', alpha=0.2)
            mean_curve = online_rewards.mean(axis=0)
            std_curve = online_rewards.std(axis=0)
            ax.plot(mean_curve, color='blue', label='Online')
            # x-axis taken from the data so it always matches the curve length.
            ax.fill_between(np.arange(len(mean_curve)),
                            mean_curve - std_curve,
                            mean_curve + std_curve, alpha=0.2)
            ax.legend()
            ax.set_xlabel('Episodes')
            ax.set_ylabel('Reward')
            ax.set_title('Online Rewards')
            writer.add_figure('Evaluation/online', fig, epoch)
        finally:
            plt.close(fig)

        # Bar chart of mean return (error bars = standard deviation) for the
        # optimal policy and the two offline settings.
        baselines = {
            'Optimal': optimal_rewards,
            'Learner(random)': offline_rewards_random,
            'Learner(expert)': offline_rewards_expert,
        }
        labels = list(baselines)
        means = [np.mean(values) for values in baselines.values()]
        stds = [np.std(values) for values in baselines.values()]
        colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))

        fig, ax = plt.subplots()
        try:
            ax.bar(labels, means, color=colors, yerr=stds, capsize=5)
            ax.set_ylabel('Average Return')
            ax.set_title('Offline Rewards')
            writer.add_figure('Evaluation/offline', fig, epoch)
        finally:
            plt.close(fig)

    return online_rewards, offline_rewards_random, offline_rewards_expert, optimal_rewards
