import torch
import random 

import numpy as np

from tqdm import tqdm
from torch.nn.functional import pad
from torch.utils.data import Dataset, DataLoader

from sgcrl.data.dbs.on_disk import load_model
from sgcrl.data.torch_datasets.relabel import compress_tokens_tensor
from sgcrl.models.quantizer import special_tokens


class ContrastiveTripletsDataset(Dataset):
    """(anchor, positive, negative) frame triplets for contrastive training.

    Every timestep of every episode with at least ``2 * window + 2`` steps is an
    anchor. Positives are drawn within ``window`` steps of the anchor in the same
    episode. Negatives are drawn more than ``window`` steps away in the same
    episode, or, when ``random_negative`` is set, from a random frame of the
    episode owning a uniformly random index entry (so longer episodes, which own
    more entries, are picked more often).

    Args:
        pytorch_episodes_db: episode store exposing ``get_ids()`` and item access
            by id, returning a dict of per-key arrays/tensors.
        window: half-width (in timesteps) of the positive neighbourhood.
        random_negative: draw negatives from random episodes instead of far-away
            frames of the anchor's episode.
        cache_size: maximum number of episode tensors kept in memory.
        keys_to_tokenize: episode keys concatenated (``torch.cat`` along dim 0)
            into the per-episode tensor; defaults to ``["obs/pos"]``.
        is_godot: if True, only columns 0 and 2 of the concatenated tensor are kept.
    """

    def __init__(self, pytorch_episodes_db, window, random_negative, cache_size=100000, keys_to_tokenize=None, is_godot=False):
        # None instead of a list literal: a mutable default is shared by every
        # instance, so a mutation through one dataset would leak into the others.
        if keys_to_tokenize is None:
            keys_to_tokenize = ["obs/pos"]
        self._episodes_db = pytorch_episodes_db
        self._index = []      # all (episode id, timestep, positive timesteps list, negative timesteps list)
        self._cache = {}      # {episode id: tensor built from keys_to_tokenize}
        self._idx_cache = []  # cached episode ids, oldest first (eviction order)
        self._cache_size = cache_size
        self.window = window
        self.keys_to_tokenize = keys_to_tokenize
        self.is_godot = is_godot
        # take a random frame from a random episode as negative example instead of
        # a random far-away frame from the same episode
        self.random_negative = random_negative
        self._run_index()
        self.cache_init = True  # the cache has been initialized by _load_cache
        self._load_cache()

    @staticmethod
    def _to_tensor(value):
        """Return ``value`` as a tensor.

        Existing tensors are detached rather than passed to ``torch.tensor``, which
        would copy them and emit a copy-construct UserWarning; the ``torch.cat`` in
        ``_get_episode`` produces a fresh tensor anyway.
        """
        if torch.is_tensor(value):
            return value.detach()
        return torch.tensor(value)

    def _get_episode(self, idx):
        """Return the tensor for episode ``idx``, loading and caching it on a miss.

        When more than ``cache_size`` episodes are cached, the oldest-inserted
        entries are evicted first.
        """
        if idx in self._cache:
            return self._cache[idx]

        episode = self._episodes_db[idx]  # read the episode once for all keys
        if self.is_godot:
            e = torch.cat([episode[key] for key in self.keys_to_tokenize])[:, [0, 2]].contiguous()
        else:
            e = torch.cat([self._to_tensor(episode[key]) for key in self.keys_to_tokenize]).contiguous()

        self._cache[idx] = e
        self._idx_cache.append(idx)
        while len(self._idx_cache) > self._cache_size:
            del self._cache[self._idx_cache.pop(0)]
        # return the local reference: with cache_size == 0 the entry is already evicted
        return e

    # used to pre-load worker's cache
    # can be useful when launching experiments with slow disk access (like on our cluster)
    def _load_cache(self):
        """Load the first ``cache_size`` indexed episodes into the cache."""
        print("Pre-loading cache to minimize disk access during training")
        for episode_id in tqdm(self._ids[:self._cache_size]):
            self._get_episode(episode_id)

    def _run_index(self):
        """Build ``self._index`` with one entry per timestep of each long-enough episode.

        Episodes whose ``keys_to_tokenize[0]`` entry has fewer than
        ``2 * window + 2`` steps are skipped. Positive/negative timestep lists start
        empty and are sampled lazily in ``__getitem__``.
        """
        self._ids = list(self._episodes_db.get_ids())
        self._lengths = {}
        self._index = []
        for _id in tqdm(self._ids):
            d = self._episodes_db[_id]

            # if the episode is too short for the window, skip it
            if d[self.keys_to_tokenize[0]].shape[0] < 2 * self.window + 2:
                continue

            # NOTE(review): the length is taken from the first entry of the episode
            # dict rather than keys_to_tokenize[0]; presumably all keys share the
            # same time dimension — confirm.
            _, tensor = next(iter(d.items()))
            T = tensor.shape[0]
            self._lengths[_id] = T

            self._index.extend((_id, t, [], []) for t in range(T))

    def _get_frame(self, _id, t, positive_index, negative_index):
        """Return (anchor, positive, negative) rows of episode ``_id``.

        With ``random_negative``, ``negative_index`` is ignored and the negative is a
        random row of the episode owning a random index entry.
        """
        episode = self._get_episode(_id)
        if self.random_negative:
            negative_episode = self._get_episode(self._index[np.random.randint(0, len(self._index))][0])
            return episode[t], episode[positive_index], negative_episode[np.random.randint(0, negative_episode.shape[0])]
        return episode[t], episode[positive_index], episode[negative_index]

    def __getitem__(self, idx):
        """Return an (anchor, positive, negative) triplet for index entry ``idx``.

        Ten positive and ten negative timesteps are pre-sampled per anchor, stored
        back into ``self._index`` and consumed one per call; once exhausted they are
        re-sampled.
        """
        # cache_init is set to True in __init__, so this pre-load only runs if a
        # caller has reset it.
        if not self.cache_init:
            self._load_cache()
            self.cache_init = True
        _id, t, positive_indexes, negative_indexes = self._index[idx]
        if len(positive_indexes) == 0:
            length = self._get_episode(_id).shape[0]
            positive_indexes = list(np.random.randint(max(0, t - self.window), min(t + self.window + 1, length), 10))
            negative_indexes = list(np.random.choice(list(range(t - self.window)) + list(range(t + self.window + 1, length)), 10))
            self._index[idx] = _id, t, positive_indexes, negative_indexes
        # pop a sample of positive and negative examples
        return self._get_frame(_id, t, positive_indexes.pop(), negative_indexes.pop())

    def __len__(self):
        """Number of anchor timesteps across all indexed episodes."""
        return len(self._index)

class ContrastiveTripletsDatasetVisual(Dataset):
    """Contrastive (anchor, positive, negative) triplets of visual observations.

    Item ``idx`` samples an anchor frame from episode ``idx``, a positive frame within
    ``window`` frames of it in the same episode, and a negative frame from a
    uniformly random episode.

    Args:
        episodes_reader: object exposing ``lengths`` (per-episode frame counts),
            ``__len__`` and ``get_episode_frame(episode, frame)`` returning a dict
            with an ``'observations'`` entry.
        window: half-width (in frames) of the positive neighbourhood.
        target_slice: optional column indices applied as ``x[:, target_slice]`` to
            every returned observation; ``None`` returns observations unchanged.
        cache_size, device, is_godot: stored on the instance; not read by the
            methods of this class.
    """

    def __init__(
        self,
        episodes_reader,
        window,
        target_slice,
        cache_size=100000,
        device="cpu",
        is_godot=False
    ):
        self.episodes_reader = episodes_reader
        self.n_episodes = len(episodes_reader)
        # index / cache / idx_cache / cache_size / device / is_godot are not read by
        # this class's methods; kept for interface compatibility.
        self.index = []
        self.cache = {}
        self.idx_cache = []
        self.cache_size = cache_size
        self.device = device
        self.window = window
        self.is_godot = is_godot

        self.target_slice = None if target_slice is None else np.array(target_slice)

    def _observation(self, episode, frame):
        """Return the ``'observations'`` entry of frame ``frame`` of ``episode``."""
        return self.episodes_reader.get_episode_frame(episode, frame)['observations']

    @torch.no_grad()
    def __getitem__(self, idx):
        """Return ``(sample, positive, negative)`` observations for episode ``idx``."""
        length = self.episodes_reader.lengths[idx]

        # Frames 0..length-1 are valid (the positive range below already reaches
        # length - 1), so the anchor and negative draw from the full range too.
        # This also avoids ``randint(0)`` raising on single-frame episodes.
        sample_index = np.random.randint(length)
        sample = self._observation(idx, sample_index)

        positive_index = np.random.randint(max(0, sample_index - self.window), min(sample_index + self.window + 1, length))
        positive = self._observation(idx, positive_index)

        negative_episode = np.random.randint(self.n_episodes)
        negative_index = np.random.randint(self.episodes_reader.lengths[negative_episode])
        negative = self._observation(negative_episode, negative_index)

        if self.target_slice is not None:
            sample = sample[:, self.target_slice]
            positive = positive[:, self.target_slice]
            negative = negative[:, self.target_slice]

        return sample, positive, negative

    def __len__(self):
        """Number of episodes (one item per episode)."""
        return len(self.episodes_reader)

class TransformerDataset(Dataset):
    """Padded compressed-token sequences (with their goal token) for a transformer.

    Each episode is tokenized with ``tokenizer``, compressed with
    ``compress_tokens_tensor``, truncated just after the first occurrence of its goal
    token (the last compressed token) and, if ``remove_cycles`` is set, stripped of
    cycles. Items are wrapped in SOS/EOS and right-padded to ``max_length``.

    With ``stitching``, extra items join episode ``i`` with the tail of episode ``j``
    that follows the last position (excluding ``j``'s final token) where ``j`` holds
    ``i``'s goal token; duplicate stitched sequences are skipped.

    Args:
        pytorch_episodes_db: episode store exposing ``get_ids()`` and item access by id.
        keys_to_tokenize: episode keys concatenated (dim 0) before tokenization
            (not used when ``is_godot``, which tokenizes ``sensor/position[:, [0, 2]]``).
        tokenizer: object exposing ``tokenize_tensor``; element ``[0]`` of its
            result is used as the token sequence.
        augment_dataset_prob: stored; not read by this class.
        remove_cycles: remove loops from token sequences (see ``_remove_cycles``).
        stitching: also index stitched episode pairs (O(n^2) in episodes).
        cache_size: maximum number of relabelled episodes kept in memory.
        is_godot: use the godot position layout described above.
    """

    def __init__(
        self,
        pytorch_episodes_db,
        keys_to_tokenize,
        tokenizer,
        augment_dataset_prob,
        remove_cycles,
        stitching,
        cache_size=np.inf,
        is_godot=False
    ):
        self._episodes_db = pytorch_episodes_db
        self._index = []  # list of (episode id, episode id, stitch index) entries
        self._cache = {}
        self._idx_cache = []
        self._cache_size = cache_size
        self.keys_to_tokenize = keys_to_tokenize
        self.device = "cpu"
        self.tokenizer = tokenizer
        self.augment_dataset_prob = augment_dataset_prob
        self.remove_cycles = remove_cycles
        self.stitching = stitching
        self.max_length = 158
        self.is_godot = is_godot

        self.eos_token = special_tokens['EOS_TOKEN']
        self.sos_token = special_tokens['SOS_TOKEN']
        self.padding_value = special_tokens['PADDING_VALUE']

        self._run_index()

    def _remove_cycles(self, tensor):
        """Return ``tensor`` with loops removed.

        When a token reappears while an earlier occurrence is still kept, everything
        after that earlier occurrence up to and including the reappearance is dropped.
        """
        element_indices = {}  # token value -> positions where it was recorded
        # explicit bool dtype: torch.tensor([]) would be float and break indexing
        mask = torch.ones(tensor.shape[0], dtype=torch.bool)
        for idx, element in enumerate(tensor):
            value = element.item()
            kept = [j for j in element_indices.get(value, []) if mask[j]]
            if kept:
                # cycle back to a kept occurrence: drop everything after it
                mask[kept[0] + 1:idx + 1] = False
            else:
                element_indices.setdefault(value, []).append(idx)
        return tensor[mask]

    def _relabel(self, episode):
        """Tokenize ``episode`` and return its compressed tokens and goal token.

        Returns a new dict with ``'token/compressed'`` (tokens up to and including the
        first goal token, cycle-free if ``remove_cycles``) and ``'goal/token'`` (a
        one-element tensor). The input dict is also given ``token/current``,
        ``token/next``, ``goal/token`` and ``token/compressed`` entries.
        """
        if self.is_godot:
            episode['token/current'] = self.tokenizer.tokenize_tensor(episode['sensor/position'][:, [0, 2]].to(self.device))[0].cpu()
        else:
            episode['token/current'] = self.tokenizer.tokenize_tensor(torch.cat([episode[key] for key in self.keys_to_tokenize]).to(self.device))[0].cpu()

        compressed_tokens = compress_tokens_tensor(episode['token/current'])

        # token/next[i]: compressed_tokens[j] for the tracker j at step i, or EOS
        # once j runs past the end of compressed_tokens.
        episode['token/next'] = torch.zeros_like(episode['token/current'])
        j = 1
        for i, token in enumerate(episode['token/current']):
            if j < len(compressed_tokens) and token == compressed_tokens[j]:
                j += 1

            if j < len(compressed_tokens):
                episode['token/next'][i] = compressed_tokens[j]
            else:
                episode['token/next'][i] = self.eos_token

        episode['goal/token'] = compressed_tokens[-1] * torch.ones_like(episode['token/current'])
        episode['token/compressed'] = compressed_tokens
        # keep tokens up to and including the first occurrence of the goal token
        goal = episode['goal/token'][[0]].item()
        first_goal = (episode['token/compressed'] == goal).nonzero().squeeze().int().view(-1)[0]
        episode['token/compressed'] = episode['token/compressed'][:first_goal + 1]
        if self.remove_cycles:
            episode['token/compressed'] = self._remove_cycles(episode['token/compressed'])
        return {'token/compressed': episode['token/compressed'], 'goal/token': episode['goal/token'][[0]]}

    def _pad(self, episode):
        """Wrap tokens in SOS/EOS and right-pad with ``padding_value`` to ``max_length``.

        Returns a new dict; the ``'goal/token'`` entry is passed through unchanged.
        NOTE(review): sequences longer than ``max_length - 2`` yield a negative pad
        amount — confirm that case cannot occur or is intended.
        """
        return {'token/compressed': pad(torch.tensor([self.sos_token] + list(episode['token/compressed']) + [self.eos_token]),
                                        (0, self.max_length - len(episode['token/compressed']) - 2),
                                        "constant", value=self.padding_value).int(),
                'goal/token': episode['goal/token']}

    def _get_episode(self, idx):
        """Return the relabelled episode ``idx``, loading and caching it on a miss.

        When more than ``cache_size`` episodes are cached, the oldest-inserted
        entries are evicted first.
        """
        if idx in self._cache:
            return self._cache[idx]
        e = self._episodes_db[idx]
        # if necessary, convert numpy arrays into torch tensors
        for k, v in e.items():
            if isinstance(v, np.ndarray):
                e[k] = torch.from_numpy(v)
        e = self._relabel(e)
        episode = {k: v.contiguous() for k, v in e.items()}
        self._cache[idx] = episode
        self._idx_cache.append(idx)
        # evict into a separate name: reusing ``idx`` here made the final lookup
        # fetch the just-evicted id and raise KeyError
        while len(self._idx_cache) > self._cache_size:
            del self._cache[self._idx_cache.pop(0)]
        return episode

    def _run_index(self):
        """Index every episode as ``(id, id, 0)`` and, with ``stitching``, unique stitched pairs."""
        ids = list(self._episodes_db.get_ids())
        self._index = [(i, i, 0) for i in ids]
        if not self.stitching:
            return

        seen = set()
        for i in tqdm(ids):
            for j in ids:
                if i == j:
                    continue

                episode_i = self._get_episode(i)
                episode_j = self._get_episode(j)

                # positions (excluding j's last token) where j holds i's goal token
                matches = (episode_i['goal/token'][0] == episode_j['token/compressed'][:-1]).nonzero().squeeze().int()
                if not matches.numel():
                    continue
                last_match = matches.view(-1)[-1]
                stitched = torch.cat((episode_i['token/compressed'], episode_j['token/compressed'][last_match + 1:]))
                if self.remove_cycles:
                    stitched = self._remove_cycles(stitched)
                seq = tuple(stitched.numpy())

                if seq not in seen:
                    seen.add(seq)
                    self._index.append((i, j, last_match))

    @torch.no_grad()
    def __getitem__(self, idx):
        """Return the padded ``{'token/compressed', 'goal/token'}`` item ``idx``."""
        _idx1, _idx2, stich_idx = self._index[idx]
        if _idx1 != _idx2:
            first = self._get_episode(_idx1)
            second = self._get_episode(_idx2)
            episode = {'token/compressed': torch.cat((first['token/compressed'], second['token/compressed'][stich_idx + 1:])),
                       'goal/token': second['goal/token'][[0]]}
        else:
            # shallow copy so the reassignment below does not overwrite the cache entry
            episode = dict(self._get_episode(_idx1))

        if self.remove_cycles:
            episode['token/compressed'] = self._remove_cycles(episode['token/compressed'])

        return self._pad(episode)

    def __len__(self):
        """Number of indexed (possibly stitched) sequences."""
        return len(self._index)
    
class PytorchEpisodeFrameDatasetComplexRelabel(Dataset):
    """Fixed-size windows of consecutive frames, with episode- and frame-level relabelling.

    Each episode is converted to tensors, passed through ``episode_relabellers``
    (called as ``r(episode, None, None, None, None)``) and cached. Each item is a
    window of ``frame_size`` steps starting at an indexed timestep, after passing
    the episode through ``frame_relabellers`` (called as
    ``r(episode, id, t, frame_size, dataset)``).

    Args:
        pytorch_episodes_db: episode store exposing ``get_ids()`` and item access by id.
        frame_size: number of consecutive timesteps per item.
        cache_size: maximum number of episodes kept in memory.
        episode_relabellers: callables applied once per loaded episode; defaults to none.
        frame_relabellers: callables applied per item; defaults to none.
    """

    def __init__(self, pytorch_episodes_db, frame_size, cache_size=100000, episode_relabellers=None, frame_relabellers=None):
        self._episodes_db = pytorch_episodes_db
        self._frame_size = frame_size
        self._index = []
        self._cache = {}
        self._idx_cache = []
        self._cache_size = cache_size
        # None instead of list literals: mutable defaults are shared by every
        # instance, so appending to one dataset's relabellers would affect all others.
        self._episode_relabellers = [] if episode_relabellers is None else episode_relabellers
        self.frame_relabellers = [] if frame_relabellers is None else frame_relabellers
        self._run_index()
        self.cache_init = True
        self._load_cache()

    def _get_episode(self, idx):
        """Return episode ``idx`` after episode relabelling, loading and caching it on a miss.

        When more than ``cache_size`` episodes are cached, the oldest-inserted
        entries are evicted first.
        """
        if idx in self._cache:
            return self._cache[idx]
        e = self._episodes_db[idx]
        # if necessary, convert numpy arrays into torch tensors
        for k, v in e.items():
            if isinstance(v, np.ndarray):
                e[k] = torch.from_numpy(v)

        for r in self._episode_relabellers:
            e = r(e, None, None, None, None)

        self._cache[idx] = e
        self._idx_cache.append(idx)
        while len(self._idx_cache) > self._cache_size:
            del self._cache[self._idx_cache.pop(0)]
        # return the local reference: with cache_size == 0 the entry is already evicted
        return e

    # used to pre-load worker's cache
    # can be useful when launching experiments with slow disk access (like on our cluster)
    def _load_cache(self):
        """Load the first ``cache_size`` episodes into the cache."""
        print("pre-loading cache to minimize disk access during training")
        for episode_id in tqdm(self._ids[:self._cache_size]):
            self._get_episode(episode_id)

    def _run_index(self, include_terminals=False):
        """Build ``self._index`` of ``(episode id, start timestep)`` pairs.

        Every start leaving room for a full window is indexed, except that, unless
        ``include_terminals`` is set, starts whose ``token/next`` is the EOS token
        are skipped when the episode has a ``token/next`` entry.
        """
        self._ids = list(self._episodes_db.get_ids())
        self._lengths = {}
        self._index = []
        for _id in tqdm(self._ids):
            d = self._get_episode(_id)
            # NOTE(review): the length is read from the first dict entry; presumably
            # every per-timestep entry shares it — confirm.
            _, tensor = next(iter(d.items()))
            T = tensor.size()[0]
            self._lengths[_id] = T
            assert (
                self._frame_size <= T
            ), "Episodes are too short for the frame size that has been chosen"
            for i in range(T - self._frame_size + 1):
                if include_terminals or 'token/next' not in d or d['token/next'][i] != special_tokens['EOS_TOKEN']:
                    self._index.append((_id, i))

    def _split(self, episode, t):
        """Slice every entry to ``[t, t + frame_size)``; length-1 entries are kept whole."""
        results = {}
        for k, v in episode.items():
            if len(v) == 1:  # relabelled data is already sliced
                results[k] = v
            else:
                results[k] = v[t : t + self._frame_size]
        return results

    def _get_frame(self, _id, t):
        """Apply the frame relabellers to episode ``_id`` and return the window at ``t``."""
        # NOTE(review): relabellers receive the cached episode object itself, not a copy.
        episode = self._get_episode(_id)

        for relabeller in self.frame_relabellers:
            episode = relabeller(episode, _id, t, self._frame_size, self)

        return self._split(episode, t)

    def __getitem__(self, idx):
        """Return the frame window for index entry ``idx``."""
        # cache_init is set to True in __init__, so this pre-load only runs if a
        # caller has reset it.
        if not self.cache_init:
            self._load_cache()
            self.cache_init = True
        _id, t = self._index[idx]
        return self._get_frame(_id, t)

    def __len__(self):
        """Number of indexed windows."""
        return len(self._index)
    
