import numpy as np
import torch
import os
import pickle
import string
import random
from torch.distributions import Chi2
import galois
GF2 = galois.GF(2)

def min_cycle_length_galois(matrix, max_len=6):
    """
    Heuristically search for a short linear dependency among rows over GF(2).

    The 0/1 input ``matrix`` (shape (n, m)) is interleaved column-wise with
    ``n // m`` vertically stacked copies of the m x m identity, giving an
    n x 2m matrix over GF(2). Rows are then reduced one at a time against the
    basis built so far (Gaussian elimination), while tracking which original
    rows have been XOR-ed together. A row that reduces to zero closes a
    dependency; its length is the number of distinct original rows involved.

    Rows whose running combination grows beyond ``max_len`` are abandoned and
    are NOT added to the basis, so this is a bounded heuristic: the result is
    the shortest dependency *found*, not necessarily the true minimum.

    Args:
        matrix: array-like of shape (n, m) with 0/1 entries; n must be a
            multiple of m.
        max_len: combinations longer than this are abandoned (default 6).

    Returns:
        The smallest dependency length found, or None if none was found.

    Raises:
        ValueError: if n is not a multiple of m.
    """
    matrix = np.asarray(matrix)
    n, m = matrix.shape
    if n % m != 0:
        raise ValueError(
            f'number of rows ({n}) must be a multiple of the number of columns ({m})'
        )
    identity_rows = np.tile(np.identity(m, dtype=np.uint8), (n // m, 1))
    augmented = np.stack([identity_rows, matrix], axis=-1).reshape((-1, 2 * m)).astype(np.uint8)
    augmented = GF2(augmented)

    basis = []         # reduced rows; each is zero at the pivots of all earlier rows
    pivots = []        # pivot column (first nonzero entry) of each basis row
    combinations = []  # set of original row indices XOR-ed into each basis row

    min_len = None
    for i in range(n):
        vec = augmented[i].copy()
        comb = {i}
        for b, pivot, b_comb in zip(basis, pivots, combinations):
            if vec[pivot] == 1:
                vec ^= b
                # Over GF(2) a row used twice cancels out, so combine the
                # index sets by symmetric difference rather than concatenation.
                comb ^= b_comb
            if len(comb) > max_len:
                break
        if len(comb) > max_len:
            continue
        if np.any(vec):
            basis.append(vec)
            pivots.append(int(np.argmax(vec)))
            combinations.append(comb)
        elif min_len is None or len(comb) < min_len:
            min_len = len(comb)

    return min_len

class SphericalCodes:
    """
    Binary-message watermark codec over latent noise tensors.

    Embedding: an ``lm``-bit message is repeated ``N`` times, ``lr`` random
    bits are appended, the result is multiplied by the GF(2) matrix ``T``
    (mod 2), mapped to +/-1, reshaped per sample to (K_len, -1), left-multiplied
    by the orthogonal matrix ``K`` and each resulting column is rescaled by
    sqrt(chi2(K_len) / K_len). Extraction undoes the rotation, thresholds the
    signs, applies ``T_inv`` (mod 2) and rounds the mean over the ``N``
    repetitions.

    Setting ``ablate_signet`` / ``ablate_rotation`` (both False by default)
    skips the T stage / K stage respectively.

    Constructor arguments:
        from_file: if truthy, the name (without ``.pkl``) of a key file in
            ``keys_path`` to load; N, lm, lr, t, K_len and latent_shape are
            then read from that file and the passed values are ignored.
            Otherwise a new key is generated and saved under ``keys_path``.
        keys_path: directory holding key files.
        N: repeated times for secret message
        lm: len of secret message
        lr: len of random message
        t: sparsity of matrix R (number of ones in each row of R)
        K_len: size of the square rotation matrix K; must divide N * lm + lr.
        batch_size: stored as ``self.batch_size``; not read by this class.
        latent_shape: per-sample output shape of ``embed_watermark``; its
            element count must equal N * lm + lr.
        device: torch device for T, T_inv, K and K_inv.
        P: optional lr x lr matrix used as the lower-right block of T. Its
            transpose is used as its inverse, so it should be a permutation
            matrix. P is not written to the key file and must be passed again
            when loading.
    """
    def __init__(self, from_file, keys_path, N, lm, lr, t, K_len, batch_size, latent_shape, device, P=None):
        self.device = device
        self.ablate_signet = False
        self.ablate_rotation = False
        self.P = P.to(device) if P is not None else None
        if from_file:
            self._load_signet(from_file, keys_path)
        else:
            self._create_signet(keys_path, N, lm, lr, t, K_len, latent_shape)
        self.batch_size = batch_size
        self.chisquare = Chi2(df=self.K_len)

    def _load_signet(self, from_file, keys_path):
        """Rebuild T, T_inv, K, K_inv from a key file written by ``_create_signet``."""
        self.from_file = from_file
        self.keys_path = keys_path
        file_path = os.path.join(keys_path, f'{from_file}.pkl')
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only load key files from a trusted location.
        with open(file_path, 'rb') as f:
            (self.N, self.lm, self.lr, self.t, self.latent_shape,
             R_positions, K_seed, self.K_len) = pickle.load(f)
        # Use the loaded parameters, not the constructor arguments, which may
        # be placeholders when loading from a file.
        self.total_len = self.N * self.lm + self.lr
        T, T_inv, K, K_inv = self.construct_signet_from_file(R_positions, K_seed)
        self.T = T.to(self.device)
        self.T_inv = T_inv.to(self.device)
        self.K = K.to(self.device)
        self.K_inv = K_inv.to(self.device)

    def _create_signet(self, keys_path, N, lm, lr, t, K_len, latent_shape):
        """Sample a new signet, store it on the instance and save its key file."""
        self.N = N
        self.lm = lm
        self.lr = lr
        self.t = t
        self.K_len = K_len
        self.total_len = N * lm + lr
        self.latent_shape = latent_shape

        K_seed = random.getrandbits(64)
        T, T_inv, K, K_inv, R_positions = self.construct_signet(K_seed)
        # Leading dimension of 1 so the matrices broadcast over a batch.
        self.T = T.unsqueeze(0)
        self.T_inv = T_inv.unsqueeze(0)
        self.K = K.unsqueeze(0)
        self.K_inv = K_inv.unsqueeze(0)

        os.makedirs(keys_path, exist_ok=True)
        from_file = f'{self._random_file_name()}_{t}_{N}_{self.total_len}_{self.lr}'
        file_path = os.path.join(keys_path, f'{from_file}.pkl')
        # Only the sparse positions of R and the seed of K are stored; both
        # matrices are rebuilt from them on load. P is not stored.
        with open(file_path, 'wb') as f:
            pickle.dump((N, lm, lr, t, self.latent_shape, R_positions, K_seed, self.K_len), f)
        self.from_file = from_file
        self.keys_path = keys_path
        print(f"Constructed watermark signet located at {keys_path}/{self.from_file}.pkl")

    def _random_file_name(self):
        """Return a 64-character alphanumeric name (not cryptographically random)."""
        letters = string.ascii_letters + string.digits
        return ''.join(random.choice(letters) for _ in range(64))

    def _construct_R(self, N, lm, lr, t, compute_rank=False):
        """
        Sample the sparse binary matrix R of shape (N * lm, lr).

        For every message position, ``N * t`` distinct columns are drawn and
        split into N disjoint groups of ``t``; row (i, position) gets ones at
        group i.

        Returns:
            (R as an int64 tensor of shape (N * lm, lr),
             list of (i, round_index, columns) tuples for ``_construct_R_from_positions``).

        Raises:
            ValueError: if lr < N * t (not enough columns for disjoint groups).
        """
        if lr < N * t:
            raise ValueError('lr must be greater than or equal to N * t')

        R = np.zeros((N, lm, lr), dtype=int)
        positions = []
        for round_index in range(lm):
            selected = np.random.permutation(lr)[:N * t]
            for i in range(N):
                group = list(np.sort(selected[i * t: (i + 1) * t]))
                R[i, round_index, group] = 1
                positions.append((i, round_index, group))
        result = torch.from_numpy(R).reshape((-1, lr))
        if compute_rank:
            print(min_cycle_length_galois(R.reshape(-1, lr)))
        return result, positions

    def _construct_R_from_positions(self, N, lm, lr, R_positions):
        """Rebuild R (shape (N * lm, lr)) from the positions returned by ``_construct_R``."""
        R = np.zeros((N, lm, lr), dtype=int)
        for i, round_index, group in R_positions:
            R[i, round_index, group] = 1
        return torch.from_numpy(R).reshape((-1, lr))

    @staticmethod
    def _assemble_T(R, lr, P, device):
        """
        Build T = [[I, R], [0, P]] and T_inv = [[I, (R @ P^T) % 2], [0, P^T]].

        Both are float32 tensors on ``device``. T_inv is the inverse of T over
        GF(2) provided P^T is the inverse of P (P a permutation matrix);
        ``P=None`` means the lr x lr identity.
        """
        n_rows = R.shape[0]
        I_s = torch.eye(n_rows, dtype=torch.float32, device=device)
        R = R.to(I_s)
        null_matrix = torch.zeros((lr, n_rows), dtype=torch.float32, device=device)
        if P is None:
            P = torch.eye(lr, dtype=torch.float32, device=device)
        else:
            P = P.to(device)
        P_inv = P.transpose(0, 1)

        T = torch.cat((torch.cat((I_s, R), dim=1),
                       torch.cat((null_matrix, P), dim=1)), dim=0)
        T_inv = torch.cat((torch.cat((I_s, (R @ P_inv) % 2), dim=1),
                           torch.cat((null_matrix, P_inv), dim=1)), dim=0)
        return T, T_inv

    def _construct_T(self, N, lm, lr, t, P, device):
        """Sample a fresh R and return (T, T_inv, R_positions)."""
        R, R_positions = self._construct_R(N, lm, lr, t)
        T, T_inv = self._assemble_T(R, lr, P, device)
        return T, T_inv, R_positions

    def _construct_T_from_positions(self, N, lm, lr, R_positions, P, device):
        """Rebuild (T, T_inv) from stored R positions."""
        R = self._construct_R_from_positions(N, lm, lr, R_positions)
        return self._assemble_T(R, lr, P, device)

    def _construct_K(self, K_len, device, K_seed):
        """Return a seeded orthogonal matrix K (QR of a Gaussian matrix) and its transpose."""
        generator = torch.Generator()
        generator.manual_seed(K_seed)
        K_init = torch.randn((K_len, K_len), generator=generator).to(device)
        K, _ = torch.linalg.qr(K_init)
        # K has orthonormal columns, so its transpose is its inverse.
        K_inv = K.transpose(0, 1)
        return K, K_inv

    def construct_signet(self, K_seed):
        """Sample a new signet: returns (T, T_inv, K, K_inv, R_positions)."""
        T, T_inv, R_positions = self._construct_T(self.N, self.lm, self.lr, self.t, self.P, self.device)
        K, K_inv = self._construct_K(self.K_len, self.device, K_seed)
        return T, T_inv, K, K_inv, R_positions

    def construct_signet_from_file(self, R_positions, K_seed):
        """Rebuild a stored signet: returns (T, T_inv, K, K_inv)."""
        T, T_inv = self._construct_T_from_positions(self.N, self.lm, self.lr, R_positions, self.P, self.device)
        K, K_inv = self._construct_K(self.K_len, self.device, K_seed)
        return T, T_inv, K, K_inv

    def embed_watermark(self, message):
        """
        Encode a batch of binary messages into latent noise.

        Args:
            message: 0/1 tensor of shape (batch_size, lm) or (batch_size, lm, 1).

        Returns:
            Float tensor of shape (batch_size, *latent_shape).
        """
        if message.dim() == 2:
            message = message.unsqueeze(-1)
        batch_size = message.shape[0]
        # Cast to T's device and dtype explicitly instead of relying on
        # torch.cat's type promotion for integer messages.
        repeated_message = message.repeat(1, self.N, 1).to(self.T)
        random_stream = torch.randint(0, 2, (batch_size, self.lr, 1)).to(self.T)
        input_message = torch.cat((repeated_message, random_stream), dim=1)
        if not self.ablate_signet:
            out_sign = (self.T @ input_message) % 2
        else:
            out_sign = input_message
        out_sign = out_sign * 2 - 1  # {0, 1} -> {-1, +1}

        if not self.ablate_rotation:
            out_sign = out_sign.reshape((batch_size, self.K_len, -1))
            out_noise = self.K @ out_sign
            # One chi2(K_len) sample per (sample, column); each column has norm
            # sqrt(K_len) before scaling, so its norm becomes sqrt(chi2 sample).
            chi_rand = self.chisquare.sample((batch_size * out_noise.shape[-1],)).to(out_noise).reshape((batch_size, 1, -1))
            out_noise = out_noise * torch.sqrt(chi_rand) / (self.K_len ** 0.5)
        else:
            out_noise = out_sign * torch.abs(torch.randn_like(out_sign))
        return out_noise.reshape((batch_size, *self.latent_shape))

    def extract_watermark(self, pred_noise):
        """
        Decode binary messages from (possibly perturbed) latent noise.

        Args:
            pred_noise: tensor whose leading dimension is the batch and whose
                per-sample element count is N * lm + lr.

        Returns:
            Float 0/1 tensor of shape (batch_size, lm).
        """
        batch_size = pred_noise.shape[0]
        if not self.ablate_rotation:
            # Match K_inv's device and dtype so CPU / half / double inputs work.
            pred_noise = pred_noise.to(self.K_inv).reshape((batch_size, self.K_len, -1))
            pred_sign = self.K_inv @ pred_noise
        else:
            pred_sign = torch.sign(pred_noise)
        pred_sign = pred_sign.reshape((batch_size, -1, 1))
        # sign -> {0, 0.5, 1}; torch.round sends 0.5 (exact zeros) to 0.
        pred_sign = torch.round((torch.sign(pred_sign) + 1) / 2)

        if not self.ablate_signet:
            pred_input_message = (self.T_inv @ pred_sign.to(self.T_inv)) % 2
        else:
            pred_input_message = pred_sign
        # The first N * lm entries hold N consecutive copies of the message.
        pred_repeated_message = pred_input_message[:, :(self.lm * self.N), :].reshape((-1, self.N, self.lm))
        # Majority vote over the N copies.
        pred_message = torch.round(torch.mean(pred_repeated_message, dim=1))
        return pred_message

if __name__ == '__main__':
    # Round-trip demo: build a new signet (writing a key file under keys_path),
    # embed a random message, extract it again and print the bit error rate.
    N = 31
    lm = 512
    lr = 512
    t = 3
    P = None
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    from_file = ''  # empty -> construct and save a new signet
    keys_path = './keys/sph/'

    sph_codes = SphericalCodes(from_file=from_file, keys_path=keys_path, N=N, lm=lm, lr=lr, t=t, K_len=128, device=device, P=P, batch_size=1, latent_shape=(4, 64, 64))
    message = torch.randint(0, 2, (1, lm,)).to(device)
    out = sph_codes.embed_watermark(message)
    pred = sph_codes.extract_watermark(out)
    # Fraction of message bits decoded incorrectly.
    print((pred - message).abs().mean())




