import torch
from scipy.stats import norm,truncnorm
from functools import reduce
from scipy.special import betainc
import numpy as np
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import math
# from dreamsim import dreamsim

from our_utils import *

class Gaussian_Shading_chacha:
    """Tiled-bit latent watermark whose bit stream is masked with ChaCha20.

    A random bit tensor (the watermark) is tiled ``ch`` times along the channel
    axis and ``hw`` times along each spatial axis, optionally encrypted with a
    stream cipher, and then every group of ``bits_required`` bits selects one
    of ``n_split`` equal-probability bins of the standard normal distribution
    from which one latent value is drawn.  Detection quantises a recovered
    latent back to bits, decrypts, and majority-votes over the tiled copies.
    """

    def __init__(self, ch_factor, hw_factor, fpr, user_number, args):
        """Set up geometry, quantisation edges and detection thresholds.

        Args:
            ch_factor: Number of copies of the watermark along the channel axis.
            hw_factor: Number of copies along each of the two spatial axes.
            fpr: Target false-positive rate used to choose ``tau_onebit`` and
                ``tau_bits``.
            user_number: Factor the tail probability is multiplied by when
                choosing ``tau_bits``.
            args: Object exposing ``n_split`` (number of Gaussian bins) and
                ``no_encrypt`` (truthy to skip the stream cipher).
        """
        self.ch = ch_factor
        self.hw = hw_factor
        self.nonce = None
        self.key = None
        self.watermark = None
        self.latentlength = 4 * 64 * 64
        self.n_split = args.n_split
        # Number of bits needed to index n_split bins.
        self.bits_required = math.floor(math.log2(self.n_split - 1)) + 1
        self.marklength = (
            self.latentlength // (self.ch * self.hw * self.hw) * self.bits_required
        )

        self.no_encrypt = args.no_encrypt

        # Majority-vote cut-off: a bit decodes to 1 only when strictly more
        # than half of its ch*hw*hw copies are 1.  With a single copy this is
        # 0, so the lone copy is returned as-is (a cut-off of 1 would force
        # every decoded bit to 0).
        self.threshold = self.ch * self.hw * self.hw // 2
        self.tp_onebit_count = 0
        self.tp_bits_count = 0
        self.tau_onebit = None
        self.tau_bits = None

        # Interior bin edges used to quantise a recovered latent back to a bin.
        percentiles = np.linspace(0, 1, self.n_split + 1)[1:-1]
        self.z_values = norm.ppf(percentiles).reshape((1, -1))

        # Smallest accuracy whose tail probability (regularised incomplete
        # beta at 0.5) does not exceed fpr.  Both thresholds stay None when no
        # index qualifies; eval_watermark then fails on the comparison.
        for i in range(self.marklength):
            tail = betainc(i + 1, self.marklength - i, 0.5)
            if self.tau_onebit is None and tail <= fpr:
                self.tau_onebit = i / self.marklength
            if self.tau_bits is None and tail * user_number <= fpr:
                self.tau_bits = i / self.marklength
            if self.tau_onebit is not None and self.tau_bits is not None:
                break
        print(self.tau_onebit, self.tau_bits)

    def stream_key_encrypt(self, sd):
        """Encrypt the bit array *sd* with ChaCha20 under a fresh key and nonce.

        A new 32-byte key and 12-byte nonce are generated on every call and
        stored on the instance for later decryption.

        Returns:
            A ``uint8`` NumPy array holding the encrypted bits.
        """
        self.key = get_random_bytes(32)
        self.nonce = get_random_bytes(12)
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        m_byte = cipher.encrypt(np.packbits(sd).tobytes())
        m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
        return m_bit

    def stream_key_encrypt_aes(self, sd):
        """AES-CTR counterpart of :meth:`stream_key_encrypt`.

        Generates a fresh 32-byte key and 16-byte counter block, stores them on
        the instance and returns the encrypted bits as a ``uint8`` array.
        """
        self.key = get_random_bytes(32)
        self.nonce = get_random_bytes(16)
        cipher = Cipher(algorithms.AES(self.key), modes.CTR(self.nonce), backend=default_backend())
        encryptor = cipher.encryptor()
        m_byte = encryptor.update(np.packbits(sd).tobytes()) + encryptor.finalize()
        m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
        return m_bit

    def stream_key_decrypt_aes(self, reversed_m):
        """Decrypt AES-CTR bits with the stored key and counter block.

        Returns:
            A CUDA ``uint8`` tensor of shape ``(1, 4*bits_required, 64, 64)``.
        """
        cipher = Cipher(algorithms.AES(self.key), modes.CTR(self.nonce), backend=default_backend())
        decryptor = cipher.decryptor()
        sd_byte = decryptor.update(np.packbits(reversed_m).tobytes()) + decryptor.finalize()
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4 * self.bits_required, 64, 64).to(torch.uint8)
        return sd_tensor.cuda()

    def truncSampling(self, message):
        """Draw one latent value per ``bits_required``-bit symbol of *message*.

        Each symbol is read most-significant bit first and used as the index
        of the Gaussian bin the value is sampled from.

        Returns:
            A CUDA half-precision tensor of shape ``(1, 4, 64, 64)``.
        """
        z = np.zeros(self.latentlength)
        denominator = float(self.n_split)
        # Bin edges, including -inf and +inf at the two ends.
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for i in range(self.latentlength):
            symbol = message[i * self.bits_required:(i + 1) * self.bits_required]
            dec_mes = int(reduce(lambda a, b: 2 * a + b, symbol))
            # Sampled one value at a time so the RNG call sequence is stable.
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
        z = torch.from_numpy(z).reshape(1, 4, 64, 64).half()
        return z.cuda()

    def create_watermark_and_return_w(self):
        """Sample a new watermark and return the latent tensor that encodes it.

        The watermark is stored in ``self.watermark``; unless ``no_encrypt`` is
        set, the tiled bits are ChaCha20-encrypted before sampling.
        """
        self.watermark = torch.randint(
            0, 2, [1, 4 // self.ch * self.bits_required, 64 // self.hw, 64 // self.hw]
        ).cuda()
        sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
        if self.no_encrypt:
            m = sd.flatten().cpu().numpy()
        else:
            # stream_key_encrypt_aes is the AES-CTR alternative to this call.
            m = self.stream_key_encrypt(sd.flatten().cpu().numpy())
        w = self.truncSampling(m)
        return w

    def stream_key_decrypt(self, reversed_m):
        """Decrypt ChaCha20 bits with the stored key and nonce.

        Returns:
            A CUDA ``uint8`` tensor of shape ``(1, 4*bits_required, 64, 64)``.
        """
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4 * self.bits_required, 64, 64).to(torch.uint8)
        return sd_tensor.cuda()

    def diffusion_inverse(self, watermark_r):
        """Majority-vote the tiled copies in *watermark_r* back to one watermark.

        The tensor is cut into ``ch`` channel groups and ``hw`` x ``hw``
        spatial tiles, the copies are summed, and a bit is set to 1 when its
        sum exceeds ``self.threshold``.

        Returns:
            A tensor with the watermark's shape minus the leading batch axis.
        """
        ch_stride = 4 // self.ch * self.bits_required
        hw_stride = 64 // self.hw
        ch_list = [ch_stride] * self.ch
        hw_list = [hw_stride] * self.hw
        # Stack every copy along dim 0 so one sum counts the votes.
        split_dim1 = torch.cat(torch.split(watermark_r, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0).clone()
        vote[vote <= self.threshold] = 0
        vote[vote > self.threshold] = 1
        return vote

    def reversed_to_m(self, reversed_w):
        """Quantise recovered latent values to bits.

        Each value's bin index is the number of interior edges in
        ``self.z_values`` it exceeds; the index is emitted as
        ``bits_required`` bits, most-significant bit first.

        Returns:
            A flat ``uint8`` NumPy array of ``numel * bits_required`` bits.
        """
        reversed_w = reversed_w.flatten().cpu().numpy()
        reversed_w_expanded = np.expand_dims(reversed_w, axis=-1)
        comparison_matrix = reversed_w_expanded > self.z_values

        partitions = np.sum(comparison_matrix, axis=-1)
        # unpackbits yields little-endian bits; flip to most-significant first.
        reversed_m = np.unpackbits(
            partitions.astype(np.uint8)[:, None],
            axis=-1,
            count=self.bits_required,
            bitorder='little',
        )
        reversed_m = reversed_m[:, ::-1]
        reversed_m = reversed_m.reshape(reversed_w.shape[0] * self.bits_required)
        return reversed_m

    def eval_watermark(self, reversed_w):
        """Score a recovered latent against the stored watermark.

        Updates the two true-positive counters when the bit accuracy reaches
        ``tau_onebit`` / ``tau_bits``.

        Returns:
            The fraction of watermark bits recovered correctly.
        """
        reversed_m = self.reversed_to_m(reversed_w)
        if self.no_encrypt:
            reversed_sd = torch.from_numpy(reversed_m).reshape(
                1, 4 * self.bits_required, 64, 64
            ).to(torch.uint8).cuda()
        else:
            # stream_key_decrypt_aes is the AES-CTR alternative to this call.
            reversed_sd = self.stream_key_decrypt(reversed_m)
        reversed_watermark = self.diffusion_inverse(reversed_sd)
        correct = (reversed_watermark == self.watermark).float().mean().item()
        if correct >= self.tau_onebit:
            self.tp_onebit_count = self.tp_onebit_count + 1
        if correct >= self.tau_bits:
            self.tp_bits_count = self.tp_bits_count + 1
        return correct

    def get_tpr(self):
        """Return the ``(one-bit, multi-bit)`` true-positive counters."""
        return self.tp_onebit_count, self.tp_bits_count


from scipy.fftpack import fft2, ifft2, fftshift, ifftshift

class ring:
    """Tiled-bit latent watermark combined with a ring pattern in one channel.

    The latent channel ``args.channel`` carries a ring pattern: positions
    selected by each mask in ``mask_list`` are sampled from a bin chosen by one
    random symbol per mask.  The remaining channels carry the same tiled,
    optionally ChaCha20-encrypted bit watermark as the plain Gaussian-shading
    scheme.  Detection scores both parts and keeps the better accuracy.
    """

    def __init__(self, ch_factor, hw_factor, fpr, user_number, args):
        """Set up geometry, ring masks, quantisation edges and thresholds.

        Args:
            ch_factor: Number of copies of the watermark along the channel axis.
            hw_factor: Number of copies along each of the two spatial axes.
            fpr: Target false-positive rate used to choose ``tau_onebit`` and
                ``tau_bits``.
            user_number: Factor the tail probability is multiplied by when
                choosing ``tau_bits``.
            args: Object exposing ``n_split``, ``channel``, ``no_encrypt``,
                ``w_r_start``, ``w_r_end``, ``w_r_interval``, ``disk`` and
                ``ring_threshold``.
        """
        self.ch = ch_factor
        self.hw = hw_factor
        self.nonce = None
        self.key = None
        self.watermark = None

        self.n_split = args.n_split
        # Number of bits needed to index n_split bins.
        self.bits_required = math.floor(math.log2(self.n_split - 1)) + 1

        # The ring channel is excluded from the channels carrying the bits.
        self.channel = args.channel
        self.w_channels = list(range(4))
        self.w_channels.remove(self.channel)

        self.latentlength = len(self.w_channels) * 64 * 64
        self.marklength = (
            self.latentlength // (self.ch * self.hw * self.hw) * self.bits_required
        )
        self.shape = [1, 4, 64, 64]

        self.no_encrypt = args.no_encrypt

        # One mask per radius, from w_r_start down towards w_r_end (exclusive).
        # NOTE(review): masks come from project helpers; presumably 64x64
        # boolean arrays - confirm against our_utils.
        self.mask_list = []
        for i in range(args.w_r_start, args.w_r_end, -args.w_r_interval):
            if args.disk:
                mask_i = create_round_ring_disk_mask(64, radius=i)
            else:
                mask_i = create_round_ring_mask(64, radius=i)
            self.mask_list.append(mask_i)

        # Majority-vote cut-off: a bit decodes to 1 only when strictly more
        # than half of its ch*hw*hw copies are 1.  With a single copy this is
        # 0, so the lone copy is returned as-is (a cut-off of 1 would force
        # every decoded bit to 0).
        self.threshold = self.ch * self.hw * self.hw // 2
        self.tp_onebit_count = 0
        self.tp_bits_count = 0
        self.tau_onebit = None
        self.tau_bits = None

        # Interior bin edges used to quantise a recovered latent back to a bin.
        percentiles = np.linspace(0, 1, self.n_split + 1)[1:-1]
        self.z_values = norm.ppf(percentiles).reshape((1, -1))

        # Ring channel: a positive ring_threshold replaces the regular bin
        # edges with a single cut at that quantile.
        self.ring_threshold = args.ring_threshold
        if self.ring_threshold > 0:
            self.z_values_ring = norm.ppf([self.ring_threshold]).reshape((1, -1))
        else:
            self.z_values_ring = norm.ppf(percentiles).reshape((1, -1))

        # Smallest accuracy whose tail probability (regularised incomplete
        # beta at 0.5) does not exceed fpr.  Both thresholds stay None when no
        # index qualifies; eval_watermark then fails on the comparison.
        for i in range(self.marklength):
            tail = betainc(i + 1, self.marklength - i, 0.5)
            if self.tau_onebit is None and tail <= fpr:
                self.tau_onebit = i / self.marklength
            if self.tau_bits is None and tail * user_number <= fpr:
                self.tau_bits = i / self.marklength
            if self.tau_onebit is not None and self.tau_bits is not None:
                break
        print(self.tau_onebit, self.tau_bits)

    def stream_key_encrypt(self, sd):
        """Encrypt the bit array *sd* with ChaCha20 under a fresh key and nonce.

        A new 32-byte key and 12-byte nonce are generated on every call and
        stored on the instance for later decryption.

        Returns:
            A ``uint8`` NumPy array holding the encrypted bits.
        """
        self.key = get_random_bytes(32)
        self.nonce = get_random_bytes(12)
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        m_byte = cipher.encrypt(np.packbits(sd).tobytes())
        m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
        return m_bit

    def stream_key_encrypt_aes(self, sd):
        """AES-CTR counterpart of :meth:`stream_key_encrypt`.

        Generates a fresh 32-byte key and 16-byte counter block, stores them on
        the instance and returns the encrypted bits as a ``uint8`` array.
        """
        self.key = get_random_bytes(32)
        self.nonce = get_random_bytes(16)
        cipher = Cipher(algorithms.AES(self.key), modes.CTR(self.nonce), backend=default_backend())
        encryptor = cipher.encryptor()
        m_byte = encryptor.update(np.packbits(sd).tobytes()) + encryptor.finalize()
        m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
        return m_bit

    def stream_key_decrypt_aes(self, reversed_m):
        """Decrypt AES-CTR bits with the stored key and counter block.

        Returns:
            A CUDA ``uint8`` tensor of shape
            ``(1, len(w_channels)*bits_required, 64, 64)``.
        """
        cipher = Cipher(algorithms.AES(self.key), modes.CTR(self.nonce), backend=default_backend())
        decryptor = cipher.decryptor()
        sd_byte = decryptor.update(np.packbits(reversed_m).tobytes()) + decryptor.finalize()
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        # Only the non-ring channels carry bits, so the channel count must
        # match stream_key_decrypt rather than the full four channels.
        sd_tensor = torch.from_numpy(sd_bit).reshape(
            1, len(self.w_channels) * self.bits_required, 64, 64
        ).to(torch.uint8)
        return sd_tensor.cuda()

    def truncSampling(self, message):
        """Draw one latent value per ``bits_required``-bit symbol of *message*.

        Each symbol is read most-significant bit first and used as the index
        of the Gaussian bin the value is sampled from.

        Returns:
            A CUDA half-precision tensor of shape
            ``(1, len(w_channels), 64, 64)``.
        """
        z = np.zeros(self.latentlength)
        denominator = float(self.n_split)
        # Bin edges, including -inf and +inf at the two ends.
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for i in range(self.latentlength):
            symbol = message[i * self.bits_required:(i + 1) * self.bits_required]
            dec_mes = int(reduce(lambda a, b: 2 * a + b, symbol))
            # Sampled one value at a time so the RNG call sequence is stable.
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
        z = torch.from_numpy(z).reshape(1, len(self.w_channels), 64, 64).half()
        return z.cuda()

    def truncSampling_ring(self):
        """Sample a full latent whose masked positions encode ``self.s``.

        Starts from standard-normal noise of shape ``self.shape`` and, for each
        mask, overwrites the masked positions of every channel with values
        drawn from the bin selected by that mask's symbol ``self.s[i]``.

        Returns:
            A CUDA half-precision tensor of shape ``self.shape``.
        """
        w_np = np.random.randn(*self.shape)
        denominator = float(self.n_split)
        if self.ring_threshold > 0:
            # Two-bin layouts: symbol 0 samples below the ring_threshold
            # quantile, symbol 1 samples above the (1 - ring_threshold) one.
            ppf_0 = [norm.ppf(0), norm.ppf(self.ring_threshold), norm.ppf(1)]
            ppf_1 = [norm.ppf(0), norm.ppf(1 - self.ring_threshold), norm.ppf(1)]
        else:
            ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for i, mask in enumerate(self.mask_list):
            n_mask = np.sum(mask)
            mes = self.s[i]
            z = np.zeros(n_mask)
            if self.ring_threshold > 0:
                ppf = ppf_0 if mes == 0 else ppf_1
            for j in range(n_mask):
                z[j] = truncnorm.rvs(ppf[mes], ppf[mes + 1])
            # The same values are written to every channel at the masked
            # positions of the single batch item.
            w_np[:, :, mask] = z
        w = torch.from_numpy(w_np).half().cuda()
        return w

    def create_watermark_and_return_w(self):
        """Sample a new watermark plus ring symbols and return the latent.

        Stores ``self.watermark`` (tiled bits), ``self.s`` (random ring
        symbols), ``self.z_values_ring_list`` (per-symbol decision cut) and
        ``self.ring_watermark`` (expected symbol at every masked position).
        """
        self.watermark = torch.randint(
            0,
            2,
            [1, len(self.w_channels) // self.ch * self.bits_required, 64 // self.hw, 64 // self.hw],
        ).cuda()
        sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
        if self.no_encrypt:
            m = sd.flatten().cpu().numpy()
        else:
            # stream_key_encrypt_aes is the AES-CTR alternative to this call.
            m = self.stream_key_encrypt(sd.flatten().cpu().numpy())
        w_ = self.truncSampling(m)

        # Uniformly random ring symbols.
        self.s = np.random.randint(0, 2, len(self.mask_list) * self.bits_required)
        self.z_values_ring_list = []
        for v in self.s:
            if v == 0:
                self.z_values_ring_list.append(norm.ppf([self.ring_threshold]).reshape((1, -1)))
            else:
                self.z_values_ring_list.append(norm.ppf([1 - self.ring_threshold]).reshape((1, -1)))

        # Expected symbol for every masked position, in mask order.
        self.ring_watermark = []
        for i, mask in enumerate(self.mask_list):
            n_mask = np.sum(mask)
            self.ring_watermark.extend([self.s[i]] * n_mask)

        self.ring_watermark = torch.tensor(self.ring_watermark)
        w = self.truncSampling_ring()

        # Overwrite every channel except the ring channel with the bit latent.
        w[:, :self.channel, :, :] = w_[:, :self.channel, :, :]
        w[:, self.channel + 1:, :, :] = w_[:, self.channel:, :, :]

        return w

    def stream_key_decrypt(self, reversed_m):
        """Decrypt ChaCha20 bits with the stored key and nonce.

        Returns:
            A CUDA ``uint8`` tensor of shape
            ``(1, len(w_channels)*bits_required, 64, 64)``.
        """
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        sd_tensor = torch.from_numpy(sd_bit).reshape(
            1, len(self.w_channels) * self.bits_required, 64, 64
        ).to(torch.uint8)
        return sd_tensor.cuda()

    def stream_key_decrypt_ring(self, reversed_m):
        """Decrypt a ring bit sequence with ChaCha20, dropping padding bits.

        ``np.packbits`` pads to a whole byte, so the result is truncated back
        to the length of *reversed_m*.

        Returns:
            A flat ``uint8`` tensor (left on the CPU).
        """
        initial_length = len(reversed_m)
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        if len(sd_bit) > initial_length:
            sd_bit = sd_bit[:initial_length]
        sd_tensor = torch.from_numpy(sd_bit).reshape(-1).to(torch.uint8)
        return sd_tensor

    def diffusion_inverse(self, watermark_r):
        """Majority-vote the tiled copies in *watermark_r* back to one watermark.

        The tensor is cut into ``ch`` channel groups and ``hw`` x ``hw``
        spatial tiles, the copies are summed, and a bit is set to 1 when its
        sum exceeds ``self.threshold``.

        Returns:
            A tensor with the watermark's shape minus the leading batch axis.
        """
        ch_stride = len(self.w_channels) // self.ch * self.bits_required
        hw_stride = 64 // self.hw
        ch_list = [ch_stride] * self.ch
        hw_list = [hw_stride] * self.hw
        # Stack every copy along dim 0 so one sum counts the votes.
        split_dim1 = torch.cat(torch.split(watermark_r, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0).clone()
        vote[vote <= self.threshold] = 0
        vote[vote > self.threshold] = 1
        return vote

    def reversed_to_m(self, reversed_w):
        """Quantise recovered latent values to bits.

        Each value's bin index is the number of interior edges in
        ``self.z_values`` it exceeds; the index is emitted as
        ``bits_required`` bits, most-significant bit first.

        Returns:
            A flat ``uint8`` NumPy array of ``numel * bits_required`` bits.
        """
        reversed_w = reversed_w.flatten().cpu().numpy()
        reversed_w_expanded = np.expand_dims(reversed_w, axis=-1)
        comparison_matrix = reversed_w_expanded > self.z_values

        partitions = np.sum(comparison_matrix, axis=-1)
        # unpackbits yields little-endian bits; flip to most-significant first.
        reversed_m = np.unpackbits(
            partitions.astype(np.uint8)[:, None],
            axis=-1,
            count=self.bits_required,
            bitorder='little',
        )
        reversed_m = reversed_m[:, ::-1]
        reversed_m = reversed_m.reshape(reversed_w.shape[0] * self.bits_required)
        return reversed_m

    def reversed_to_m_ring(self, reversed_w, i):
        """Quantise the ring-channel values of mask *i* to bits.

        With a positive ``ring_threshold`` the cut stored for symbol *i* in
        ``z_values_ring_list`` is used; otherwise the edges in
        ``z_values_ring`` are used.

        Returns:
            A flat ``uint8`` NumPy array of ``numel * bits_required`` bits.
        """
        reversed_w = reversed_w.flatten().cpu().numpy()
        reversed_w_expanded = np.expand_dims(reversed_w, axis=-1)
        if self.ring_threshold > 0:
            comparison_matrix = reversed_w_expanded > self.z_values_ring_list[i]
        else:
            comparison_matrix = reversed_w_expanded > self.z_values_ring

        partitions = np.sum(comparison_matrix, axis=-1)
        reversed_m = np.unpackbits(
            partitions.astype(np.uint8)[:, None],
            axis=-1,
            count=self.bits_required,
            bitorder='little',
        )
        reversed_m = reversed_m[:, ::-1]
        reversed_m = reversed_m.reshape(reversed_w.shape[0] * self.bits_required)
        return reversed_m

    def eval_watermark(self, reversed_w):
        """Score a recovered latent against the stored bit and ring watermarks.

        The bit accuracy over the non-ring channels and the symbol accuracy
        over the ring masks are both printed; the larger one is compared with
        ``tau_onebit`` / ``tau_bits`` to update the true-positive counters.

        Returns:
            The larger of the two accuracies.
        """
        # Drop the ring channel before decoding the tiled bits.
        w_part1 = reversed_w[:, :self.channel, :, :]
        w_part2 = reversed_w[:, self.channel + 1:, :, :]
        reversed_w_ = torch.cat((w_part1, w_part2), dim=1)
        reversed_m = self.reversed_to_m(reversed_w_)
        if self.no_encrypt:
            reversed_sd = torch.from_numpy(reversed_m).reshape(
                1, len(self.w_channels) * self.bits_required, 64, 64
            ).to(torch.uint8).cuda()
        else:
            # stream_key_decrypt_aes is the AES-CTR alternative to this call.
            reversed_sd = self.stream_key_decrypt(reversed_m)
        reversed_watermark = self.diffusion_inverse(reversed_sd)
        correct = (reversed_watermark == self.watermark).float().mean().item()

        # Decode the ring channel mask by mask, in the order used at embedding.
        reversed_ring_watermark = []
        for i, mask in enumerate(self.mask_list):
            reversed_watermark_mask = self.reversed_to_m_ring(reversed_w[:, self.channel, mask], i)
            reversed_ring_watermark.extend(reversed_watermark_mask)
        reversed_ring_watermark = torch.tensor(reversed_ring_watermark)
        correct_ring = (reversed_ring_watermark == self.ring_watermark).float().mean().item()
        print(correct, correct_ring)
        correct = max(correct, correct_ring)

        if correct >= self.tau_onebit:
            self.tp_onebit_count = self.tp_onebit_count + 1
        if correct >= self.tau_bits:
            self.tp_bits_count = self.tp_bits_count + 1
        return correct

    def get_tpr(self):
        """Return the ``(one-bit, multi-bit)`` true-positive counters."""
        return self.tp_onebit_count, self.tp_bits_count
    






class Gaussian_Shading:
    """Tiled one-bit-per-latent watermark masked with a random XOR key.

    A random bit tensor (the watermark) is tiled ``ch`` times along the channel
    axis and ``hw`` times along each spatial axis, XOR-ed with a random key of
    the full latent shape, and each resulting bit picks the negative or the
    positive half of the standard normal distribution from which one latent
    value is drawn.  Detection takes the sign of a recovered latent, removes
    the key and majority-votes over the tiled copies.
    """

    def __init__(self, ch_factor, hw_factor, fpr, user_number, args):
        """Set up geometry and detection thresholds.

        Args:
            ch_factor: Number of copies of the watermark along the channel axis.
            hw_factor: Number of copies along each of the two spatial axes.
            fpr: Target false-positive rate used to choose ``tau_onebit`` and
                ``tau_bits``.
            user_number: Factor the tail probability is multiplied by when
                choosing ``tau_bits``.
            args: Accepted for signature parity with the other watermark
                classes in this file; not read here.
        """
        self.ch = ch_factor
        self.hw = hw_factor
        self.key = None
        self.watermark = None
        self.latentlength = 4 * 64 * 64
        self.marklength = self.latentlength // (self.ch * self.hw * self.hw)

        # Majority-vote cut-off: a bit decodes to 1 only when strictly more
        # than half of its ch*hw*hw copies are 1.  With a single copy this is
        # 0, so the lone copy is returned as-is (a cut-off of 1 would force
        # every decoded bit to 0).
        self.threshold = self.ch * self.hw * self.hw // 2
        self.tp_onebit_count = 0
        self.tp_bits_count = 0
        self.tau_onebit = None
        self.tau_bits = None

        self.split_num = 2

        # Smallest accuracy whose tail probability (regularised incomplete
        # beta at 0.5) does not exceed fpr.  Both thresholds stay None when no
        # index qualifies; eval_watermark then fails on the comparison.
        for i in range(self.marklength):
            tail = betainc(i + 1, self.marklength - i, 0.5)
            if self.tau_onebit is None and tail <= fpr:
                self.tau_onebit = i / self.marklength
            if self.tau_bits is None and tail * user_number <= fpr:
                self.tau_bits = i / self.marklength
            if self.tau_onebit is not None and self.tau_bits is not None:
                break

    def truncSampling(self, message):
        """Draw one latent value per bit of *message*.

        Bit 0 samples from the negative half of the standard normal
        distribution and bit 1 from the positive half.

        Returns:
            A CUDA half-precision tensor of shape ``(1, 4, 64, 64)``.
        """
        z = np.zeros(self.latentlength)
        denominator = 2.0
        # Edges of the two half-line bins: -inf, 0, +inf.
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for i in range(self.latentlength):
            dec_mes = int(reduce(lambda a, b: 2 * a + b, message[i:i + 1]))
            # Sampled one value at a time so the RNG call sequence is stable.
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
        z = torch.from_numpy(z).reshape(1, 4, 64, 64).half()
        return z.cuda()

    def create_watermark_and_return_w(self):
        """Sample a new key and watermark and return the latent encoding them.

        Both ``self.key`` and ``self.watermark`` are replaced on every call.
        """
        self.key = torch.randint(0, 2, [1, 4, 64, 64]).cuda()
        self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.hw, 64 // self.hw]).cuda()
        sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
        # Addition modulo 2 is a bitwise XOR with the key.
        m = ((sd + self.key) % 2).flatten().cpu().numpy()
        w = self.truncSampling(m)
        return w

    def diffusion_inverse(self, watermark_sd):
        """Majority-vote the tiled copies in *watermark_sd* back to one watermark.

        The tensor is cut into ``ch`` channel groups and ``hw`` x ``hw``
        spatial tiles, the copies are summed, and a bit is set to 1 when its
        sum exceeds ``self.threshold``.

        Returns:
            A tensor with the watermark's shape minus the leading batch axis.
        """
        ch_stride = 4 // self.ch
        hw_stride = 64 // self.hw
        ch_list = [ch_stride] * self.ch
        hw_list = [hw_stride] * self.hw
        # Stack every copy along dim 0 so one sum counts the votes.
        split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0).clone()
        vote[vote <= self.threshold] = 0
        vote[vote > self.threshold] = 1
        return vote

    def eval_watermark(self, reversed_m):
        """Score a recovered latent against the stored watermark.

        The sign of each latent value gives a bit, the key is removed by
        addition modulo 2, and the copies are majority-voted.  The two
        true-positive counters are updated when the bit accuracy reaches
        ``tau_onebit`` / ``tau_bits``.

        Returns:
            The fraction of watermark bits recovered correctly.
        """
        reversed_m = (reversed_m > 0).int()
        reversed_sd = (reversed_m + self.key) % 2
        reversed_watermark = self.diffusion_inverse(reversed_sd)
        correct = (reversed_watermark == self.watermark).float().mean().item()
        if correct >= self.tau_onebit:
            self.tp_onebit_count = self.tp_onebit_count + 1
        if correct >= self.tau_bits:
            self.tp_bits_count = self.tp_bits_count + 1
        return correct

    def get_tpr(self):
        """Return the ``(one-bit, multi-bit)`` true-positive counters."""
        return self.tp_onebit_count, self.tp_bits_count


