import math
from itertools import islice
from typing import Dict, Optional, OrderedDict

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn.functional as F
from torch import nn
from torch import Tensor, nn

from utils.quantizer import quantize
from utils.arm import (
    Arm,
    _get_neighbor,
    _get_non_zero_pixel_ctx_index,
    _laplace_cdf,
)
from utils.upsampling import Upsampling
from enc.utils.misc import (
    MAX_ARM_MASK_SIZE,
    POSSIBLE_DEVICE,
    DescriptorCoolChic,
    DescriptorNN,
    measure_expgolomb_rate,
)

class SynthesisLayer(nn.Module):
    """Replication-padded convolution followed by an optional non-linearity.

    The input is padded by ``(kernel_size - 1) // 2`` pixels on every side, so
    odd kernel sizes preserve the spatial size. At construction the default
    PyTorch convolution weights are divided by ``output_ft ** 2`` and the bias
    is set to zero.
    """

    def __init__(
        self,
        input_ft: int,
        output_ft: int,
        kernel_size: int,
        non_linearity: Optional[nn.Module] = None,
    ):
        """
        Args:
            input_ft: Number of input channels.
            output_ft: Number of output channels.
            kernel_size: Side of the square convolution kernel.
            non_linearity: Module applied after the convolution. ``None`` (the
                default) means a fresh ``nn.Identity()`` for this layer.
        """
        super().__init__()

        self.pad = nn.ReplicationPad2d((kernel_size - 1) // 2)
        self.conv_layer = nn.Conv2d(input_ft, output_ft, kernel_size)

        # Create the identity per instance instead of sharing a single
        # nn.Identity() object through a default argument.
        self.non_linearity = nn.Identity() if non_linearity is None else non_linearity

        # Shrink the default init and start from a zero bias.
        with torch.no_grad():
            self.conv_layer.weight.div_(output_ft ** 2)
            self.conv_layer.bias.zero_()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``non_linearity(conv(pad(x)))``."""
        return self.non_linearity(self.conv_layer(self.pad(x)))

class SynthesisResidualLayer(nn.Module):
    """Padded convolution with an identity skip: ``act(conv(pad(x)) + x)``.

    Both the weights and the bias are zero-initialised, so a freshly built
    layer returns ``non_linearity(x)``.
    """

    def __init__(
        self,
        input_ft: int,
        output_ft: int,
        kernel_size: int,
        non_linearity: Optional[nn.Module] = None,
    ):
        """
        Args:
            input_ft: Number of input channels; must equal ``output_ft``.
            output_ft: Number of output channels.
            kernel_size: Side of the square convolution kernel.
            non_linearity: Module applied after the skip addition. ``None``
                (the default) means a fresh ``nn.Identity()`` for this layer.

        Raises:
            AssertionError: if ``input_ft != output_ft``.
        """
        super().__init__()

        # The skip connection adds x to the conv output, so widths must match.
        assert input_ft == output_ft,\
            f'Residual layer in/out dim must match. Input = {input_ft}, output = {output_ft}'

        self.pad = nn.ReplicationPad2d((kernel_size - 1) // 2)
        self.conv_layer = nn.Conv2d(input_ft, output_ft, kernel_size)

        # Create the identity per instance instead of sharing a single
        # nn.Identity() object through a default argument.
        self.non_linearity = nn.Identity() if non_linearity is None else non_linearity

        # Zero init: the residual branch starts as a no-op.
        with torch.no_grad():
            self.conv_layer.weight.zero_()
            self.conv_layer.bias.zero_()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``non_linearity(conv(pad(x)) + x)``."""
        return self.non_linearity(self.conv_layer(self.pad(x)) + x)

class ModConv(nn.Module):
    """Four-layer conv stack producing 3 channels.

    Layout: 1x1 conv to ``hid_channels`` (GELU), 1x1 conv to 3 (GELU),
    3x3 residual (GELU), 3x3 residual (no activation).

    Args:
        in_channels: Channels of the input tensor.
        hid_channels: Width of the first 1x1 layer.
        out_channels: Accepted for signature compatibility; not used here.
        mod_layer: Stored as ``self.hid_layer``; not otherwise used here.
    """

    def __init__(
        self,
        in_channels,
        hid_channels,
        out_channels,
        mod_layer,
    ):
        super().__init__()
        # Plain attributes kept for any external readers.
        self.residual = False
        self.hid_channels = hid_channels
        self.hid_layer = mod_layer

        # Registration order fixes both the parameter order returned by
        # get_param() and the order of RNG draws for weight init.
        self.conv1_1 = SynthesisLayer(in_channels, hid_channels, 1, nn.GELU())
        self.conv1_2 = SynthesisLayer(hid_channels, 3, 1, nn.GELU())
        self.conv2_1 = SynthesisResidualLayer(3, 3, 3, nn.GELU())
        self.conv2_2 = SynthesisResidualLayer(3, 3, 3)

    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return detached copies of all parameters, keyed by parameter name."""
        snapshot = OrderedDict()
        for name, tensor in self.named_parameters():
            snapshot[name] = tensor.detach().clone()
        return snapshot

    def set_param(self, param: OrderedDict[str, Tensor]) -> None:
        """Load ``param`` (as produced by :meth:`get_param`) into this module."""
        self.load_state_dict(param)

    def forward(self, x):
        """Run the four layers in order and return the last output."""
        hidden = x
        for stage in (self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2):
            hidden = stage(hidden)
        return hidden
class LocallyConnectedBlock(nn.Module):
    """One independent stack of 1x1 conv layers per region.

    Only ``local_hid_channels`` and ``len(target_mask_list)`` are used; the
    other constructor arguments are accepted for signature compatibility.
    Each per-region net maps a 2-channel input to 3 channels. Because its
    residual layers require matching in/out widths, the last one
    (``local_hid_channels`` -> 3) asserts ``local_hid_channels == 3``.
    """

    def __init__(self, in_channels, global_hid_channels, local_hid_channels, out_channels, mod_layer,target_mask_list):
        super().__init__()
        self.n_regions = len(target_mask_list)
        self.net_list = nn.ModuleList()

        width = local_hid_channels
        for _ in range(self.n_regions):
            self.net_list.append(
                nn.Sequential(
                    SynthesisLayer(2, width, 1, nn.GELU()),
                    SynthesisResidualLayer(width, width, 1, nn.GELU()),
                    SynthesisResidualLayer(width, width, 1, nn.GELU()),
                    SynthesisResidualLayer(width, 3, 1),
                )
            )

    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return detached copies of all parameters, keyed by parameter name."""
        snapshot = OrderedDict()
        for name, tensor in self.named_parameters():
            snapshot[name] = tensor.detach().clone()
        return snapshot

    def set_param(self, param: OrderedDict[str, Tensor]) -> None:
        """Load ``param`` (as produced by :meth:`get_param`) into this module."""
        self.load_state_dict(param)

    def forward(self, inputs_per_region):
        """Apply region ``i``'s net to ``inputs_per_region[i]``.

        Returns:
            List of ``n_regions`` outputs, in region order.
        """
        return [self.net_list[idx](inputs_per_region[idx]) for idx in range(self.n_regions)]
    
class LocalGlobalBlock(LocallyConnectedBlock):
    """Per-region local nets fused, layer by layer, with one shared global net.

    The global stack ``full_net`` runs once on ``combined_latent``. Its
    intermediate outputs are concatenated along the channel axis into three
    progressively wider feature maps. Each region's local net (built by
    :class:`LocallyConnectedBlock`) is fed ``coordinate`` and advanced one layer
    at a time. After each of its first three layers, the result is concatenated
    with the corresponding global feature map, zeroed outside the region's
    mask, and projected back to 3 channels by the matching ``agg_func``
    layer. The last local layer is applied at the end.
    """

    def __init__(self, in_channels, global_hid_channels,local_hid_channels, out_channels, mod_layer,target_mask_list):
        super().__init__(in_channels, global_hid_channels, local_hid_channels, out_channels, mod_layer, target_mask_list)

        self.target_mask_list = target_mask_list

        # Aggregation layer k takes the 3-channel local feature concatenated
        # with out_full[k] (global_hid_channels + 3 * (k + 1) channels), i.e.
        # global_hid_channels + 6 / + 9 / + 12 inputs, and outputs 3 channels.
        # Used as an indexed container in forward(), never called as a whole.
        self.agg_func = nn.Sequential(
            nn.Sequential(SynthesisLayer(global_hid_channels + 6, 3, 1, nn.GELU())),
            nn.Sequential(SynthesisLayer(global_hid_channels + 9, 3, 1, nn.GELU())),
            nn.Sequential(SynthesisLayer(global_hid_channels + 12, 3, 1, nn.GELU())),
        )

        # Global stack; output widths: global_hid_channels, 3, 3, 3.
        self.full_net = nn.Sequential(
            SynthesisLayer(in_channels, global_hid_channels, 1, nn.GELU()),
            SynthesisLayer(global_hid_channels, 3, 1, nn.GELU()),
            SynthesisResidualLayer(3, 3, 3, nn.GELU()),
            SynthesisResidualLayer(3, 3, 3, nn.GELU()),
        )

    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return detached copies of all parameters, keyed by parameter name."""
        return OrderedDict({k: v.detach().clone() for k, v in self.named_parameters()})

    def set_param(self, param: OrderedDict[str, Tensor]) -> None:
        """Load ``param`` (as produced by :meth:`get_param`) into this module."""
        self.load_state_dict(param)

    def forward(self, coordinate, combined_latent):
        """
        Args:
            coordinate: Input to every local net (their first layer expects
                2 channels); cloned once per region.
            combined_latent: Input to the global net (``in_channels`` channels).

        Returns:
            List with one 3-channel output per region, in
            ``target_mask_list`` order.
        """
        local_layer_input = [coordinate.clone() for _ in self.target_mask_list]

        # Run the global net once, keeping every intermediate output.
        global_outputs = []
        hidden = combined_latent
        for full_layer in self.full_net:
            hidden = full_layer(hidden)
            global_outputs.append(hidden)

        # Channel-wise concatenations of the first 2, first 3 and all 4 global
        # outputs; entry k feeds aggregation layer k.
        out_full = [
            torch.cat(global_outputs[:2], dim=1),
            torch.cat(global_outputs[:3], dim=1),
            torch.cat(global_outputs, dim=1),
        ]

        device = out_full[0].device

        # Normalise masks to bool on the working device; squeeze() drops
        # singleton dims so they broadcast against the trailing (H, W) dims.
        # The stored attribute is rewritten, as before.
        self.target_mask_list = [m.to(device).squeeze().bool() for m in self.target_mask_list]

        # zip(*net_list) yields the k-th layer of every region's net; zipping
        # with the 3 aggregation layers stops after the first 3 local layers.
        for layer_id, (net_tuple, agg_layer) in enumerate(zip(zip(*self.net_list), self.agg_func)):
            global_feat = out_full[layer_id]
            # Identical for every region, so build it once per layer.
            zeros = torch.zeros_like(global_feat)
            for region_id, mask in enumerate(self.target_mask_list):
                local_feat = net_tuple[region_id](local_layer_input[region_id])
                # Global features inside the region, zero elsewhere.
                masked_global = torch.where(mask, global_feat, zeros)
                local_layer_input[region_id] = agg_layer(
                    torch.cat([local_feat, masked_global], dim=-3)
                )

        # Final (index -1) layer of each region's local net.
        for region_id in range(len(self.target_mask_list)):
            local_layer_input[region_id] = self.net_list[region_id][-1](local_layer_input[region_id])

        return local_layer_input
   

class ReReIC(nn.Module):
    """Latent-modulated local/global synthesis model with an ARM rate term.

    The model holds ``args.mod_base`` learnable latent grids
    (``modulation_sf``), each paired with a fixed standard-normal buffer of the
    same shape (``latent_gaussian_{i}``). ``forward`` proceeds as follows:

    - The latents are quantized.
    - The quantized latents and the Gaussian grids are upsampled.
    - The result, together with the coordinates, is fed to a
      :class:`LocalGlobalBlock`.
    - The per-region outputs are scattered into a ``(batch, h * w, 3)`` tensor.

    The ARM estimates the rate of the quantized latents.
    """

    def __init__(self, args,target_mask_list,log2_sigma,saliency_tensor):
        super().__init__()
        self.net = []

        self.h = args.patch_h
        self.w = args.patch_w
        self.target_mask_list = target_mask_list
        self.hidden_layers = 2
        self.log2_sigma = log2_sigma
        self.saliency_tensor = saliency_tensor

        self.pe_flag = 0

        self.upsampling_2d = Upsampling(
            args.local_upsampling_kernel_size, args.static_upsampling_kernel, args.highest_flag
        )

        self.dim_arm = args.dim_arm_mod
        self.n_hidden_layers_arm = 2
        self.arm = Arm(args.context_arm, args.dim_arm_mod, self.n_hidden_layers_arm)

        # Settings forwarded to `quantize`; in eval mode quantize_all_latent
        # switches to "none" noise and "hardround".
        self.quantizer_type = "softround"
        self.quantizer_noise_type = "kumaraswamy"
        self.soft_round_temperature = 0.3
        self.noise_parameter = 2.0
        max_mask_size = 9

        self.modulation_base_number = args.mod_base

        # Spatial size of each latent level, listed smallest first after
        # reverse(). With highest_flag == 1 the largest level is (h, w);
        # otherwise it is (h // 2, w // 2).
        self.fact_shape = []
        first_level = 0 if args.highest_flag == 1 else 1
        for i in range(self.modulation_base_number):
            factor = 2 ** (i + first_level)
            self.fact_shape.append((self.h // factor, self.w // factor))
        self.fact_shape.reverse()

        max_context_pixel = int((max_mask_size**2 - 1) / 2)
        assert self.dim_arm <= max_context_pixel, (
            f"You can not have more context pixels "
            f" than {max_context_pixel}. Found {self.dim_arm}"
        )

        self.mask_size = max_mask_size
        self.encoder_gains_sf = 16
        print('Quantizer parameter: encoding gain ',self.encoder_gains_sf)

        self.all_pix_num = self.h * self.w // args.scale // args.scale
        print('total pixel:',self.all_pix_num)

        self.register_buffer(
            "non_zero_pixel_ctx_index",
            _get_non_zero_pixel_ctx_index(args.context_arm),
            persistent=False,
        )

        self.latent_factor = args.latent_factor

        self.conv_mod = LocalGlobalBlock(in_channels=2*self.modulation_base_number, global_hid_channels=args.sythesis_features,local_hid_channels=3, out_channels=self.hidden_layers+1, mod_layer=args.mod_hid_layer, target_mask_list=self.target_mask_list)

        self.modules_to_send = ['arm', 'conv_mod', 'upsampling_2d']

        self.nn_q_step: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
        self.nn_expgol_cnt: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }

        # Create the latents directly on the target device as leaf Parameters.
        # `nn.Parameter(...).cuda()` yields a non-leaf, non-Parameter copy
        # that only works if ParameterList re-wraps it, and it crashes on
        # machines without CUDA. CUDA is still used whenever it is available.
        latent_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.modulation_sf = nn.ParameterList()
        for layer_idx in range(self.modulation_base_number):
            mod_shape = self.fact_shape[layer_idx]
            shits = nn.Parameter(
                torch.zeros(args.batch_size, 1, mod_shape[0], mod_shape[1], device=latent_device)
            )
            self.modulation_sf.append(shits)
            print('Get Mod with shape',shits.shape,'at layer:',layer_idx+1)

        self.noise_stddev = 1
        # One fixed standard-normal grid per latent, with the same shape,
        # device and dtype as that latent.
        for i, param in enumerate(self.modulation_sf):
            self.register_buffer(f"latent_gaussian_{i}", torch.randn(param.shape, device=param.device, dtype=param.dtype))

    def _latent_gaussian_list(self):
        """Return the ``latent_gaussian_{i}`` buffers in level order.

        These are read on every call because ``.to()`` / ``.cuda()`` replace
        buffer objects, so a previously cached list could point at stale
        tensors.
        """
        return [getattr(self, f"latent_gaussian_{i}") for i in range(self.modulation_base_number)]

    def quantize_all_latent(self,latent,coords):
        """Quantize the latents and run the synthesis network.

        Args:
            latent: Sequence of latent tensors (``self.modulation_sf`` when
                called from ``forward``). Each is multiplied by
                ``encoder_gains_sf`` before quantization.
            coords: Coordinate input forwarded to ``conv_mod``.

        Returns:
            ``(q_latents, local_outputs)``: the list of quantized latents and
            the per-region outputs of ``conv_mod``.
        """
        training = self.training
        q_shifts_all_for_conv = [
            quantize(
                latent_i * self.encoder_gains_sf,
                self.quantizer_noise_type if training else "none",
                self.quantizer_type if training else "hardround",
                self.soft_round_temperature,
                self.noise_parameter,
            )
            for latent_i in latent
        ]

        # Upsample the quantized latents and the fixed Gaussian grids
        # separately, then stack them along channels as conv_mod's input.
        q_upsample_conv = self.upsampling_2d(q_shifts_all_for_conv)
        random_upsample_conv = self.upsampling_2d(self._latent_gaussian_list())
        conv_in = torch.cat((q_upsample_conv, random_upsample_conv), dim=1)

        local_layer_output = self.conv_mod(coords, conv_in)

        return q_shifts_all_for_conv, local_layer_output

    def get_param(self):
        """Return a snapshot of all transmitted parameters and the latents.

        Keys are prefixed with the owning submodule (``conv_mod.``, ``arm.``,
        ``upsampling_2d.``) or ``modulation_sf.<i>``. The latents are detached
        clones, so the snapshot does not change as training continues.
        """
        param = OrderedDict()
        param.update({f"conv_mod.{k}": v for k, v in self.conv_mod.get_param().items()})
        param.update({f"arm.{k}": v for k, v in self.arm.get_param().items()})
        param.update({f"upsampling_2d.{k}": v for k, v in self.upsampling_2d.get_param().items()})
        param.update(
            {f"modulation_sf.{i}": v.detach().clone() for i, v in enumerate(self.modulation_sf)}
        )
        return param

    def set_param(self, param):
        """Load a snapshot produced by :meth:`get_param`.

        Submodule entries are routed to each submodule's ``set_param``. Latent
        entries are copied in place into ``modulation_sf``.
        """
        conv_mod_param = {k[len("conv_mod.") :]: v for k, v in param.items() if k.startswith("conv_mod.")}
        arm_param = {k[len("arm.") :]: v for k, v in param.items() if k.startswith("arm.")}
        upsampling_param = {k[len("upsampling_2d.") :]: v for k, v in param.items() if k.startswith("upsampling_2d.")}

        self.conv_mod.set_param(conv_mod_param)
        self.arm.set_param(arm_param)
        self.upsampling_2d.set_param(upsampling_param)

        modulation_sf_param = {int(k.split(".")[1]): v for k, v in param.items() if k.startswith("modulation_sf.")}
        for i, v in modulation_sf_param.items():
            self.modulation_sf[i].data.copy_(v.data)

    def estimate_rate(self, decoder_side_latent,arm_model):
        """Estimate the per-element rate, in bits, of the quantized latents.

        Args:
            decoder_side_latent: Sequence of quantized latent tensors.
            arm_model: Callable returning ``(mu, scale, log_scale)`` for the
                stacked contexts.

        Returns:
            1-D tensor of ``-log2(p)`` over all latent elements, concatenated
            in input order.
        """
        flat_context = torch.cat(
            [
                _get_neighbor(latent_i, self.mask_size, self.non_zero_pixel_ctx_index)
                for latent_i in decoder_side_latent
            ],
            dim=0,
        )
        flat_latent = torch.cat([latent_i.view(-1) for latent_i in decoder_side_latent], dim=0)

        # Add a leading axis and swap the last two before calling the ARM.
        flat_context_in = flat_context.unsqueeze(0).transpose(1, 2)
        flat_mu, flat_scale, _ = arm_model(flat_context_in)

        # Mass of the unit-width bin around each latent value under the
        # Laplace CDF, floored at 2**-16 so log2 stays finite.
        proba = torch.clamp_min(
            _laplace_cdf(flat_latent + 0.5, flat_mu, flat_scale)
            - _laplace_cdf(flat_latent - 0.5, flat_mu, flat_scale),
            min=2**-16,
        )
        return -torch.log2(proba)

    def get_network_rate(self):
        """Return the exp-Golomb rate of each module listed in ``modules_to_send``."""
        rate_per_module: DescriptorCoolChic = {
            module_name: {"weight": 0.0, "bias": 0.0}
            for module_name in self.modules_to_send
        }

        for module_name in self.modules_to_send:
            cur_module = getattr(self, module_name)
            rate_per_module[module_name] = measure_expgolomb_rate(
                cur_module,
                self.nn_q_step.get(module_name),
                self.nn_expgol_cnt.get(module_name),
            )
        return rate_per_module

    def forward(self, coords):
        """
        Args:
            coords: Coordinate input forwarded to ``conv_mod``; its first
                dimension is used as the batch size.

        Returns:
            ``(output, flat_rate)``. ``output`` has shape ``(batch, h * w, 3)``
            and is zero at pixels covered by no mask. ``flat_rate`` is the
            per-element rate from :meth:`estimate_rate`.
        """
        # Kept as an attribute for any external readers.
        self.latent_gaussian_list = self._latent_gaussian_list()
        q_shifts_all_viewed, local_layer_output = self.quantize_all_latent(self.modulation_sf, coords)
        flat_rate = self.estimate_rate(q_shifts_all_viewed, self.arm)

        device = local_layer_output[0].device
        batch_size = coords.shape[0]
        total_length = self.h * self.w
        concatenated_input = torch.zeros(batch_size, 3, total_length, device=device)

        # Scatter each region's prediction into the pixels of its mask.
        for local_output, mask in zip(local_layer_output, self.target_mask_list):
            mask = mask.flatten().bool().to(device)
            concatenated_input[:, :, mask] = local_output.view(batch_size, 3, -1)[:, :, mask]

        return concatenated_input.permute(0, 2, 1), flat_rate
