import math
import torch

import nn_plot
import pickle
from nn_util import compute_distance_torch, compute_distance_with_rot_torch, compute_accum_distance_torch, load_and_scale_data, set_seed, create_matrices
from logging_utils import logger
from plotting_utils import quick_plot

# Debug switch: when True, NNAgentEuclidean.get_neighbors sorts the final
# neighbor indices and prints them.
DEBUG = False

class NN_METHOD:
    """Integer codes for the neighbor-selection strategies used by the agents.

    See how ``NNAgentEuclidean.get_neighbors`` branches on these codes for the
    value each strategy returns.
    """

    NN, NS, KNN, KNN_AND_DELTA = range(4)

    @classmethod
    def from_string(cls, name):
        """Translate a config string into one of the method codes.

        Recognized names are 'nn', 'ns', 'knn' and 'knn_and_delta'. Any other
        value (including None) logs a warning and falls back to ``NN``.
        """
        table = (
            ('nn', NN_METHOD.NN),
            ('ns', NN_METHOD.NS),
            ('knn', NN_METHOD.KNN),
            ('knn_and_delta', NN_METHOD.KNN_AND_DELTA),
        )
        for label, code in table:
            if name == label:
                return code
        logger.warning(f"No such method {name}! Defaulting to NN")
        return NN_METHOD.NN

class NNAgent:
    """Base nearest-neighbor agent holding demonstration datasets and retrieval settings.

    This class calls ``self.to_device`` but does not define it; subclasses must
    provide it (and ``get_neighbors``). A subclass may assign ``self.datasets``
    before calling ``super().__init__`` to control how data is loaded.
    """

    #@profile
    def __init__(self, env_cfg, policy_cfg):
        """Seed RNGs, load the demonstration datasets and read hyperparameters.

        Args:
            env_cfg: Environment/data config. Keys read here: 'seed', 'device',
                'mixed', 'demo_pkl', 'rot_indices', 'weights', 'type',
                'rgb_demo_pkl', 'plot', and (when 'mixed' is set) per-dataset
                sub-configs under 'retrieval', 'state' and 'delta_state'.
            policy_cfg: Policy config. Keys read here: 'method', 'k',
                'lookback', 'decay_rate', 'dtw_window',
                'final_neighbors_ratio', 'obs_horizon'.
        """
        set_seed(env_cfg.get("seed", 42))
        self.env_cfg = env_cfg
        self.policy_cfg = policy_cfg

        self.device = env_cfg['device']
        self.method = NN_METHOD.from_string(policy_cfg.get('method'))

        # If this is already defined, a subclass has intentionally set it
        if not hasattr(self, 'datasets'):
            self.datasets = {}
            if env_cfg.get('mixed', False):
                # Separate datasets may be used for retrieval, neighbor state and
                # state delta; entries sharing a demo_pkl share one loaded object.
                loaded_by_path = {}
                for dataset in ['retrieval', 'state', 'delta_state']:
                    dataset_cfg = env_cfg[dataset]
                    path = dataset_cfg['demo_pkl']
                    if path not in loaded_by_path:
                        loaded_by_path[path] = load_and_scale_data(
                            path,
                            dataset_cfg.get('rot_indices', []),
                            dataset_cfg.get('weights', []),
                            # Same default as the single-dataset branch below.
                            ob_type=dataset_cfg.get('type', 'state'),
                            use_torch=True,
                            scale=False,
                        )
                    self.datasets[dataset] = loaded_by_path[path]
            else:
                one_dataset = load_and_scale_data(
                    env_cfg['demo_pkl'],
                    env_cfg.get('rot_indices', []),
                    env_cfg.get('weights', []),
                    ob_type=env_cfg.get('type', 'state'),
                    use_torch=True,
                    scale=False,
                )
                for dataset in ['retrieval', 'state', 'delta_state']:
                    self.datasets[dataset] = one_dataset

        # Move data to the target device only once self.datasets exists, since
        # to_device iterates over it. NOTE(review): with scale=False the
        # datasets' scalers may be absent - to_device must tolerate that.
        self.to_device(self.device)

        retrieval = self.datasets['retrieval']
        self.rot_indices = retrieval.rot_indices
        self.non_rot_indices = retrieval.non_rot_indices

        self.candidates = policy_cfg.get('k', 100)
        self.lookback = policy_cfg.get('lookback', 10)
        self.decay = policy_cfg.get('decay_rate', 1)
        self.window = policy_cfg.get('dtw_window', 0)
        self.final_neighbors_ratio = policy_cfg.get('final_neighbors_ratio', 1)
        self.obs_horizon = policy_cfg.get('obs_horizon', 1)

        rgb_demo_pkl = env_cfg.get("rgb_demo_pkl")
        if rgb_demo_pkl:
            # NOTE: pickle can execute arbitrary code - only load trusted demo files.
            with open(rgb_demo_pkl, 'rb') as f:
                rgb_demos = pickle.load(f)
            rgb_obs_matrix, _, _ = create_matrices(rgb_demos, use_torch=True)
            self.rgb_obs = torch.cat([torch.as_tensor(obs, device=self.device) for obs in rgb_obs_matrix], dim=0)
            self.include_rgb = True
        else:
            self.include_rgb = False

        # Precompute constants
        self.obs_history = torch.tensor([], dtype=torch.float32, device=self.device)

        # Weights lookback, lookback-1, ..., 1 raised to the decay rate.
        self.i_array = torch.arange(self.lookback, 0, -1, dtype=torch.float32)
        self.decay_factors = torch.pow(self.i_array, self.decay)

        if env_cfg.get('plot', False):
            self.plot = nn_plot.NNPlot(self.datasets['retrieval'])
        else:
            self.plot = False

        if self.method == NN_METHOD.KNN:
            # Just for testing - not recommended
            self.sq_instead_of_diff = False

        # Vote counter for debugging; not read by any active code in this file.
        # NOTE(review): the size 10830 is hard-coded, presumably one dataset's
        # row count - confirm before relying on it.
        self.nearest_neighbors_votes = torch.zeros(10830)

    def update_obs_history(self, current_ob):
        """Prepend ``current_ob`` to the history, so row 0 is always the newest observation."""
        if len(self.obs_history) > 0:
            self.obs_history = torch.vstack((current_ob, self.obs_history))
        else:
            self.obs_history = current_ob.clone().unsqueeze(0)

    def set_obs_history(self, obs_history):
        """Replace the observation history; an empty input resets it."""
        if len(obs_history) == 0:
            self.reset_obs_history()
        else:
            self.obs_history = obs_history

    def reset_obs_history(self):
        """Clear the history to an empty float32 tensor on the agent's device (matches ``__init__``)."""
        self.obs_history = torch.tensor([], dtype=torch.float32, device=self.device)

class NNAgentEuclidean(NNAgent):
    """Nearest-neighbor agent using weighted Euclidean distance to demonstration states."""

    #@profile
    def get_neighbors(self, current_ob):
        """Look up demonstration neighbors of the most recent observation.

        Args:
            current_ob: Observation window of shape [k, n], or a batch of
                windows of shape [b, k, n]. Row -1 along k is the query
                observation; rows whose feature 0 is NaN are not counted in
                the per-window sequence length.

        Returns:
            Depends on ``self.method``:
              * NN: action of the nearest state, as a NumPy array.
              * NS: action of the candidate with the smallest accumulated
                (lookback) distance, as a torch tensor.
              * KNN: CPU tensor of row indices into the retrieval dataset
                for the final neighbors.
              * KNN_AND_DELTA: tuple of those indices and their accumulated
                distances.
            For unbatched input the leading batch dimension is dropped.
        """
        # Batched input [b, k, n]
        is_batched = current_ob.dim() == 3
        if not is_batched:
            current_ob = current_ob.unsqueeze(0)
        batch_size = current_ob.shape[0]
        self.obs_history = current_ob
        retrieval = self.datasets['retrieval']

        valid_mask = ~torch.isnan(current_ob[:, :, 0])  # [b, k] - True where observations are valid
        sequence_lengths = torch.sum(valid_mask, dim=1)  # [b] - count of valid observations per window

        query = current_ob[:, -1]  # [b, n]

        # Reuse the weighted-query buffer between calls; reallocate whenever the
        # query's shape, dtype or device differs from the cached one.
        buffer = getattr(self, '_weighted_ob_buffer', None)
        if (
            buffer is None
            or buffer.shape != query.shape
            or buffer.dtype != query.dtype
            or buffer.device != query.device
        ):
            self._weighted_ob_buffer = torch.empty(query.shape, device=query.device, dtype=query.dtype)
        weighted_ob = self._weighted_ob_buffer
        torch.mul(query, retrieval.weights.to(self.device), out=weighted_ob)

        obs_matrix = retrieval.processed_obs_matrix
        # Plain Euclidean distance over features that do not wrap around: [b, N]
        diffs = torch.subtract(
            obs_matrix[:, self.non_rot_indices].unsqueeze(0),
            weighted_ob[:, self.non_rot_indices].unsqueeze(1),
        )
        all_distances = torch.sqrt(torch.sum(torch.pow(diffs, 2), dim=2))
        # Rotational features wrap around, so they use a dedicated distance.
        if len(self.rot_indices) > 0:
            # Select the rot columns on the feature axis (dim 1) of the [b, n] query.
            # NOTE(review): assumes compute_distance_with_rot_torch broadcasts a
            # [b, r] query against the [N, r] matrix - confirm for b > 1.
            all_distances += compute_distance_with_rot_torch(
                weighted_ob[:, self.rot_indices],
                obs_matrix[:, self.rot_indices],
                retrieval.weights[retrieval.rot_indices],
            )

        # When training, don't include the state itself
        if self.method in (NN_METHOD.KNN, NN_METHOD.KNN_AND_DELTA):
            all_distances[all_distances == 0.0] = torch.inf

        # topk requires k <= number of dataset rows.
        num_candidates = min(self.candidates, all_distances.shape[1])
        _, nearest_neighbor_indices = torch.topk(all_distances, k=num_candidates, largest=False, dim=1)
        nearest_neighbors = nearest_neighbor_indices.to(torch.int32).to(self.device)

        # Map each flat neighbor index to (trajectory number, offset within trajectory).
        retrieval.traj_starts = retrieval.traj_starts.to(self.device)
        traj_starts = retrieval.traj_starts
        flat_neighbors = nearest_neighbors.reshape(-1)
        flat_traj_nums = torch.searchsorted(traj_starts, flat_neighbors, right=True) - 1
        traj_nums = flat_traj_nums.reshape(batch_size, num_candidates)
        obs_nums = (flat_neighbors - traj_starts[flat_traj_nums]).reshape(batch_size, num_candidates)

        batch_indices = torch.arange(batch_size, device=self.device)

        if self.method == NN_METHOD.NN:
            # Direct nearest neighbor: return that state's action.
            nearest_neighbor_idx = torch.argmin(torch.gather(all_distances, 1, nearest_neighbors.long()), dim=1)
            actions = retrieval.act_matrix[traj_nums[batch_indices, nearest_neighbor_idx], obs_nums[batch_indices, nearest_neighbor_idx]]
            return actions.cpu().numpy() if is_batched else actions[0].cpu().numpy()

        if self.lookback == 1 or self.obs_history.shape[-2] == 1:
            # No lookback needed
            accum_distances = torch.gather(all_distances, 1, nearest_neighbors.long())
        else:
            # Lookback per neighbor is bounded by min(lookback hyperparameter,
            # valid length of the obs history, neighbor's offset into its trajectory + 1).
            max_lookbacks = torch.minimum(
                torch.tensor(self.lookback, dtype=torch.int64, device=self.device),
                torch.minimum(
                    obs_nums + 1,
                    sequence_lengths.unsqueeze(1).expand(-1, num_candidates),
                ),
            )
            accum_distances = compute_accum_distance_torch(
                nearest_neighbors,
                max_lookbacks.to(self.device),
                self.obs_history.to(self.device),
                sequence_lengths,
                retrieval.flattened_obs_matrix.to(self.device),
                self.decay_factors.to(self.device),
            )

        if self.method == NN_METHOD.NS:
            # Direct nearest sequence: return that state's action.
            nearest_sequence_idx = torch.argmin(accum_distances, dim=1)
            actions = retrieval.act_matrix[traj_nums[batch_indices, nearest_sequence_idx], obs_nums[batch_indices, nearest_sequence_idx]]
            return actions if is_batched else actions[0]

        # Final pass: keep the top (final_neighbors_ratio * 100)% of candidates by
        # accumulated distance, clamped to at least one and at most all candidates.
        num_available = accum_distances.shape[1]
        final_neighbor_num = min(num_available, max(1, math.floor(num_available * self.final_neighbors_ratio)))
        _, final_neighbor_indices = torch.topk(accum_distances, k=final_neighbor_num, largest=False, dim=1)

        if DEBUG:
            final_neighbor_indices, _ = torch.sort(final_neighbor_indices)

        row_indices = batch_indices.unsqueeze(1)
        final_neighbors = nearest_neighbors[row_indices, final_neighbor_indices].to('cpu')

        if DEBUG:
            print(f"{final_neighbors=}")

        if self.method == NN_METHOD.KNN:
            return final_neighbors if is_batched else final_neighbors[0]
        if self.method == NN_METHOD.KNN_AND_DELTA:
            final_distances = torch.gather(accum_distances, 1, final_neighbor_indices)
            if is_batched:
                return final_neighbors, final_distances
            return final_neighbors[0], final_distances[0]

    def to_device(self, device):
        """Record ``device`` and move cached tensors and every loaded dataset onto it.

        Safe to call before ``self.datasets`` exists, and tolerates datasets
        whose scalers or tensor attributes are absent (None or not tensors).
        """
        self.device = device
        if hasattr(self, "obs_history"):
            self.obs_history = self.obs_history.to(device)
        if hasattr(self, "_weighted_ob_buffer"):
            self._weighted_ob_buffer = self._weighted_ob_buffer.to(device)

        tensor_attrs = (
            'rot_indices',
            'weights',
            'traj_starts',
            'flattened_obs_matrix',
            'flattened_act_matrix',
            'processed_obs_matrix',
        )
        for dataset in getattr(self, 'datasets', {}).values():
            for scaler_name in ('obs_scaler', 'act_scaler'):
                scaler = getattr(dataset, scaler_name, None)
                if scaler is not None:
                    scaler.to_device(device)
            for attr in tensor_attrs:
                value = getattr(dataset, attr, None)
                if isinstance(value, torch.Tensor):
                    setattr(dataset, attr, value.to(device))

# Standard Euclidean distance, but each non-rotational observation dimension is first standardized with the retrieval dataset's obs_scaler
class NNAgentEuclideanStandardized(NNAgentEuclidean):
    """Euclidean agent that standardizes non-rotational observation features.

    Datasets are loaded here (before ``super().__init__``) without the
    ``use_torch=True, scale=False`` arguments the base class passes.
    Presumably this lets load_and_scale_data fit the scalers used below;
    confirm against load_and_scale_data.
    """

    #@profile
    def __init__(self, env_cfg, policy_cfg):
        """Load the datasets on ``env_cfg['device']``, then run the base initializer."""
        device = env_cfg['device']
        self.datasets = {}
        if env_cfg.get('mixed', False):
            # Separate datasets may be used for retrieval, neighbor state and
            # state delta; entries sharing a demo_pkl share one loaded object.
            loaded_by_path = {}
            for dataset in ['retrieval', 'state', 'delta_state']:
                dataset_cfg = env_cfg[dataset]
                path = dataset_cfg['demo_pkl']
                if path not in loaded_by_path:
                    loaded_by_path[path] = load_and_scale_data(
                        path,
                        dataset_cfg.get('rot_indices', []),
                        dataset_cfg.get('weights', []),
                        ob_type=dataset_cfg.get('type', 'state'),
                        device=device,
                    )
                self.datasets[dataset] = loaded_by_path[path]
        else:
            # A single demo file backs all three roles: load it once and share it.
            one_dataset = load_and_scale_data(
                env_cfg['demo_pkl'],
                env_cfg.get('rot_indices', []),
                env_cfg.get('weights', []),
                ob_type=env_cfg.get('type', 'state'),
                device=device,
            )
            self.datasets = dict.fromkeys(('retrieval', 'state', 'delta_state'), one_dataset)

        super().__init__(env_cfg, policy_cfg)

    def _copy_as_tensor(self, obs):
        """Return a tensor copy of ``obs`` so in-place edits never touch the caller's data."""
        if torch.is_tensor(obs):
            return torch.clone(obs)
        # torch.from_numpy accepts no dtype/device arguments; torch.tensor always copies.
        return torch.tensor(obs, dtype=torch.float32, device=self.device)

    def _standardize_(self, obs):
        """Standardize the retrieval dataset's non-rotational columns (last axis) of ``obs`` in place; return it."""
        dataset = self.datasets['retrieval']
        cols = dataset.non_rot_indices
        obs[..., cols] = dataset.obs_scaler.transform(obs[..., cols])
        return obs

    #@profile
    def get_neighbors(self, current_ob, normalize=True):
        """Optionally standardize ``current_ob`` ([k, n] or [b, k, n]), then delegate to the Euclidean lookup."""
        current_ob = self._copy_as_tensor(current_ob)
        if normalize:
            self._standardize_(current_ob)
        return super().get_neighbors(current_ob)

    def set_obs_history(self, obs_history):
        """Store a standardized copy of ``obs_history``, scaling the same columns as ``get_neighbors``."""
        if len(obs_history) > 0:
            obs_history = self._standardize_(self._copy_as_tensor(obs_history))
        super().set_obs_history(obs_history)
