import pickle
import torch
import numpy as np
import metaworld
from tqdm import tqdm
from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
from stable_baselines3 import SAC
from matplotlib import pyplot as plt
import argparse
import common_args
from net import Transformer
from easydict import EasyDict
# Module-wide compute device: use CUDA when torch reports it available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import random
import os

def perturb_state(
    state,
    eps_ee=0.01,       # end-effector position
    eps_obj=0.005,     # object position
    eps_quat=0.01,     # object orientation (quaternion)
    eps_goal=0.005,    # goal position
    rng=None,
):
    """
    Apply bounded uniform perturbation to selected dimensions of a state vector.

    The input is never modified; a perturbed copy is returned. Four index
    groups are perturbed, each with independent uniform noise in
    ``[-eps, eps]``:

    * ``[0:3]``   end-effector position, then clipped to ``[-0.2, 0.2]``
    * ``[7:10]``  object position, then clipped to ``[0.0, 0.2]``
    * ``[10:14]`` object quaternion, then re-normalised to unit length
    * ``[35:38]`` goal position, then clipped to ``[0.0, 0.2]``

    All other entries are returned unchanged.

    Args:
        state: 1-D array-like with at least 38 entries (a Meta-World
            observation is expected to have shape (39,)).
        eps_ee, eps_obj, eps_quat, eps_goal: float, maximum perturbation
            magnitude for the corresponding group.
        rng: optional seed or ``numpy.random.Generator``. ``None`` (the
            default) draws fresh OS entropy on every call, as before; pass a
            seed or a generator to make the perturbation reproducible.

    Returns:
        perturbed_state: np.ndarray with the same shape as ``state``.

    Raises:
        ValueError: if ``state`` is not 1-D or has fewer than 38 entries.
    """
    # np.array(..., copy=True) both copies ndarrays and accepts plain sequences.
    state = np.array(state, copy=True)
    if not np.issubdtype(state.dtype, np.floating):
        # In-place addition of float noise into an integer array would fail.
        state = state.astype(np.float64)
    if state.ndim != 1 or state.shape[0] < 38:
        raise ValueError(
            f"state must be a 1-D vector with at least 38 entries, got shape {state.shape}"
        )

    # default_rng passes an existing Generator through unchanged and builds a
    # new one from None or an integer seed.
    rng = np.random.default_rng(rng)

    # NOTE(review): the clip bounds below are hard-coded; confirm they match
    # the workspace limits of the task these states come from.

    # End-effector position [0:3]
    state[0:3] += rng.uniform(-eps_ee, eps_ee, size=3)
    state[0:3] = np.clip(state[0:3], -0.2, 0.2)

    # Object position [7:10]
    state[7:10] += rng.uniform(-eps_obj, eps_obj, size=3)
    state[7:10] = np.clip(state[7:10], 0.0, 0.2)

    # Object orientation (quaternion) [10:14]; the small constant guards
    # against division by zero when the perturbed quaternion is all zeros.
    state[10:14] += rng.uniform(-eps_quat, eps_quat, size=4)
    state[10:14] /= (np.linalg.norm(state[10:14]) + 1e-8)

    # Goal position [35:38]
    state[35:38] += rng.uniform(-eps_goal, eps_goal, size=3)
    state[35:38] = np.clip(state[35:38], 0.0, 0.2)

    return state


class Controller:
    """Roll a policy out in an environment and report the collected reward.

    The controller holds a model, the environment chosen by ``args.env`` and
    the current in-context batch. ``deploy`` is bound at construction time to
    one of three strategies:

    * ``'offline'`` -> ``deploy_offline(env_traj)``: context comes from a
      stored trajectory.
    * ``'online'``  -> ``deploy_online(task_id, Heps)``: context is built up
      from the controller's own rollouts.
    * ``'optimal'`` -> ``deploy_optimal(task_id)``: actions come from a
      reference policy instead of the model.
    """

    def __init__(self, model, device, args, mode='offline'):
        """Store the collaborators and bind ``self.deploy`` for ``mode``.

        Args:
            model: callable taking the batch dict and returning an action tensor.
            device: torch device used for the context tensors.
            args: attribute-style config; ``env`` is always read, and
                ``state_dim``/``action_dim``/``dim``/``darkroom_goal_path``
                are read by specific methods.
            mode: one of ``'offline'``, ``'online'``, ``'optimal'``.

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        self.model = model
        self.device = device
        self.args = args
        self.horizon = 100  # number of environment steps per episode
        if mode == 'offline':
            self.deploy = self.deploy_offline
        elif mode == 'online':
            self.deploy = self.deploy_online
        elif mode == 'optimal':
            self.deploy = self.deploy_optimal
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def convert_to_tensor(self, x):
        """Return ``x`` as a float32 tensor on ``self.device``."""
        return torch.tensor(np.asarray(x)).float().to(self.device)

    def set_env(self, env_traj=None, task_id=None):
        """Build the environment for ``args.env`` and store it in ``self.env``.

        Args:
            env_traj: optional mapping describing a stored trajectory. When
                given it supplies ``'task_id'`` (ml1_pick_place) or ``'goal'``
                (darkroom) and takes precedence over ``task_id``.
            task_id: index of the task to instantiate when ``env_traj`` is
                not given.

        Raises:
            ValueError: if neither argument identifies a task, or if
                ``args.env`` is not a supported environment name.
        """
        env_name = self.args.env
        if env_name == 'ml1_pick_place':
            if env_traj is not None:
                task_id = env_traj['task_id']
            if task_id is None:
                raise ValueError("set_env needs env_traj or task_id")
            # The task id doubles as the benchmark seed.
            ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
            env = ml1.train_classes['pick-place-v2']()
            task = ml1.train_tasks[task_id]
            env.set_task(task)
            env.max_path_length = self.horizon

        elif env_name == 'darkroom':
            if env_traj is not None:
                goal = env_traj['goal']
            elif task_id is None:
                raise ValueError("set_env needs env_traj or task_id")
            else:
                goal = np.load(self.args.darkroom_goal_path)[task_id]
            # NOTE(review): DarkroomEnv is not imported in the visible part of
            # this file; this branch raises NameError until it is.
            env = DarkroomEnv(dim=self.args.dim, goal=goal, horizon=self.horizon)

        else:
            # Previously fell through to an UnboundLocalError on `env`.
            raise ValueError(f"Unsupported env: {env_name}")
        self.env = env

    def set_batch(self, batch):
        """Set the context batch that ``act`` feeds to the model."""
        self.batch = batch

    def act(self, state):
        """Query the model for the action to take in ``state``.

        The state is written into ``self.batch['query_states']`` with a
        leading batch axis (this mutates the stored batch), the model is
        called on the batch, and its output is returned as a squeezed NumPy
        array.
        """
        # Use the device the model's parameters actually live on.
        device = next(self.model.parameters()).device
        state = torch.tensor(np.array(state)).float().to(device)
        self.batch['query_states'] = state[None, :]

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            action = self.model(self.batch)
        return action.cpu().detach().numpy().squeeze()

    def act_optimal(self, state):
        """Return the reference policy's action for ``state``.

        Raises:
            ValueError: if ``args.env`` has no reference policy here.
        """
        env_name = self.args.env
        if env_name == 'ml1_pick_place':
            return self.policy.get_action(state)
        if env_name == 'cheetah_vel':
            # NOTE(review): nothing in this class assigns self.policy for
            # cheetah_vel; the caller presumably has to set it first.
            action, _ = self.policy.predict(state, deterministic=True)
            return action
        if env_name == 'darkroom':
            return self.env.opt_action(state)
        # Previously fell through to an UnboundLocalError on `action`.
        raise ValueError(f"Unsupported env: {env_name}")

    def _deploy(self):
        """Roll out ``self.horizon`` steps in ``self.env`` using ``self.act``.

        Returns:
            Tuple ``(states, actions, next_states, rewards)`` of NumPy arrays,
            each with ``self.horizon`` entries along the first axis.
        """
        states, actions, next_states, rewards = [], [], [], []

        reset_out = self.env.reset()
        # Accept both `obs` and `(obs, info)` return styles from reset().
        state = reset_out[0] if isinstance(reset_out, tuple) else reset_out

        for _ in range(self.horizon):
            action = self.act(state)
            # Only the first two step() outputs are used, so 4- and 5-tuple
            # step APIs both work. Termination flags are ignored: the rollout
            # always runs for the full horizon.
            step_out = self.env.step(action)
            next_state, reward = step_out[0], step_out[1]

            states.append(state)
            actions.append(action)
            next_states.append(next_state)
            rewards.append(reward)
            state = next_state

        return (
            np.array(states),
            np.array(actions),
            np.array(next_states),
            np.array(rewards),
        )

    def deploy_offline(self, env_traj):
        """Roll out once with context taken from ``env_traj``; return total reward."""
        self.set_env(env_traj=env_traj)
        # Each context array gets a leading batch axis; rewards also get a
        # trailing feature axis.
        batch = {
            'context_states': self.convert_to_tensor(env_traj['context_states'][None, :, :]),
            'context_actions': self.convert_to_tensor(env_traj['context_actions'][None, :, :]),
            'context_next_states': self.convert_to_tensor(env_traj['context_next_states'][None, :, :]),
            'context_rewards': self.convert_to_tensor(env_traj['context_rewards'][None, :, None]),
        }
        self.set_batch(batch)
        states, actions, next_states, rewards = self._deploy()
        return rewards.sum()

    def deploy_online(self, task_id, Heps):
        """Run ``Heps`` episodes, feeding earlier rollouts back as context.

        The context starts as a single all-zero episode. After every episode
        the new rollout is appended and only the latest ``Heps`` episodes are
        kept.

        Returns:
            List with the summed reward of each episode, in order.
        """
        self.set_env(task_id=task_id)

        horizon = self.horizon
        widths = {
            'context_states': self.args.state_dim,
            'context_actions': self.args.action_dim,
            'context_next_states': self.args.state_dim,
            'context_rewards': 1,
        }

        # Context tensors are kept as (1, episodes, horizon, dim).
        context = {
            key: torch.zeros((1, 1, horizon, width)).float().to(self.device)
            for key, width in widths.items()
        }

        rewards_all = []
        for _ in range(Heps):
            # Flatten the episode axis: (1, episodes * horizon, dim).
            self.set_batch({
                key: tensor.reshape(1, -1, widths[key])
                for key, tensor in context.items()
            })

            # Rollout a new episode using the current context.
            states, actions, next_states, rewards = self._deploy()
            rewards_all.append(rewards.sum())

            fresh = {
                'context_states': states[None, :, :],
                'context_actions': actions[None, :, :],
                'context_next_states': next_states[None, :, :],
                'context_rewards': rewards[None, :, None],
            }
            for key, episode in fresh.items():
                # (1, H, dim) -> (1, 1, H, dim), then append and keep the
                # latest `Heps` episodes (FIFO).
                new_episode = self.convert_to_tensor(episode)[:, None, :, :]
                context[key] = torch.cat([context[key], new_episode], dim=1)[:, -Heps:, :, :]

        return rewards_all

    def deploy_optimal(self, task_id):
        """Roll out the reference policy on ``task_id``; return total reward.

        Note that this rebinds ``self.act`` to ``self.act_optimal`` on the
        instance, so later rollouts on the same controller also use the
        reference policy.
        """
        self.set_env(task_id=task_id)
        if self.args.env == 'ml1_pick_place':
            self.policy = SawyerPickPlaceV2Policy()

        # darkroom: act_optimal asks the env itself, so no policy object is needed.
        self.act = self.act_optimal
        states, actions, next_states, rewards = self._deploy()
        return rewards.sum()



if __name__ == '__main__':
    # Online evaluation of a trained Transformer on Meta-World pick-place:
    # for every test task, roll the model out for one warm-up episode with an
    # all-zero context, then for `n_episodes` further episodes in which the
    # context is the previous episode's rollout. The summed reward of each
    # episode is recorded in `final_results` and printed at the end.

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    common_args.add_model_args(parser)
    common_args.add_eval_args(parser)
    common_args.add_ml1_pick_place_dataset_args(parser)
    parser.add_argument('--seed', type=int, default=0)

    # Parse once; EasyDict supports both args['key'] and args.key access.
    args = EasyDict(vars(parser.parse_args()))

    # Apply --seed, which was previously parsed but never used.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    state_dim = 39
    action_dim = 4
    rollout_len = 100   # steps per episode; also used as env.max_path_length
    n_episodes = 40     # evaluation episodes per task (after the warm-up one)

    config = {
        'horizon': args['H'],
        'state_dim': state_dim,
        'action_dim': action_dim,
        'n_layer': 4,    # fixed here; the corresponding CLI argument is not used
        'n_embd': 32,    # fixed here; the corresponding CLI argument is not used
        'n_head': args['head'],
        'dropout': args['dropout'],
        'test': True,
    }

    model = Transformer(config).to(device)

    def convert_to_tensor(x):
        """Return ``x`` as a float32 tensor on the module-level ``device``."""
        return torch.tensor(np.asarray(x)).float().to(device)

    def make_batch(ctx_states, ctx_actions, ctx_next_states, ctx_rewards):
        """Flatten (1, episodes, steps, dim) contexts into a model batch."""
        return {
            'context_states': ctx_states.reshape(1, -1, state_dim),
            'context_actions': ctx_actions.reshape(1, -1, action_dim),
            'context_next_states': ctx_next_states.reshape(1, -1, state_dim),
            'context_rewards': ctx_rewards.reshape(1, -1, 1),
        }

    def rollout(env, controller, n_steps):
        """Reset ``env`` and step it ``n_steps`` times with ``controller.act``.

        Termination flags from ``env.step`` are ignored, so the rollout
        always has exactly ``n_steps`` transitions. Returns
        ``(states, actions, next_states, rewards)`` as NumPy arrays.
        """
        states, actions, next_states, rewards = [], [], [], []
        state, _ = env.reset()
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(n_steps):
                action = controller.act(state)
                next_state, reward, _done, _truncated, _info = env.step(action)

                states.append(state)
                actions.append(action)
                next_states.append(next_state)
                rewards.append(reward)
                state = next_state
        return (
            np.array(states),
            np.array(actions),
            np.array(next_states),
            np.array(rewards),
        )

    model_path = 'models/metaworld_shufTrue_lr0.001_do0_embd32_layer4_head4_envs100000_hists1_samples1_H100_d10_seed0_epoch50_iter6.pt'
    print('model_path: ', model_path)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)

    model.load_state_dict(checkpoint)
    model.eval()
    controller = Controller(model, device, args, mode='online')

    test_tasks = list(args.test_tasks)
    # One column per task. At least 5 columns are kept so the printed array
    # is unchanged for runs with up to 5 tasks; more tasks no longer overflow.
    final_results = np.zeros((n_episodes, max(5, len(test_tasks))))

    for iters in range(1):  # single pass; results are overwritten per pass
        for count, task_id in enumerate(tqdm(test_tasks)):  # FIXME: change for darkroom
            # The task id doubles as the benchmark seed.
            ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
            env = ml1.train_classes['pick-place-v2']()
            task = ml1.train_tasks[task_id]
            env.set_task(task)
            env.max_path_length = rollout_len

            # Contexts are kept as (1, episodes, steps, dim), starting from a
            # single all-zero episode.
            context_states = torch.zeros(
                (1, 1, rollout_len, state_dim)).float().to(device)
            context_actions = torch.zeros(
                (1, 1, rollout_len, action_dim)).float().to(device)
            context_next_states = torch.zeros(
                (1, 1, rollout_len, state_dim)).float().to(device)
            context_rewards = torch.zeros(
                (1, 1, rollout_len, 1)).float().to(device)

            full_context_states = []
            full_context_actions = []
            full_context_next_states = []
            full_context_rewards = []

            # Warm-up episode: act with the empty (all-zero) context and use
            # the result as the initial context for the evaluation episodes.
            controller.set_batch(make_batch(
                context_states, context_actions, context_next_states, context_rewards))
            states_lnr, actions_lnr, next_states_lnr, rewards_lnr = rollout(
                env, controller, rollout_len)

            context_states[:, 0, :, :] = convert_to_tensor(states_lnr)
            context_actions[:, 0, :, :] = convert_to_tensor(actions_lnr)
            context_next_states[:, 0, :, :] = convert_to_tensor(next_states_lnr)
            context_rewards[:, 0, :, :] = convert_to_tensor(rewards_lnr[None, :, None])

            for epis in range(n_episodes):
                controller.set_batch(make_batch(
                    context_states, context_actions, context_next_states, context_rewards))
                states_lnr, actions_lnr, next_states_lnr, rewards_lnr = rollout(
                    env, controller, rollout_len)

                # Despite the printed label this is the episode's summed
                # reward, not an average.
                mean = np.sum(rewards_lnr, axis=-1)
                print('mean: ', mean)
                final_results[epis, count] = mean

                # Keep the raw rollout of this episode.
                full_context_states.append(states_lnr)        # (steps, state_dim)
                full_context_actions.append(actions_lnr)      # (steps, action_dim)
                full_context_next_states.append(next_states_lnr)
                full_context_rewards.append(rewards_lnr)      # (steps,)

                # Add batch and episode axes: (1, 1, steps, dim).
                states_lnr = convert_to_tensor(states_lnr[None, None, :, :])
                actions_lnr = convert_to_tensor(actions_lnr[None, None, :, :])
                next_states_lnr = convert_to_tensor(next_states_lnr[None, None, :, :])
                rewards_lnr = convert_to_tensor(rewards_lnr[None, None, :, None])

                # Roll in new data: drop the oldest episode, append the new
                # one. With a single-episode context this replaces it.
                context_states = torch.cat(
                    (context_states[:, 1:, :, :], states_lnr), dim=1)
                context_actions = torch.cat(
                    (context_actions[:, 1:, :, :], actions_lnr), dim=1)
                context_next_states = torch.cat(
                    (context_next_states[:, 1:, :, :], next_states_lnr), dim=1)
                context_rewards = torch.cat(
                    (context_rewards[:, 1:, :, :], rewards_lnr), dim=1)

            print('-------------------------------')

            # Stack the per-episode rollouts along a new axis 1.
            # NOTE: these arrays are not consumed below; they are kept
            # available for a future trajectory-dataset export step.
            full_context_states = np.stack(full_context_states, axis=1)            # (steps, n_episodes, state_dim)
            full_context_actions = np.stack(full_context_actions, axis=1)          # (steps, n_episodes, action_dim)
            full_context_next_states = np.stack(full_context_next_states, axis=1)  # (steps, n_episodes, state_dim)
            full_context_rewards = np.stack(full_context_rewards, axis=1)          # (steps, n_episodes)

    print('cum means: ', final_results)

