import pickle
import torch
import numpy as np
import metaworld
from tqdm import tqdm
from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
from stable_baselines3 import SAC
import sys
sys.path.append('..')
from envs.mujoco_control_envs.mujoco_control_envs import HalfCheetahVelEnv
from envs.darkroom.darkroom_env import DarkroomEnv
from matplotlib import pyplot as plt


class Controller:
    """Roll out a sequence model (or a reference policy) in one of the supported envs.

    The environment is selected by ``args.env`` and must be one of
    ``'ml1_pick_place'``, ``'cheetah_vel'`` or ``'darkroom'``. Depending on
    ``mode``, ``self.deploy`` is bound to one of:

    * ``deploy_offline(env_traj)``  -- one episode, prompted by a stored context
    * ``deploy_online(task_id, Heps)`` -- several episodes, each prompted by the
      previous episode's own rollout
    * ``deploy_optimal(task_id)`` -- one episode driven by a reference policy

    Args:
        model: object exposing ``get_action(states, actions, returns_to_go,
            timesteps, prompt)``; not used in ``'optimal'`` mode.
        device: torch device on which all rollout tensors are created.
        args: namespace read for ``horizon``, ``gamma``, ``env``, ``state_dim``,
            ``action_dim`` and env-specific paths.
        mode: ``'offline'``, ``'online'`` or ``'optimal'``.
        prompt_K: maximum number of leading transitions used as the prompt.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values.
    """

    def __init__(self, model, device, args, mode='offline', prompt_K=5):
        self.model = model
        self.device = device
        self.args = args
        self.horizon = args.horizon
        self.prompt_K = prompt_K
        deploy_by_mode = {
            'offline': self.deploy_offline,
            'online': self.deploy_online,
            'optimal': self.deploy_optimal,
        }
        if mode not in deploy_by_mode:
            raise ValueError(f"Invalid mode: {mode}")
        self.deploy = deploy_by_mode[mode]

    def convert_to_tensor(self, x):
        """Return ``x`` as a float32 tensor on ``self.device``."""
        return torch.tensor(np.asarray(x)).float().to(self.device)

    # only prompt rewards need to be calculated
    def calc_cum_rewards(self, rewards):
        """Return the discounted return-to-go for every prompt step.

        Args:
            rewards: tensor shaped ``(batch, K, 1)`` of per-step rewards. ``K``
                may be smaller than ``self.prompt_K``.

        Returns:
            Tensor of the same shape whose entry ``t`` equals
            ``sum_{j >= t} gamma**(j - t) * rewards[j]`` with
            ``gamma = self.args.gamma``.
        """
        steps = rewards.shape[1]
        idx = torch.arange(steps, device=rewards.device)
        # lag[t, j] = max(j - t, 0); entries below the diagonal are masked out
        # by triu, so their clamped value is irrelevant.
        lag = (idx[None, :] - idx[:, None]).clamp(min=0).to(rewards.dtype)
        # Building the discount matrix directly (instead of dividing a reversed
        # cumsum by gamma**t) keeps the result finite when gamma == 0.
        weights = torch.triu(torch.full_like(lag, self.args.gamma).pow(lag))
        return torch.matmul(weights, rewards)

    def set_env(self, env_traj=None, task_id=None):
        """Build the environment for a task and store it in ``self.env``.

        Args:
            env_traj: optional trajectory dict. When given, the task is taken
                from it (``'task_id'``, or ``'goal'`` for darkroom) and
                overrides ``task_id``.
            task_id: index of the task, used when ``env_traj`` is ``None``.

        Raises:
            ValueError: if ``self.args.env`` is not a supported environment.
        """
        env_name = self.args.env

        if env_name == 'ml1_pick_place':
            if env_traj is not None:
                task_id = env_traj['task_id']
            ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
            env = ml1.train_classes['pick-place-v2']()
            env.set_task(ml1.train_tasks[task_id])
            env.max_path_length = self.horizon

        elif env_name == 'cheetah_vel':
            if env_traj is not None:
                task_id = env_traj['task_id']
            task_path = f'{self.args.model_ckpt_path}/task_{task_id}/config_cheetah_vel_task{task_id}.pkl'
            # NOTE: pickle executes arbitrary code on load; only point
            # model_ckpt_path at trusted checkpoints.
            with open(task_path, 'rb') as f:
                task_info = pickle.load(f)
            env = HalfCheetahVelEnv(task_info, include_goal = False)

        elif env_name == 'darkroom':
            if env_traj is not None:
                goal = env_traj['goal']
            else:
                goal = np.load(self.args.darkroom_goal_path)[task_id]
            env = DarkroomEnv(dim=self.args.dim, goal=goal, horizon=self.horizon)

        else:
            raise ValueError(f"Unsupported env: {env_name}")

        self.env = env

    def set_batch(self, batch):
        """Install the rolling episode buffer (``states``/``actions``/``prompt``)."""
        self.batch = batch

    def _new_batch(self, prompt=None):
        """Return an empty episode buffer carrying the given prompt."""
        return {
            'states': torch.empty((0, self.args.state_dim), dtype=torch.float32, device=self.device),
            'actions': torch.empty((0, self.args.action_dim), dtype=torch.float32, device=self.device),
            'prompt': prompt,
        }

    def _make_prompt(self, states, actions, rewards):
        """Build the prompt list from the first ``prompt_K`` transitions.

        ``states`` and ``actions`` are indexed as 2-D arrays (step, feature)
        and ``rewards`` as a 1-D array (step,). The returned list holds
        ``[states, actions, returns_to_go, timesteps]``.
        """
        k = self.prompt_K
        prompt_states = self.convert_to_tensor(states[None, :k, :])
        prompt_actions = self.convert_to_tensor(actions[None, :k, :])
        prompt_returns = self.calc_cum_rewards(self.convert_to_tensor(rewards[None, :k, None]))
        # Use the actual prompt length so trajectories shorter than prompt_K
        # still yield consistent shapes.
        prompt_timesteps = torch.arange(0, prompt_states.shape[1]).to(self.device)
        return [prompt_states, prompt_actions, prompt_returns, prompt_timesteps]

    def act(self, state):
        """Query the model for the next action and record it in the buffer.

        The state and a zero placeholder action are appended to ``self.batch``
        before calling ``self.model.get_action``; the placeholder is then
        overwritten with the chosen action. Returns the action as a numpy array.
        """
        # Use self.device (not model.device): the buffer tensors live there.
        state = self.convert_to_tensor(state)
        self.batch['states'] = torch.cat([self.batch['states'], state[None, :]], dim=0)
        placeholder = torch.zeros((1, self.args.action_dim), dtype=torch.float32, device=self.device)
        self.batch['actions'] = torch.cat([self.batch['actions'], placeholder], dim=0)
        action = self.model.get_action(
            states=self.batch['states'],
            actions=self.batch['actions'],
            # the model is always conditioned on a constant return-to-go of 1
            returns_to_go=torch.ones((*self.batch['states'].shape[:-1], 1), dtype=torch.float32, device=self.device),
            timesteps=torch.arange(self.batch['states'].shape[0], device=self.device),
            prompt=self.batch['prompt']
        )
        if self.args.env == 'darkroom': # convert to one hot with one of the max values
            action = torch.nn.functional.one_hot(action.argmax(), num_classes=self.args.action_dim).to(self.device)
        self.batch['actions'][-1] = action
        return action.detach().cpu().numpy()

    def act_optimal(self, state):
        """Return the reference policy's action for ``state``.

        Raises:
            ValueError: if ``self.args.env`` is not a supported environment.
        """
        env_name = self.args.env
        if env_name == 'ml1_pick_place':
            return self.policy.get_action(state)
        if env_name == 'cheetah_vel':
            action, _ = self.policy.predict(state, deterministic=True)
            return action
        if env_name == 'darkroom':
            return self.env.opt_action(state)
        raise ValueError(f"Unsupported env: {env_name}")

    def _deploy(self, act_fn=None):
        """Run one episode of ``self.horizon`` steps in ``self.env``.

        Args:
            act_fn: callable mapping a state to an action; defaults to
                ``self.act``.

        Returns:
            Tuple ``(states, actions, next_states, rewards)`` of numpy arrays
            stacked over the episode steps.
        """
        if act_fn is None:
            act_fn = self.act

        states = []
        actions = []
        next_states = []
        rewards = []

        if self.args.env == 'ml1_pick_place' or self.args.env == 'darkroom':
            state = self.env.reset()
        else:
            state, _ = self.env.reset()

        for _ in range(self.horizon):
            action = act_fn(state)
            if self.args.env == 'darkroom':
                next_state, reward = self.env.transit(state, action)
            else:
                # Only the first two entries are needed; indexing (rather than
                # unpacking four values) also accepts a 5-tuple step result.
                step_result = self.env.step(action)
                next_state, reward = step_result[0], step_result[1]

            states.append(state)
            actions.append(action)
            next_states.append(next_state)
            rewards.append(reward)
            state = next_state

        return np.array(states), np.array(actions), np.array(next_states), np.array(rewards)

    def deploy_offline(self, env_traj):
        """Run one episode prompted by ``env_traj``'s stored context.

        The prompt is built from ``env_traj['context_states']``,
        ``['context_actions']`` and ``['context_rewards']``. Returns the summed
        episode reward.
        """
        self.set_env(env_traj=env_traj)
        prompt = self._make_prompt(
            env_traj['context_states'],
            env_traj['context_actions'],
            env_traj['context_rewards'],
        )
        self.set_batch(self._new_batch(prompt))
        states, actions, next_states, rewards = self._deploy()
        return rewards.sum()

    def deploy_online(self, task_id, Heps):
        """Run consecutive episodes, each prompted by the previous rollout.

        The first episode runs with no prompt. At least one episode is always
        run, even if ``Heps`` is smaller than 1.

        Returns:
            List with the summed reward of every episode.
        """
        self.set_env(task_id=task_id)
        rewards_all = []
        previous = None  # (states, actions, rewards) of the last episode
        for _ in range(max(1, Heps)):
            prompt = None if previous is None else self._make_prompt(*previous)
            self.set_batch(self._new_batch(prompt))
            states, actions, next_states, rewards = self._deploy()
            rewards_all.append(rewards.sum())
            previous = (states, actions, rewards)
        return rewards_all

    def deploy_optimal(self, task_id):
        """Run one episode with the reference policy and return its summed reward."""
        self.set_env(task_id=task_id)
        if self.args.env == 'ml1_pick_place':
            self.policy = SawyerPickPlaceV2Policy()
        elif self.args.env == 'cheetah_vel':
            model_ckpt_path = f'{self.args.model_ckpt_path}/task_{task_id}'
            self.policy = SAC.load(f'{model_ckpt_path}/sac_checkpoint_task_{task_id}_best')
        # darkroom: the env itself provides the optimal action, no policy to load
        # Pass the policy explicitly instead of overwriting self.act, so the
        # instance's model-driven act() stays intact.
        states, actions, next_states, rewards = self._deploy(act_fn=self.act_optimal)
        return rewards.sum()


def online(model, device, args):
    """Run the online deployment on every train and test task.

    Each task is rolled out for ``args.Heps`` episodes through a Controller in
    ``'online'`` mode.

    Returns:
        Array of per-episode summed rewards, one row per task
        (shape ``(num_tasks, Heps)``).
    """
    runner = Controller(model, device, args, mode='online')
    task_ids = args.train_tasks + args.test_tasks
    per_task_rewards = [runner.deploy(tid, args.Heps) for tid in tqdm(task_ids)]
    return np.array(per_task_rewards)


def offline(env_trajs, model, device, args):
    """Run the offline deployment once per stored trajectory.

    Every entry of ``env_trajs`` is passed to a Controller in ``'offline'``
    mode, and the value it returns is reduced with ``.sum()``.

    Returns:
        1-D array with one total reward per trajectory.
    """
    runner = Controller(model, device, args, mode='offline')
    totals = [runner.deploy(traj).sum() for traj in tqdm(env_trajs)]
    return np.array(totals)


def optimal(model, device, args):
    """Collect the reference-policy return on every train and test task.

    A Controller in ``'optimal'`` mode is deployed once per task id in
    ``args.train_tasks + args.test_tasks``.

    Returns:
        1-D array with one total reward per task.
    """
    runner = Controller(model, device, args, mode='optimal')
    task_ids = args.train_tasks + args.test_tasks
    returns = [runner.deploy(tid) for tid in tqdm(task_ids)]
    return np.array(returns)


def evaluation(env_trajs, model, device, epoch, args, writer=None):
    """Evaluate the model offline, online and against the reference policy.

    Args:
        env_trajs: pair of trajectory collections; index 0 is reported as the
            "random" context set and index 1 as the "expert" context set.
        model: model handed to the offline/online/optimal runners.
        device: torch device handed to the runners.
        epoch: global step used when logging figures.
        args: run configuration forwarded to the runners.
        writer: optional logger exposing ``add_figure(tag, figure, step)``.
            When ``None`` no figures are produced.

    Returns:
        Tuple ``(online_rewards, offline_rewards_random,
        offline_rewards_expert, optimal_rewards)``.
    """
    offline_rewards_random = offline(env_trajs[0], model, device, args)
    offline_rewards_expert = offline(env_trajs[1], model, device, args)
    online_rewards = online(model, device, args)
    optimal_rewards = optimal(model, device, args)

    if writer is not None:
        # Each plot gets its own figure so that the caller's current pyplot
        # figure is never drawn on or cleared.

        # online rewards per episode: one faint line per task plus mean +/- std
        fig, ax = plt.subplots()
        for task_curve in online_rewards:
            ax.plot(task_curve, color='blue', alpha=0.2)
        mean_curve = online_rewards.mean(axis=0)
        std_curve = online_rewards.std(axis=0)
        ax.plot(mean_curve, color='blue', label='Online')
        # x-axis follows the number of episodes actually recorded
        ax.fill_between(np.arange(mean_curve.shape[0]),
                        mean_curve - std_curve,
                        mean_curve + std_curve, alpha=0.2)
        ax.legend()
        ax.set_xlabel('Episodes')
        ax.set_ylabel('Reward')
        ax.set_title('Online Rewards')
        writer.add_figure('Evaluation/online', fig, epoch)
        plt.close(fig)  # no-op if the writer already closed it

        # offline rewards bar chart (mean with std error bars)
        baselines = {
            'Optimal': optimal_rewards,
            'Learner(random)': offline_rewards_random,
            'Learner(expert)': offline_rewards_expert,
        }
        labels = list(baselines)
        means = [np.mean(baselines[name]) for name in labels]
        stds = [np.std(baselines[name]) for name in labels]
        colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
        fig, ax = plt.subplots()
        ax.bar(labels, means, color=colors, yerr=stds, capsize=5)
        ax.set_ylabel('Average Return')
        ax.set_title('Offline Rewards')
        writer.add_figure('Evaluation/offline', fig, epoch)
        plt.close(fig)

    return online_rewards, offline_rewards_random, offline_rewards_expert, optimal_rewards
