import torch
import random 

import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from torch.utils.data.dataloader import default_collate
from torchvision import transforms

from sgcrl.models.quantizer import special_tokens


#########
# QPHIL #
#########
# Goal-sampling modes. The value relabellers below pass GOAL_TYPES to
# np.random.choice together with their `probabilities` argument, so the order
# of this list defines which probability applies to which mode.
CURRENT = 0  # goal is taken from the current token; reward is set to 1
FUTURE = 1  # goal is the next token; reward depends on the following frame's token
RANDOM = 2  # can be sampled, but the relabellers in this file have no active handling for it
GOAL_TYPES = [CURRENT, FUTURE, RANDOM]

def compress_tokens_tensor(tokens, k=4):
    """Drop tokens that already appear among the last ``k`` kept tokens.

    The sequence is scanned in order; a token is kept only if its value
    differs from every one of the (at most) ``k`` most recently kept tokens.
    The first token is always kept.

    Args:
        tokens: Sequence of single-element tensors (``.item()`` is called on
            each element), typically a 1-D token tensor.
        k: Size of the look-back window over the kept tokens.

    Returns:
        An empty list when ``tokens`` is empty, otherwise a tensor built from
        the kept tokens.
    """
    if len(tokens) == 0:
        return []

    kept = [tokens[0]]
    for candidate in tokens:
        # Only the trailing ``k`` kept tokens can veto the candidate; a
        # non-positive ``k`` yields an empty window, so nothing is vetoed.
        window = kept[max(len(kept) - k, 0):]
        if all(candidate.item() != previous.item() for previous in window):
            kept.append(candidate)

    return torch.tensor(kept)

class KeyRelabeller:
    """Episode relabeller that exposes an existing entry under a second key."""

    def __init__(self, key, new_key):
        """Remember the source key and the key it should also be stored under."""
        self.key, self.new_key = key, new_key

    def __call__(self, episode, idx, t, frame_size, dataset):
        """Store ``episode[self.key]`` under ``self.new_key`` and return the episode.

        The value is not copied: both keys refer to the same object. The
        remaining arguments belong to the episode-relabeller call signature
        and are not used here.
        """
        source_value = episode[self.key]
        episode[self.new_key] = source_value
        return episode
    
class KeyRelabellerFrame:
    """Frame relabeller that stores entries 0 and 2 of one key under another key."""

    def __init__(self, before, after):
        """Remember the source key (``before``) and the destination key (``after``)."""
        self.before, self.after = before, after

    def __call__(self, frame):
        """Write ``frame[before][[0, 2]]`` to ``frame[after]`` and return the frame."""
        # List indexing along the first axis: keeps entries 0 and 2 only.
        kept_entries = [0, 2]
        frame[self.after] = frame[self.before][kept_entries]
        return frame
    
class OnlyKeepDimensionsRelabeller():
    """Episode relabeller that keeps only selected columns of some entries.

    Args:
        keys: Episode keys to reduce. Defaults to
            ``['sensor/position', 'goal', 'sensor/absolute_goal_position']``.
        dimensions: Index used on the second axis of each entry. Defaults to
            ``[[0, 2]]``, i.e. columns 0 and 2.
    """

    def __init__(self, keys=None, dimensions=None):
        # Build the defaults per instance: list defaults in the signature
        # would be shared, so mutating one instance's attributes would
        # silently change every later default-constructed instance.
        if keys is None:
            keys = ['sensor/position', 'goal', 'sensor/absolute_goal_position']
        if dimensions is None:
            dimensions = [[0, 2]]
        self.keys = keys
        self.dimensions = dimensions

    def __call__(self, episode, idx, t, frame_size, dataset):
        """Replace each configured entry by its selected columns, in place.

        Only ``episode`` is used; the other arguments belong to the
        episode-relabeller call signature.
        """
        for key in self.keys:
            # With the default nested index the selection has an extra axis of
            # size 1 at position 1, which squeeze(1) removes again.
            episode[key] = episode[key][:, self.dimensions].squeeze(1)

        return episode
    
class NextTokenRelabeller():
    """Episode relabeller that attaches current/next subgoal tokens to each frame.

    The episode is tokenized frame by frame, the token sequence is compressed
    with ``compress_tokens_tensor``, and every frame is labelled with the next
    compressed token it has not reached yet (and the representations of the
    one, two and three steps ahead). Frames past the last compressed token get
    ``special_tokens['EOS_TOKEN']`` and the episode goal as representation.

    Args:
        tokenizer: Object providing ``tokenize_tensor(tensor)`` (returning a
            ``(tokens, representations)`` pair) and ``get_token(tokens)``.
        device: Device the concatenated inputs are moved to before tokenizing.
        keys_to_tokenize: Episode keys whose tensors are concatenated and fed
            to the tokenizer.
    """

    def __init__(self, tokenizer, device, keys_to_tokenize):
        self.keys_to_tokenize = keys_to_tokenize
        self.device = device
        self.tokenizer = tokenizer

    def __call__(self, episode, _id, t, frame_size, dataset):
        """Add the token entries to ``episode`` and return it.

        Writes ``token/current``, ``token/current/obs_representation``,
        ``token/next``, ``token/next/obs_representation``,
        ``token/next_next/obs_representation``,
        ``token/next_next_next/obs_representation`` and ``goal/token``.
        Only ``episode`` is used; the other arguments belong to the
        episode-relabeller call signature.
        """
        to_tokenize = torch.cat([episode[key] for key in self.keys_to_tokenize]).to(self.device)
        tokens, representations = self.tokenizer.tokenize_tensor(to_tokenize)
        # Tensor.cpu() returns a new tensor instead of moving in place, so its
        # result has to be stored for the episode entries to live on the CPU.
        episode['token/current'] = tokens.cpu()
        episode['token/current/obs_representation'] = representations.cpu()

        compressed_tokens = compress_tokens_tensor(episode['token/current'])
        compressed_tokens_obs_representation = self.tokenizer.get_token(compressed_tokens)
        n_compressed = len(compressed_tokens)
        last = n_compressed - 1

        # One flattened row per frame (flattening matters for image inputs).
        # Buffers are allocated on the CPU so that every entry written by this
        # relabeller ends up on the same device.
        flat_inputs = to_tokenize.view(to_tokenize.shape[0], -1)
        episode['token/next'] = torch.zeros_like(episode['token/current'])
        episode['token/next/obs_representation'] = torch.zeros_like(flat_inputs, device='cpu')
        episode['token/next_next/obs_representation'] = torch.zeros_like(flat_inputs, device='cpu')
        episode['token/next_next_next/obs_representation'] = torch.zeros_like(flat_inputs, device='cpu')

        # j points at the next compressed token the trajectory has not reached.
        j = 1
        for i, token in enumerate(episode['token/current']):
            if j < n_compressed and token == compressed_tokens[j]:
                j += 1

            if j < n_compressed:
                episode['token/next'][i] = compressed_tokens[j]

                # Look-ahead indices are clamped to the last compressed token.
                episode['token/next/obs_representation'][i] = compressed_tokens_obs_representation[j]
                episode['token/next_next/obs_representation'][i] = compressed_tokens_obs_representation[min(j + 1, last)]
                episode['token/next_next_next/obs_representation'][i] = compressed_tokens_obs_representation[min(j + 2, last)]

            else:
                # No compressed token left: aim at the episode goal instead.
                goal_representation = episode['goal'][i].view(-1)
                episode['token/next'][i] = special_tokens['EOS_TOKEN']
                episode['token/next/obs_representation'][i] = goal_representation
                episode['token/next_next/obs_representation'][i] = goal_representation
                # This entry must be filled too; otherwise it would keep the
                # zero initialisation for all frames past the last token.
                episode['token/next_next_next/obs_representation'][i] = goal_representation

        # The goal token is the last compressed token, repeated for every frame.
        episode['goal/token'] = compressed_tokens[-1] * torch.ones_like(episode['token/current'])

        return episode

class BatchRelabeller:
    """Collate a list of samples, then run a chain of batch relabellers on it."""

    def __init__(self, relabellers):
        """Store the pipeline; ``default_collate`` is always its first stage.

        Args:
            relabellers: List of callables, each mapping a batch to a batch.
        """
        collate_stage = [default_collate]
        self.relabellers = collate_stage + relabellers

    @torch.no_grad()
    def __call__(self, batch):
        """Feed ``batch`` through every stage in order, with gradients disabled."""
        current = batch
        for stage in self.relabellers:
            current = stage(current)
        return current
    
class TokenRelabeller:
    """Batch relabeller that tokenizes ``batch['goal']`` into ``batch['goal/token']``."""

    def __init__(self, tokenizer, device, keys_to_tokenize):
        """Store the tokenizer, the device it runs on and the configured keys.

        ``keys_to_tokenize`` is stored but not read by ``__call__``.
        """
        self.tokenizer = tokenizer
        self.keys_to_tokenize = keys_to_tokenize
        self.device = device

    @torch.no_grad()
    def __call__(self, batch):
        """Tokenize the goal of every frame and return the updated batch.

        ``batch['goal/token']`` is a float tensor with the first two sizes of
        ``batch['goal']`` and a trailing axis of size 1. Each slice along
        axis 1 is tokenized separately on ``self.device`` and the first
        element of the tokenizer output is copied back to the CPU.
        """
        goal = batch['goal']
        n_samples, n_frames = goal.shape[0], goal.shape[1]
        batch['goal/token'] = torch.zeros((n_samples, n_frames, 1))

        for frame in range(n_frames):
            tokenized = self.tokenizer.tokenize_tensor(goal[:, frame].to(self.device))
            batch['goal/token'][:, frame] = tokenized[0].cpu()

        return batch

class ValueBatchRelabeller():
    """Batch relabeller that builds goals and rewards for value learning.

    For every sample a goal type is drawn from ``GOAL_TYPES`` with the given
    probabilities:

    * default: goal is the next token, reward 0;
    * ``CURRENT``: goal is the token of the current state, reward 1
      (helps learning, see the HIQL paper);
    * ``FUTURE``: goal is the next token, reward 1 iff the token of the
      following frame equals that next token.

    ``RANDOM`` has no dedicated handling: such samples keep the default goal
    and reward.

    Args:
        probabilities: One probability per entry of ``GOAL_TYPES``; must sum
            to 1 (up to floating-point rounding).
        use_obs_representation: If true, goals are taken from the
            ``.../obs_representation`` entries instead of the raw tokens.

    Raises:
        AssertionError: If the probabilities do not sum to 1.
    """

    def __init__(self, probabilities, use_obs_representation=True):
        super().__init__()
        self.probabilities = probabilities
        # Compare with a tolerance: an exact `== 1` rejects valid
        # distributions such as [0.7, 0.2, 0.1], whose floating-point sum can
        # come out as 0.9999999999999999.
        assert np.isclose(sum(probabilities), 1.0), 'Sum of probabilities must be equal to 1!'
        self.use_obs_representation = use_obs_representation

    def __call__(self, batch):
        """Add ``reward``, ``value/goal`` and ``value_obs`` to ``batch``.

        Reads ``goal`` (for the batch size), ``token/current``, ``token/next``,
        ``obs/partial`` and, when ``use_obs_representation`` is set, the two
        ``.../obs_representation`` entries. Indexing below assumes axis 1 of
        these entries has two frames.
        """
        batch_size = batch['goal'].shape[0]
        goal_type = np.random.choice(GOAL_TYPES, p=self.probabilities, size=(batch_size,))

        mask_current = torch.tensor(goal_type == CURRENT, dtype=torch.bool)
        mask_future = torch.tensor(goal_type == FUTURE, dtype=torch.bool)

        # Default: the next token (or its representation) is the goal, reward 0.
        rewards = torch.zeros((batch_size, 2))
        if self.use_obs_representation:
            value_goal = batch['token/next/obs_representation'].clone()
        else:
            value_goal = batch['token/next'].clone()

        # Current: r(s_t, a_t, i_curr) = 1, with the current token as goal.
        rewards[mask_current] = torch.ones((mask_current.sum(), 2))
        if self.use_obs_representation:
            value_goal[mask_current, 0] = batch['token/current/obs_representation'][mask_current, 0]
            value_goal[mask_current, 1] = batch['token/current/obs_representation'][mask_current, 0] # nb: the 0 is not a typo
        else:
            value_goal[mask_current, 0] = batch['token/current'][mask_current, 0]
            value_goal[mask_current, 1] = batch['token/current'][mask_current, 0] # nb: the 0 is not a typo

        # Future: r(s_t, a_t, i_next) = 1 iff the token of s_t+1 is i_next.
        rewards[mask_future] = (batch['token/next'][mask_future, 0] == batch['token/current'][mask_future, 1]).repeat(1, 2).float()

        batch['reward'] = rewards
        batch['value/goal'] = value_goal

        batch['value_obs'] = torch.cat((value_goal, batch['obs/partial']), dim=2)

        return batch

# class ValueEpisodeRelabeller():
#     def __init__(self, probabilities):
#         super().__init__()
#         self.probabilities = probabilities
#         self.use_obs_representation

#     def __call__(self, episode, _id, t, frame_size, dataset):
#         goal_type = np.random.choice(GOAL_TYPES, p=self.probabilities)
#         if goal_type == CURRENT:
#             idx = t
#             source_episode = episode
#         elif goal_type == FUTURE:
#             ep_len = episode[next(episode.keys().__iter__())].size()[0]
#             idx = min(t + np.ceil(np.log(1 - np.random.rand()) / np.log(0.99)).astype(int), ep_len - 1,)
#             source_episode = episode
#         else: #Random
#             rnd_frame_idx = np.random.randint(len(dataset))
#             ep_id, idx = dataset._index[rnd_frame_idx]
#             source_episode = dataset._get_episode(ep_id)

#         # Update reward
#         if goal_type == CURRENT:
#             episode['reward'] = torch.tensor([1.0]).unsqueeze(1)
#         elif goal_type == FUTURE:
#             episode['reward'] = torch.tensor([1.0 if t == idx else 0]).unsqueeze(1)
#         else:
#             episode['reward'] = torch.tensor([1.0 if t == idx and _id[1] == ep_id[1] else 0]).unsqueeze(1)

#         # Update goal   
#         episode['value/goal/goal'] = source_episode['obs/partial'][idx].unsqueeze(0)    
#         episode['value_obs/goal'] = torch.cat((episode['value/goal/goal'], episode['obs/partial']), dim=1)

#         return episode

class ValueBatchRelabellerGodot():
    """Value-learning batch relabeller for the Godot sensor/action layout.

    Samples a goal type per sample (see ``GOAL_TYPES``), derives the value
    goal and reward from the token entries, and assembles the observation and
    action tensors from the individual ``sensor/...`` and ``action/...`` keys.

    Args:
        probabilities: One probability per entry of ``GOAL_TYPES``.
        use_obs_representation: If true, goals come from the
            ``.../obs_representation`` entries instead of the raw tokens.
        use_raycasts: If true, raycasts are appended to ``batch['observation']``.
    """

    def __init__(self, probabilities, use_obs_representation=True, use_raycasts=False):
        super().__init__()
        self.probabilities = probabilities
        self.use_obs_representation = use_obs_representation
        self.use_raycasts = use_raycasts

    def __call__(self, batch):
        """Add reward, goal, observation and action entries to ``batch``.

        Writes ``reward``, ``value/goal``, ``value_obs``,
        ``obs_subgoal/obs_representation``, ``observation`` and ``action``.
        """
        n_samples = batch['goal'].shape[0]
        sampled_types = np.random.choice(GOAL_TYPES, p=self.probabilities, size=(n_samples,))

        mask_current = torch.tensor(sampled_types == CURRENT, dtype=torch.bool)
        mask_future = torch.tensor(sampled_types == FUTURE, dtype=torch.bool)

        # Pick the token flavour once: representations or raw tokens.
        if self.use_obs_representation:
            next_key = 'token/next/obs_representation'
            current_key = 'token/current/obs_representation'
        else:
            next_key = 'token/next'
            current_key = 'token/current'

        # Default: the next token is the goal and the reward is 0.
        rewards = torch.zeros((n_samples, 2))
        value_goal = batch[next_key].clone()

        # Current: the current token becomes the goal of both frames, reward 1
        # (helps the learning, see the HIQL paper).
        rewards[mask_current] = 1.0
        current_first_frame = batch[current_key][mask_current, 0]
        for frame in (0, 1):
            # Both frames deliberately receive frame 0 of the current token.
            value_goal[mask_current, frame] = current_first_frame

        # Future: r(s_t, a_t, i_next) = 1 iff the token of s_t+1 is i_next.
        reached_next = batch['token/next'][mask_future, 0] == batch['token/current'][mask_future, 1]
        rewards[mask_future] = reached_next.repeat(1, 2).float()

        batch['reward'] = rewards
        batch['value/goal'] = value_goal

        # Positions: keep components 0 and 2 unless only two are present.
        goal = batch['sensor/absolute_goal_position']
        ball = batch["sensor/position"]
        if ball.shape[2] != 2:
            ball = ball[:, :, [0, 2]]
        if goal.shape[2] != 2:
            goal = goal[:, :, [0, 2]]

        rotation = batch["sensor/rotation"]
        cos_rot, sin_rot = rotation.cos(), rotation.sin()

        # Raycasts are flattened to one vector per frame (two frames per sample).
        raw_raycasts = batch["sensor/raycasts"]
        raycasts = raw_raycasts.view(raw_raycasts.shape[0], 2, -1)

        pose = [ball, cos_rot, sin_rot]
        batch['value_obs'] = torch.cat(pose + [raycasts, value_goal], dim=2)
        batch['obs_subgoal/obs_representation'] = torch.cat(pose + [raycasts, goal], dim=2)

        # Build observation
        observation_parts = pose + [raycasts] if self.use_raycasts else pose
        batch['observation'] = torch.cat(observation_parts, dim=2)

        # Build action: four move flags followed by cos/sin of the rotation.
        action_parts = [
            batch["action/move_right"].unsqueeze(-1),
            batch["action/move_left"].unsqueeze(-1),
            batch["action/move_forwards"].unsqueeze(-1),
            batch["action/move_backwards"].unsqueeze(-1),
            batch["action/rotation"].cos().unsqueeze(-1),
            batch["action/rotation"].sin().unsqueeze(-1),
        ]
        batch['action'] = torch.cat(action_parts, dim=1)

        return batch

class ObsRelabeller:
    """Batch relabeller that concatenates entries into complete/partial observations."""

    def __init__(self, obs_keys, partial_obs_keys):
        """Store the keys forming ``obs/complete`` and ``obs/partial`` respectively."""
        self.obs_keys = obs_keys
        self.partial_obs_keys = partial_obs_keys

    def __call__(self, batch):
        """Write ``obs/complete`` and ``obs/partial`` and return the batch.

        Both are concatenations along the last axis of the configured entries.
        """
        present = batch.keys()
        # Quick hack for concat of raycasts: flatten them to one vector per
        # frame, using the leading two sizes of the position entry.
        if 'sensor/raycasts' in present and 'sensor/position' in present:
            n_samples, n_frames, _ = batch['sensor/position'].shape
            batch['sensor/raycasts'] = batch['sensor/raycasts'].view(n_samples, n_frames, -1)

        complete_parts = [batch[name] for name in self.obs_keys]
        batch['obs/complete'] = torch.cat(complete_parts, dim=-1)
        partial_parts = [batch[name] for name in self.partial_obs_keys]
        batch['obs/partial'] = torch.cat(partial_parts, dim=-1)
        return batch
    
class ObsGoalRelabeller:
    """Batch relabeller that prepends goal information to the observations."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, batch):
        """Write the three goal-conditioned observation entries and return the batch.

        Each entry is the concatenation, along the last axis, of a goal-like
        entry followed by an observation entry.
        """
        # (destination key, goal-like key, observation key)
        recipes = (
            ('obs_goal', 'goal', 'obs/complete'),
            ('obs_subgoal', 'token/next', 'obs/partial'),
            ('obs_subgoal/obs_representation', 'token/next/obs_representation', 'obs/partial'),
        )
        for destination, goal_key, obs_key in recipes:
            batch[destination] = torch.cat([batch[goal_key], batch[obs_key]], dim=-1)
        return batch
    
class PosOtherRelabeller:
    """Frame relabeller that splits ``observation`` into its first two entries and the rest."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, frame):
        """Write ``obs/pos`` (first two entries) and ``obs/other`` (the rest)."""
        observation = frame['observation']
        split_at = 2
        frame['obs/pos'] = observation[:split_at]
        frame['obs/other'] = frame['observation'][split_at:]
        return frame
    
##########
# Images #
##########
def get_params(img, output_size):
    """Draw the parameters of a random crop of ``img``.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (tuple): Expected ``(height, width)`` of the crop.

    Returns:
        tuple: ``(i, j, h, w)`` — top, left, height and width — to be passed
        to ``crop``.

    Raises:
        ValueError: If the requested crop is larger than the image.
    """
    _, img_h, img_w = TF.get_dimensions(img)
    crop_h, crop_w = output_size

    if crop_h > img_h or crop_w > img_w:
        raise ValueError(f"Required crop size {(crop_h, crop_w)} is larger than input image size {(img_h, img_w)}")

    # A crop of exactly the image size has a single possible position.
    if (img_h, img_w) == (crop_h, crop_w):
        return 0, 0, img_h, img_w

    top = torch.randint(0, img_h - crop_h + 1, size=(1,)).item()
    left = torch.randint(0, img_w - crop_w + 1, size=(1,)).item()
    return top, left, crop_h, crop_w

def crop_padding(
        img, 
        size, 
        padding=None, 
        pad_if_needed: bool = False,
        fill: int = 0,
        padding_mode="constant"
    ):
    """Randomly crop ``img`` after optional padding, tracking a matching mask.

    A mask of ones with the shape of ``img`` goes through exactly the same
    padding and cropping as the image and is returned alongside it.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (tuple): ``(height, width)`` of the crop.
        padding: Optional padding applied to image and mask before cropping.
        pad_if_needed: Pad the image when it is smaller than ``size``.
        fill: Fill value used for padding.
        padding_mode: Padding mode forwarded to the padding function.

    Returns:
        tuple: ``(cropped image, cropped mask)``.

    Raises:
        ValueError: If the (padded) image is still smaller than ``size``.
    """
    mask = torch.ones_like(img)
    if padding is not None:
        mask = TF.pad(mask, padding, fill, padding_mode)
        img = TF.pad(img, padding, fill, padding_mode)

    _, height, width = TF.get_dimensions(img)
    # The conditional pads below must use torchvision's pad (TF): in this file
    # F is torch.nn.functional, whose pad takes (input, pad, mode, value) and
    # would receive `fill` as the mode.
    # pad the width if needed
    if pad_if_needed and width < size[1]:
        padding = [size[1] - width, 0]
        mask = TF.pad(mask, padding, fill, padding_mode)
        img = TF.pad(img, padding, fill, padding_mode)
    # pad the height if needed
    if pad_if_needed and height < size[0]:
        padding = [0, size[0] - height]
        mask = TF.pad(mask, padding, fill, padding_mode)
        img = TF.pad(img, padding, fill, padding_mode)

    i, j, h, w = get_params(img, size)

    return TF.crop(img, i, j, h, w), TF.crop(mask, i, j, h, w)

def image_data_augmentation(
    p_aug,
    x,
    square_rotation=False,
    vertical_flip=False,
    horizontal_flip=False,
    padded_random_crop=False,
    padding_size=4,
    reshaped_random_crop=False,
    crop_size=(16,16)

):
    """Randomly augment a batch of images and return it with a mask.

    Args:
        p_aug: Probability of augmenting; otherwise ``x`` is returned as is.
        x: Image batch whose axis 3 is moved to position 1 for the transforms
            (presumably channels-last images — confirm against callers).
        square_rotation: Rotate by a random multiple of 90 degrees.
        vertical_flip: Apply a random vertical flip.
        horizontal_flip: Apply a random horizontal flip.
        padded_random_crop: Pad by ``padding_size`` and crop back to the
            original spatial size; this is the only step that updates the mask.
        padding_size: Padding used by ``padded_random_crop``.
        reshaped_random_crop: Crop to ``crop_size`` and resize back to the
            original spatial size with bilinear interpolation.
        crop_size: ``(height, width)`` of the crop used by
            ``reshaped_random_crop``.

    Returns:
        tuple: ``(augmented x, mask)``, both in the layout of the input. The
        mask is all ones unless ``padded_random_crop`` is applied.
    """
    # check if we augment
    if np.random.rand() > p_aug:
        return x, torch.ones_like(x)

    # otherwise we augment
    x = x.permute(0, 3, 1, 2)
    # Crop and interpolate both expect (height, width), i.e. the sizes of the
    # last two axes in that order; the reverse only works for square images.
    x_size = (x.shape[-2], x.shape[-1])
    mask = torch.ones_like(x)

    # square rotation
    if square_rotation:
        angle = random.choice([0, 90, 180, 270])
        x = transforms.functional.rotate(x, angle)

    # vertical flip
    if vertical_flip:
        x = transforms.RandomVerticalFlip().forward(x)

    # horizontal flip
    if horizontal_flip:
        x = transforms.RandomHorizontalFlip().forward(x)

    # padded random crop
    if padded_random_crop:
        x, mask = crop_padding(x, x_size, padding=padding_size)

    # resized random crop
    if reshaped_random_crop:
        x, _ = crop_padding(x, crop_size, padding=0)
        # detach() keeps the result out of the autograd graph without
        # re-wrapping an existing tensor in torch.tensor().
        x = F.interpolate(x.detach(), size=x_size, mode='bilinear', align_corners=False)

    x = x.permute(0, 2, 3, 1)
    mask = mask.permute(0, 2, 3, 1)

    return x, mask