import torch
import torch.nn as nn
import numpy as np
import math
import pickle
try:
    from .transformer import Transformer1D
except:
    from transformer import Transformer1D

from easydict import EasyDict as edict

class DeconvBottleneck(nn.Module):
    """Three-stage bottleneck block (1x1 -> 3x3 -> 1x1) with a residual add.

    The middle stage is a plain 3x3 convolution when ``stride == 1`` and a
    3x3 transposed convolution (``padding=1``, ``output_padding=1``)
    otherwise. The block emits ``out_channels * expansion`` channels.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Width of the two inner stages.
        expansion: Multiplier applied to ``out_channels`` for the last stage.
        stride: Stride of the middle stage; any value other than 1 selects
            the transposed convolution.
        upsample: Optional module applied to the input to build the shortcut.
            When ``None`` the raw input is used, so it must already match the
            main branch's output shape for the addition to work.
    """

    def __init__(self, in_channels, out_channels, expansion=2, stride=1, upsample=None):
        super().__init__()
        self.expansion = expansion
        expanded_channels = out_channels * self.expansion

        # Stage 1: pointwise projection to the inner width.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Stage 2: spatial stage, transposed when the block changes resolution.
        if stride != 1:
            self.conv2 = nn.ConvTranspose2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                bias=False,
                padding=1,
                output_padding=1,
            )
        else:
            self.conv2 = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                bias=False,
                padding=1,
            )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Stage 3: pointwise expansion.
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded_channels)

        self.relu = nn.ReLU(inplace=False)
        self.upsample = upsample

    def forward(self, x):
        """Run the three conv/BN/ReLU stages, add the shortcut, and apply ReLU."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        out = x
        for conv, norm in stages:
            # Note: the last stage is also activated *before* the residual add.
            out = self.relu(norm(conv(out)))

        shortcut = x if self.upsample is None else self.upsample(x)
        return self.relu(out + shortcut)


class FourierFeatureTransform(nn.Module):
    """Random Fourier feature encoding of the last input dimension.

    A fixed (non-trainable) Gaussian matrix of shape
    ``(num_input_channels, mapping_size)``, scaled by ``scale``, projects the
    input; the result is multiplied by ``2*pi`` and its sine and cosine are
    concatenated, giving ``2 * mapping_size`` output channels.
    """

    def __init__(self, num_input_channels, mapping_size, scale=10):
        super().__init__()

        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # Stored as a Parameter with requires_grad=False so it is saved with
        # the module but never updated by the optimizer.
        projection = torch.randn((num_input_channels, mapping_size)) * scale
        self._B = nn.Parameter(projection, requires_grad=False)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., num_input_channels)``.

        Returns a tensor of shape ``(..., 2 * mapping_size)`` laid out as
        ``[sin(...), cos(...)]`` along the last dimension.
        """
        phase = (x @ self._B) * (2 * np.pi)
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
    
class SingleTriplane(nn.Module):
    """Learnable tri-plane conditioned on two feature sets and decoded per point.

    Pipeline (see ``forward``):
      1. a learnable ``(3, channel_dim, width, height)`` embedding is flattened
         and passed through two ``Transformer1D`` stages (``proj_dino`` then
         ``proj_arc``) that receive the conditioning features;
      2. three ``DeconvBottleneck`` blocks each double the spatial size and
         reduce the channels to 256, 128 and finally 64;
      3. the three planes are sampled bilinearly at the query points, the
         samples are summed, and an MLP maps them to ``output_dim`` values.

    Args:
        output_dim: Size of the per-point prediction.
        channel_dim: Channels of the learnable planes and transformer input.
        width: Size of dim 2 of each plane in ``self.embeddings``.
        height: Size of dim 3 of each plane in ``self.embeddings``.
    """

    def __init__(self, output_dim=3,channel_dim=512, width=32, height=32):
        super().__init__()

        # Three planes initialised with std 1/sqrt(channel_dim).
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, channel_dim, width, height),
                dtype=torch.float32,
            )
            * 1
            / math.sqrt(channel_dim)
        )

        cfg = edict()
        cfg.in_channels = channel_dim
        cfg.num_attention_heads = 4
        cfg.attention_head_dim = 32
        cfg.num_layers = 2
        cfg.cross_attention_dim=768 # hard-code, =DINO feature dim
        cfg.norm_type = "layer_norm"
        cfg.enable_memory_efficient_attention = False
        cfg.gradient_checkpointing = False
        self.proj_dino = Transformer1D(cfg)

        # NOTE(review): the same cfg object is mutated and reused here; this is
        # only safe if Transformer1D reads cfg during construction and does not
        # keep a live reference to it -- confirm against Transformer1D.
        cfg.cross_attention_dim=512 # hard-code, =Arcface feature dim
        self.proj_arc = Transformer1D(cfg)

        # Use this if you want a PE
        # self.net = nn.Sequential(
        #     FourierFeatureTransform(channel_dim, 256, scale=1),
        #     nn.Linear(512, 512),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(512, 512),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(512, output_dim),
        # )

        # _make_up_block reads and updates self.in_channels, so the three calls
        # below must stay in this order. Trailing comments give the spatial
        # size reached for the default 32x32 planes.
        self.in_channels = channel_dim
        self.uplayer1 = self._make_up_block(DeconvBottleneck, 128, 1, stride=2) #64
        self.uplayer2 = self._make_up_block(DeconvBottleneck, 64, 1, stride=2) #128
        self.uplayer3 = self._make_up_block(DeconvBottleneck, 32, 1, stride=2) #256
        # self.uplayer4 = self._make_up_block(DeconvBottleneck, 32, 1, stride=2) #512

        # Per-point decoder; 64 is the channel count produced by uplayer3 (32 * 2).
        self.net = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(inplace=False),
            nn.Linear(128, 256),
            nn.ReLU(inplace=False),
            nn.Linear(256, output_dim),
        )

    def _make_up_block(self, block, init_channels, num_layer, stride=1):
        """Build a stack of ``block`` modules ending in ``init_channels * 2`` channels.

        ``num_layer - 1`` blocks with expansion 4 are followed by one block
        with expansion 2 that carries ``stride`` and the shortcut projection.
        ``self.in_channels`` is updated to ``init_channels * 2`` afterwards.

        NOTE(review): the leading blocks all take ``self.in_channels`` as input
        but emit ``init_channels * 4`` channels, so ``num_layer > 1`` only
        chains when those happen to match; every call in this class uses
        ``num_layer == 1``.
        """
        upsample = None
        if stride != 1 or self.in_channels != init_channels * 2:
            # 1x1 transposed conv so the shortcut matches the main branch in
            # both channel count and (for stride 2) spatial size.
            upsample = nn.Sequential(
                nn.ConvTranspose2d(self.in_channels, init_channels*2,
                                   kernel_size=1, stride=stride,
                                   bias=False, output_padding=1),
                nn.BatchNorm2d(init_channels*2),
            )
        layers = []
        for i in range(1, num_layer):
            layers.append(block(self.in_channels, init_channels, 4))
        layers.append(block(self.in_channels, init_channels, 2, stride, upsample))
        self.in_channels = init_channels * 2
        return nn.Sequential(*layers)

    def sample_plane(self, coords2d, plane):
        """Bilinearly sample ``plane`` at 2D locations.

        Args:
            coords2d: ``(B, M, 2)`` sampling locations in ``grid_sample``
                convention (values in [-1, 1]; out-of-range samples read zeros).
            plane: ``(B, C, H, W)`` feature map.

        Returns:
            ``(B, M, C)`` sampled features.
        """
        assert len(coords2d.shape) == 3, coords2d.shape
        sampled_features = torch.nn.functional.grid_sample(plane,
                                                           coords2d.reshape(coords2d.shape[0], 1, -1, coords2d.shape[-1]),
                                                           mode='bilinear', padding_mode='zeros', align_corners=True)
        N, C, H, W = sampled_features.shape
        sampled_features = sampled_features.reshape(N, C, H*W).permute(0, 2, 1)
        return sampled_features

    def forward(self, coordinates,face_feat,dino_feat):
        """Predict ``output_dim`` values for each query point.

        Args:
            coordinates: ``(M, 3)`` query points; a batch dimension of 1 is
                prepended internally.
            face_feat: Conditioning passed to ``proj_arc`` (configured with
                ``cross_attention_dim=512``). Its leading dimension must be 1.
            dino_feat: Conditioning passed to ``proj_dino`` (configured with
                ``cross_attention_dim=768``). Its leading dimension must be 1.

        Returns:
            ``(M, output_dim)`` predictions.
        """
        coordinates = coordinates.unsqueeze(0) #*1.5 #4096*3
        N,C,H,W = self.embeddings.size()  #(3,512,32,32)
        triplane = self.embeddings.permute(1,0,2,3).contiguous() #(512,3,32,32)
        triplane = triplane.reshape(1,triplane.size(0),-1) #(1,512,3*32*32)
        # Only a single conditioning sample is supported.
        assert len(triplane)==len(face_feat)==len(dino_feat)
        # presumably Transformer1D preserves the (1, C, 3*H*W) shape; the
        # reshape below relies on it -- confirm against Transformer1D.
        triplane = self.proj_dino(triplane,dino_feat,modulation_cond=None,)
        triplane = self.proj_arc(triplane,face_feat,modulation_cond=None,)

        # Stack the three planes along the height axis so the 2D up-blocks
        # process them as one image.
        temp = triplane.reshape(C,N*H,W).unsqueeze(0)  #(1, 512,3*32,32)
        temp = self.uplayer1(temp) #(1, 256,3*32*2,32*2)
        temp = self.uplayer2(temp) #(1, 128,3*32*2*2,32*2*2)
        temp = self.uplayer3(temp) #(1, 64, 3*32*2*2*2,32*2*2*2)
        # temp = self.uplayer4(temp)
        # Split the stacked image back into three planes: (3, 64, 8H, 8W).
        triplane = temp.squeeze(0).reshape(-1,N,H*8,W*8).permute(1,0,2,3).contiguous()

        xy_embed = self.sample_plane(coordinates[..., 0:2],  triplane[0].unsqueeze(0))
        yz_embed = self.sample_plane(coordinates[..., 1:3],  triplane[1].unsqueeze(0))
        xz_embed = self.sample_plane(coordinates[..., :3:2], triplane[2].unsqueeze(0))

        # Aggregate the three per-plane samples by summation.
        features = torch.sum(torch.stack([xy_embed, yz_embed, xz_embed]), dim=0)
        features = features.squeeze(0)
        pred =  self.net(features)
        # pred = (torch.sigmoid(pred)*2-1)/1.5
        return pred

    def tvreg(self):
        """Total-variation regulariser over the raw plane embeddings.

        For every plane, adds the L2 norm of the finite differences along each
        of its two spatial axes.

        Returns:
            A scalar tensor.
        """
        l = 0
        # Iterating the (3, C, H, W) parameter yields 3-D (C, H, W) planes, so
        # the spatial axes are dims 1 and 2. (Indexing them as dims 2 and 3
        # raised IndexError on the second term.)
        for embed in self.embeddings:
            l += ((embed[:, 1:, :] - embed[:, :-1, :])**2).sum()**0.5
            l += ((embed[:, :, 1:] - embed[:, :, :-1])**2).sum()**0.5
        return l
   



def project_onto_planes(planes, coordinates):
    """
    Express a batch of 3D points in the frame of every plane.

    Each point set is multiplied on the right by the inverse of each plane's
    3x3 axis matrix. All three resulting components are returned; nothing is
    sliced down to 2D here.

    Args:
        planes: plane axes of shape (n_planes, 3, 3); each must be invertible.
        coordinates: points of shape (N, M, 3).

    Returns:
        Tensor of shape (N * n_planes, M, 3), ordered so that row
        ``n * n_planes + p`` holds batch item ``n`` expressed in plane ``p``.
    """
    batch, num_points, _ = coordinates.shape
    num_planes, _, _ = planes.shape
    flat_batch = batch * num_planes

    # Repeat every point set once per plane: (N, n_planes, M, 3) -> flat.
    tiled_points = coordinates[:, None].expand(-1, num_planes, -1, -1)
    tiled_points = tiled_points.reshape(flat_batch, num_points, 3)

    # Repeat every inverse once per batch item: (N, n_planes, 3, 3) -> flat.
    tiled_inverses = torch.linalg.inv(planes)[None].expand(batch, -1, -1, -1)
    tiled_inverses = tiled_inverses.reshape(flat_batch, 3, 3)

    return torch.bmm(tiled_points, tiled_inverses)

from torchvision.ops import DeformConv2d
class DConv(nn.Module):
    """Deformable convolution whose sampling offsets are predicted from the input.

    ``conv1`` is a regular convolution producing ``2 * kernel_size**2`` offset
    channels; ``conv2`` is a ``DeformConv2d`` that consumes the input together
    with those offsets. Both layers share the same kernel size, stride,
    padding and bias setting.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        # One (dy, dx) pair per kernel tap.
        offset_channels = 2 * kernel_size * kernel_size
        self.conv1 = nn.Conv2d(inplanes, offset_channels, **conv_kwargs)
        self.conv2 = DeformConv2d(inplanes, planes, **conv_kwargs)

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution to ``x``."""
        offsets = self.conv1(x)
        return self.conv2(x, offsets)
    





class SingleTriplanev2(nn.Module):
    """Pretrained tri-plane that predicts a per-point offset added to the input.

    The planes are loaded from a pickle file, reordered, and stored as a
    trainable parameter. On every forward pass they are warped by one of three
    deformation modes, sampled at the query points, averaged across planes and
    decoded by a small MLP into a value in (-1, 1) that is added to the
    query coordinates.

    Args:
        output_dim: Size of the MLP output; it is added to the ``(M, 3)``
            coordinates, so it must broadcast against them.
        channel_dim: Channel count of each plane (and MLP input width).
        width: Unused; kept for interface compatibility.
        height: Unused; kept for interface compatibility.
        trigrid_path: Pickle file holding a dict with key ``'trigrids_512'``.
            ``None`` selects ``DEFAULT_TRIGRID_PATH``.
        deform: One of ``DEFORM_MODES``: ``'stn'`` (learned per-plane affine),
            ``'dconv'`` (deformable convolution) or ``'affine'`` (fixed affine).

    Raises:
        ValueError: If ``deform`` is not one of ``DEFORM_MODES``.
    """

    # Location the planes were originally loaded from unconditionally.
    DEFAULT_TRIGRID_PATH = '/home/jy496/work/Portrait3D/test_data_head/1/samples_new_crop/final_inversion/0000_0_seg0/inversion_trigrid.pkl'
    DEFORM_MODES = ('stn', 'dconv', 'affine')
    # 2x3 affine matrix (row-major) used as the fixed warp in 'affine' mode and
    # as the initial bias of the STN regressors.
    _AFFINE_THETA = (
        (1.21346186e+00, 4.85098221e-03, -4.01300093e+01/256),
        (-4.85098221e-03, 1.21346186e+00, -2.30892305e+01/256),
    )

    def __init__(self, output_dim=3,channel_dim=36, width=512, height=512,
                 trigrid_path=None, deform='affine'):
        super().__init__()

        if deform not in self.DEFORM_MODES:
            raise ValueError(f"deform must be one of {self.DEFORM_MODES}, got {deform!r}")

        # load pretrained triplane
        path = self.DEFAULT_TRIGRID_PATH if trigrid_path is None else trigrid_path
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files from a trusted source.
        with open(path, 'rb') as file:
            data = pickle.load(file)
        plane_feature = data['trigrids_512']
        # Take the first batch item and reverse the plane order.
        plane_feature = plane_feature[0, [2, 1, 0], ...].unsqueeze(0)
        # The advanced indexing above already produced a fresh copy, so detach
        # and cast instead of copy-constructing with torch.tensor(tensor).
        self.embeddings = torch.nn.Parameter(plane_feature.detach().to(dtype=torch.float32))

        # Axis matrices handed to project_onto_planes, one per plane.
        self.planes = torch.tensor([[[1, 0, 0],
                                     [0, 1, 0],
                                     [0, 0, 1]],
                                    [[1, 0, 0],
                                     [0, 0, 1],
                                     [0, 1, 0]],
                                    [[0, 1, 0],
                                     [0, 0, 1],
                                     [1, 0, 0]]], dtype=torch.float32)

        self.deform = deform
        if self.deform == 'stn':
            print('use stn')
            init_theta = torch.tensor(self._AFFINE_THETA, dtype=torch.float).reshape(-1)
            # One localisation net + regressor per plane, registered under the
            # names localization_{1,2,3} / fc_loc_{1,2,3}.
            for idx in (1, 2, 3):
                setattr(self, f'localization_{idx}', self._make_localization())
                regressor = self._make_theta_regressor()
                # Zero weights + preset bias: every regressor initially
                # outputs exactly the predefined affine matrix.
                regressor[2].weight.data.zero_()
                regressor[2].bias.data.copy_(init_theta)
                setattr(self, f'fc_loc_{idx}', regressor)
        elif self.deform == 'dconv':
            print('use deform conv')
            self.conv1 = DConv(channel_dim, channel_dim)
            self.conv2 = DConv(channel_dim, channel_dim)
            self.conv3 = DConv(channel_dim, channel_dim)
        else:
            print('use affine deform from predefined params')

        # Per-point decoder; the Sigmoid output is rescaled to (-1, 1) in forward.
        self.net = nn.Sequential(
            nn.Linear(channel_dim, channel_dim),
            torch.nn.Softplus(),
            nn.Linear(channel_dim, output_dim),
            nn.Sigmoid())

    @staticmethod
    def _make_localization():
        """Build the conv feature extractor that maps a 1x28x28 thumbnail to 10x3x3."""
        return nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
            )

    @staticmethod
    def _make_theta_regressor():
        """Build the MLP that regresses the 6 entries of a 2x3 affine matrix."""
        return nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
            )

    def stn(self, x):
        """Warp each plane with its own learned affine transform.

        Args:
            x: ``(1, 3, C, H, W)`` planes.

        Returns:
            ``(1, 3, C, H, W)`` warped planes.

        Raises:
            NotImplementedError: If a warped plane sums to exactly zero (the
                transform moved the sampling grid entirely off the plane).
        """
        x = x.squeeze(0)
        # Channel-averaged 28x28 thumbnail of every plane feeds the localisation nets.
        thumbs = x.mean(dim=1).unsqueeze(1)
        thumbs = torch.nn.functional.interpolate(thumbs, (28, 28))

        warped = []
        per_plane = zip(torch.chunk(thumbs, 3), torch.chunk(x, 3))
        for idx, (thumb, plane) in enumerate(per_plane, start=1):
            localization = getattr(self, f'localization_{idx}')
            regressor = getattr(self, f'fc_loc_{idx}')

            xs = localization(thumb).view(-1, 10 * 3 * 3)
            theta = regressor(xs).view(-1, 2, 3)
            grid = torch.nn.functional.affine_grid(theta, plane.size(), align_corners=True)
            out = torch.nn.functional.grid_sample(plane, grid, align_corners=True)
            if out.sum() == 0:
                print(f'x_{idx} triplane error............', theta)
                raise NotImplementedError
            warped.append(out)

        return torch.cat(warped, dim=0).unsqueeze(0)

    def dconv(self,x):
        """Apply a separate deformable convolution to each of the three planes.

        Args:
            x: ``(1, 3, C, H, W)`` planes.

        Returns:
            ``(1, 3, C, H, W)`` planes (the DConv defaults preserve H and W).
        """
        a = self.conv1(x[0][0].unsqueeze(0))
        b = self.conv2(x[0][1].unsqueeze(0))
        c = self.conv3(x[0][2].unsqueeze(0))
        x = torch.cat([a,b,c],dim=0).unsqueeze(0)
        return x

    def affine(self,x):
        """Warp all planes with the fixed predefined affine transform.

        Args:
            x: ``(1, 3, C, H, W)`` planes.

        Returns:
            ``(1, 3, C, H, W)`` warped planes.
        """
        x_ = x.squeeze(0)
        theta = torch.tensor(self._AFFINE_THETA, dtype=torch.float)
        theta = theta.unsqueeze(0).expand(len(x_),-1,-1).to(x.device)
        grid = torch.nn.functional.affine_grid(theta, x_.size(),align_corners=True)
        x = torch.nn.functional.grid_sample(x_, grid,align_corners=True)
        return x.unsqueeze(0)

    def forward(self, coordinates,save_tri=False):
        """Predict displaced coordinates for the query points.

        Args:
            coordinates: ``(M, 3)`` points with every value in [-1, 1].
            save_tri: When true, the deformed tri-plane is written to
                ``new_triplane.pt`` in the current working directory.

        Returns:
            ``(M, 3)`` tensor: the input coordinates plus a predicted offset
            in (-1, 1).

        Raises:
            AssertionError: If coordinates fall outside [-1, 1] or the stored
                embedding batch size is not 1.
            NotImplementedError: If the deformed tri-plane sums to exactly zero.
        """
        assert coordinates.max() <= 1 and coordinates.min() >= -1, f"coordinates must be in [-1, 1], got {coordinates.min()} and {coordinates.max()}"
        coordinates = coordinates.unsqueeze(0)
        _, M, _ = coordinates.shape
        N, n_planes, CD, H, W = self.embeddings.shape
        assert N ==1
        # The channel axis is split into (C, D) with a single depth slice so
        # the planes can go through the 5-D form of grid_sample.
        triplane_depth = 1
        C, D = CD // triplane_depth, triplane_depth

        if self.deform == 'dconv':
            triplane = self.dconv(self.embeddings)
        elif self.deform == 'stn':
            triplane = self.stn(self.embeddings)
        else:
            triplane = self.affine(self.embeddings)

        if triplane.sum()==0:
            print('triplane error............')
            raise NotImplementedError

        if save_tri:
            print('save triplane')
            torch.save(triplane,'new_triplane.pt')

        plane_feature = triplane.view(N * n_planes, C, D, H, W)
        # (n_planes, M, 3) -> (n_planes, 1, 1, M, 3): one 3-D sample grid per plane.
        projected_coordinates = project_onto_planes(self.planes.to(coordinates.device),coordinates=coordinates).unsqueeze(1).unsqueeze(2)

        # (n_planes, C, 1, 1, M) -> (n_planes, M, 1, 1, C) -> (N, n_planes, M, C)
        output_feature = torch.nn.functional.grid_sample(
            plane_feature, projected_coordinates.float(),
            mode='bilinear', padding_mode='zeros', align_corners=False,
        ).permute(0, 4, 3, 2, 1).reshape(N, n_planes, M, C)

        # Average the three per-plane samples.
        output_feature = output_feature.mean(1)
        output_feature = output_feature.reshape(N*M,C)
        pred =  self.net(output_feature)
        # Sigmoid output (0, 1) -> offset in (-1, 1), applied to the inputs.
        pred = pred*2-1
        pred = pred+coordinates.squeeze(0)
        return pred


class SingleTriplanev3(nn.Module):
    """Pretrained feature volume decoded into per-point ``shs`` and ``opacity``.

    A tensor is loaded from disk, given a leading batch axis and a singleton
    axis at position 2, and stored as a trainable parameter. Query points are
    used to sample it with the 5-D form of ``grid_sample``; the sampled
    features pass through a small MLP and two linear heads.

    Args:
        output_dim: Unused; kept for interface compatibility.
        channel_dim: Width of the sampled features; must equal dim 0 of the
            loaded tensor.
        width: Unused; kept for interface compatibility.
        height: Unused; kept for interface compatibility.
        feature_path: File readable by ``torch.load`` that contains a single
            tensor. ``None`` selects ``DEFAULT_FEATURE_PATH``.
    """

    # Location the features were originally loaded from unconditionally.
    DEFAULT_FEATURE_PATH = '/home/jy496/work/threestudio/load/face/tri_save.pkl'

    def __init__(self, output_dim=3,channel_dim=36, width=512, height=512,
                 feature_path=None):
        super().__init__()
        path = self.DEFAULT_FEATURE_PATH if feature_path is None else feature_path
        data = torch.load(path, map_location='cpu')
        # presumably data is (C, H, W), giving (1, C, 1, H, W) -- TODO confirm
        # against the file that produced tri_save.pkl.
        plane_feature = data.unsqueeze(0).unsqueeze(2)
        # clone().detach() is the documented replacement for copy-constructing
        # with torch.tensor(tensor), which emits a UserWarning.
        self.embeddings = torch.nn.Parameter(plane_feature.detach().clone().to(dtype=torch.float32))
        self.net = nn.Sequential(
            nn.Linear(channel_dim, channel_dim),
            torch.nn.Softplus(),
            nn.Linear(channel_dim, channel_dim),
           )

        self.opacity = nn.Sequential(nn.Linear(channel_dim, 1), nn.Sigmoid())
        self.shs = nn.Linear(channel_dim, 3)

    def forward(self, coordinates,save_tri=False):
        """Decode per-point attributes for the query points.

        Args:
            coordinates: ``(M, 3)`` points in ``grid_sample`` convention
                (values in [-1, 1]; out-of-range samples read zeros).
            save_tri: Unused; kept for interface compatibility.

        Returns:
            Dict with ``'shs'`` of shape ``(M, 1, 3)`` and ``'opacity'`` of
            shape ``(M, 1)`` with values in (0, 1).
        """
        # Reorder the components to (y, z, x) before sampling.
        yzx = coordinates[:,[1,2,0]]
        triplane = self.embeddings
        # (M, 3) -> (1, 1, 1, M, 3): a single row of M sample locations.
        grid = yzx.float().unsqueeze(0).unsqueeze(1).unsqueeze(2)
        output_feature = torch.nn.functional.grid_sample(
            triplane, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        # (1, C, 1, 1, M) -> (1, M, 1, 1, C) -> (M, C)
        output_feature = output_feature.permute(0,4, 3, 2,1).reshape(len(yzx), triplane.size(1))
        fea = self.net(output_feature)
        shs = self.shs(fea)
        opacity = self.opacity(fea)
        return {'shs': shs.unsqueeze(1), 'opacity': opacity} # add scale rotation delta_xyz
        
