import time
import torch
import torch.nn as nn

from resnet import resnet18

from nerf_helpers import *



# Average-pooling variant of the image encoder (formerly named ImageEncoder; swap the class names to make it the active encoder).
class ImageEncoder_avgpool(nn.Module):
    """ResNet-18 image encoder that average-pools the final feature map.

    The backbone's ``avgpool`` is applied to the ``layer4`` output, and the
    pooled vector is projected to ``nf_out`` with a single linear layer.

    Args:
        nf_out: size of the output embedding.
    """

    def __init__(self, nf_out):
        # Zero-argument super(): the previous ``super(ImageEncoder, self)``
        # raised TypeError because this class does not inherit from
        # ImageEncoder, so the class could never be constructed.
        super().__init__()

        self.encoder = resnet18()
        # 512 = channel count of the ResNet-18 layer4 output (assumes the
        # local ``resnet`` module follows the standard architecture).
        self.fc = nn.Linear(512, nf_out)

    def forward(self, imgs):
        """Encode a batch of images.

        Args:
            imgs: tensor of shape (B, 3, H, W) with H == W == 180.

        Returns:
            Tensor of shape (B, nf_out).
        """
        B, C, H, W = imgs.size()
        assert H == 180
        assert W == 180
        x = self.encoder.conv1(imgs)
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)

        x = self.encoder.layer1(x)
        x = self.encoder.layer2(x)
        x = self.encoder.layer3(x)
        x = self.encoder.layer4(x)

        x = self.encoder.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


# Flattening variant of the image encoder (formerly named ImageEncoder_flatten); this is the encoder DynamicsModel instantiates.
class ImageEncoder(nn.Module):
    """ResNet-18 image encoder that flattens the full final feature map.

    Instead of average pooling, the entire ``layer4`` output is flattened and
    passed through a two-layer MLP, which keeps spatial information.

    Args:
        nf_out: size of the output embedding.
    """

    def __init__(self, nf_out):
        super().__init__()

        self.encoder = resnet18()

        # 18432 = 512 * 6 * 6: assuming standard ResNet-18 downsampling
        # (180 -> 90 -> 45 -> 45 -> 23 -> 12 -> 6), a 180x180 image yields a
        # 512-channel 6x6 map.
        flat_dim = 18432
        self.fc = nn.Sequential(
            nn.Linear(flat_dim, nf_out),
            nn.ReLU(),
            nn.Linear(nf_out, nf_out))

    def forward(self, imgs):
        """Encode a batch of images.

        Args:
            imgs: tensor of shape (B, 3, H, W) with H == W == 180.

        Returns:
            Tensor of shape (B, nf_out).
        """
        B, C, H, W = imgs.size()
        # The flattened size hard-coded in ``self.fc`` only matches 180x180.
        assert H == 180
        assert W == 180

        net = self.encoder
        feats = net.maxpool(net.relu(net.bn1(net.conv1(imgs))))
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            feats = stage(feats)

        return self.fc(feats.reshape(B, -1))


class ImageDecoder(nn.Module):
    """Transposed-convolution decoder from a flat embedding to a 180x180 image.

    The input has ``args.nf_hidden + 16`` features (the state embedding plus
    16 extra values, e.g. a flattened 4x4 camera pose). It is treated as a
    1x1 feature map and upsampled through six ConvTranspose2d layers.

    Args:
        args: config object; only ``args.nf_hidden`` is read.
    """

    def __init__(self, args):
        super().__init__()

        self.H = 180
        self.W = 180
        nf = args.nf_hidden

        # (in_channels, out_channels, kernel, stride) of each stage that is
        # followed by ReLU + InstanceNorm; padding is always 0.
        # Spatial size: 1 -> 4 -> 10 -> 21 -> 44 -> 89, then 180 after the head.
        stages = [
            (nf + 16, nf, 4, 1),
            (nf, nf, 4, 2),
            (nf, nf // 2, 3, 2),
            (nf // 2, nf // 2, 4, 2),
            (nf // 2, nf // 4, 3, 2),
        ]

        layers = []
        for c_in, c_out, kernel, stride in stages:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, 0),
                nn.ReLU(),
                nn.InstanceNorm2d(c_out),
            ])
        # Output head: 3 image channels, no activation.
        layers.append(nn.ConvTranspose2d(nf // 4, 3, 4, 2, 0))

        self.model = nn.Sequential(*layers)

    def forward(self, img_embeds):
        """Decode embeddings of shape (B, nf_hidden + 16) into (B, 3, 180, 180)."""
        B, nf = img_embeds.size()
        as_map = img_embeds.reshape(B, nf, 1, 1)
        return self.model(as_map)




class Transformer(nn.Module):
    """Permutation-invariant aggregator over a variable set of views.

    Despite the name, no attention is applied. Each view vector is embedded
    by an MLP, the embeddings are averaged over the valid views given by
    ``mask_nodes``, and the pooled vector goes through a second MLP.

    Args:
        nf_in: per-view input feature size.
        nf_hidden: hidden and output feature size.
        dropout, activation, nhead: accepted for interface compatibility with
            an attention-based variant; currently unused.
    """

    def __init__(self, nf_in=256, nf_hidden=256, dropout=0.2, activation='relu', nhead=2):
        super(Transformer, self).__init__()

        # Per-view embedding MLP: nf_in -> nf_hidden.
        self.embedder = nn.Sequential(
            nn.Linear(nf_in, nf_hidden),
            nn.ReLU(),
            nn.Linear(nf_hidden, nf_hidden),
            nn.ReLU())

        # Post-pooling MLP: nf_hidden -> nf_hidden (no final activation).
        self.encoder = nn.Sequential(
            nn.Linear(nf_hidden, nf_hidden),
            nn.ReLU(),
            nn.Linear(nf_hidden, nf_hidden),
            nn.ReLU(),
            nn.Linear(nf_hidden, nf_hidden))

    def _reset_parameters(self):
        r"""Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, inputs, mask_nodes):
        """Aggregate per-view features into one vector per sample.

        Args:
            inputs: tensor of shape (B, n_view, nf_in).
            mask_nodes: tensor of shape (B, n_view); 1 marks a valid view and
                0 a padded/null view.

        Returns:
            Tensor of shape (B, nf_hidden).
        """
        B, n_view, nf_in = inputs.size()

        # reshape (not view) so non-contiguous inputs are accepted as well.
        embeds = self.embedder(inputs.reshape(B * n_view, nf_in)).reshape(B, n_view, -1)

        summed = torch.sum(embeds * mask_nodes[:, :, None], 1)
        counts = torch.sum(mask_nodes, 1, keepdim=True)
        # A sample with no valid view would compute 0 / 0 = NaN, which then
        # poisons the whole batch. Divide by 1 instead so it pools to zeros.
        counts = torch.where(counts > 0, counts, torch.ones_like(counts))
        pooled = summed / counts

        return self.encoder(pooled)



class LatentSpaceDynamics(nn.Module):
    """MLP transition model operating in latent space.

    The history of ``n_his`` latent states is concatenated with the matching
    actions and flattened. The model predicts the next latent state and
    L2-normalises it per sample.

    Args:
        nf_in: flattened input size, i.e. n_his * (state_dim + act_dim).
        nf_hidden: width of the three hidden layers.
        nf_out: size of the predicted state.
    """

    def __init__(self, nf_in, nf_hidden, nf_out):
        super().__init__()

        layers = [nn.Linear(nf_in, nf_hidden), nn.ReLU()]
        for _ in range(2):
            layers += [nn.Linear(nf_hidden, nf_hidden), nn.ReLU()]
        layers.append(nn.Linear(nf_hidden, nf_out))
        self.model = nn.Sequential(*layers)

    def forward(self, inputs, actions):
        """Predict the next (unit-norm) latent state.

        Args:
            inputs: tensor of shape (B, n_his, state_dim).
            actions: tensor of shape (B, n_his, act_dim).

        Returns:
            Tensor of shape (B, nf_out) with per-row L2 norm of ~1.
        """
        B, _, _ = inputs.size()
        history = torch.cat([inputs, actions], -1).view(B, -1)
        pred = self.model(history)

        # The epsilon guards against division by zero for an all-zero prediction.
        scale = pred.norm(dim=1, keepdim=True) + 1e-8
        return pred / scale



class Renderer(nn.Module):
    """NeRF-based decoder that renders pixel colours conditioned on a latent.

    ``create_nerf`` builds a coarse NeRF MLP, plus a fine one when
    ``args.N_importance > 0``, and the keyword dicts used for rendering.
    ``render_rays`` renders a random subset of rays and ``render_imgs``
    renders every pixel of every view. In both cases each ray gets the latent
    vector concatenated onto it. Presumably the ``nerf_helpers`` functions
    (``batchify_rays`` / ``run_network``) split it back out — TODO confirm.
    """

    def __init__(self, args):
        super(Renderer, self).__init__()

        self.args = args
        # Rays sampled per batch element in render_rays().
        self.N_rand = args.N_rand
        self.render_kwargs_train, self.render_kwargs_test = self.create_nerf(args)

    def create_nerf(self, args):
        """Instantiate NeRF's MLP model(s) and the rendering keyword dicts.

        Side effects: sets ``self.model`` (coarse network) and
        ``self.model_fine`` (fine network when ``args.N_importance > 0``,
        otherwise ``None``).

        Returns:
            (render_kwargs_train, render_kwargs_test): keyword dicts for
            ``render``, which forwards keys it does not consume to
            ``batchify_rays``. The test dict is a shallow copy with
            ``perturb=False`` and ``raw_noise_std=0.``. Only the train dict
            is used by the methods of this class.
        """
        embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

        input_ch_views = 0
        embeddirs_fn = None
        if args.use_viewdirs:
            embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
        output_ch = 5 if args.N_importance > 0 else 4
        # Skip-connection layer indices passed to NeRF (semantics defined in nerf_helpers).
        skips = [4]
        self.model = NeRF(
            D=args.netdepth, W=args.netwidth,
            input_ch=input_ch, output_ch=output_ch, skips=skips,
            input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)

        self.model_fine = None
        if args.N_importance > 0:
            self.model_fine = NeRF(
                D=args.netdepth_fine, W=args.netwidth_fine,
                input_ch=input_ch, output_ch=output_ch, skips=skips,
                input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)

        # Binds the embedders and netchunk so callers only supply the inputs,
        # latents, view directions and which network to query.
        network_query_fn = lambda inputs, latents, viewdirs, network_fn : run_network(
            inputs, latents, viewdirs, network_fn,
            embed_fn=embed_fn,
            embeddirs_fn=embeddirs_fn,
            netchunk=args.netchunk)

        # NOTE(review): basedir and expname are read but not used in this method.
        basedir = args.basedir
        expname = args.expname

        render_kwargs_train = {
            'network_query_fn' : network_query_fn,
            'perturb' : args.perturb,
            'N_importance' : args.N_importance,
            'network_fine' : self.model_fine,
            'N_samples' : args.N_samples,
            'network_fn' : self.model,
            'use_viewdirs' : args.use_viewdirs,
            'white_bkgd' : args.white_bkgd,
            'raw_noise_std' : args.raw_noise_std,
        }

        # Shallow copy: the test dict shares the same network objects.
        render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
        render_kwargs_test['perturb'] = False
        render_kwargs_test['raw_noise_std'] = 0.

        return render_kwargs_train, render_kwargs_test

    def render(self, H, W, focal, latents, chunk=1024*32, rays=None, near=0., far=1.,
               use_viewdirs=True, **kwargs):
        """Render rays
        Args:
            H: int. Height of image in pixels.
            W: int. Width of image in pixels.
            focal: float. Focal length of pinhole camera.
            latents: tensor of shape [num_rays, nf]. Concatenated onto each
                flattened ray, so it must have one row per ray.
            chunk: int. Maximum number of rays to process simultaneously. Used to
                control maximum memory usage. Does not affect final results.
            rays: array of shape [2, batch_size, 3]. Ray origin and direction for
                each example in batch.
            near: float or array of shape [batch_size]. Nearest distance for a ray.
            far: float or array of shape [batch_size]. Farthest distance for a ray.
            use_viewdirs: bool. If True, use viewing direction of a point in space in model.
                The render kwargs dicts contain a 'use_viewdirs' key that binds here.
            **kwargs: forwarded unchanged to ``batchify_rays``.
        Returns:
            A list [rgb_map, disp_map, acc_map, extras]:
            rgb_map: [batch_size, 3]. Predicted RGB values for rays.
            disp_map: [batch_size]. Disparity map. Inverse of depth.
            acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
            extras: dict with every other output returned by batchify_rays().
        """
        # use provided ray batch
        rays_o, rays_d = rays

        if use_viewdirs:
            # provide ray directions as input
            viewdirs = rays_d
            viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
            viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

        # Leading shape of the ray batch; outputs are reshaped back to it below.
        sh = rays_d.shape # [..., 3]

        # Create ray batch
        rays_o = torch.reshape(rays_o, [-1, 3]).float()
        rays_d = torch.reshape(rays_d, [-1, 3]).float()

        near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
        # Per-ray layout: origin(3), direction(3), near(1), far(1)[, viewdir(3)], latent.
        rays = torch.cat([rays_o, rays_d, near, far], -1)
        if use_viewdirs:
            rays = torch.cat([rays, viewdirs], -1)

        # append latent states
        rays = torch.cat([rays, latents], -1)

        # Render and reshape
        all_ret = batchify_rays(rays, chunk, **kwargs)
        for k in all_ret:
            k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
            all_ret[k] = torch.reshape(all_ret[k], k_sh)

        k_extract = ['rgb_map', 'disp_map', 'acc_map']
        ret_list = [all_ret[k] for k in k_extract]
        ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}

        return ret_list + [ret_dict]

    def render_rays(self, latents, poses, camera_info, target):
        """Render a random subset of rays and gather the matching target pixels.

        For each batch element, N_rand pixel indices are drawn without
        replacement from all n_view * H * W pixels of that element. This
        requires n_view * H * W >= N_rand.

        Returns:
            (rgb, extras, target_s): rgb and target_s have shape
            (B * N_rand, 3); extras is the dict returned by ``render``.

        NOTE(review): ``near``/``far`` are read from camera_info but never
        passed on, so ``render`` uses its defaults (0., 1.) unless
        render_kwargs_train is given 'near'/'far' keys elsewhere — confirm.
        ``np`` is presumably provided by ``from nerf_helpers import *``.
        """
        # latents: B x nf_hidden
        # poses: B x n_view x 3 x 4
        # target: B x n_view x 3 x H x W
        B, n_view, _, _ = poses.size()

        H, W, focal = camera_info['hwf']
        H, W = int(H.item()), int(W.item())
        near, far = camera_info['near'], camera_info['far']

        N_rand = self.N_rand

        '''
        print()
        print('latents.size()', latents.size())
        print('poses.size()', poses.size())
        print('target.size()', target.size())
        '''

        # rays_o: B x n_view x H x W x 3
        # rays_d: B x n_view x H x W x 3
        rays_o, rays_d = get_rays_batch(
            H, W, focal,
            torch.Tensor(poses.reshape(B * n_view, 3, 4)))
        rays_o = rays_o.view(B, n_view, H, W, 3)
        rays_d = rays_d.view(B, n_view, H, W, 3)


        ### sample rays from (B x n_view x H x W)
        # coords: (B * N_rand) x 2 -- (batch index, flattened view*pixel index)
        select_inds = np.concatenate([
            np.random.choice(n_view * H * W, size=[N_rand], replace=False) for i in range(B)])
        batch_idx = np.repeat(np.arange(B), N_rand)
        coords = np.stack([batch_idx, select_inds], -1)

        # batch_rays: 2 x (B * N_rand) x 3
        rays_o = rays_o.reshape(B, n_view * H * W, 3)[coords[:, 0], coords[:, 1]]
        rays_d = rays_d.reshape(B, n_view * H * W, 3)[coords[:, 0], coords[:, 1]]
        batch_rays = torch.stack([rays_o, rays_d], 0)

        # latents: (B * N_rand) x nf_hidden
        latents = latents[:, None, :].repeat(1, N_rand, 1).reshape(B * N_rand, -1)

        ### rendering
        rgb, disp, acc, extras = self.render(
            H, W, focal, latents=latents, chunk=self.args.chunk, rays=batch_rays,
            retraw=True, **self.render_kwargs_train)

        # rgb: (B * N_rand) x 3
        # target: permuted to B x n_view x H x W x 3, then gathered at coords
        target = target.permute(0, 1, 3, 4, 2)
        target_s = target.reshape(B, n_view * H * W, 3)[coords[:, 0], coords[:, 1]]  # (B * N_rand, 3)
        return rgb, extras, target_s

    def render_imgs(self, latents, poses, camera_info):
        """Render every pixel of every view.

        Returns:
            (rgb, extras): rgb has shape (B, n_view, H, W, 3); extras is the
            dict returned by ``render``.

        NOTE(review): this uses render_kwargs_train (with perturbation and
        raw noise), not render_kwargs_test — confirm that is intended for
        full-image rendering. ``near``/``far`` are read but unused here
        (see render_rays).
        """
        # latents: B x nf_hidden
        # poses: B x n_view x 3 x 4
        B, n_view, _, _ = poses.size()

        H, W, focal = camera_info['hwf']
        H, W = int(H.item()), int(W.item())
        near, far = camera_info['near'], camera_info['far']

        # rays_o: B x n_view x H x W x 3
        # rays_d: B x n_view x H x W x 3
        rays_o, rays_d = get_rays_batch(
            H, W, focal,
            torch.Tensor(poses.reshape(B * n_view, 3, 4)))
        rays_o = rays_o.reshape(B * n_view * H * W, 3)
        rays_d = rays_d.reshape(B * n_view * H * W, 3)
        batch_rays = torch.stack([rays_o, rays_d], 0)

        # One copy of each latent per ray: (B * n_view * H * W) x nf_hidden
        latents = latents[:, None, :].repeat(1, n_view * H * W, 1).reshape(B * n_view * H * W, -1)

        ### rendering
        rgb, disp, acc, extras = self.render(
            H, W, focal, latents=latents, chunk=self.args.chunk, rays=batch_rays,
            retraw=True,  **self.render_kwargs_train)

        rgb = rgb.reshape(B, n_view, H, W, 3)

        return rgb, extras



class DynamicsModel(nn.Module):
    """Latent dynamics model.

    Components:
        img_encoder: per-view image -> feature vector.
        tf_encoder: aggregates per-view features plus flattened 4x4 poses
            into one state embedding per time step.
        latent_space_dynamics: predicts the next state from n_his
            (state, action) pairs.
        decoder: a NeRF ``Renderer`` when ``args.auto_loss == 0``, an
            ``ImageDecoder`` when ``args.auto_loss == 1``; not created
            for any other value.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        hidden = args.nf_hidden

        self.latent_space_dynamics = LatentSpaceDynamics(
            nf_in=(hidden + args.act_dim) * args.n_his,
            nf_hidden=hidden,
            nf_out=hidden)

        self.img_encoder = ImageEncoder(nf_out=hidden)

        # +16 for the flattened 4x4 camera pose appended to each view feature.
        self.tf_encoder = Transformer(nf_in=hidden + 16, nf_hidden=hidden)

        if args.auto_loss == 0:
            self.decoder = Renderer(args)
        elif args.auto_loss == 1:
            self.decoder = ImageDecoder(args)

    def encode_img(self, imgs):
        """Encode images of shape (B, N, n_view, 3, H, W) to (B, N, n_view, nf_hidden)."""
        B, N, n_view, C, H, W = imgs.size()
        flat_imgs = imgs.reshape(B * N * n_view, C, H, W)
        feats = self.img_encoder(flat_imgs)
        return feats.view(B, N, n_view, -1)

    def encode_state(self, img_embeds, poses, mask_nodes):
        """Fuse per-view embeddings and poses into unit-norm state embeddings.

        Args:
            img_embeds: (B, N, n_view, nf_hidden).
            poses: (B, N, n_view, 4, 4).
            mask_nodes: (B, N, n_view); 1 = valid view, 0 = null view.

        Returns:
            Tensor of shape (B, N, nf_hidden) with per-row L2 norm of ~1.
        """
        B, N, n_view, _ = img_embeds.size()

        pose_feats = poses.view(B, N, n_view, -1)
        tokens = torch.cat([img_embeds, pose_feats], -1).view(B * N, n_view, -1)

        state = self.tf_encoder(tokens, mask_nodes.view(B * N, n_view))

        denom = (state.norm(dim=1) + 1e-8).view(B * N, -1)
        state = state / denom

        return state.view(B, N, -1)

    def decode_img(self, state_embeds, camera_info):
        """Decode one image per (state, view) with the ImageDecoder.

        Args:
            state_embeds: (B, N, nf_hidden).
            camera_info: dict whose 'poses' entry is (B, N, n_view, 4, 4).

        Returns:
            Tensor of shape (B, N, n_view, C, H, W).
        """
        poses = camera_info['poses']
        B, N, n_view, _, _ = poses.size()

        flat_poses = poses.view(B * N, n_view, -1)
        states = state_embeds.view(B * N, -1)[:, None, :].repeat(1, n_view, 1)
        decoder_in = torch.cat([states, flat_poses], 2).view(B * N * n_view, -1)

        assert decoder_in.size() == (B * N * n_view, state_embeds.size(-1) + 16)

        imgs = self.decoder(decoder_in)
        _, C, H, W = imgs.size()
        return imgs.view(B, N, n_view, C, H, W)

    def render_rays(self, state_embeds, camera_info, target_imgs):
        """Render a random ray subset with the NeRF decoder.

        Args:
            state_embeds: (B, N, nf_hidden).
            camera_info: dict with 'poses' of shape (B, N, n_view, 4, 4) plus
                whatever ``Renderer.render_rays`` reads.
            target_imgs: (B, N, n_view, 3, H, W).

        Returns:
            (rgb, extras, target_s) as returned by ``Renderer.render_rays``.
        """
        B, N, n_view, C, H, W = target_imgs.size()

        cam_to_world = camera_info['poses'].view(B * N, n_view, 4, 4)
        rgb, extras, target_s = self.decoder.render_rays(
            state_embeds.view(B * N, -1),
            cam_to_world[:, :, :3, :4],
            camera_info,
            target_imgs.view(B * N, n_view, C, H, W))
        return rgb, extras, target_s

    def render_imgs(self, state_embeds, camera_info):
        """Render full images; state_embeds is (B, nf_hidden), poses (B, n_view, 4, 4).

        Returns:
            (rgb, extras) with rgb of shape (B, n_view, H, W, 3).
        """
        top_rows = camera_info['poses'][:, :, :3, :4]
        rgb, extras = self.decoder.render_imgs(state_embeds, top_rows, camera_info)
        return rgb, extras

    def dynamics_prediction(self, state_cur, action_cur):
        """Predict the next state from (B, n_his, nf) states and (B, n_his, act_dim) actions."""
        return self.latent_space_dynamics(state_cur, action_cur)

