import os
# Set before importing torch so CUDA initializes with only physical GPU 1 visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import json
import re
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from PIL import Image

import util.config as c

# Module-level CUDA device handle. Because CUDA_VISIBLE_DEVICES is set to "1" at
# the top of this file, "cuda" here refers to the only visible GPU (physical GPU 1).
device = torch.device("cuda")



def gauss_noise(shape, device=None):
    """Return a tensor of i.i.d. standard-normal samples with the given shape.

    Args:
        shape: Target shape (tuple, list or ``torch.Size``).
        device: Device to create the tensor on. ``None`` (the default) keeps the
            original behaviour of placing the result on the current CUDA device.

    Returns:
        A float tensor of ``shape`` filled with N(0, 1) samples.
    """
    # Sample once, directly on the target device. This replaces the old approach
    # (zero-filled GPU tensor, then a per-sample CPU randn + host->device copy),
    # which also failed for 0-d shapes.
    target = "cuda" if device is None else device
    return torch.randn(shape, device=target)
    
def l1_loss(output, bicubic_image):
    """Return the mean absolute error between ``output`` and ``bicubic_image``.

    Equivalent to the old ``L1Loss(reduce=True, size_average=True)``. It uses
    the current ``reduction="mean"`` API, so PyTorch no longer emits a
    deprecation warning on every call.
    """
    return torch.nn.functional.l1_loss(output, bicubic_image, reduction="mean")

def get_parameter_number(net):
    """Count the elements in ``net``'s parameters.

    Returns a dict with key ``'Total'`` (all parameter elements) and key
    ``'Trainable'`` (elements of parameters with ``requires_grad`` set).
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}

def load(name, net, map_location=None):
    """Load the ``'net'`` state dict stored in checkpoint ``name`` into ``net``.

    Entries whose key contains ``'tmp_var'`` are dropped before loading.
    ``load_state_dict`` runs with its default strict key matching.

    Args:
        name: Path of the checkpoint file written by ``torch.save``.
        net: Module to load weights into (modified in place).
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load a checkpoint saved on GPU on a machine without that GPU.
            ``None`` keeps the previous behaviour.
    """
    # NOTE(security): torch.load may unpickle arbitrary objects depending on the
    # torch version / weights_only default; only load trusted checkpoint files.
    state_dicts = torch.load(name, map_location=map_location)
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    net.load_state_dict(network_state_dict)

class Normalize(nn.Module):
    """Scale 0-255 inputs to [0, 1], then standardize per channel.

    ``mean`` and ``std`` are per-channel sequences. They are stored as buffers
    so they follow the module across ``.to()`` / ``.cuda()`` and appear in
    ``state_dict``. Any channel count is supported: the buffers are broadcast
    along dim 1. For 4-D input this is the channel axis of an NCHW batch
    (assumption, confirm against callers).
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self.register_buffer('mean', torch.Tensor(mean))
        self.register_buffer('std', torch.Tensor(std))

    def forward(self, input):
        # Input is assumed to be in the 0-255 range; bring it to [0, 1] first.
        input = input / 255.0
        # (1, C, 1, 1) so the per-channel stats broadcast over batch and spatial
        # dims. -1 infers C from len(mean) instead of hard-coding 3 channels.
        mean = self.mean.reshape(1, -1, 1, 1)
        std = self.std.reshape(1, -1, 1, 1)
        return (input - mean) / std


def normal_r(output_r):
    """Min-max rescale ``output_r`` so its global min maps to 0 and max to 1.

    The min and max are taken over all elements, not per sample or per channel.
    A constant tensor has zero range; it is mapped to all zeros instead of the
    NaNs that ``0 / 0`` would produce.
    """
    r_max = torch.max(output_r)
    r_min = torch.min(output_r)
    r_range = r_max - r_min
    # Substitute 1 for a zero range: (x - min) is already all zeros in that case.
    # torch.where avoids a host sync and keeps the graph differentiable.
    safe_range = torch.where(r_range > 0, r_range, torch.ones_like(r_range))
    return (output_r - r_min) / safe_range

def to_rgb(image):
    """Return a new RGB image of the same size with ``image`` pasted at (0, 0).

    ``Image.paste`` converts the source to the canvas mode (RGB) when the
    modes differ.
    """
    width, height = image.size
    canvas = Image.new(mode="RGB", size=(width, height))
    canvas.paste(image)
    return canvas

def imglist(path, mat):
    """Recursively collect files under ``path`` whose name ends with ``mat``.

    ``mat`` is passed to ``str.endswith``, so it may be a single suffix or a
    tuple of suffixes; matching is case-sensitive. Paths are returned in
    ``os.walk`` traversal order, joined onto the directory they were found in.
    A missing ``path`` yields an empty list.
    """
    return [
        os.path.join(root, fname)
        for root, _subdirs, fnames in os.walk(path)
        for fname in fnames
        if fname.endswith(mat)
    ]

def l_cal(img1, img2):
    """Measure the difference ``img1 - img2`` over all elements.

    Returns:
        ``(l2, l_inf)``: ``l2`` is the *squared* Euclidean norm of the flattened
        difference and ``l_inf`` is its max-absolute value. Both are 0-d tensors.
    """
    diff = torch.flatten(img1 - img2)
    l2 = (torch.norm(diff, p=2, dim=0) ** 2).sum()
    l_inf = torch.norm(diff, p=float('inf'), dim=0).sum()
    return l2, l_inf


def latest_checkpoint(checkpoints_dir: Optional[str] = None) -> Optional[int]:
    """Return the largest number appearing in any checkpoint file name.

    Args:
        checkpoints_dir: Directory to scan. Defaults to ``c.gan_checkpoints_dir``.

    Returns:
        The maximum integer found in any file name, or ``None`` when the
        directory does not exist or no file name contains digits.
    """
    if checkpoints_dir is None:
        checkpoints_dir = c.gan_checkpoints_dir
    if not os.path.isdir(checkpoints_dir):
        return None
    # Scan each name separately. The old code joined all names into one string,
    # so digits at a name boundary could merge (e.g. "model10" + "5model" -> 105).
    numbers = [
        int(digits)
        for fname in os.listdir(checkpoints_dir)
        for digits in re.findall(r"\d+", fname)
    ]
    return max(numbers, default=None)

def weights_init_target(param: Any) -> None:
    """Initialize the weights of ``nn.Conv2d`` and ``nn.Linear`` modules in place.

    - Conv2d: Xavier-uniform weights; bias (if present) set to 0.2.
    - Linear: weights ~ N(0, 0.01); bias (if present) set to 0.
    Any other module is left untouched, so this can be used with
    ``net.apply(weights_init_target)``.
    """
    if isinstance(param, nn.Conv2d):
        torch.nn.init.xavier_uniform_(param.weight.data)
        if param.bias is not None:
            torch.nn.init.constant_(param.bias.data, 0.2)
    elif isinstance(param, nn.Linear):
        torch.nn.init.normal_(param.weight.data, mean=0.0, std=0.01)
        # Linear(bias=False) has bias None; guard like the Conv2d branch does.
        if param.bias is not None:
            torch.nn.init.constant_(param.bias.data, 0.0)


def initWeights(module):
    """Initialize exact ``nn.Conv2d`` / ``nn.Linear`` modules in place.

    - Conv2d: Kaiming-normal weights (fan_in, ReLU gain), but only when the
      weight requires grad; the bias is left as is.
    - Linear: weights ~ N(0, 0.01); bias (if present) set to 0.
    """
    # Exact type checks (not isinstance) are kept on purpose: subclasses of these
    # layers are skipped, as in the original behaviour.
    if type(module) is nn.Conv2d:
        if module.weight.requires_grad:
            nn.init.kaiming_normal_(module.weight.data, mode="fan_in", nonlinearity="relu")

    if type(module) is nn.Linear:
        nn.init.normal_(module.weight.data, mean=0, std=0.01)
        # Linear(bias=False) has bias None; previously this raised AttributeError.
        if module.bias is not None:
            nn.init.constant_(module.bias.data, val=0)