import os
import torch
import random
import argparse
import pickle
import numpy as np
import metaworld
from tqdm import tqdm
from gym.wrappers import TimeLimit
from train_DPT import build_metaworld_data_filename, Transformer
from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
from collect_ml1_pick_place import overload_env_step, load_sac_model, calculate_advantage

class Controller:
    """Roll a policy out in a MetaWorld ``pick-place-v2`` task and collect rewards.

    ``mode`` selects which ``deploy`` (and, for model-driven modes, ``act``)
    implementation is bound on the instance:

    * ``'offline'`` - the model acts with a fixed context taken from a stored
      trajectory dict.
    * ``'online'``  - the model acts with a context that is replaced after each
      episode by the episode it just produced (``'DPT'`` or ``'DPTPR'``
      variant, chosen by ``args.model_type``).
    * ``'deter'``   - the scripted ``SawyerPickPlaceV2Policy`` acts.
    * ``'optimal'`` - the first model returned by ``load_sac_model`` acts.
    """

    def __init__(self, model, args, mode='offline'):
        """Bind the rollout strategy for ``mode``.

        Args:
            model: Callable taking the batch dict and returning an action
                tensor. May be ``None`` for the ``'deter'``/``'optimal'`` modes,
                which never query it.
            args: Namespace providing at least ``max_episode_steps`` (and
                ``model_type`` for ``mode='online'``).
            mode: One of ``'offline'``, ``'online'``, ``'deter'``, ``'optimal'``.

        Raises:
            ValueError: If ``mode`` (or the ``mode``/``model_type`` pair) is not
                supported. Previously this surfaced later as an
                ``AttributeError`` on the missing ``deploy`` attribute.
        """
        self.model = model
        # Tensors are created on the device the model lives on, so the two can
        # never disagree; falls back to CUDA when there is no model to inspect.
        self.device = self._infer_device(model)
        self.args = args
        self.horizon = args.max_episode_steps
        # Callers may overwrite this (see ``act_offline``).
        self.act_mode = 'greedy'
        if mode == 'offline':
            self.deploy = self.deploy_offline
            self.act = self.act_offline
        elif mode == 'online' and args.model_type == 'DPTPR':
            self.deploy = self.deploy_online_pr
            self.act = self.act_online_pr
        elif mode == 'online' and args.model_type == 'DPT':
            self.deploy = self.deploy_online
            self.act = self.act_online
        elif mode == 'deter':
            # ``act`` is bound inside ``deploy_deter``.
            self.deploy = self.deploy_deter
        elif mode == 'optimal':
            # ``act`` is bound inside ``deploy_optimal``.
            self.deploy = self.deploy_optimal
        else:
            raise ValueError(
                f"Unsupported mode {mode!r} for model_type "
                f"{getattr(args, 'model_type', None)!r}"
            )

    @staticmethod
    def _infer_device(model):
        """Return the device of ``model``'s first parameter, else ``cuda``."""
        parameters = getattr(model, 'parameters', None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device('cuda')

    def convert_to_tensor(self, x):
        """Convert an array-like to a float32 tensor on ``self.device``."""
        return torch.tensor(np.asarray(x)).float().to(self.device)

    def _zeros(self, args):
        """Build the ``'zeros'`` batch entry: shape ``[1, state_dim**2 + action_dim + 1]``."""
        return self.convert_to_tensor(
            np.zeros(args.state_dim ** 2 + args.action_dim + 1))[None, :]

    def _empty_context(self, dim):
        """Build a zero-length context tensor of shape ``[1, 0, dim]``."""
        return torch.empty((1, 0, dim)).float().to(self.device)

    def set_env(self, env_traj=None, task_id=None):
        """Create the time-limited environment for one task.

        Args:
            env_traj: Optional trajectory dict; when given, its ``'task_id'``
                entry overrides ``task_id``.
            task_id: Index into ``ML1.train_tasks``; also used as the ML1 seed.

        Raises:
            ValueError: If neither argument supplies a task id.
        """
        if env_traj is not None:
            task_id = env_traj['task_id']
        if task_id is None:
            raise ValueError("set_env requires either env_traj or task_id")
        ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
        env = ml1.train_classes['pick-place-v2']()
        task = ml1.train_tasks[task_id]
        env.set_task(task)
        env.max_path_length = self.horizon
        env = TimeLimit(env, max_episode_steps=self.horizon)
        # Replace ``step`` on this wrapper instance with the project's override.
        env.step = overload_env_step.__get__(env, env.__class__)
        self.task_id = task_id
        self.env = env

    def set_batch(self, batch):
        """Store the context batch that ``_act`` feeds to the model."""
        self.batch = batch

    def _act(self, state):
        """Query the model for an action at ``state`` using the current batch."""
        self.batch['query_states'] = self.convert_to_tensor(state)[None, :]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action = self.model(self.batch)
        return action.cpu().detach().numpy().squeeze()

    def act_offline(self, state):
        """Return the model action, optionally perturbed by Gaussian noise.

        ``self.act_mode`` must be ``'greedy'`` (raw model action) or
        ``'gaussian'`` (model action plus N(0, 0.01) noise per component).

        Raises:
            ValueError: For any other ``act_mode`` (the original silently
                returned ``None``).
        """
        model_action = self._act(state)
        if self.act_mode == 'greedy':
            return model_action
        if self.act_mode == 'gaussian':
            return model_action + np.random.normal(0, 0.01, size=model_action.shape)
        raise ValueError(f"Unknown act_mode {self.act_mode!r}")

    def act_online_pr(self, state):
        """Return ``(preferred_action, non_preferred_action)`` for ``state``.

        The model's action is scored with ``self.ref_model`` and the bad
        policy's deterministic action with ``self.bad_model``; the action with
        the strictly larger advantage comes first (ties favour the bad action).
        """
        model_action = self._act(state)
        model_advantage = calculate_advantage(self.ref_model, state, model_action)
        bad_action, _ = self.bad_model.predict(state, deterministic=True)
        bad_advantage = calculate_advantage(self.bad_model, state, bad_action)
        if model_advantage > bad_advantage:
            return model_action, bad_action
        return bad_action, model_action

    def act_online(self, state):
        """Return the raw model action for ``state``."""
        return self._act(state)

    def act_deter(self, state):
        """Return the scripted policy's action for ``state``."""
        return self.policy.get_action(state)

    def act_optimal(self, state):
        """Return the loaded policy's deterministic action for ``state``."""
        action, _ = self.policy.predict(state, deterministic=True)
        return action

    def _deploy(self):
        """Run one episode with ``self.act`` until the env reports ``done``.

        Returns:
            Tuple ``(states, actions, next_states, rewards)`` of numpy arrays,
            one entry per executed step.
        """
        states = []
        actions = []
        next_states = []
        rewards = []
        self.env.seed(self.task_id)
        # NOTE(review): reset is called twice, as in the original code; the
        # first call looks redundant - confirm whether the env relies on it.
        self.env.reset()
        obs = self.env.reset()
        done = False
        while not done:
            states.append(obs)
            action = self.act(obs)
            obs, reward, done, info = self.env.step(action)
            actions.append(action)
            next_states.append(obs)
            rewards.append(reward)
        return (np.array(states), np.array(actions),
                np.array(next_states), np.array(rewards))

    def _deploy_pr(self):
        """Run one episode where ``self.act`` yields a preference pair.

        Only the preferred action is executed in the environment; the
        non-preferred one is just recorded.

        Returns:
            Tuple ``(states, pr_actions, non_pr_actions, next_states, rewards)``
            of numpy arrays, one entry per executed step.
        """
        states = []
        pr_actions = []
        non_pr_actions = []
        next_states = []
        rewards = []
        self.env.seed(self.task_id)
        # NOTE(review): double reset kept from the original - see ``_deploy``.
        self.env.reset()
        obs = self.env.reset()
        done = False
        while not done:
            states.append(obs)
            pr_action, non_pr_action = self.act(obs)
            pr_actions.append(pr_action)
            non_pr_actions.append(non_pr_action)
            obs, reward, done, info = self.env.step(pr_action)
            next_states.append(obs)
            rewards.append(reward)
        return (np.array(states), np.array(pr_actions), np.array(non_pr_actions),
                np.array(next_states), np.array(rewards))

    def deploy_offline(self, env_traj, args):
        """Run one episode conditioned on a stored context trajectory.

        Args:
            env_traj: Dict with ``'task_id'`` plus the context arrays required
                by ``args.model_type`` (preference pairs for ``'DPTPR'``,
                actions and rewards otherwise).
            args: Namespace with ``model_type``, ``max_episode_steps``,
                ``state_dim`` and ``action_dim``.

        Returns:
            Array of per-step rewards for the episode.
        """
        self.set_env(env_traj=env_traj)
        horizon = args.max_episode_steps

        def context(key):
            # Add a batch axis and truncate the context to ``horizon`` steps.
            return self.convert_to_tensor(env_traj[key])[None, :horizon, :]

        if args.model_type == 'DPTPR':
            batch = {
                'context_states': context('context_states'),
                'preferred_actions': context('preferred_actions'),
                'non_preferred_actions': context('non_preferred_actions'),
                'context_next_states': context('context_next_states'),
                'zeros': self._zeros(args),
            }
        else:
            batch = {
                'context_states': context('context_states'),
                'context_actions': context('context_actions'),
                'context_next_states': context('context_next_states'),
                # Rewards get a trailing feature axis before batching.
                'context_rewards': self.convert_to_tensor(
                    env_traj['context_rewards'])[..., None][None, :horizon, :],
                'zeros': self._zeros(args),
            }
        self.set_batch(batch)
        _, _, _, rewards = self._deploy()
        return rewards

    def deploy_online_pr(self, task_id, args):
        """Run ``args.Heps`` episodes (at least one) with a preference context.

        The first episode uses an empty context; every later episode uses the
        transitions and preference pairs of the episode just before it.

        Returns:
            Array stacking the per-step rewards of every episode.
        """
        rewards_all = []
        self.set_env(task_id=task_id)
        self.ref_model, _, self.bad_model = load_sac_model(task_id, args.p_good, args.p_bad)
        batch = {
            'context_states': self._empty_context(self.args.state_dim),
            'preferred_actions': self._empty_context(self.args.action_dim),
            'non_preferred_actions': self._empty_context(self.args.action_dim),
            'context_next_states': self._empty_context(self.args.state_dim),
            'zeros': self._zeros(args),
        }
        self.set_batch(batch)
        states, pr_actions, non_pr_actions, next_states, rewards = self._deploy_pr()
        rewards_all.append(rewards)
        for _ in range(1, args.Heps):
            batch = {
                'context_states': self.convert_to_tensor(states[None, :, :]),
                'preferred_actions': self.convert_to_tensor(pr_actions[None, :, :]),
                'non_preferred_actions': self.convert_to_tensor(non_pr_actions[None, :, :]),
                'context_next_states': self.convert_to_tensor(next_states[None, :, :]),
                'zeros': self._zeros(args),
            }
            self.set_batch(batch)
            states, pr_actions, non_pr_actions, next_states, rewards = self._deploy_pr()
            rewards_all.append(rewards)
        return np.array(rewards_all)

    def deploy_online(self, task_id, args):
        """Run ``args.Heps`` episodes (at least one) with a reward context.

        The first episode uses an empty context; every later episode uses the
        transitions and rewards of the episode just before it.

        Returns:
            Array stacking the per-step rewards of every episode.
        """
        rewards_all = []
        self.set_env(task_id=task_id)
        # NOTE(review): ``act_online`` never reads these two models; the load is
        # kept only to preserve the original side effects - confirm and drop.
        self.ref_model, _, self.bad_model = load_sac_model(task_id, args.p_good, args.p_bad)
        batch = {
            'context_states': self._empty_context(self.args.state_dim),
            'context_actions': self._empty_context(self.args.action_dim),
            'context_next_states': self._empty_context(self.args.state_dim),
            'context_rewards': self._empty_context(1),
            'zeros': self._zeros(args),
        }
        self.set_batch(batch)
        states, actions, next_states, rewards = self._deploy()
        rewards_all.append(rewards)
        for _ in range(1, args.Heps):
            batch = {
                'context_states': self.convert_to_tensor(states[None, :, :]),
                'context_actions': self.convert_to_tensor(actions[None, :, :]),
                'context_next_states': self.convert_to_tensor(next_states[None, :, :]),
                'context_rewards': self.convert_to_tensor(rewards[None, :, None]),
                'zeros': self._zeros(args),
            }
            self.set_batch(batch)
            states, actions, next_states, rewards = self._deploy()
            rewards_all.append(rewards)
        return np.array(rewards_all)

    def deploy_deter(self, task_id, args):
        """Run one episode with the scripted ``SawyerPickPlaceV2Policy``.

        Returns:
            Array of per-step rewards for the episode.
        """
        self.set_env(task_id=task_id)
        self.policy = SawyerPickPlaceV2Policy()
        self.act = self.act_deter
        # Fixed: the original called ``self.deploy_offline()`` with no
        # arguments, which raises ``TypeError``; the rollout lives in ``_deploy``.
        _, _, _, rewards = self._deploy()
        return rewards

    def deploy_optimal(self, task_id, args):
        """Run one episode with the first model returned by ``load_sac_model``.

        Returns:
            Array of per-step rewards for the episode.
        """
        self.set_env(task_id=task_id)
        self.policy, _, _ = load_sac_model(task_id, args.p_good, args.p_bad)
        self.act = self.act_optimal
        # Fixed: same wrong ``deploy_offline()`` call as in ``deploy_deter``.
        _, _, _, rewards = self._deploy()
        return rewards


def eval(model, env_trajs, act_mode, mode, args):
    """Deploy ``model`` once per item of ``env_trajs`` and stack the rewards.

    Args:
        model: Model handed to the ``Controller``.
        env_trajs: Iterable of items passed one by one to ``controller.deploy``
            (wrapped in a tqdm progress bar).
        act_mode: Value assigned to the controller's ``act_mode`` attribute.
        mode: Controller mode selecting the deploy strategy.
        args: Namespace forwarded to the controller and to every deploy call.

    Returns:
        ``np.ndarray`` built from the per-item deploy results.
    """
    runner = Controller(model, args, mode=mode)
    runner.act_mode = act_mode
    progress = tqdm(env_trajs)
    return np.array([runner.deploy(item, args) for item in progress])


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line, as before.

    Returns:
        ``argparse.Namespace`` with ``model_type``, ``device``, ``eval_tasks``,
        ``max_episode_steps``, ``Heps``, ``p_good``, ``p_bad``, ``n_trajs``,
        ``state_dim`` and ``action_dim``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a Decision Preference Transformer on MetaWorld data")
    parser.add_argument('--model_type', type=str, default='DPTPR', choices=['DPTPR', 'DPT'], help="Model type")
    # Dataset parameters
    parser.add_argument('--device', type=str, default='cuda', help="Device to run on (cuda/cpu)")
    # parser.add_argument('--test_tasks', type=list, default=[1,2,3,4,5,6,7,8,9,
    #                                                         11,12,13,14,15,16,17,18,19,
    #                                                         21,22,23,24,25,26,27,28,29,
    #                                                         31,32,33,34,35,36,37,38,39,
    #                                                         41,42,43,44,45,46,47,48,49], help="Tasks to train on")
    # ``type=list`` would split the raw string into single characters; parse
    # one or more whitespace-separated integers instead.
    parser.add_argument('--eval_tasks', type=int, nargs='+', default=[0, 10, 20, 30, 40], help="Tasks to eval on")
    parser.add_argument('--max-episode-steps', type=int, default=100, help="Max episode steps")
    # Evaluation parameters
    parser.add_argument('--Heps', type=int, default=50, help="Horizon")
    parser.add_argument('--p-good', type=int, default=100, help="Good policy probability")
    parser.add_argument('--p-bad', type=int, default=20, help="Bad policy probability")
    parser.add_argument('--n-trajs', type=int, default=50, help="Number of trajectories")
    parser.add_argument('--state-dim', type=int, default=39, help="State dimension")
    parser.add_argument('--action-dim', type=int, default=4, help="Action dimension")
    return parser.parse_args(argv)



if __name__ == '__main__':
    # Read the command-line configuration and echo it.
    args = parse_args()
    print("Args: ", args)

    # Seed every random number generator in use with the same fixed value.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load the pickled model object and set its ``test`` flag.
    model_name = f"h{args.max_episode_steps}_{args.model_type}"
    model = torch.load(os.path.join('models', f"{model_name}.pt"), map_location='cuda', weights_only=False)
    model.test = True

    # Start from previously saved results when a results file already exists.
    results_path = os.path.join('evaluations', f"{model_name}.pkl")
    results = {}
    if os.path.exists(results_path):
        with open(results_path, 'rb') as f:
            results = pickle.load(f)

    # Offline evaluation: one stored trajectory file per p_good level, each
    # evaluated with noisy ('gaussian') actions first and raw ('greedy') second.
    for p_good in (20, 40, 50, 60, 80, 100):
        eval_path = build_metaworld_data_filename(
            len(args.eval_tasks), args.n_trajs, p_good, args.p_bad, args.model_type, mode=2)
        with open(eval_path, 'rb') as f:
            eval_trajs = pickle.load(f)
        # eval_trajs = eval_trajs[::5]
        gaussian_rewards = eval(model, eval_trajs, 'gaussian', 'offline', args)
        greedy_rewards = eval(model, eval_trajs, 'greedy', 'offline', args)
        # The printed values are the rewards summed over the last axis.
        print(f"Offline rewards (gaussian) on {p_good}: {np.sum(gaussian_rewards,axis=-1)}")
        print(f"Offline rewards (greedy) on {p_good}: {np.sum(greedy_rewards,axis=-1)}")
        results[f'offline_gaussian_{p_good}'] = gaussian_rewards
        results[f'offline_greedy_{p_good}'] = greedy_rewards

    # Online evaluation: the eval task list repeated ten times.
    online_rewards = eval(model, args.eval_tasks*10, None, 'online', args)
    print(f"Online rewards: {np.sum(online_rewards,axis=-1)}")
    results['online'] = online_rewards

    # Persist everything next to earlier results for this model.
    os.makedirs('evaluations', exist_ok=True)
    with open(results_path, 'wb') as f:
        pickle.dump(results, f)