#import d4rl
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader
import torch.distributions.normal as Normal
import random
import pickle
import torch


def K_step_dataset(dataset):
    """Annotate ``dataset`` with K-step (latent-action segment) targets.

    For every index ``t``, ``K`` is the number of steps until the latent action
    first differs from ``latent_action[t]``. Three keys are added to
    ``dataset`` (the dict is modified in place and also returned):

    - ``"K"``: the step count ``K`` for each ``t``.
    - ``"K_observations"``: ``observations[t + K]``.
    - ``"K_rewards"``: ``sum(rewards[t:t + K])``, i.e. the ``K`` rewards of
      steps ``t, ..., t + K - 1``.

    If the latent action never changes after ``t``, ``K`` falls back to 1.
    At the final index there is no later step, so ``K`` is 1 and the
    observation and reward of ``t`` itself are used.

    Args:
        dataset: dict with ``"latent_action"``, ``"observations"`` and
            ``"rewards"`` sequences of equal length (anything supporting
            ``len`` and integer indexing, e.g. numpy arrays or torch tensors).

    Returns:
        The same ``dataset`` dict with ``"K"``, ``"K_observations"`` and
        ``"K_rewards"`` added as numpy arrays.
    """
    latent_action = dataset["latent_action"]
    observations = dataset["observations"]
    rewards = dataset["rewards"]

    n = len(latent_action)

    # next_change[t] is the first index > t whose latent action differs from
    # latent_action[t], or None if there is none. Built right-to-left in one
    # O(n) pass instead of rescanning every constant run from each t.
    next_change = [None] * n
    for t in range(n - 2, -1, -1):
        if latent_action[t + 1] != latent_action[t]:
            next_change[t] = t + 1
        else:
            next_change[t] = next_change[t + 1]

    K_observations = []
    K_list = []
    K_rewards = []

    for t in range(n):
        if t == n - 1:
            K = 1
            K_observation = observations[t]
            K_reward = rewards[t]
        else:
            # NOTE(review): a run that lasts until the end of the data uses K=1
            # rather than extending to the last index — confirm this is intended.
            K = 1 if next_change[t] is None else next_change[t] - t
            K_observation = observations[t + K]
            # Rewards of the K transitions t .. t+K-1, consistent with the single
            # reward used at the final index (K=1 -> one reward).
            K_reward = np.sum(rewards[t:t + K])

        K_observations.append(K_observation)
        K_list.append(K)
        K_rewards.append(K_reward)

    dataset["K_observations"] = np.array(K_observations)
    dataset["K"] = np.array(K_list)
    dataset["K_rewards"] = np.array(K_rewards)

    return dataset
    
def modify_reward(dataset):
    """Replace every reward with the mean reward of the trajectory it belongs to.

    Trajectories are delimited by ``dataset["terminals"]``: each segment runs
    up to and including a terminal step. Steps after the last terminal form a
    final (possibly truncated) segment and are averaged the same way, instead
    of being left at zero.

    Args:
        dataset: dict with equal-length ``"rewards"`` and ``"terminals"``
            arrays; any nonzero terminal value marks a segment end.

    Returns:
        The same dict, with ``"rewards"`` replaced by a new array of the same
        dtype holding the per-segment means.
    """
    rewards = np.asarray(dataset["rewards"])
    terminals = np.asarray(dataset["terminals"])
    n = len(rewards)
    refined_rewards = np.zeros_like(rewards)

    start_idx = 0
    for end_idx in np.flatnonzero(terminals):
        segment = rewards[start_idx:end_idx + 1]
        refined_rewards[start_idx:end_idx + 1] = np.sum(segment) / len(segment)
        start_idx = end_idx + 1

    if start_idx < n:
        # Trailing steps without a closing terminal still form a trajectory.
        tail = rewards[start_idx:]
        refined_rewards[start_idx:] = np.sum(tail) / len(tail)

    dataset["rewards"] = refined_rewards
    return dataset

# def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
#     """
#     Returns datasets formatted for use by standard Q-learning algorithms,
#     with observations, actions, next_observations, rewards, and a terminal
#     flag.

#     Args:
#         env: An OfflineEnv object.
#         dataset: An optional dataset to pass in for processing. If None,
#             the dataset will default to env.get_dataset()
#         terminate_on_end (bool): Set done=True on the last timestep
#             in a trajectory. Default is False, and will discard the
#             last timestep in each trajectory.
#         **kwargs: Arguments to pass to env.get_dataset().

#     Returns:
#         A dictionary containing keys:
#             observations: An N x dim_obs array of observations.
#             actions: An N x dim_action array of actions.
#             next_observations: An N x dim_obs array of next observations.
#             rewards: An N-dim float array of rewards.
#             terminals: An N-dim boolean array of "done" or episode termination flags.
#     """
#     if dataset is None:
#         dataset = env.get_dataset(**kwargs)
    
#     has_next_obs = True if 'next_observations' in dataset.keys() else False

#     N = dataset['rewards'].shape[0]
#     obs_ = []
#     next_obs_ = []
#     action_ = []
#     reward_ = []
#     done_ = []

#     # The newer version of the dataset adds an explicit
#     # timeouts field. Keep old method for backwards compatibility.
#     use_timeouts = False
#     if 'timeouts' in dataset:
#         use_timeouts = True

#     episode_step = 0
#     dataset = modify_reward(dataset)
    
#     for i in range(N-1):
#         obs = dataset['observations'][i].astype(np.float32)
#         if has_next_obs:
#             new_obs = dataset['next_observations'][i].astype(np.float32)
#         else:
#             new_obs = dataset['observations'][i+1].astype(np.float32)
#         action = dataset['actions'][i].astype(np.float32)
#         reward = dataset['rewards'][i].astype(np.float32)
#         done_bool = bool(dataset['terminals'][i])

#         if use_timeouts:
#             final_timestep = dataset['timeouts'][i]
#         else:
#             final_timestep = (episode_step == env._max_episode_steps - 1)
#         # if (not terminate_on_end) and final_timestep:
#         #     # Skip this transition and don't apply terminals on the last step of an episode
#         #     episode_step = 0
#         #     continue  
#         # if done_bool or final_timestep:
#         #     episode_step = 0
#         #     if not has_next_obs:
#         #         continue

#         obs_.append(obs)
#         next_obs_.append(new_obs)
#         action_.append(action)
#         reward_.append(reward)
#         done_.append(done_bool)
#         episode_step += 1

#     return {
#         'observations': np.array(obs_),
#         'actions': np.array(action_),
#         'next_observations': np.array(next_obs_),
#         'rewards': np.array(reward_),
#         'terminals': np.array(done_),
#     }

def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, latent actions, next_observations, rewards,
    terminal flags and K-step (latent-action segment) targets.

    Rewards are first averaged per terminal-delimited trajectory by
    ``modify_reward``; K-step targets are computed by ``K_step_dataset``.
    Both helpers write entries into ``dataset`` in place.

    Args:
        env: An OfflineEnv object. Only used when ``dataset`` is None.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset(). Must contain a
            'latent_action' entry in addition to 'observations', 'actions',
            'rewards' and 'terminals'.
        terminate_on_end (bool): Currently ignored. Timeout/terminal skipping
            is disabled, so every row except the last becomes a transition.
            Kept for API compatibility.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary whose entries all have one row per transition
        (N - 1 rows for an input with N >= 1 rows):
            observations: float32 array of observations.
            actions: float32 array of actions.
            latent_action: the 'latent_action' entry of each transition.
            next_observations: float32 array of next observations.
            rewards: float32 array of trajectory-averaged rewards.
            terminals: boolean array of "done" flags.
            K_rewards, K, K_observations: K-step targets from K_step_dataset,
                truncated to the same rows.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    N = dataset['rewards'].shape[0]

    obs_ = []
    next_obs_ = []
    action_ = []
    latent_action_ = []
    reward_ = []
    done_ = []

    dataset = modify_reward(dataset)

    # Timeout / terminate_on_end handling is disabled: transitions are kept
    # across episode boundaries, and only the final row (which has no
    # successor in 'observations') is dropped.
    for i in range(N - 1):
        obs = dataset['observations'][i].astype(np.float32)
        if has_next_obs:
            new_obs = dataset['next_observations'][i].astype(np.float32)
        else:
            new_obs = dataset['observations'][i + 1].astype(np.float32)

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(dataset["actions"][i].astype(np.float32))
        latent_action_.append(dataset['latent_action'][i])
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(bool(dataset['terminals'][i]))

    num_transitions = len(obs_)
    dataset = K_step_dataset(dataset)
    # K_step_dataset annotates every input row, including the final one that
    # was dropped above; truncate so all returned arrays have equal length.
    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'latent_action': np.array(latent_action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'K_rewards': dataset["K_rewards"][:num_transitions],
        'K': dataset["K"][:num_transitions],
        'K_observations': dataset["K_observations"][:num_transitions],
    }

def save(args, save_name, model, wandb, ep=None):
    """Save ``model.state_dict()`` under ``./trained_models/`` and pass the file to ``wandb.save``.

    The file name is ``args.run_name + save_name``, followed by ``str(ep)``
    when ``ep`` is given, and the ``.pth`` extension.

    Args:
        args: object with a ``run_name`` string attribute.
        save_name: string appended to ``args.run_name``.
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        wandb: object providing ``save(path)`` (presumably the ``wandb`` module).
        ep: optional epoch/step tag appended to the file name; ``0`` counts
            as a tag, only ``None`` omits it.
    """
    import os

    save_dir = './trained_models/'
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(save_dir, exist_ok=True)
    suffix = "" if ep is None else str(ep)
    path = save_dir + args.run_name + save_name + suffix + ".pth"
    torch.save(model.state_dict(), path)
    wandb.save(path)


def collect_random(env, dataset, num_samples=200):
    """Add ``num_samples`` transitions from ``env.action_space.sample()`` actions to ``dataset``.

    Uses the 4-tuple ``env.step`` API (``next_state, reward, done, info``),
    stores each transition with
    ``dataset.add(state, action, reward, next_state, done)`` and calls
    ``env.reset()`` whenever a step reports ``done``.
    """
    current = env.reset()
    for _step in range(num_samples):
        sampled = env.action_space.sample()
        following, reward, done, _info = env.step(sampled)
        dataset.add(current, sampled, reward, following, done)
        current = env.reset() if done else following


def reparameterize(mean, std):
    """Sample ``mean + std * eps`` with ``eps ~ N(0, 1)`` (reparameterization trick).

    ``eps`` is drawn with the shape, dtype and device of ``mean``, so this
    works on CPU as well as GPU instead of being forced onto CUDA.

    Args:
        mean: tensor of means.
        std: tensor of standard deviations, broadcastable against ``mean``.

    Returns:
        A tensor of samples with gradients flowing through ``mean`` and ``std``.
    """
    eps = torch.randn_like(mean)
    return mean + std * eps

def chunks(obs, actions, H, stride, threshold=0.8):
    '''
    Slice aligned length-H windows from obs and actions, keeping only smooth ones.

    obs is an N x obs_dim array; its first two columns are used to measure per-step displacement
    actions is an N x action_dim array aligned with obs
    H is the length of each chunk
    stride is how far we move between chunks. So if stride=H, chunks are non-overlapping. If stride < H, they overlap
    threshold is the largest allowed per-step displacement norm; windows with any larger step are dropped.
        Environment-specific values: Antmaze large 0.8 medium 0.67 / Franka 0.23 mixed/complete 0.25 partial / Maze2d 0.22

    Returns (obs_chunks, action_chunks): float32 tensors of shape (num_kept, H, obs_dim)
    and (num_kept, H, action_dim). Both are empty along dim 0 when no window is kept.
    '''
    obs_chunks = []
    action_chunks = []
    N = obs.shape[0]
    # Every start index i*stride whose full window fits: start + H <= N.
    num_windows = (N - H) // stride + 1 if N >= H else 0
    for i in range(num_windows):
        start_ind = i * stride
        end_ind = start_ind + H

        obs_chunk = torch.tensor(obs[start_ind:end_ind, :], dtype=torch.float32)
        action_chunk = torch.tensor(actions[start_ind:end_ind, :], dtype=torch.float32)

        # Per-step displacement of the first two observation dims (Franka or Maze2d).
        loc_deltas = obs_chunk[1:, :2] - obs_chunk[:-1, :2]
        norms = np.linalg.norm(loc_deltas.numpy(), axis=-1)
        if np.all(norms <= threshold):
            obs_chunks.append(obs_chunk)
            action_chunks.append(action_chunk)

    print('len(obs_chunks): ', len(obs_chunks))
    print('len(action_chunks): ', len(action_chunks))

    if not obs_chunks:
        # torch.stack rejects an empty list; return correctly shaped empty tensors.
        return (torch.empty((0, H, obs.shape[1]), dtype=torch.float32),
                torch.empty((0, H, actions.shape[1]), dtype=torch.float32))
    return torch.stack(obs_chunks), torch.stack(action_chunks)


def _num_chunks(length, horizon, stride):
    """Return how many ``horizon``-step windows starting every ``stride`` steps fit in ``length`` steps."""
    if length < horizon:
        return 0
    return (length - horizon) // stride + 1


def _train_test_indices(num_samples, test_split, separate_test_trajectories):
    """Return ``(train_indices, test_indices)`` partitioning ``range(num_samples)``.

    ``int(test_split * num_samples)`` samples go to the test set: the trailing
    ones when ``separate_test_trajectories`` is true, otherwise a random subset
    drawn with ``np.random.choice``. Train indices are shuffled in place with
    ``np.random.shuffle``.
    """
    num_test_samples = int(test_split * num_samples)
    if separate_test_trajectories:
        train_indices = np.arange(0, num_samples - num_test_samples)
        test_indices = np.arange(num_samples - num_test_samples, num_samples)
    else:
        test_indices = np.random.choice(np.arange(num_samples), num_test_samples, replace=False)
        # Sorted complement; keeps an integer dtype even when it is empty.
        train_indices = np.setdiff1d(np.arange(num_samples), test_indices)
    np.random.shuffle(train_indices)
    return train_indices, test_indices


def get_dataset(env_name, horizon, stride, test_split=0.2, append_goals=False, get_rewards=False, separate_test_trajectories=False, cum_rewards=True):
    """Load ``data/<env_name>.pkl``, cut it into ``horizon``-step chunks and split them into train/test sets.

    How chunks are extracted depends on ``env_name``:

    - ``antmaze-large-diverse-v2`` / ``antmaze-medium-diverse-v2``: the data
      is treated as 999 back-to-back trajectories of 1001 steps; windows start
      every ``stride`` steps. With ``get_rewards``, chunking of a trajectory
      stops at the first chunk containing a positive reward, whose reward
      targets are set to ones.
    - names containing ``kitchen``: trajectories end at ``terminals``; windows
      start every ``stride`` steps. Rewards are always returned, either as
      stored (``cum_rewards=True``) or as step-to-step differences inside each
      chunk (first entry 0).
    - names containing ``maze2d``: the whole dataset is one sequence; windows
      start every ``stride`` steps.
    - anything else: stride-1 windows over the whole dataset that jump ahead by
      ``horizon`` after a chunk containing a terminal or once ``episode_step``
      reaches ``1000 - horizon`` (presumably assuming 1000-step episodes).
      Rewards and terminals are always returned; ``stride`` is not used.

    Args:
        env_name: dataset name; selects the file and the extraction branch.
        horizon: number of steps per chunk.
        stride: distance between consecutive chunk starts (antmaze/kitchen/maze2d).
        test_split: fraction of chunks held out for testing.
        append_goals: antmaze/maze2d only - append ``infos/goal`` to observations.
        get_rewards: antmaze/maze2d only - also return reward chunks (else None).
        separate_test_trajectories: if true the last chunks form the test set,
            otherwise test chunks are drawn at random.
        cum_rewards: kitchen only - see above.

    Returns:
        dict with ``observations_train/test``, ``actions_train/test`` and
        ``rewards_train/test`` (None when rewards are not collected), plus
        ``terminals_train/test`` for the default branch. Tensors are float32
        with shape (num_chunks, horizon, dim).

    Raises:
        AssertionError: if an antmaze dataset does not have exactly 999 timeouts.
    """
    dataset_file = 'data/'+env_name+'.pkl'
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted dataset files.
    with open(dataset_file, "rb") as f:
        dataset = pickle.load(f)

    observations = []
    actions = []
    terminals = []
    # Created unconditionally: the kitchen and default branches always collect
    # rewards, even when get_rewards is False.
    rewards = []

    if env_name == 'antmaze-large-diverse-v2' or env_name == 'antmaze-medium-diverse-v2':

        num_trajectories = np.where(dataset['timeouts'])[0].shape[0]
        assert num_trajectories == 999, 'Dataset has changed. Review the dataset extraction'

        if append_goals:
            dataset['observations'] = np.hstack([dataset['observations'], dataset['infos/goal']])
        print('Total trajectories: ', num_trajectories)

        for traj_idx in range(num_trajectories):
            # Trajectories are stored back to back, 1001 steps each.
            start_idx = traj_idx * 1001
            end_idx = (traj_idx + 1) * 1001

            obs = dataset['observations'][start_idx:end_idx]
            act = dataset['actions'][start_idx:end_idx]
            if get_rewards:
                rew = np.expand_dims(dataset['rewards'][start_idx:end_idx], axis=1)

            for chunk_idx in range(_num_chunks(obs.shape[0], horizon, stride)):
                chunk_start_idx = chunk_idx * stride
                chunk_end_idx = chunk_start_idx + horizon

                observations.append(torch.tensor(obs[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
                actions.append(torch.tensor(act[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
                if get_rewards:
                    if np.any(rew[chunk_start_idx:chunk_end_idx] > 0):
                        # First chunk with a positive reward: label it all ones
                        # and stop chunking the rest of this trajectory.
                        rewards.append(torch.ones((chunk_end_idx - chunk_start_idx, 1), dtype=torch.float32))
                        break
                    rewards.append(torch.tensor(rew[chunk_start_idx:chunk_end_idx], dtype=torch.float32))

        observations = torch.stack(observations)
        actions = torch.stack(actions)
        if get_rewards:
            rewards = torch.stack(rewards)

        num_samples = observations.shape[0]
        print('Total data samples extracted: ', num_samples)
        train_indices, test_indices = _train_test_indices(num_samples, test_split, separate_test_trajectories)

        return dict(observations_train=observations[train_indices],
                    actions_train=actions[train_indices],
                    rewards_train=rewards[train_indices] if get_rewards else None,
                    observations_test=observations[test_indices],
                    actions_test=actions[test_indices],
                    rewards_test=rewards[test_indices] if get_rewards else None,
                    )

    elif 'kitchen' in env_name:

        terminal_indices = np.where(dataset['terminals'])[0]
        num_trajectories = terminal_indices.shape[0]

        print('Total trajectories: ', num_trajectories)

        # Trajectory k spans (terminal k-1, terminal k]; -1 seeds the first start.
        boundaries = np.append(-1, terminal_indices)

        for traj_idx in range(1, len(boundaries)):
            start_idx = boundaries[traj_idx - 1] + 1
            end_idx = boundaries[traj_idx] + 1

            obs = dataset['observations'][start_idx:end_idx]
            act = dataset['actions'][start_idx:end_idx]
            rew = np.expand_dims(dataset['rewards'][start_idx:end_idx], axis=1)

            for chunk_idx in range(_num_chunks(obs.shape[0], horizon, stride)):
                chunk_start_idx = chunk_idx * stride
                chunk_end_idx = chunk_start_idx + horizon
                chunk_rew = rew[chunk_start_idx:chunk_end_idx]

                observations.append(torch.tensor(obs[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
                actions.append(torch.tensor(act[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
                if cum_rewards:
                    rewards.append(torch.tensor(chunk_rew, dtype=torch.float32))
                else:
                    # Step-to-step reward differences; the first entry is 0
                    # because the chunk's own first reward is prepended.
                    rewards.append(torch.tensor(np.diff(chunk_rew, axis=0, prepend=rew[chunk_start_idx, 0]), dtype=torch.float32))

        observations = torch.stack(observations)
        actions = torch.stack(actions)
        rewards = torch.stack(rewards)

        num_samples = observations.shape[0]
        print('Total data samples extracted: ', num_samples)
        train_indices, test_indices = _train_test_indices(num_samples, test_split, separate_test_trajectories)

        return dict(observations_train=observations[train_indices],
                    actions_train=actions[train_indices],
                    rewards_train=rewards[train_indices],
                    observations_test=observations[test_indices],
                    actions_test=actions[test_indices],
                    rewards_test=rewards[test_indices],
                    )

    elif 'maze2d' in env_name:

        if append_goals:
            dataset['observations'] = np.hstack([dataset['observations'], dataset['infos/goal']])

        obs = dataset['observations']
        act = dataset['actions']
        if get_rewards:
            rew = np.expand_dims(dataset['rewards'], axis=1)

        for chunk_idx in range(_num_chunks(obs.shape[0], horizon, stride)):
            chunk_start_idx = chunk_idx * stride
            chunk_end_idx = chunk_start_idx + horizon

            observations.append(torch.tensor(obs[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
            actions.append(torch.tensor(act[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
            if get_rewards:
                rewards.append(torch.tensor(rew[chunk_start_idx:chunk_end_idx], dtype=torch.float32))

        observations = torch.stack(observations)
        actions = torch.stack(actions)
        if get_rewards:
            rewards = torch.stack(rewards)

        num_samples = observations.shape[0]
        print('Total data samples extracted: ', num_samples)
        train_indices, test_indices = _train_test_indices(num_samples, test_split, separate_test_trajectories)

        return dict(observations_train=observations[train_indices],
                    actions_train=actions[train_indices],
                    rewards_train=rewards[train_indices] if get_rewards else None,
                    observations_test=observations[test_indices],
                    actions_test=actions[test_indices],
                    rewards_test=rewards[test_indices] if get_rewards else None,
                    )

    else:
        obs = dataset['observations']
        act = dataset['actions']
        rew = np.expand_dims(dataset['rewards'], axis=1)
        dones = np.expand_dims(dataset['terminals'], axis=1)
        episode_step = 0
        chunk_idx = 0

        while chunk_idx < rew.shape[0] - horizon + 1:
            chunk_start_idx = chunk_idx
            chunk_end_idx = chunk_start_idx + horizon

            observations.append(torch.tensor(obs[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
            actions.append(torch.tensor(act[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
            rewards.append(torch.tensor(rew[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
            terminals.append(torch.tensor(dones[chunk_start_idx:chunk_end_idx], dtype=torch.float32))
            if np.any(dones[chunk_start_idx:chunk_end_idx] > 0) or episode_step == 1000 - horizon:
                # Episode boundary (terminal in this chunk or step limit):
                # restart the counter and jump past the current chunk.
                episode_step = 0
                chunk_idx += horizon
            else:
                episode_step += 1
                chunk_idx += 1

        observations = torch.stack(observations)
        actions = torch.stack(actions)
        rewards = torch.stack(rewards)
        terminals = torch.stack(terminals)

        num_samples = observations.shape[0]
        print('Total data samples extracted: ', num_samples)
        train_indices, test_indices = _train_test_indices(num_samples, test_split, separate_test_trajectories)

        return dict(observations_train=observations[train_indices],
                    actions_train=actions[train_indices],
                    rewards_train=rewards[train_indices],
                    terminals_train=terminals[train_indices],
                    observations_test=observations[test_indices],
                    actions_test=actions[test_indices],
                    rewards_test=rewards[test_indices],
                    terminals_test=terminals[test_indices]
                    )
