import numpy as np
import torch
import torch.nn as nn
from typing import Callable, Optional
import copy

from torch.nn.modules.dropout import Dropout
from pathlib import Path
from torch.nn import functional as F
from torch.distributions import Normal, kl_divergence

import gym
from pathlib import Path
import h5py
from tqdm import tqdm
import d4rl

def get_keys(h5file):
    """Collect the full path of every dataset inside an open HDF5 file or group.

    Groups are walked recursively with ``visititems``; only ``h5py.Dataset``
    objects are recorded, in visitation order. Group paths are skipped.
    """
    dataset_paths = []

    def _record_dataset(path, node):
        # h5py stops ``visititems`` early if the callback returns a non-None
        # value, so this callback deliberately falls through (returns None).
        if isinstance(node, h5py.Dataset):
            dataset_paths.append(path)

    h5file.visititems(_record_dataset)
    return dataset_paths

def call_otdf_cost(env_name, src_datatype, tar_datatype):
    """Load the precomputed OTDF cost array for an environment/dataset pair.

    The cost file is
    ``<grandparent of this module>/costlogs/<env>-srcdatatype-<src>-tardatatype-<tar>.hdf5``.
    Underscores in ``env_name`` are replaced by hyphens.

    Only the top-level ``'cost'`` dataset is read. Other datasets in the file
    are never loaded into memory.

    Returns:
        The contents of the ``'cost'`` dataset (a NumPy array, or a NumPy
        scalar for a scalar dataset).

    Raises:
        KeyError: if the file has no ``'cost'`` object.
        OSError: if the file cannot be opened (raised by h5py).
    """
    env_name = env_name.replace('_', '-')

    file_name = f"{env_name}-srcdatatype-{src_datatype!s}-tardatatype-{tar_datatype!s}.hdf5"
    cost_path = Path(__file__).parent.parent.absolute() / 'costlogs' / file_name

    with h5py.File(str(cost_path), 'r') as dataset_file:
        # ``[()]`` reads the full dataset for both array and scalar shapes,
        # replacing the old "try [:] then fall back to [()]" pattern.
        return dataset_file['cost'][()]

# def call_tar_dataset(tar_env_name, tar_datatype):
#     if '-' in tar_env_name:
#         tar_env_name = tar_env_name.replace('-', '_')

#     if any(name in tar_env_name for name in ['halfcheetah', 'hopper', 'walker2d']) or tar_env_name.split('_')[0] == 'ant':
#         domain = 'mujoco'
#         make_env_name = tar_env_name.split('_')[0]
#         env = gym.make(make_env_name + '-medium-v2')
#         _max_episode_steps = env._max_episode_steps
#     else:
#         raise NotImplementedError
    
#     if 'gravity' in tar_env_name:
#         tar_dataset_path = str(Path(__file__).parent.parent.absolute()) + '/datasets/' + tar_env_name + '_0.5_' + str(tar_datatype.replace('-', '_')) + '.hdf5'
#     elif 'morph' in tar_env_name:
#         tar_dataset_path = str(Path(__file__).parent.parent.absolute()) + '/datasets/' + tar_env_name + '_' + str(tar_datatype.replace('-', '_')) + '.hdf5'
#     else:
#         tar_dataset_path = str(Path(__file__).parent.parent.absolute()) + '/datasets/' + tar_env_name + '_kinematic_' + str(tar_datatype.replace('-', '_')) + '.hdf5'


#     data_dict = {}
#     with h5py.File(tar_dataset_path, 'r') as dataset_file:
#         for k in tqdm(get_keys(dataset_file), desc="load datafile"):
#             try:  # first try loading as an array
#                 data_dict[k] = dataset_file[k][:]
#             except ValueError as e:  # try loading as a scalar
#                 data_dict[k] = dataset_file[k][()]
        
#     dataset = data_dict
    
#     N = dataset['rewards'].shape[0]
#     obs_ = []
#     next_obs_ = []
#     action_ = []
#     reward_ = []
#     done_ = []

#     # The newer version of the dataset adds an explicit
#     # timeouts field. Keep old method for backwards compatibility.
#     use_timeouts = False
#     if 'timeouts' in dataset:
#         use_timeouts = True

#     episode_step = 0
#     # count how many trajectories are included, ensuring the number of trajectories does not exceed number_of_trajectories
#     counter = 0
#     for i in range(N-1):
#         obs = dataset['observations'][i].astype(np.float32)
#         new_obs = dataset['observations'][i+1].astype(np.float32)
#         action = dataset['actions'][i].astype(np.float32)
#         try:
#             reward = dataset['rewards'][i].astype(np.float32)[0]
#         except:
#             reward = dataset['rewards'][i].astype(np.float32)
#         done_bool = bool(dataset['terminals'][i])

#         if use_timeouts:
#             final_timestep = dataset['timeouts'][i]
#         else:
#             final_timestep = (episode_step == _max_episode_steps - 1)

#         if done_bool or final_timestep:
#             counter +=1
#             episode_step = 0

#         obs_.append(obs)
#         next_obs_.append(new_obs)
#         action_.append(action)
#         reward_.append(reward)
#         done_.append(done_bool)
#         episode_step += 1

#     return {
#         'observations': np.array(obs_),
#         'actions': np.array(action_),
#         'next_observations': np.array(next_obs_),
#         'rewards': np.array(reward_),
#         'terminals': np.array(done_),
#     }
    
def convert_to_tensor(data, device):
    """Return ``data`` as a ``torch.Tensor`` placed on ``device``.

    NumPy arrays are copied into a new tensor with ``torch.tensor``, keeping
    their dtype. Existing tensors are moved with ``.to``. Any other input type
    raises ``TypeError``.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, np.ndarray):
        return torch.tensor(data, device=device)
    raise TypeError(f"Unsupported data type: {type(data)}. Expected numpy array or torch tensor.")

def expectile_regression(pred, target, expectile):
    """Element-wise asymmetric squared error used for expectile regression.

    Residuals ``target - pred`` that are positive are weighted by
    ``expectile``. All others (including zero) are weighted by
    ``1 - expectile``. No reduction is applied.
    """
    residual = target - pred
    weight = torch.where(residual > 0, expectile, 1 - expectile)
    return weight * residual.pow(2)

def make_target(m: nn.Module) -> nn.Module:
    """Return a frozen deep copy of ``m`` for use as a target network.

    The copy has ``requires_grad`` disabled on every parameter and is put in
    eval mode. The source module is left untouched.
    """
    # Module.requires_grad_ and Module.eval both return the module itself.
    return copy.deepcopy(m).requires_grad_(False).eval()

def compute_action_weight(distance):
    """Placeholder for the OT-distance-based critic-loss weight.

    Intended behavior: exponentiate the normalized OT distance. This is not
    implemented yet, and the function currently returns ``None`` for any input.
    """
    # TODO: implement the exponentiated, normalized-distance weighting.
    return None

class ReplayBuffer(object):
    """Ring buffer of (state, action, next_state, reward, not_done) transitions.

    Transitions live in NumPy arrays. ``sample`` draws indices uniformly with
    replacement and returns float32 tensors on ``device``.
    """

    def __init__(self, state_dim, action_dim, device, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0  # next slot ``add`` writes to
        self.size = 0  # number of valid transitions currently stored

        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))

        self.device = device

    def add(self, state, action, next_state, reward, done):
        """Store one transition at ``ptr``, overwriting the oldest once full."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly with replacement.

        Returns:
            Tuple ``(state, action, next_state, reward, not_done)`` of float32
            tensors on ``self.device``, each with leading dimension ``batch_size``.
        """
        ind = np.random.randint(0, self.size, size=batch_size)
        return tuple(
            torch.FloatTensor(arr[ind]).to(self.device)
            for arr in (self.state, self.action, self.next_state, self.reward, self.not_done)
        )

    def convert_D4RL(self, dataset):
        """Replace the buffer contents with a D4RL-style transition dict.

        ``dataset`` must provide ``'observations'``, ``'actions'``,
        ``'next_observations'``, ``'rewards'`` and ``'terminals'``. The arrays
        are referenced, not copied, so a later ``add`` writes into them.
        """
        self.state = dataset['observations']
        self.action = dataset['actions']
        self.next_state = dataset['next_observations']
        self.reward = dataset['rewards'].reshape(-1, 1)
        self.not_done = 1. - dataset['terminals'].reshape(-1, 1)
        self.size = self.state.shape[0]
        # Capacity is now exactly the loaded length. Without this reset, a
        # later ``add`` would grow ``size`` past the array length and
        # ``sample`` could index out of bounds.
        self.max_size = self.size
        self.ptr = 0


class OTReplayBuffer(object):
    """Ring buffer of transitions that also stores a per-transition OT cost.

    Like ``ReplayBuffer``, but each transition carries a ``cost`` (initialised
    to 1e8) that ``sample`` returns and ``preprocess`` filters on.
    """

    def __init__(self, state_dim, action_dim, device, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0  # next slot ``add`` writes to
        self.size = 0  # number of valid transitions currently stored

        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))
        # set cost to max value
        self.cost = np.ones((max_size, 1)) * 1e8

        self.device = device

    def add(self, state, action, next_state, reward, done):
        """Store one transition at ``ptr`` (its cost slot is left unchanged)."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly with replacement.

        Returns:
            Tuple ``(state, action, next_state, reward, not_done, cost)`` of
            float32 tensors on ``self.device``, each with leading dimension
            ``batch_size``.
        """
        ind = np.random.randint(0, self.size, size=batch_size)
        return tuple(
            torch.FloatTensor(arr[ind]).to(self.device)
            for arr in (self.state, self.action, self.next_state,
                        self.reward, self.not_done, self.cost)
        )

    def convert_D4RL(self, dataset):
        """Replace the transitions with a D4RL-style dict (``cost`` is untouched).

        The arrays are referenced, not copied, so a later ``add`` writes into them.
        """
        self.state = dataset['observations']
        self.action = dataset['actions']
        self.next_state = dataset['next_observations']
        self.reward = dataset['rewards'].reshape(-1, 1)
        self.not_done = 1. - dataset['terminals'].reshape(-1, 1)
        self.size = self.state.shape[0]
        # Capacity now equals the loaded length; keep add/sample in bounds.
        self.max_size = self.size
        self.ptr = 0

    def preprocess(self, filter_num):
        """Keep only the ``filter_num`` stored transitions with the lowest cost.

        Only the first ``size`` cost entries are considered. They are
        flattened, so both ``(N,)`` and ``(N, 1)`` cost arrays work. If
        ``filter_num`` is at least the number of costs considered, all of
        them are kept. The order of the kept transitions is unspecified
        (``np.argpartition``).
        """
        # Flatten first: partitioning the 2-D (N, 1) array would work along its
        # length-1 last axis and raise for any filter_num >= 1.
        costs = np.asarray(self.cost[:self.size]).reshape(-1)
        if filter_num >= costs.shape[0]:
            indices = np.arange(costs.shape[0])
        else:
            indices = np.argpartition(costs, filter_num)[:filter_num]

        self.state = self.state[indices]
        self.action = self.action[indices]
        self.next_state = self.next_state[indices]
        self.reward = self.reward[indices]
        self.not_done = self.not_done[indices]
        self.cost = self.cost[indices]
        self.size = self.state.shape[0]
        # Arrays were shrunk to ``size``; keep the ring-buffer bookkeeping consistent.
        self.max_size = self.size
        self.ptr = 0

# class SequenceReplayBuffer(object):
#     def __init__(self, state_dim, action_dim, device, max_size=int(1e6), seq_len=20):
#         self.max_size = max_size
#         self.seq_len = seq_len
#         self.ptr = 0
#         self.size = 0

#         self.state = np.zeros((max_size, state_dim))
#         self.action = np.zeros((max_size, action_dim))
#         self.next_state = np.zeros((max_size, state_dim))
#         self.reward = np.zeros((max_size, 1))
#         self.not_done = np.zeros((max_size, 1))
#         self.timestep = np.zeros((max_size, 1))
#         self.cost = np.zeros((max_size, 1))

#         self.device = device

#     def add(self, state, action, next_state, reward, done):
#         self.state[self.ptr] = state
#         self.action[self.ptr] = action
#         self.next_state[self.ptr] = next_state
#         self.reward[self.ptr] = reward
#         self.not_done[self.ptr] = 1. - done
#         self.size = min(self.size + 1, self.max_size)
#         prev_ptr = (self.ptr + self.size - 1) % self.size
#         if self.size == 1 or (self.not_done[prev_ptr] == 0 and self.not_done[self.ptr] == 1):
#             self.timestep[self.ptr] = 0
#         else:
#             self.timestep[self.ptr] = self.timestep[prev_ptr] + 1
#         self.ptr = (self.ptr + 1) % self.max_size


#     def sample(self, batch_size):
#         ind = np.random.randint(0, self.size, size=batch_size)

#         return (
#             torch.FloatTensor(self.state[ind]).to(self.device),
#             torch.FloatTensor(self.action[ind]).to(self.device),
#             torch.FloatTensor(self.next_state[ind]).to(self.device),
#             torch.FloatTensor(self.reward[ind]).to(self.device),
#             torch.FloatTensor(self.not_done[ind]).to(self.device),
#             torch.FloatTensor(self.timestep[ind]).to(self.device),
#             torch.FloatTensor(self.cost[ind]).to(self.device)
#         )
    
#     def get_sequence(self, start):
#         """
#         Get a sequence starting from the given index.
#         The sequence will continue until a terminal state is reached or the maximum sequence length is reached.
#         """
#         states = []
#         actions = []
#         next_states = []
#         rewards = []
#         not_dones = []
#         timesteps = []
#         costs = []

#         current_index = start
#         while True:
#             states.append(self.state[current_index])
#             actions.append(self.action[current_index])
#             next_states.append(self.next_state[current_index])
#             rewards.append(self.reward[current_index])
#             not_dones.append(self.not_done[current_index])
#             timesteps.append(self.timestep[current_index])
#             costs.append(self.cost[current_index])

#             if self.not_done[current_index] == 0 or len(states) >= self.seq_len:
#                 break
            
#             current_index = (current_index + 1) % self.size
        
#         if len(states) < self.seq_len:
#             # zero pad on left side
#             padding_length = self.seq_len - len(states)
#             states = np.concatenate((np.zeros((padding_length, states[0].shape[0])), states), axis=0)
#             actions = np.concatenate((np.zeros((padding_length, actions[0].shape[0])), actions), axis=0)
#             next_states = np.concatenate((np.zeros((padding_length, next_states[0].shape[0])), next_states), axis=0)
#             rewards = np.concatenate((np.zeros((padding_length, rewards[0].shape[0])), rewards), axis=0)
#             not_dones = np.concatenate((np.zeros((padding_length, not_dones[0].shape[0])), not_dones), axis=0)
#             timesteps = np.concatenate((np.zeros((padding_length, timesteps[0].shape[0])), timesteps), axis=0)
#             costs = np.concatenate((np.zeros((padding_length, costs[0].shape[0])), costs), axis=0)

#         return (
#             np.array(states),
#             np.array(actions),
#             np.array(next_states),
#             np.array(rewards),
#             np.array(not_dones),
#             np.array(timesteps),
#             np.array(costs)
#         )
    
#     def sample_sequence(self):
#         start = np.random.randint(0, self.size)
#         while self.not_done[start] == 0 or self.not_done[(start + 1) % self.size] == 0:
#             start = np.random.randint(0, self.size)
#         state, action, next_state, reward, not_done, timestep, cost = self.get_sequence(start)
#         return (
#             torch.FloatTensor(state).to(self.device),
#             torch.FloatTensor(action).to(self.device),
#             torch.FloatTensor(next_state).to(self.device),
#             torch.FloatTensor(reward).to(self.device),
#             torch.FloatTensor(not_done).to(self.device),
#             torch.FloatTensor(timestep).to(self.device),
#             torch.FloatTensor(cost).to(self.device)
#         )
        
#     def sample_sequence_batch(self, batch_size):
#         states = []
#         actions = []
#         next_states = []
#         rewards = []
#         not_dones = []
#         timesteps = []
#         costs = []

#         for _ in range(batch_size):
#             state, action, next_state, reward, not_done, timestep, cost = self.sample_sequence()
#             states.append(state)
#             actions.append(action)
#             next_states.append(next_state)
#             rewards.append(reward)
#             not_dones.append(not_done)
#             timesteps.append(timestep)
#             costs.append(cost)

#         return (
#             torch.stack(states).to(self.device),
#             torch.stack(actions).to(self.device),
#             torch.stack(next_states).to(self.device),
#             torch.stack(rewards).to(self.device),
#             torch.stack(not_dones).to(self.device),
#             torch.stack(timesteps).to(self.device),
#             torch.stack(costs).to(self.device)
#         )

#     def convert_D4RL(self, dataset):
#         self.state = dataset['observations']
#         self.action = dataset['actions']
#         self.next_state = dataset['next_observations']
#         self.reward = dataset['rewards'].reshape(-1,1)
#         self.not_done = 1. - dataset['terminals'].reshape(-1,1)
#         self.size = self.state.shape[0]
#         self.timestep = np.zeros((self.size, 1))
#         self.cost = np.zeros((self.size, 1))
#         for i in range(1, self.size):
#             if self.not_done[i-1] == 0:
#                 self.timestep[i] = 0
#             else:
#                 self.timestep[i] = self.timestep[i-1] + 1
    
#     def preprocess(self, filter_num):
#         # get filter_num smallest cost indices
#         indices = np.argpartition(self.cost[:self.size].flatten(), filter_num)[:filter_num]

#         self.state = self.state[indices]
#         self.action = self.action[indices]
#         self.next_state = self.next_state[indices]
#         self.reward = self.reward[indices]
#         self.not_done = self.not_done[indices]
#         self.cost = self.cost[indices]
#         self.timestep = self.timestep[indices]
#         self.size = self.state.shape[0]

#     def _get_slice(self, arr, start, end):
#         if end < start:
#             return np.concatenate((arr[start:], arr[:end]), axis=0)
#         else:
#             return arr[start:end]


class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with optional dropout.

    Layers are ``in_dim -> hidden_dim``, then ``n_layers - 2`` hidden
    ``hidden_dim -> hidden_dim`` layers, then ``hidden_dim -> out_dim``.
    There are always at least two Linear layers, even when ``n_layers < 2``.
    The activation (and dropout, when ``dropout_rate`` is given) follows every
    layer except the last, unless ``activate_final`` is set.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim,
        n_layers,
        activations: Callable = nn.ReLU,
        activate_final: bool = False,
        dropout_rate: Optional[float] = None
    ) -> None:
        super().__init__()

        layers = [nn.Linear(in_dim, hidden_dim)]
        layers.extend(nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 2))
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.affines = nn.ModuleList(layers)

        # A single activation instance is shared by all layers.
        self.activations = activations()
        self.activate_final = activate_final
        self.dropout_rate = dropout_rate
        if dropout_rate is not None:
            self.dropout = Dropout(self.dropout_rate)
            # NOTE: not used in forward, but kept so state_dict keys stay
            # compatible with existing checkpoints.
            self.norm_layer = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        last_index = len(self.affines) - 1
        for i, affine in enumerate(self.affines):
            x = affine(x)
            if i != last_index or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = self.dropout(x)
        return x

def identity(x):
    """Return the argument unchanged (a no-op activation)."""
    result = x
    return result

def fanin_init(tensor, scale=1):
    """Fill ``tensor`` in place from ``U(-scale/sqrt(fan_in), scale/sqrt(fan_in))``.

    ``fan_in`` is ``size[0]`` for 2-D tensors (rows treated as the input
    dimension) and ``prod(size[1:])`` for higher-rank tensors. Tensors with
    fewer than two dimensions raise ``Exception``.

    Returns:
        The filled ``tensor.data``.
    """
    shape = tensor.size()
    rank = len(shape)
    if rank < 2:
        raise Exception("Shape must be have dimension at least 2.")
    fan_in = shape[0] if rank == 2 else np.prod(shape[1:])
    limit = scale / np.sqrt(fan_in)
    return tensor.data.uniform_(-limit, limit)

def orthogonal_init(tensor, gain=0.01):
    """Overwrite ``tensor`` in place with a (semi-)orthogonal matrix scaled by ``gain``.

    Returns ``None``. The result lives in ``tensor`` itself.
    """
    nn.init.orthogonal_(tensor, gain=gain)
    return None

class ParallelizedLayerMLP(nn.Module):
    """One affine layer replicated across an ensemble and applied in a single matmul.

    ``W`` has shape ``(ensemble_size, input_dim, output_dim)`` and ``b`` has
    shape ``(ensemble_size, 1, output_dim)`` (float32). The bias broadcasts
    over the batch dimension.
    """

    def __init__(self, ensemble_size, input_dim, output_dim, w_std_value=1.0, b_init_value=0.0):
        super().__init__()

        # torch.fmod folds standard-normal draws into (-2, 2); the original
        # author describes this as an approximation to a 2-std truncated normal.
        raw_weights = torch.randn((ensemble_size, input_dim, output_dim))
        self.W = nn.Parameter(torch.fmod(raw_weights, 2) * w_std_value, requires_grad=True)

        # Every bias entry starts at the same constant.
        bias = torch.full((ensemble_size, 1, output_dim), b_init_value, dtype=torch.float32)
        self.b = nn.Parameter(bias, requires_grad=True)

    def forward(self, x):
        # Expects x as (ensemble_size, batch_size, input_dim); a leading dim of
        # 1 also broadcasts against W.
        return torch.matmul(x, self.W) + self.b


class ParallelizedEnsembleFlattenMLP(nn.Module):
    """Ensemble of ``ensemble_size`` MLPs evaluated in one batched pass.

    Positional inputs to ``forward`` are concatenated on the last axis and
    fed through ``ParallelizedLayerMLP`` layers with ReLU hidden activations
    and an identity output. 1-D and 2-D inputs are repeated across the
    ensemble. The output is ``(ensemble_size, batch_size, output_size)``, or
    ``(ensemble_size, output_size)`` for a 1-D input.
    """

    def __init__(
            self,
            ensemble_size,
            hidden_sizes,
            input_size,
            output_size,
            init_w=3e-3,
            hidden_init=fanin_init,
            w_scale=1,
            b_init_value=0.1,
            layer_norm=None,
            final_init_scale=None,
            dropout_rate=None,
    ):
        super().__init__()

        self.ensemble_size = ensemble_size
        self.input_size = input_size
        self.output_size = output_size
        self.elites = list(range(self.ensemble_size))

        self.sampler = np.random.default_rng()

        self.hidden_activation = F.relu
        self.output_activation = identity

        # Optional callable applied after every hidden layer; the same object
        # is reused for all layers.
        self.layer_norm = layer_norm

        # Plain list for iteration order; each layer is also registered as
        # ``fc<i>`` below so its parameters belong to this module.
        self.fcs = []

        self.dropout_rate = dropout_rate
        if self.dropout_rate is not None:
            self.dropout = Dropout(self.dropout_rate)

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = ParallelizedLayerMLP(
                ensemble_size=ensemble_size,
                input_dim=in_size,
                output_dim=next_size,
            )
            # Re-initialise each ensemble member's slice of the layer.
            for j in self.elites:
                hidden_init(fc.W[j], w_scale)
                fc.b[j].data.fill_(b_init_value)
            self.__setattr__('fc%d' % i, fc)
            self.fcs.append(fc)
            in_size = next_size

        self.last_fc = ParallelizedLayerMLP(
            ensemble_size=ensemble_size,
            input_dim=in_size,
            output_dim=output_size,
        )
        if final_init_scale is None:
            self.last_fc.W.data.uniform_(-init_w, init_w)
            self.last_fc.b.data.uniform_(-init_w, init_w)
        else:
            for j in self.elites:
                orthogonal_init(self.last_fc.W[j], final_init_scale)
                self.last_fc.b[j].data.fill_(0)

    def forward(self, *inputs, **kwargs):
        # ``kwargs`` is accepted for call-site compatibility but not used.
        flat_inputs = torch.cat(inputs, dim=-1)

        dim = len(flat_inputs.shape)
        # Repeat the input across the ensemble. A 3-D input is assumed to be
        # already laid out per ensemble member (e.g. bootstrapped batches).
        if dim < 3:
            flat_inputs = flat_inputs.unsqueeze(0)
            if dim == 1:
                flat_inputs = flat_inputs.unsqueeze(0)
            flat_inputs = flat_inputs.repeat(self.ensemble_size, 1, 1)

        h = flat_inputs
        # getattr default keeps instances unpickled without ``layer_norm`` working.
        layer_norm = getattr(self, 'layer_norm', None)
        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
            # Truthiness check: dropout_rate=0.0 skips dropout entirely.
            if self.dropout_rate:
                h = self.dropout(h)
            if layer_norm is not None:
                h = layer_norm(h)
        output = self.output_activation(self.last_fc(h))

        # Drop the batch axis added above for 1-D inputs.
        if dim == 1:
            output = output.squeeze(1)

        return output

    def sample(self, *inputs):
        """Return the element-wise min over two distinct random members, plus their indices.

        Requires ``ensemble_size >= 2`` (``np.random.choice`` without replacement).
        """
        preds = self.forward(*inputs)

        sample_idxs = np.random.choice(self.ensemble_size, 2, replace=False)
        preds_sample = preds[sample_idxs]

        return torch.min(preds_sample, dim=0)[0], sample_idxs