import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import clip
import tree
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

import promptrl.archs as archs
import promptrl.utils as utils

class ObsEmbedPrompt(nn.Module):
    """Base module that wraps per-observation embeddings with soft prompts for an LM.

    For each observation, the embedding returned by ``_encode_obs`` is placed
    between the two halves of a task-specific ``wrap_unit`` soft prompt. At
    sequence position 0 the goal's LM token embeddings are prepended as well,
    and ``forward`` additionally returns the ``begin_unit``'s
    ``past_key_values``.

    Subclasses must implement ``_encode_obs`` and ``_per_obs_tokens``.
    """

    def __init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim, shared_proj=False, hidden_layers=1, dropout_prob=0.0, lm_embed_init=False):
        # ``shared_proj`` and ``hidden_layers`` are not used here; they are
        # accepted so subclasses can forward their full argument list.
        super().__init__()
        self.soft_prompt_dim = soft_prompt_dim
        self.deep_prompt_dim = deep_prompt_dim
        self.num_tasks = num_tasks

        # Tuple-wrapped so nn.Module does not register the LM's token
        # embedding as a submodule of this prompt.
        self._lm_wte = (model._get_wte(),)

        self.begin_unit = archs.MTUnit(num_tasks, model, 0, deep_prompt_dim=deep_prompt_dim, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init)
        self.wrap_unit = archs.MTUnit(num_tasks, model, soft_prompt_dim, dropout_prob=dropout_prob)

    def _per_obs_tokens(self, n_obs, n_goal_toks, pos):
        """Return how many prompt tokens one observation expands to (subclass hook)."""
        raise NotImplementedError

    def _wrap_obs_prompt(self, task_id, obs_embeds, pos, goal):
        """Surround one observation's embeddings with the task's wrap soft prompt.

        Returns, concatenated along dim 0:
        ``[goal token embeddings (pos == 0 only)] + wrap[:half] + obs_embeds + wrap[half:]``.
        """
        wrap_prompt = self.wrap_unit(task_id)['input_embeds']
        # Same split as torch.chunk(wrap_prompt, 2, dim=0) (the first half gets
        # the extra token when the length is odd). torch.chunk, however, returns
        # a single chunk for a 1-token prompt, which made ``chunks[1]`` raise.
        half = (wrap_prompt.shape[0] + 1) // 2
        embeds = torch.cat((wrap_prompt[:half], obs_embeds, wrap_prompt[half:]), dim=0)
        if pos == 0:
            embeds = torch.cat((self._lm_wte[0](goal), embeds), dim=0)
        return embeds

    def _prompt_attn_mask(self, obs_embeds, pos):
        """Return an all-ones attention mask for one wrapped prompt segment.

        At ``pos == 0`` the mask is ``deep_prompt_dim`` entries longer,
        presumably to cover the ``begin_unit`` past_key_values prefix -- TODO
        confirm against MTUnit.
        """
        attn_len = obs_embeds.shape[0]
        if pos == 0:
            attn_len += self.deep_prompt_dim
        return torch.ones((attn_len,), device=obs_embeds.device)

    def _process_obs_embeds(self, task_ids, obs_embeds_tree, obs_pos, goals):
        """Wrap every observation in the tree and build the matching attention masks."""
        prompt_embeds = tree.map_structure(self._wrap_obs_prompt, task_ids, obs_embeds_tree, obs_pos, goals)
        attn_masks = tree.map_structure(self._prompt_attn_mask, prompt_embeds, obs_pos)
        return prompt_embeds, attn_masks

    def _encode_obs(self, task_ids, obs, goals, obs_precomputed=None, goals_precomputed=None):
        """Encode observations into a tree of LM-width embeddings (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _is_sequence_begin(obs_pos):
        """Return True when the first batch row starts at position 0.

        ``obs_pos`` is either a flat sequence of positions or a sequence of
        per-row position lists. Empty inputs count as "not a sequence start"
        instead of raising IndexError.
        """
        if len(obs_pos) == 0:
            return False
        first = obs_pos[0]
        if isinstance(first, (list, tuple)):
            return len(first) > 0 and first[0] == 0
        return first == 0

    def fill_sequence(self, task_ids, obs, goals, fill_arr, fill_locs, obs_precomputed=None, goals_precomputed=None, **kwargs):
        """Embed ``fill_arr`` and overwrite its reserved spans with observation prompts.

        Args:
            fill_arr: batch of LM token ids; the spans listed in ``fill_locs``
                must hold placeholder ids summing to zero.
            fill_locs: per batch row, one ``(start_idx, end_idx)`` span per
                observation, each exactly as long as that observation's prompt.

        Returns:
            dict with ``inputs_embeds`` (the filled embedding tensor) and the
            ``past_key_values`` produced by ``forward``.
        """
        assert len(fill_locs) == fill_arr.shape[0] == len(goals)
        assert len(fill_locs) == (len(obs) if obs is not None else len(obs_precomputed))
        # Every sequence is filled from its start, so positions run 0..len-1.
        obs_pos = [list(range(len(seq))) for seq in (obs if obs is not None else obs_precomputed)]
        prompts = self.forward(task_ids, obs, obs_pos, goals, sequence_begin=True, obs_precomputed=obs_precomputed, goals_precomputed=goals_precomputed)
        assert len(fill_locs) == len(prompts['inputs_embeds'])

        fill_arr_embeds = self._lm_wte[0](fill_arr)
        for i, (prompt_row, locs_row) in enumerate(zip(prompts['inputs_embeds'], fill_locs)):
            for prompt, (start_idx, end_idx) in zip(prompt_row, locs_row):
                assert len(prompt) == end_idx - start_idx
                assert fill_arr[i, start_idx:end_idx].sum().item() == 0
                fill_arr_embeds[i, start_idx:end_idx] = prompt
        return {
            'inputs_embeds': fill_arr_embeds,
            'past_key_values': prompts['past_key_values'],
        }

    def forward(self, task_ids, obs, obs_pos, goals=None, obs_precomputed=None, goals_precomputed=None, **kwargs):
        """Build wrapped prompt embeddings, attention masks and begin-of-sequence KV.

        Returns:
            dict with ``inputs_embeds`` and ``attention_mask`` (trees shaped like
            ``obs_pos``) and ``past_key_values`` (the ``begin_unit`` output when
            the batch starts at position 0, otherwise ``None``).
        """
        # task ids broadcast to the per-observation structure of obs_pos
        per_obs_task_ids = utils.unravel_as(task_ids, obs_pos)
        obs_embeds = self._encode_obs(per_obs_task_ids, obs, goals, obs_precomputed, goals_precomputed)
        prompt_embeds, attention_mask = self._process_obs_embeds(per_obs_task_ids, obs_embeds, obs_pos, utils.unravel_as(goals, obs_pos))

        if self._is_sequence_begin(obs_pos):
            begin_prompt = self.begin_unit(task_ids)['past_key_values']
        else:
            begin_prompt = None
        return {
            'inputs_embeds': prompt_embeds,
            'attention_mask': attention_mask,
            'past_key_values': begin_prompt,
        }

class TextUnitPrompt(ObsEmbedPrompt):
    """Prompt for text observations given as LM token ids.

    Observations are embedded with the LM's own token embedding. The first
    observation (``pos == 0``) is wrapped by ``wrap_begin_unit``
    (``soft_prompt_dim`` tokens) and prefixed with the goal tokens; later
    observations are wrapped by ``wrap_unit`` (``obs_embed_dim`` tokens).
    """

    def __init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim, obs_embed_dim, dropout_prob=0.0, lm_embed_init=False):
        # ObsEmbedPrompt.__init__ is not called: it would build a single
        # soft_prompt_dim-token wrap_unit, while this class needs separate
        # wrap units for the first and the later observations.
        nn.Module.__init__(self)
        self.soft_prompt_dim = soft_prompt_dim
        self.deep_prompt_dim = deep_prompt_dim
        self.obs_embed_dim = obs_embed_dim
        self.num_tasks = num_tasks

        self._lm_embed_dim = model.config.n_embd
        # Tuple-wrapped so nn.Module does not register the LM embedding as a submodule.
        self._lm_wte = (model._get_wte(),)

        self.begin_unit = archs.MTUnit(num_tasks, model, 0, deep_prompt_dim=deep_prompt_dim, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init)
        self.wrap_begin_unit = archs.MTUnit(num_tasks, model, soft_prompt_dim, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init)
        self.wrap_unit = archs.MTUnit(num_tasks, model, obs_embed_dim, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init)

    def _per_obs_tokens(self, n_obs, n_goal_toks, pos):
        """Return the prompt length of one observation of ``n_obs`` tokens at ``pos``."""
        return n_obs + ((n_goal_toks + self.soft_prompt_dim) if pos == 0 else self.obs_embed_dim)

    def _encode_obs(self, task_ids, obs, goals, obs_precomputed=None, goals_precomputed=None):
        """Embed every token-id leaf of ``obs`` with the LM embedding.

        ``obs_precomputed`` / ``goals_precomputed`` are ignored by this class.
        """
        return tree.map_structure(self._lm_wte[0], obs)

    def _wrap_obs_prompt(self, task_id, obs_embeds, pos, goal):
        """Wrap one observation with the begin (pos 0) or regular wrap soft prompt."""
        if pos == 0:
            wrap_prompt = self.wrap_begin_unit(task_id)['input_embeds']
        else:
            wrap_prompt = self.wrap_unit(task_id)['input_embeds']
        # Same split as torch.chunk(wrap_prompt, 2, dim=0), but also valid for
        # a 1-token prompt (torch.chunk then returns one chunk, so chunks[1]
        # raised IndexError).
        half = (wrap_prompt.shape[0] + 1) // 2
        embeds = torch.cat((wrap_prompt[:half], obs_embeds, wrap_prompt[half:]), dim=0)
        if pos == 0:
            embeds = torch.cat((self._lm_wte[0](goal), embeds), dim=0)
        return embeds

class ImgDummyPrompt(ObsEmbedPrompt):
    """Image-free baseline: observations contribute zero tokens to the prompt.

    Every observation leaf is encoded as an empty ``(0, n_embd)`` tensor, so
    each wrapped prompt consists only of the wrap soft prompt (plus the goal
    tokens at position 0).
    """

    def __init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim, dropout_prob=0.0, lm_embed_init=False):
        super().__init__(
            num_tasks,
            model,
            soft_prompt_dim,
            deep_prompt_dim,
            dropout_prob=dropout_prob,
            lm_embed_init=lm_embed_init,
        )
        # LM hidden width; gives the empty observation tensors their column count.
        self._lm_embed_dim = model.config.n_embd

    def _per_obs_tokens(self, n_obs, n_goal_toks, pos):
        """Return the soft-prompt length, plus the goal tokens at position 0."""
        extra = 0
        if pos == 0:
            extra = n_goal_toks
        return self.soft_prompt_dim + extra

    def _encode_obs(self, task_ids, obs, goals, obs_precomputed=None, goals_precomputed=None):
        """Map every observation leaf to an empty tensor on the goals' device."""
        source = obs_precomputed if obs is None else obs

        def _empty_embedding(_leaf):
            return torch.zeros((0, self._lm_embed_dim), device=goals[0].device)

        return tree.map_structure(_empty_embedding, source)

class CLIPEmbedPrompt(ObsEmbedPrompt):
    """Prompt whose observation tokens come from a frozen CLIP image encoder.

    Each image frame is encoded by CLIP without gradients and projected by
    ``obs_embed`` into ``obs_embed_dim`` LM-width tokens, which are then
    wrapped by the soft prompt as in ``ObsEmbedPrompt``.
    """

    def __init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim, obs_embed_dim, shared_proj=False, hidden_layers=2, dropout_prob=0.0, lm_embed_init=False, clip_arch='ViT-B/32', clip_batch_size=8, patched=False, patch_dim=(3, 3), detach_obs_forward=False):
        super().__init__(num_tasks, model, soft_prompt_dim, deep_prompt_dim, shared_proj=shared_proj, hidden_layers=hidden_layers, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init)

        self.patched = patched
        self.patch_dim = patch_dim
        self.obs_embed_dim = obs_embed_dim
        self.detach_obs_forward = detach_obs_forward

        self.clip_arch = clip_arch
        self.obs_batch_size = clip_batch_size
        # Only the RN* and ViT* CLIP families are supported. These widths are
        # only a fallback: the real embedding width differs within a family
        # (e.g. RN101, RN50x4, ViT-L/14), so it is read from the loaded model.
        if self.clip_arch.startswith('RN'):
            fallback_hidden_dim = 1024
        elif self.clip_arch.startswith('ViT'):
            fallback_hidden_dim = 512
        else:
            raise NotImplementedError

        # TODO make device selection for clip with other models through accelerate
        self.clip_device = "cuda" if torch.cuda.is_available() else "cpu"
        clip_model, self.preprocess = clip.load(self.clip_arch, device=self.clip_device)
        # Tuple-wrapped so nn.Module does not register CLIP as a submodule.
        self.clip_model = (clip_model,)
        visual = getattr(clip_model, 'visual', None)
        self.clip_hidden_dim = getattr(visual, 'output_dim', fallback_hidden_dim)

        self.obs_embed = archs.Projection(model, self.clip_hidden_dim, obs_embed_dim, 0, n_layers=hidden_layers, shared_proj=shared_proj, dropout_prob=dropout_prob)

    def _per_obs_tokens(self, n_obs, n_goal_toks, pos):
        """Return soft prompt + ``obs_embed_dim`` tokens per frame (+ goal tokens at pos 0)."""
        return self.soft_prompt_dim + self.obs_embed_dim * n_obs + (n_goal_toks if pos == 0 else 0)

    @torch.no_grad()
    def _encode_clip(self, chunk):
        """Encode a batch of preprocessed images with CLIP (no gradients, float32)."""
        chunk = chunk.to(self.clip_device)
        return self.clip_model[0].encode_image(chunk).float().detach()

    def _preprocess_imgs(self, im):
        """Split a stack of channel-last frames into a list of PIL images (optionally patched)."""
        if self.patched:
            im = utils.patchify_np(im, self.patch_dim)
        assert im.shape[-1] == 3  # channel at end
        im = [Image.fromarray(im[i]) for i in range(im.shape[0])]
        return im

    def _postprocess_imgs(self, obs_embeds):
        """Project image features into LM-width prompt embeddings."""
        return self.obs_embed(obs_embeds)['input_embeds']

    def _task_detach(self, task_id, obs_embed):
        """Detach the observation embedding for task 0 only."""
        if task_id == 0:
            return obs_embed.detach()
        else:
            return obs_embed

    def _encode_obs(self, task_ids, obs, goals, obs_precomputed=None, goals_precomputed=None):
        """Encode image observations (or precomputed features) into prompt embeddings."""
        if obs_precomputed is None:
            # frames -> PIL images -> CLIP input tensors -> CLIP features -> projection
            obs_spl = tree.map_structure(self._preprocess_imgs, obs)
            obs_spl = tree.map_structure(self.preprocess, obs_spl)
            obs_embeds = utils.map_structure_batched(self._encode_clip, obs_spl, post_func=self._postprocess_imgs, batch_size=self.obs_batch_size)
        else:
            # presumably precomputed CLIP features: split each leaf along dim 0
            # and only apply the projection
            obs = obs_precomputed
            obs_spl = tree.map_structure(list, obs)
            obs_embeds = utils.map_structure_stacked(self._postprocess_imgs, obs_spl)
        obs_embeds = tree.map_structure_up_to(obs, torch.cat, obs_embeds)  # tree of (frames x obs_embed_dim) x LM hidden size
        if self.detach_obs_forward:
            obs_embeds = tree.map_structure(self._task_detach, task_ids, obs_embeds)
        return obs_embeds

class ResNetEmbedPrompt(CLIPEmbedPrompt):
    """CLIPEmbedPrompt variant that encodes frames with a torchvision ResNet-50.

    It reuses CLIPEmbedPrompt's image preprocessing and projection helpers but
    bypasses CLIPEmbedPrompt.__init__, so no CLIP model is loaded. The ResNet
    classification head is replaced by Identity, and its 2048-d output is
    projected into ``obs_embed_dim`` LM-width tokens.
    """

    def __init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim, obs_embed_dim, shared_proj=False, hidden_layers=2, dropout_prob=0.0, lm_embed_init=False, resnet_batch_size=16, patched=False, patch_dim=(3, 3), detach_obs_forward=False, pretrained=False, freeze_resnet=False):
        ObsEmbedPrompt.__init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim, shared_proj=shared_proj, hidden_layers=hidden_layers, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init)
        self.patched = patched
        self.patch_dim = patch_dim
        self.obs_embed_dim = obs_embed_dim
        self.detach_obs_forward = detach_obs_forward
        self.pretrained = pretrained
        self.obs_batch_size = resnet_batch_size
        self.freeze_resnet = freeze_resnet

        weights = ResNet50_Weights.DEFAULT
        self.preprocess = weights.transforms()
        if self.pretrained:
            self.resnet = resnet50(weights=weights)
        else:
            self.resnet = resnet50(weights=None)
        self.resnet.fc = nn.Identity()  # expose pooled features instead of class logits
        if self.freeze_resnet:
            for param in self.resnet.parameters():
                param.requires_grad = False
            # Frozen weights alone are not enough: BatchNorm in train mode still
            # normalizes with batch statistics and overwrites its running stats.
            self.resnet.eval()

        self.obs_embed = archs.Projection(model, 2048, obs_embed_dim, 0, n_layers=hidden_layers, shared_proj=shared_proj, dropout_prob=dropout_prob)

    def train(self, mode=True):
        """Set training mode, keeping a frozen ResNet backbone in eval mode."""
        super().train(mode)
        if self.freeze_resnet:
            self.resnet.eval()
        return self

    def _encode_resnet(self, chunk):
        """Run the ResNet backbone on a batch of preprocessed images."""
        chunk = chunk.to(self.obs_embed.proj[-1].weight.device)
        if self.freeze_resnet:
            with torch.no_grad():
                embed = self.resnet(chunk)
            return embed.detach()
        else:
            embed = self.resnet(chunk)
            return embed

    def _encode_obs(self, task_ids, obs, goals, obs_precomputed=None, goals_precomputed=None):
        """Encode image observations (or precomputed features) into prompt embeddings."""
        if obs_precomputed is None:
            # frames -> PIL images -> ResNet input tensors -> ResNet features -> projection
            obs_spl = tree.map_structure(self._preprocess_imgs, obs)
            obs_spl = tree.map_structure(self.preprocess, obs_spl)
            obs_embeds = utils.map_structure_batched(self._encode_resnet, obs_spl, post_func=self._postprocess_imgs, batch_size=self.obs_batch_size)
        else:
            # presumably precomputed backbone features: split each leaf along
            # dim 0 and only apply the projection
            obs = obs_precomputed
            obs_spl = tree.map_structure(list, obs)
            obs_embeds = utils.map_structure_stacked(self._postprocess_imgs, obs_spl)
        obs_embeds = tree.map_structure_up_to(obs, torch.cat, obs_embeds)  # tree of (frames x obs_embed_dim) x LM hidden size
        if self.detach_obs_forward:
            obs_embeds = tree.map_structure(self._task_detach, task_ids, obs_embeds)
        return obs_embeds

class CLIPGoalPrompt(CLIPEmbedPrompt):
    """CLIP observation prompt whose image features are conditioned on the goal.

    The goal token ids are decoded with ``tokenizer`` and re-encoded with
    CLIP's text encoder. Each frame's CLIP image feature is fused with that
    goal feature through a learned ``nn.Bilinear`` layer before the usual
    projection into LM tokens.
    """

    def __init__(self, num_tasks, model, tokenizer, soft_prompt_dim, deep_prompt_dim, obs_embed_dim, shared_proj=False, hidden_layers=2, dropout_prob=0.0, lm_embed_init=False, clip_arch='ViT-B/32', clip_batch_size=8):
        # CLIPEmbedPrompt supplies the CLIP model/preprocess, clip_hidden_dim,
        # obs_batch_size and the projection this class relies on.
        super().__init__(num_tasks, model, soft_prompt_dim, deep_prompt_dim, obs_embed_dim, shared_proj=shared_proj, hidden_layers=hidden_layers, dropout_prob=dropout_prob, lm_embed_init=lm_embed_init, clip_arch=clip_arch, clip_batch_size=clip_batch_size)

        self.tokenizer = tokenizer  # LM tokenizer; turns goal ids back into text for CLIP
        self.bilinear = nn.Bilinear(self.clip_hidden_dim, self.clip_hidden_dim, self.clip_hidden_dim)

    def _get_encode_clip(self, goal, goals_precomputed=None, compute_obs=True):
        """Return a batch encoder that fuses image features with one goal's text feature.

        Args:
            goal: LM token ids of the goal (decoded with ``self.tokenizer``).
            goals_precomputed: optional precomputed CLIP text feature used
                instead of encoding ``goal``.
            compute_obs: if True, chunks are preprocessed images that are
                encoded with CLIP first; otherwise chunks are already features.
        """
        if goals_precomputed is not None:
            goal_embeds = goals_precomputed
        else:
            goal_str = self.tokenizer.decode(goal)
            goal_ids = clip.tokenize(goal_str).to(self.clip_device)
            with torch.no_grad():
                goal_embeds = self.clip_model[0].encode_text(goal_ids).float()

        def _encode_clip(chunk):
            if compute_obs:
                with torch.no_grad():
                    chunk = chunk.to(self.clip_device)
                    chunk = self.clip_model[0].encode_image(chunk).float()
            # The bilinear layer lives with the trainable parameters, which need
            # not share CLIP's device or the precomputed features' dtype.
            weight = self.bilinear.weight
            chunk = chunk.to(device=weight.device, dtype=weight.dtype)
            goal_rows = goal_embeds.to(device=weight.device, dtype=weight.dtype).expand(chunk.shape)
            return self.bilinear(chunk, goal_rows)

        return _encode_clip

    def _encode_obs(self, task_ids, obs, goals, obs_precomputed=None, goals_precomputed=None):
        """Encode each batch row's frames, fused with that row's goal, into prompt embeddings."""
        compute_obs = obs_precomputed is None
        if compute_obs:
            # frames -> PIL images -> CLIP input tensors
            obs_spl = tree.map_structure(self._preprocess_imgs, obs)
            obs_spl = tree.map_structure(self.preprocess, obs_spl)
        else:
            # presumably precomputed CLIP image features: split each leaf along dim 0
            obs = obs_precomputed
            obs_spl = tree.map_structure(list, obs)

        obs_embeds = []
        for i, (obs_row, goal) in enumerate(zip(obs_spl, goals)):
            goal_pre = goals_precomputed[i] if goals_precomputed is not None else None
            encode_fn = self._get_encode_clip(goal, goal_pre, compute_obs)
            obs_embeds.append(utils.map_structure_batched(encode_fn, obs_row, post_func=self._postprocess_imgs, batch_size=self.obs_batch_size))
        obs_embeds = tree.map_structure_up_to(obs, torch.cat, obs_embeds)  # tree of (frames x obs_embed_dim) x LM hidden size
        if self.detach_obs_forward:
            obs_embeds = tree.map_structure(self._task_detach, task_ids, obs_embeds)
        return obs_embeds

