import torch
from torch import nn
import torch.nn.functional as F
import math
import numpy as np
import torch.autograd as autograd
from torch import Tensor, index_select, nn
from utils.quantizer import quantize
from utils.arm import (
    Arm,
    _get_neighbor,
    _get_non_zero_pixel_ctx_index,
    _laplace_cdf,
)
from utils.upsampling import Upsampling
from enc.misc import (
    MAX_ARM_MASK_SIZE,
    POSSIBLE_DEVICE,
    DescriptorCoolChic,
    DescriptorNN,
    measure_expgolomb_rate,
)
from typing import Any, Dict, List, Optional, OrderedDict

class SynthesisLayer(nn.Module):
    """Replication-padded 2-D convolution followed by a non-linearity.

    Every side is padded by ``(kernel_size - 1) // 2`` pixels, so for an odd
    ``kernel_size`` the output has the same spatial size as the input.

    Args:
        input_ft: Number of input channels.
        output_ft: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        non_linearity: Module applied after the convolution. When omitted
            (or ``None``), a fresh ``nn.Identity`` is created for this layer.
    """

    def __init__(
        self,
        input_ft: int,
        output_ft: int,
        kernel_size: int,
        non_linearity: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.pad = nn.ReplicationPad2d((kernel_size - 1) // 2)
        self.conv_layer = nn.Conv2d(input_ft, output_ft, kernel_size)
        # Build the default per instance: a module object used as a default
        # argument would be one instance shared by every layer built without
        # an explicit non-linearity.
        self.non_linearity = nn.Identity() if non_linearity is None else non_linearity

        # Shrink PyTorch's default weight init by output_ft ** 2 and start
        # from a zero bias.
        with torch.no_grad():
            self.conv_layer.weight.div_(output_ft ** 2)
            self.conv_layer.bias.zero_()

    def forward(self, x: Tensor) -> Tensor:
        """Pad ``x``, convolve it, then apply the non-linearity."""
        return self.non_linearity(self.conv_layer(self.pad(x)))

class SynthesisResidualLayer(nn.Module):
    """Residual replication-padded convolution: ``act(conv(pad(x)) + x)``.

    The convolution weights and bias start at zero, so before training the
    layer computes exactly ``non_linearity(x)``. The residual sum needs the
    convolution to preserve spatial size, which holds for odd
    ``kernel_size``.

    Args:
        input_ft: Number of input channels; must equal ``output_ft``.
        output_ft: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        non_linearity: Module applied after the residual sum. When omitted
            (or ``None``), a fresh ``nn.Identity`` is created for this layer.

    Raises:
        AssertionError: If ``input_ft != output_ft``.
    """

    def __init__(
        self,
        input_ft: int,
        output_ft: int,
        kernel_size: int,
        non_linearity: Optional[nn.Module] = None,
    ):
        super().__init__()

        assert input_ft == output_ft,\
            f'Residual layer in/out dim must match. Input = {input_ft}, output = {output_ft}'

        self.pad = nn.ReplicationPad2d((kernel_size - 1) // 2)
        self.conv_layer = nn.Conv2d(input_ft, output_ft, kernel_size)
        # Build the default per instance instead of sharing one module object
        # between every layer created without an explicit non-linearity.
        self.non_linearity = nn.Identity() if non_linearity is None else non_linearity

        # Zero init: the residual branch contributes nothing at start.
        with torch.no_grad():
            self.conv_layer.weight.zero_()
            self.conv_layer.bias.zero_()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``non_linearity(conv(pad(x)) + x)``."""
        return self.non_linearity(self.conv_layer(self.pad(x)) + x)

class LocallyConnectedBlock(nn.Module):
    """Local branch made of four synthesis stages applied in sequence.

    Stage 0 maps 2 input channels to ``local_hid_channels`` (kernel 1, GELU).
    Stages 1 and 2 are GELU residual layers with kernels 1 and 3. Stage 3 is a
    residual layer with kernel 3, no non-linearity and 3 output channels.
    Because stage 3 is residual, ``local_hid_channels`` must be 3.

    ``in_channels``, ``global_hid_channels``, ``out_channels`` and
    ``mod_layer`` are accepted but not used to build any layer here.
    """

    def __init__(self, in_channels, global_hid_channels, local_hid_channels, out_channels, mod_layer):
        super().__init__()
        width = local_hid_channels
        stages = [
            SynthesisLayer(2, width, 1, nn.GELU()),
            SynthesisResidualLayer(width, width, 1, nn.GELU()),
            SynthesisResidualLayer(width, width, 3, nn.GELU()),
            SynthesisResidualLayer(width, 3, 3),
        ]
        # Each stage sits in its own one-element Sequential so parameter names
        # carry the ``net.<i>.0.`` prefix expected by saved state dicts.
        self.net = nn.Sequential(*[nn.Sequential(stage) for stage in stages])

    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return a detached copy of every named parameter of the block."""
        return OrderedDict(
            (name, tensor.detach().clone()) for name, tensor in self.named_parameters()
        )

    def set_param(self, param: OrderedDict[str, Tensor]) -> None:
        """Load ``param`` into the block through ``load_state_dict``."""
        self.load_state_dict(param)

    def forward(self, x):
        """Run ``x`` through the four local stages in order."""
        return self.net(x)
    
class LocalGlobalBlock(LocallyConnectedBlock):
    """Fuse a masked local branch with a global branch.

    The local branch (``self.net``, inherited) processes ``coordinate`` after
    every position outside ``mask`` has been zeroed. The global branch
    (``self.full_net``) processes ``x``. After each of the first three local
    stages, the local and global features are concatenated along dim -3. A
    single ``agg_func`` layer, shared by all three steps, then maps them back
    to 3 channels. The last local stage produces the output.

    Args:
        in_channels: Channels of ``x`` fed to the global branch.
        global_hid_channels: Width of the global branch.
        local_hid_channels: Width of the local branch; ``agg_func`` expects
            3 local channels, so this must be 3.
        out_channels: Forwarded to the parent; unused by the layers built here.
        mod_layer: Forwarded to the parent; unused by the layers built here.
        mask: Tensor selecting which positions of ``coordinate`` are kept.
            It is cast to ``bool`` at each forward pass and must be
            expandable to ``coordinate``'s shape.

    ``get_param`` / ``set_param`` are inherited unchanged from
    ``LocallyConnectedBlock``.
    """

    def __init__(self, in_channels, global_hid_channels, local_hid_channels, out_channels, mod_layer, mask):
        super().__init__(in_channels, global_hid_channels, local_hid_channels, out_channels, mod_layer)
        self.mask = mask
        # Fusion layer shared by every fusion step: 3 local channels plus
        # global_hid_channels global ones -> 3 channels. The double Sequential
        # keeps the ``agg_func.0.0.`` parameter-name prefix.
        self.agg_func = nn.Sequential(
            nn.Sequential(SynthesisLayer(global_hid_channels + 3, 3, 1, nn.GELU()))
        )
        self.full_net = nn.Sequential(
            SynthesisLayer(in_channels, global_hid_channels, 1, nn.GELU()),
            SynthesisResidualLayer(global_hid_channels, global_hid_channels, 1, nn.GELU()),
            SynthesisResidualLayer(global_hid_channels, global_hid_channels, 1, nn.GELU()),
        )

    def forward(self, coordinate, x):
        """Decode ``coordinate`` (masked) together with the global features ``x``.

        Args:
            coordinate: Local-branch input. The first local stage expects 2
                channels at dim -3; presumably (B, 2, H, W) -- TODO confirm.
            x: Global-branch input with ``in_channels`` channels at dim -3.

        Returns:
            Output of the last local stage (3 channels).
        """
        # Move the mask to coordinate's device so torch.where accepts it; this
        # is a no-op when it is already there. The module attribute is left
        # untouched.
        keep = self.mask.bool().to(coordinate.device).expand_as(coordinate)
        local_feat = torch.where(keep, coordinate, torch.zeros_like(coordinate))
        global_feat = x

        # net has 4 stages and full_net 3, so zip pairs the first three local
        # stages with the three global ones; net[-1] runs on its own below.
        for local_stage, global_stage in zip(self.net, self.full_net):
            local_feat = local_stage(local_feat)
            global_feat = global_stage(global_feat)
            local_feat = self.agg_func(torch.cat([local_feat, global_feat], dim=-3))

        return self.net[-1](local_feat)




class MoRIC(nn.Module):
    """Masked latent-modulation model with an ARM-based rate estimate.

    The model owns a pyramid of ``modulation_base_number`` learnable latent
    grids (``modulation_sf``), ordered coarsest first. Each forward pass:

    1. scales the grids by ``encoder_gains_sf`` and quantizes them;
    2. zeroes them outside the target mask at their resolution;
    3. upsamples them with ``upsampling_2d``;
    4. decodes them with ``conv_mod`` together with the input coordinates.

    ``arm`` predicts the parameters used to estimate each masked latent's rate.

    Note: latents and masks are placed on the GPU via ``device="cuda"`` /
    ``.cuda()``, so construction requires CUDA.

    Args:
        args: Configuration object. This class reads
            ``upsampling_kernel_size``, ``dim_arm_mod``, ``context_arm`` and
            ``sythesis_features``.
        target_mask: Mask whose last two dimensions give the height and
            width. Non-zero entries mark the positions that are kept. It is
            max-pooled with ``F.max_pool2d``; presumably shaped (1, 1, H, W)
            -- TODO confirm.
    """

    def __init__(self, args, target_mask):
        super().__init__()
        self.net = []  # not used by this class; kept for attribute compatibility
        self.h = target_mask.shape[-2]
        self.w = target_mask.shape[-1]
        self.target_mask = target_mask
        self.upsampling_2d = Upsampling(args.upsampling_kernel_size, False, 1)
        self.dim_arm = args.dim_arm_mod
        self.n_hidden_layers_arm = 2
        self.arm = Arm(args.context_arm, args.dim_arm_mod, self.n_hidden_layers_arm)
        self.quantizer_type = "softround"
        self.quantizer_noise_type = "kumaraswamy"
        self.soft_round_temperature = 0.3
        self.noise_parameter = 2.0

        self.modulation_base_number = 7

        # One latent-grid shape per level, coarsest first:
        # fact_shape[k] == (h // 2**(n-1-k), w // 2**(n-1-k)).
        self.fact_shape = [
            (self.h // (2 ** i), self.w // (2 ** i))
            for i in range(self.modulation_base_number)
        ]
        self.fact_shape.reverse()
        self.mask_size = 9
        self.encoder_gains_sf = 16
        print('Quantizer parameter: encoding gain ', self.encoder_gains_sf)
        self.all_pix_num = self.h * self.w
        print('total pixel:', self.all_pix_num)

        self.register_buffer(
            "non_zero_pixel_ctx_index",
            _get_non_zero_pixel_ctx_index(args.context_arm),
            persistent=False,
        )

        self.latent_factor = 1

        self.conv_mod = LocalGlobalBlock(
            in_channels=self.modulation_base_number,
            global_hid_channels=args.sythesis_features,
            local_hid_channels=3,
            out_channels=2 + 1,
            mod_layer=0,
            mask=self.target_mask,
        )

        self.modules_to_send = ['arm', 'conv_mod', 'upsampling_2d']

        self.nn_q_step: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
        self.nn_expgol_cnt: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
        self.modulation_sf = nn.ParameterList()

        # mask_sf[0] is the full-resolution target mask and mask_sf[k] is that
        # mask max-pooled k times, i.e. finest first. This is the REVERSE of
        # modulation_sf's order, hence the ``len - i - 1`` indexing below.
        self.mask_sf = []
        level_mask = target_mask
        for layer_idx in range(self.modulation_base_number):
            mod_shape = self.fact_shape[layer_idx]
            # Create the Parameter directly on the GPU. ``nn.Parameter(...).cuda()``
            # returns a plain non-leaf Tensor, not a Parameter; registering it
            # relied on ParameterList silently re-wrapping it, and older
            # PyTorch versions raise TypeError instead.
            shifts = nn.Parameter(
                torch.zeros(1, 1, mod_shape[0], mod_shape[1], device="cuda")
            )
            if layer_idx > 0:
                level_mask = F.max_pool2d(level_mask.float(), kernel_size=2)
            self.mask_sf.append(level_mask.cuda())
            self.modulation_sf.append(shifts)
            print('Get Mod with shape', shifts.shape, 'at layer:', layer_idx + 1)

    def quantize_all_latent(self, latent, coords):
        """Quantize and mask every latent level, then decode them.

        In training mode ``quantize`` gets ``quantizer_noise_type`` and
        ``quantizer_type``; in eval mode it gets ``"none"`` and
        ``"hardround"``.

        Args:
            latent: Sequence of latent grids ordered coarsest first, as in
                ``self.modulation_sf``.
            coords: Passed to ``conv_mod`` as its ``coordinate`` input.

        Returns:
            A pair:

            - the list of quantized, masked latents, in the same order as
              ``latent``;
            - the output of ``conv_mod``.
        """
        n_levels = len(latent)
        q_shifts_all_for_conv = []
        for level, latent_level in enumerate(latent):
            q_level = quantize(
                latent_level * self.encoder_gains_sf,
                self.quantizer_noise_type if self.training else "none",
                self.quantizer_type if self.training else "hardround",
                self.soft_round_temperature,
                self.noise_parameter,
            )
            # latent is coarsest-first while mask_sf is finest-first.
            q_shifts_all_for_conv.append(q_level * self.mask_sf[n_levels - level - 1])

        q_upsample_conv_o = self.upsampling_2d(q_shifts_all_for_conv, self.mask_sf)
        weight_shift_all = self.conv_mod(coords, q_upsample_conv_o)
        return q_shifts_all_for_conv, weight_shift_all

    def estimate_rate(self, decoder_side_latent, arm_model):
        """Estimate the rate, in bits, of every masked quantized latent value.

        Args:
            decoder_side_latent: List of quantized latent grids, coarsest first.
            arm_model: Callable returning ``(mu, scale, log_scale)`` for the
                stacked neighbour contexts.

        Returns:
            ``-log2(p)``, where ``p`` is the ``_laplace_cdf`` mass of the
            unit-width bin around each masked latent value, floored at
            ``2**-16`` (so at most 16 bits per value).
        """
        n_levels = len(decoder_side_latent)
        # Boolean selection masks aligned with decoder_side_latent's order.
        level_masks = [
            self.mask_sf[n_levels - i - 1].flatten().bool() for i in range(n_levels)
        ]

        flat_context = torch.cat(
            [
                _get_neighbor(latent_i, self.mask_size, self.non_zero_pixel_ctx_index)[mask_i, :]
                for latent_i, mask_i in zip(decoder_side_latent, level_masks)
            ],
            dim=0,
        )
        flat_latent = torch.cat(
            [
                latent_i.view(-1)[mask_i]
                for latent_i, mask_i in zip(decoder_side_latent, level_masks)
            ],
            dim=0,
        )

        flat_context_in = flat_context.unsqueeze(0).transpose(1, 2)
        flat_mu, flat_scale, _ = arm_model(flat_context_in)
        proba = torch.clamp_min(
            _laplace_cdf(flat_latent + 0.5, flat_mu, flat_scale)
            - _laplace_cdf(flat_latent - 0.5, flat_mu, flat_scale),
            min=2**-16,
        )
        return -torch.log2(proba)

    def get_network_rate(self):
        """Return, per module in ``modules_to_send``, its ``measure_expgolomb_rate``.

        ``measure_expgolomb_rate`` is called with the module's ``nn_q_step``
        and ``nn_expgol_cnt`` entries.
        """
        rate_per_module: DescriptorCoolChic = {
            module_name: measure_expgolomb_rate(
                getattr(self, module_name),
                self.nn_q_step.get(module_name),
                self.nn_expgol_cnt.get(module_name),
            )
            for module_name in self.modules_to_send
        }
        return rate_per_module

    def forward(self, coords):
        """Decode the modulation latents at ``coords`` and estimate their rate.

        Returns:
            A pair:

            - the ``conv_mod`` output viewed as ``(batch, 3, -1)``, restricted
              to positions where ``mask_sf[0]`` is non-zero, then permuted to
              ``(batch, n_masked, 3)``;
            - the per-value rate from ``estimate_rate``.
        """
        q_shifts_all_viewed, decoded = self.quantize_all_latent(self.modulation_sf, coords)
        flat_rate = self.estimate_rate(q_shifts_all_viewed, self.arm)
        batch_size = decoded.shape[0]
        decoded = decoded.view(batch_size, 3, -1)[:, :, self.mask_sf[0].flatten().bool()]
        return decoded.permute(0, 2, 1), flat_rate