import torch
import torch.nn as nn
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import torch.nn.functional as F
from MiniResNetPlus import MiniResNetPlus
from joblib import Parallel, delayed

import os
import sys
# Put this script's parent directory at the front of sys.path so modules that
# live one level up can be imported.
# NOTE(review): this runs after `from MiniResNetPlus import MiniResNetPlus`
# above, so it has no effect on that import.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, "../")) 
sys.path.insert(0, parent_dir)


import matplotlib.pyplot as plt





def generate_cifar100_like_inputs(num_samples, device="cpu"):
    """Build ``num_samples`` random CIFAR-100-shaped input tensors.

    Each sample has shape (1, 3, 32, 32) and lives on ``device``. Its pixels
    are drawn uniformly from [0, 1) and then normalised with the CIFAR-100
    per-channel mean and standard deviation.

    Returns:
        A list with ``num_samples`` tensors (empty when ``num_samples`` <= 0).
    """
    channel_mean = torch.tensor([0.5071, 0.4867, 0.4408]).view(3, 1, 1).to(device)
    channel_std = torch.tensor([0.2675, 0.2565, 0.2761]).view(3, 1, 1).to(device)

    def _draw():
        # One uniform image, normalised, with a leading batch dimension.
        pixels = torch.empty(3, 32, 32, device=device).uniform_(0, 1)
        normalised = (pixels - channel_mean) / channel_std
        return normalised.unsqueeze(0)

    return [_draw() for _ in range(num_samples)]


def collect_activations(model, input_tensors=None, num_samples=1, device='cpu'):
    """Run ``model`` ``num_samples`` times and capture Conv2d / ReLU activations.

    A forward hook is attached to every ``nn.Conv2d`` (records its input) and
    every ``nn.ReLU`` (records its input and output) returned by
    ``model.modules()``. The hooks are removed before returning, even when a
    forward pass raises, so the model is never left instrumented.

    Args:
        model: module to probe; it is called under ``torch.no_grad()``.
        input_tensors: optional sequence of inputs consumed in order. A 3-D
            tensor is treated as one unbatched sample and gets a batch dim.
            When ``None``, inputs are synthesised: shape (1, 3, 32, 32),
            pixels uniform in [0, 1), normalised with the CIFAR-100 channel
            mean/std.
        num_samples: number of forward passes (and records) to produce.
        device: device the inputs are moved to.

    Returns:
        A list of dicts, one per forward pass, with keys ``"input_tensor"``,
        ``"conv_inputs"``, ``"relu_inputs"`` and ``"relu_outputs"``. Captured
        tensors are detached CPU copies listed in hook-firing order.

    Raises:
        IndexError: if ``input_tensors`` has fewer than ``num_samples`` items.
    """
    conv_inputs = []
    relu_inputs = []
    relu_outputs = []

    def conv_hook(module, input, output):
        conv_inputs.append(input[0].detach().cpu())

    def relu_hook(module, input, output):
        relu_inputs.append(input[0].detach().cpu())
        relu_outputs.append(output.detach().cpu())

    input_shape = (1, 3, 32, 32)
    if input_tensors is None:
        # CIFAR-100 per-channel normalisation, built once for all samples.
        mean = torch.tensor([0.5071, 0.4867, 0.4408]).view(3, 1, 1).to(device)
        std = torch.tensor([0.2675, 0.2565, 0.2761]).view(3, 1, 1).to(device)

    hooks = []
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            hooks.append(layer.register_forward_hook(conv_hook))
        elif isinstance(layer, nn.ReLU):
            hooks.append(layer.register_forward_hook(relu_hook))

    records = []
    try:
        for i in range(num_samples):
            if input_tensors is None:
                x = torch.empty(input_shape).uniform_(0, 1).to(device)
                inputs = (x - mean) / std
                print(f"[INFO] Generated sample {len(records)+1}/{num_samples} (CIFAR100-like)")
            else:
                inputs = input_tensors[i].to(device)
                if inputs.dim() == 3:
                    inputs = inputs.unsqueeze(0)  # [C,H,W] -> [1,C,H,W]

            conv_inputs.clear()
            relu_inputs.clear()
            relu_outputs.clear()

            with torch.no_grad():
                model(inputs)

            # Clone so later in-place ops (e.g. inplace ReLU) cannot alter the record.
            records.append({
                "input_tensor": inputs.detach().cpu(),
                "conv_inputs": [t.clone() for t in conv_inputs],
                "relu_inputs": [t.clone() for t in relu_inputs],
                "relu_outputs": [t.clone() for t in relu_outputs],
            })
    finally:
        # Always detach the hooks, even when a forward pass fails.
        for hook in hooks:
            hook.remove()

    return records



def collect_block_activations(model, layer_name, block_idx, input_tensors, device='cpu'):
    """Run one block of ``model`` on its own and capture Conv2d / ReLU activations.

    The block is ``getattr(model, layer_name)[block_idx]``. Every ``nn.Conv2d``
    inside it (found via ``block.modules()``) has its input recorded, and every
    ``nn.ReLU`` has its input and output recorded. Each tensor in
    ``input_tensors`` is fed straight into the block. A 3-D tensor is treated
    as one unbatched sample and gets a batch dim, matching
    ``collect_activations``. Hooks are removed before returning, even when a
    forward pass raises.

    Returns:
        A list with one dict per input, holding ``"conv_inputs"``,
        ``"relu_inputs"`` and ``"relu_outputs"``. Each is a list of detached
        CPU copies in hook-firing order.
    """
    conv_inputs = []
    relu_inputs = []
    relu_outputs = []

    def conv_hook(module, input, output):
        conv_inputs.append(input[0].detach().cpu())

    def relu_hook(module, input, output):
        relu_inputs.append(input[0].detach().cpu())
        relu_outputs.append(output.detach().cpu())

    block = getattr(model, layer_name)[block_idx]

    hooks = []
    for module in block.modules():
        if isinstance(module, nn.Conv2d):
            hooks.append(module.register_forward_hook(conv_hook))
        elif isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(relu_hook))

    records = []
    try:
        for input_tensor in input_tensors:
            conv_inputs.clear()
            relu_inputs.clear()
            relu_outputs.clear()

            input_tensor = input_tensor.to(device)
            if input_tensor.dim() == 3:
                # Downstream code indexes activations as [0, c, h, w].
                input_tensor = input_tensor.unsqueeze(0)

            with torch.no_grad():
                block(input_tensor)

            records.append({
                "conv_inputs": [t.clone() for t in conv_inputs],
                "relu_inputs": [t.clone() for t in relu_inputs],
                "relu_outputs": [t.clone() for t in relu_outputs]
            })
    finally:
        for h in hooks:
            h.remove()

    return records


def extract_patch(input_tensor, h_out, w_out, conv_layer):
    """Return the input window that ``conv_layer`` reads to produce output (h_out, w_out).

    Taps that fall outside the input (the conv's padding) are filled with
    zeros. The result therefore always has shape (N, C, kH, kW), and
    ``(patch[0] * conv_layer.weight[o]).sum() + conv_layer.bias[o]`` equals the
    conv output at (o, h_out, w_out).

    Honours ``conv_layer.stride``, ``padding`` (numeric, or the string forms
    ``'valid'`` / ``'same'`` that ``nn.Conv2d`` keeps verbatim) and
    ``dilation``. Only zero padding is modelled; other ``padding_mode``
    values are not.

    Args:
        input_tensor: 4-D tensor (N, C, H, W) fed to the conv.
        h_out, w_out: output coordinates.
        conv_layer: an ``nn.Conv2d``-like object with ``kernel_size``,
            ``stride``, ``padding`` and (optionally) ``dilation``.
    """
    kH, kW = conv_layer.kernel_size
    sH, sW = conv_layer.stride
    dH, dW = getattr(conv_layer, "dilation", (1, 1))

    padding = conv_layer.padding
    if isinstance(padding, str):
        if padding == "valid":
            pH, pW = 0, 0
        else:
            # 'same': PyTorch puts the smaller half of the total padding
            # d*(k-1) on the top/left side.
            pH, pW = dH * (kH - 1) // 2, dW * (kW - 1) // 2
    else:
        pH, pW = padding

    # Extent of the receptive field once dilation spreads the taps apart.
    span_h = dH * (kH - 1) + 1
    span_w = dW * (kW - 1) + 1

    # Top-left input coordinate of the window (may be negative inside padding).
    h_in = h_out * sH - pH
    w_in = w_out * sW - pW

    # Zero-pad only as much as this window needs on each side.
    pad_top = max(-h_in, 0)
    pad_left = max(-w_in, 0)
    pad_bottom = max(h_in + span_h - input_tensor.size(2), 0)
    pad_right = max(w_in + span_w - input_tensor.size(3), 0)

    padded = torch.nn.functional.pad(input_tensor,
        (pad_left, pad_right, pad_top, pad_bottom))

    # In padded coordinates the window starts at max(h_in, 0) / max(w_in, 0).
    h_start = max(h_in, 0)
    w_start = max(w_in, 0)
    return padded[:, :, h_start:h_start + span_h:dH, w_start:w_start + span_w:dW]



def _accept_difference_row(A, row):
    """Return True if ``row`` should be appended to the design-matrix rows ``A``.

    The first row is always accepted. After that, a row is accepted only if it
    raises the rank of ``A`` and either both condition numbers (before and
    after) are infinite, or the new condition number stays below 1e12 and
    grows by less than a factor of 50.
    """
    if len(A) == 0:
        return True
    A_old = np.array(A, dtype=np.float64)
    rank_old = np.linalg.matrix_rank(A_old)
    A_new = np.vstack([A_old, row])
    rank_new = np.linalg.matrix_rank(A_new)
    if rank_new <= rank_old:
        return False
    try:
        cond_old = np.linalg.cond(A_old) if rank_old >= 2 else 1.0
    except np.linalg.LinAlgError:
        cond_old = np.inf
    try:
        cond_new = np.linalg.cond(A_new)
    except np.linalg.LinAlgError:
        cond_new = np.inf
    if np.isinf(cond_old) and np.isinf(cond_new):
        return True
    return cond_new < 1e12 and cond_new / (cond_old + 1e-8) < 50


def _flat_patch(record, conv_idx, h, w, conv_layer, dead_input_channels):
    """Flattened conv-input window of one record at output (h, w), dead input channels removed."""
    patch = extract_patch(record['conv_inputs'][conv_idx], h, w, conv_layer).cpu().numpy()
    if patch.ndim == 4:
        patch = patch[0]  # records hold one sample; drop the batch axis -> (C, kh, kw)
    return np.delete(patch, dead_input_channels, axis=0).flatten()


def process_general_channle(n_triplets, c, l, num_params, records, conv_idx, conv_layer, in_channels, kernel_size, dead_input_channels=None):
    """Recover the weights (and, when observable, the bias) of output channel ``c``.

    A least-squares system is built from recorded conv inputs and the ReLU
    outputs that follow the conv. ``records`` is consumed as consecutive
    triplets (indices 3k..3k+2). For every output position (h, w) of triplet k:

    * Each pair of samples that are both active (ReLU output > 0) gives a
      difference row ``(patch_i - patch_j, 0) -> y_i - y_j``. The bias cancels,
      hence the 0 in the bias column. The row is kept only if it raises the
      rank with a bounded condition number.
    * A sample that is the only active one in its triplet gives an absolute
      row ``(patch, 1) -> y``, which makes the bias observable.

    Triplets are consumed until, after at least ``n_triplets`` of them, the
    rank reaches ``num_params`` (or ``num_params - 1`` when no bias row was
    seen). Consumption also stops once 20 triplets have produced no row.

    Args:
        n_triplets: minimum number of triplets before the rank-based stop.
        c: output channel to recover.
        l: number of available triplets (callers pass len(records) // 3).
        num_params: number of unknowns, including the bias column.
        records: dicts with ``"conv_inputs"`` / ``"relu_outputs"`` lists.
            ``relu_outputs[conv_idx]`` is assumed to be the ReLU that directly
            follows conv ``conv_idx``. TODO: confirm for the model in use.
        conv_idx: index of the conv (and its ReLU) inside each record.
        conv_layer: conv module whose geometry drives ``extract_patch``.
        in_channels, kernel_size: shape of one output channel's weight.
        dead_input_channels: input channels excluded from the unknowns. Their
            recovered weights are zero.

    Returns:
        ``(weight, bias)``. ``weight`` is a float32 array of shape
        (in_channels, *kernel_size). ``bias`` is the solved scalar, or
        ``torch.tensor([0])`` when no bias row was observed. Returns ``None``
        when no row was collected or the rank is below ``num_params - 10``.
    """
    if dead_input_channels is None:
        dead_input_channels = []

    A = []
    b = []
    bias = False

    if l > 0:
        out_h, out_w = records[0]['relu_outputs'][conv_idx].shape[2:4]
    else:
        out_h, out_w = 0, 0

    for k in range(l):
        triplet = range(3 * k, 3 * k + 3)
        for h in range(out_h):
            for w in range(out_w):
                # Only the three samples of the current triplet are ever read.
                activations = {
                    i: records[i]['relu_outputs'][conv_idx][0, c, h, w].item()
                    for i in triplet
                }

                # Both samples active: the ReLU is linear for both, so the
                # bias cancels in the difference.
                for i in triplet:
                    for j in range(i + 1, 3 * k + 3):
                        if not (activations[i] > 0 and activations[j] > 0):
                            continue
                        patch_i = _flat_patch(records[i], conv_idx, h, w, conv_layer, dead_input_channels)
                        patch_j = _flat_patch(records[j], conv_idx, h, w, conv_layer, dead_input_channels)
                        delta_x = np.append(patch_i - patch_j, 0.0)
                        delta_y = activations[i] - activations[j]
                        if _accept_difference_row(A, delta_x):
                            A.append(delta_x)
                            b.append(delta_y)

                # Exactly one active sample: an absolute equation that pins the bias.
                for i in triplet:
                    others_inactive = all(activations[j] <= 0 for j in triplet if j != i)
                    if activations[i] > 0 and others_inactive:
                        bias = True
                        patch_i = _flat_patch(records[i], conv_idx, h, w, conv_layer, dead_input_channels)
                        A.append(np.append(patch_i, 1.0))
                        b.append(activations[i])

        if len(A) == 0:
            if k >= 20:
                break
            continue
        A_rank = np.linalg.matrix_rank(np.array(A, dtype=np.float64))
        if A_rank >= num_params - 1 and not bias and k >= n_triplets:
            print(f"channel {c} bias can't be recovered")
            print(f"channel {c} data triplets num = {k}")
            break
        if A_rank >= num_params and bias and k >= n_triplets:
            print(f"channel {c} data triplets num = {k}")
            break

    A = np.array(A, dtype=np.float64)
    b = np.array(b, dtype=np.float64)
    if len(A) == 0:
        print(f"channel {c} cant not be recovered")
        return None

    if not bias:
        # No absolute row was seen, so the bias column is identically zero.
        # Drop it on every exit path (the loop may finish without breaking);
        # otherwise W has one entry too many and the reshape below fails.
        A = A[:, :-1]
        num_params = num_params - 1

    # Zero-weight ridge rows. They leave the least-squares solution unchanged
    # but keep the system at least square, so cond() sees every singular value.
    ridge = 0.0
    A_reg = np.vstack([A, ridge * np.eye(A.shape[1])])
    b_reg = np.concatenate([b, np.zeros(A.shape[1])])
    A_rank = np.linalg.matrix_rank(A_reg)

    if A_rank < num_params - 10:
        print(f"channel {c} cant not be recovered,A_rank = {A_rank}")
        return None

    cond_num = np.linalg.cond(A_reg)
    print(f"channel {c} Condition number: {cond_num:.2e}")
    W, _, _, _ = np.linalg.lstsq(A_reg, b_reg, rcond=1e-6)

    if bias:
        recovered_w_flat, reconstructed_bias = W[:-1], W[-1]
    else:
        recovered_w_flat, reconstructed_bias = W, torch.tensor([0])

    # Scatter the solved weights back into the full (in_channels, kh, kw) layout.
    valid_channels = [ch for ch in range(in_channels) if ch not in dead_input_channels]
    reconstructed_weight = np.zeros((in_channels, *kernel_size), dtype=np.float32)
    recovered = recovered_w_flat.reshape(len(valid_channels), *kernel_size)
    for idx, ch in enumerate(valid_channels):
        reconstructed_weight[ch] = recovered[idx]
    return reconstructed_weight, reconstructed_bias


def recover_general_conv_layer(n_triplets, records, conv_layer, conv_idx, dead_input_channels=None, n_jobs=48):
    """Recover every output channel of ``conv_layer`` in parallel.

    Each channel is solved independently by ``process_general_channle`` on a
    joblib pool of ``n_jobs`` workers.

    Args:
        n_triplets: minimum number of sample triplets per channel.
        records: activation records. They are consumed as triplets, so
            ``len(records) // 3`` triplets are available.
        conv_layer: the conv whose weights are recovered.
        conv_idx: index of this conv inside each record.
        dead_input_channels: input channels excluded from the unknowns.
        n_jobs: joblib parallelism. The default 48 matches the previous
            hard-coded value.

    Returns:
        ``(weights, biases, num_zeros)``. ``weights`` has shape
        (out_channels, in_channels, kh, kw). ``biases`` is a list with one
        entry per output channel. ``num_zeros`` counts the channels that could
        not be recovered; their weights and bias stay 0.
    """
    if dead_input_channels is None:
        dead_input_channels = []

    out_channels = conv_layer.out_channels
    in_channels = conv_layer.in_channels
    kernel_size = conv_layer.kernel_size
    kernel_area = kernel_size[0] * kernel_size[1]
    # Unknowns: one weight per live input tap plus the bias.
    num_params = (in_channels - len(dead_input_channels)) * kernel_area + 1

    reconstructed_weight = np.zeros((out_channels, in_channels, *kernel_size))
    reconstructed_bias = [0] * out_channels

    l = len(records) // 3

    results = Parallel(n_jobs=n_jobs)(
        delayed(process_general_channle)(n_triplets, c, l, num_params, records, conv_idx, conv_layer,
                                         in_channels, kernel_size, dead_input_channels)
        for c in range(out_channels)
    )

    num_zeros = 0
    for c, result in enumerate(results):
        if result is None:
            num_zeros += 1
            continue
        reconstructed_weight[c], reconstructed_bias[c] = result

    return reconstructed_weight, reconstructed_bias, num_zeros



def validate_conv_bn(org_conv, org_bn, org_relu, reconstructed_conv, reconstructed_bn, reconstructed_relu, test_inputs, test_outputs, device=None):
    """Run conv -> BN -> ReLU for both the original and reconstructed modules and dump the outputs.

    Each test input is passed through the conv. The result is unsqueezed to
    add a leading batch dim (so inputs are presumably single unbatched
    (C, H, W) samples), then passed through BN and ReLU. Three text files are
    written to the current working directory, one flattened row per input:
    ``reconstructed_outputs.txt``, ``original_outputs.txt`` and
    ``test_outputs.txt`` (the matching entry of ``test_outputs``).

    Args:
        device: device the inputs are moved to. Defaults to the device of
            ``org_conv``'s first parameter (CPU if it has none). The original
            code read an undefined module-level ``device`` name here.
    """
    if device is None:
        first_param = next(org_conv.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # BN must use its running statistics, not batch statistics.
    reconstructed_conv.eval()
    reconstructed_bn.eval()
    org_conv.eval()
    org_bn.eval()

    def _forward(conv, bn, relu, x):
        y = conv(x).unsqueeze(0)
        y = relu(bn(y))
        return y.squeeze(0).cpu().numpy().flatten()

    all_y_re = []
    all_y_org = []
    all_test = []

    for i, x in enumerate(test_inputs):
        x = x.to(device)
        with torch.no_grad():
            all_y_re.append(_forward(reconstructed_conv, reconstructed_bn, reconstructed_relu, x))
            all_y_org.append(_forward(org_conv, org_bn, org_relu, x))
        all_test.append(test_outputs[i].cpu().numpy().flatten())

    np.savetxt('reconstructed_outputs.txt', np.array(all_y_re), fmt='%.8f')
    np.savetxt('original_outputs.txt', np.array(all_y_org), fmt='%.8f')
    np.savetxt('test_outputs.txt', np.array(all_test), fmt='%.8f')

    print("Outputs saved to txt files.")


def process_residual_identity_channle(n_triplets,c, l, num_params, records, conv_idx , conv_layer, in_channels, kernel_size, dead_input_channels):
    """Recover output channel ``c`` of a conv whose ReLU output also includes an identity term.

    The least-squares target is the recorded ReLU output at (c, h, w) minus
    ``x_vals``, which is channel ``c`` of ``conv_inputs[0]`` at the same
    position. NOTE(review): this assumes ``conv_inputs[0]`` is the block
    input added back by an identity shortcut, and that spatial size is
    preserved. TODO: confirm against the model definition.

    ``records`` is consumed as consecutive triplets (indices 3k..3k+2). Only
    every third output row and column is sampled (step 3 in the h/w loops).
      * Both samples of a pair active -> difference row
        ``(patch_i - patch_j, 0)``. It is kept only if it raises the rank
        with a bounded condition number.
      * Exactly one sample of the triplet active -> absolute row
        ``(patch, 1)``, which makes the bias observable.

    Args:
        n_triplets: minimum number of triplets before the rank-based stop.
        c: output channel to recover.
        l: number of available triplets (callers pass len(records) / 3).
        num_params: number of unknowns including the bias column.
        records: dicts with ``"conv_inputs"`` / ``"relu_outputs"`` lists.
        conv_idx: index of the conv (and its ReLU output) inside a record.
        conv_layer: conv module whose geometry drives ``extract_patch``.
        in_channels, kernel_size: shape of one output channel's weight.
        dead_input_channels: input channels excluded from the unknowns.
            Their weights are returned as zeros.

    Returns:
        ``(weight, bias)``. ``weight`` is float32 with shape
        (in_channels, *kernel_size). ``bias`` is the solved scalar, or
        ``torch.tensor([0.0])`` if no absolute row was seen. Returns ``None``
        when no row was collected.
    """

    A = []
    b = []
    bias = False  # becomes True once any absolute (bias-carrying) row is added
    for k in range(l):
        for h in range(0,records[0]['relu_outputs'][conv_idx].shape[2],3):
            for w in range(0,records[0]['relu_outputs'][conv_idx].shape[3],3):
                # ReLU outputs and identity-path values at (c, h, w) for every record;
                # only indices 3k..3k+2 are used below.
                y_vals = [rec['relu_outputs'][conv_idx][0, c, h, w].item() for rec in records]
                x_vals = [rec['conv_inputs'][0][0, c, h, w].item() for rec in records]
                
                valid_pairs = []
                
                # Pairs inside the triplet where both ReLU outputs are positive.
                for i in range(3*k, 3*k+3):
                    for j in range(i + 1, 3*k+3):
                        if y_vals[i] > 0 and y_vals[j] > 0:
                            valid_pairs.append((i, j))

                
                
                for pair in valid_pairs:
                    i, j = pair

                    delta_y = y_vals[i] - y_vals[j]
 
                    delta_x = x_vals[i] - x_vals[j]
 
                    # Difference of the residual branch only (identity term removed).
                    delta_F = delta_y - delta_x
                    

                    patch_i = extract_patch(records[i]['conv_inputs'][conv_idx], h, w, conv_layer).cpu().numpy()
                    patch_j = extract_patch(records[j]['conv_inputs'][conv_idx], h, w, conv_layer).cpu().numpy()

                    # Drop the batch axis -> (C, kh, kw).
                    patch_i = patch_i[0]
                    patch_j = patch_j[0]
 
                    patch_i = np.delete(patch_i, dead_input_channels, axis=0).flatten()
                    patch_j = np.delete(patch_j, dead_input_channels, axis=0).flatten()
                    delta_input = patch_i - patch_j

                    # Bias column is 0: the bias cancels in a difference.
                    feat = np.append(delta_input, 0.0)  
                    
                    # Accept the row only if it raises the rank without blowing up
                    # the condition number (new cond < 1e12 and < 50x the old one),
                    # or if both condition numbers are infinite.
                    if len(A) > 0:
                        A_old = np.array(A, dtype=np.float64)
                        rank_old = np.linalg.matrix_rank(A_old)                        
                        A_new = np.vstack([A_old, feat])
                        rank_new = np.linalg.matrix_rank(A_new)                        
                        if rank_new > rank_old:
                            try:
                                cond_old = np.linalg.cond(A_old) if rank_old >= 2 else 1.0
                            except np.linalg.LinAlgError:
                                cond_old = np.inf
                            try:
                                cond_new = np.linalg.cond(A_new)
                            except np.linalg.LinAlgError:
                                cond_new = np.inf

                           
                            if np.isinf(cond_old) and np.isinf(cond_new):
                               
                                A.append(feat)
                                b.append(delta_F)
                            elif cond_new < 1e12 and cond_new / (cond_old + 1e-8) < 50:
                                A.append(feat)
                                b.append(delta_F)
                    else:
                        A.append(feat)
                        b.append(delta_F)


                valid_pairs = []
                
                # Triplet members that are the only positive output in their triplet.
                if y_vals[3*k] > 0 and y_vals[3*k+1] <= 0 and y_vals[3*k+2] <= 0:
                    valid_pairs.append(3*k)
                if y_vals[3*k] <= 0 and y_vals[3*k+1] > 0 and y_vals[3*k+2] <= 0:
                    valid_pairs.append(3*k+1)
                if y_vals[3*k] <= 0 and y_vals[3*k+1] <= 0 and y_vals[3*k+2] > 0:
                    valid_pairs.append(3*k+2)

                for i in valid_pairs:
                    bias = True                
                    # Absolute target of the residual branch: ReLU output minus identity term.
                    Fx = y_vals[i] - x_vals[i]
                    # Still a torch tensor here (no .numpy()); np.delete converts it.
                    patch_i_raw = extract_patch(records[i]['conv_inputs'][conv_idx], h, w, conv_layer).cpu()
                    if patch_i_raw.ndim == 4 and patch_i_raw.shape[0] == 1:
                        patch_i_raw = patch_i_raw[0]  # shape: (C, kh, kw)
                    patch = np.delete(patch_i_raw, dead_input_channels, axis=0).flatten()  
                    feat = np.append(patch, 1.0)
                    A.append(feat)
                    b.append(Fx)

        A_ = np.array(A, dtype=np.float64)
        b_ = np.array(b, dtype=np.float64)
        # Give up after 10+ triplets without a single usable row.
        if(len(A_) == 0 and k >= 10):
            break
        if(len(A_) == 0):
            continue
        # np.sqrt(0) == 0: these extra rows are all zero and do not change the rank.
        A_reg = np.vstack([A_, np.sqrt(0)*np.eye(A_.shape[1])])
        b_reg = np.concatenate([b_, np.zeros(A_.shape[1])])
        A_rank = np.linalg.matrix_rank(A_reg)
        # Early stop once enough independent rows exist (after >= n_triplets triplets).
        if((A_rank >= num_params - 1) and (bias == False) and k >= n_triplets):
            print(f"channel {c} bias cant be recovered")
            print(f"channel {c} data triplets num = {k}")
            break
        if((A_rank >= num_params) and (bias == True) and k >= n_triplets):
            print(f"channel {c} data triplets num = {k}")
            break
        # Unreachable: the len(A_) == 0 case already hit `continue` above.
        if((len(A_) == 0) and k > 20):
            break
            
    A = np.array(A, dtype=np.float64)
    b = np.array(b, dtype=np.float64)
    if(len(A) == 0):
        print(f"channel {c} cant be recovered")
        return None     
    else:
        # Ridge regularisation: 1e-3 * I rows with zero targets.
        A_reg = np.vstack([A, 1e-3*np.eye(A.shape[1])])
        b_reg = np.concatenate([b, np.zeros(A.shape[1])])
        # NOTE(review): because of the 1e-3 * I block, A_reg always has full
        # column rank, so A_rank == A.shape[1]. The `A_rank >= num_params`
        # check below therefore always passes and the else branch is
        # unreachable. If rank gating is intended, the rank of A (without the
        # ridge block) is probably what should be tested.
        A_rank = np.linalg.matrix_rank(A_reg)
        if bias:
            None
        else:
            # No absolute row seen: the bias column is all zeros, so drop it.
            A_reg = A_reg[:,:-1]
            num_params = num_params-1
        if A_rank >= num_params:  
            cond_num = np.linalg.cond(A_reg)
            print(f"channel {c} Condition number: {cond_num:.2e}")
            W, _, _, _ = np.linalg.lstsq(A_reg, b_reg, rcond=1e-6)

            if bias:
                recovered_w_flat = W[:-1]
                reconstructed_bias = W[-1]
            else:
                recovered_w_flat = W
                reconstructed_bias = torch.tensor([0.0])
            reconstructed_weight = np.zeros((in_channels, *kernel_size), dtype=np.float32)

            # Scatter solved weights back into the live input channels; dead ones stay 0.
            valid_channels = [i for i in range(in_channels) if i not in dead_input_channels]
            recovered_weight_reshaped = recovered_w_flat.reshape(len(valid_channels), *kernel_size)
            for idx, ch in enumerate(valid_channels):
                reconstructed_weight[ch] = recovered_weight_reshaped[idx]
            return reconstructed_weight, reconstructed_bias
        
        else:
            print(f"channel {c} cant recovered, A_rank = {A_rank}")
            return None
           

def recover_residual_identity_oneconv(n_triplets, records, conv_layer, conv_idx, dead_input_channels=None, n_jobs=48):
    """Recover every output channel of a conv inside an identity-shortcut block, in parallel.

    Each channel is solved independently by
    ``process_residual_identity_channle`` on a joblib pool of ``n_jobs``
    workers.

    Args:
        n_triplets: minimum number of sample triplets per channel.
        records: activation records. They are consumed as triplets, so
            ``len(records) // 3`` triplets are available.
        conv_layer: the conv whose weights are recovered.
        conv_idx: index of this conv inside each record.
        dead_input_channels: input channels excluded from the unknowns.
        n_jobs: joblib parallelism. The default 48 matches the previous
            hard-coded value.

    Returns:
        ``(weights, biases, num_zeros)``. ``weights`` has shape
        (out_channels, in_channels, kh, kw). ``biases`` is a per-channel list.
        ``num_zeros`` counts the channels that could not be recovered; their
        weights and bias stay 0.
    """
    if dead_input_channels is None:
        dead_input_channels = []

    out_channels = conv_layer.out_channels
    in_channels = conv_layer.in_channels
    kernel_size = conv_layer.kernel_size
    kernel_area = kernel_size[0] * kernel_size[1]
    # Unknowns: one weight per live input tap plus the bias.
    num_params = (in_channels - len(dead_input_channels)) * kernel_area + 1

    reconstructed_weight = np.zeros((out_channels, in_channels, *kernel_size))
    reconstructed_bias = [0] * out_channels

    l = len(records) // 3

    results = Parallel(n_jobs=n_jobs)(
        delayed(process_residual_identity_channle)(n_triplets, c, l, num_params, records, conv_idx, conv_layer,
                                                   in_channels, kernel_size, dead_input_channels)
        for c in range(out_channels)
    )

    num_zeros = 0
    for c, result in enumerate(results):
        if result is None:
            num_zeros += 1
            continue
        reconstructed_weight[c], reconstructed_bias[c] = result

    return reconstructed_weight, reconstructed_bias, num_zeros




def process_residual_block(n_triplets, c, l, num_params, records, idx1, idx2, conv1, conv2,
                           in_channels1, in_channels2, k1, k2,
                           dead_input_channels1, dead_input_channels2):
    """Jointly recover output channel ``c`` of two convs that feed the same ReLU output.

    Each least-squares row concatenates three parts:
      * the patch seen by ``conv1`` (record entry ``conv_inputs[idx1]``);
      * the patch seen by ``conv2`` (entry ``conv_inputs[idx2]``, the
        shortcut-branch conv per the inline comment) at the same output
        position;
      * a bias column.
    The target is ``relu_outputs[idx1]`` at (c, h, w). NOTE(review): this
    assumes that ReLU sees conv1(x1) + conv2(x2) + bias directly, and that
    both convs share the output grid. TODO: confirm against the model.

    ``records`` is consumed as consecutive triplets (indices 3k..3k+2).
      * Two active samples whose outputs differ by more than 0.1 give a
        difference row (bias column 0).
      * A sample that is the only active one in its triplet gives an
        absolute row (bias column 1).
    The final system is solved with a 1e-3 ridge term via the normal
    equations, falling back to ``lstsq`` if that solve fails.

    Returns:
        ``(weight1, b1, weight2, b2)``. The weights have shapes
        (in_channels1, *k1) and (in_channels2, *k2), float32, with zeros for
        dead input channels. The system has a single bias column, so the
        solved bias is returned as ``b1`` and ``b2`` is always
        ``np.zeros(1)``. ``b1`` is ``np.zeros(1)`` when no absolute row was
        seen. Returns ``None`` when no rows were collected.
    """
    A_list = []
    b_list = []
    bias = False  # becomes True once any absolute (bias-carrying) row is added
    
    H, W = records[0]['relu_outputs'][idx1].shape[2:4]

    for k in range(l):
        for h in range(H):
            for w in range(W):        
                # ReLU outputs at (c, h, w) for every record; only 3k..3k+2 are used.
                activations = torch.tensor([rec['relu_outputs'][idx1][0, c, h, w].item() for rec in records])
                valid_pairs = []
                for i in range(3 * k, 3 * k + 3):
                    for j in range(i + 1, 3 * k + 3):
                        if activations[i] > 0 and activations[j] > 0 and abs(activations[i] - activations[j]) > 0.1:
                            valid_pairs.append((i, j))

                # NOTE(review): a triplet yields at most 3 pairs, so this branch never runs.
                # np.random.choice also needs a 1-D array and would reject a list of tuples.
                if len(valid_pairs) > 20:
                    valid_pairs = np.random.choice(valid_pairs, 20, replace=False)

                for i, j in valid_pairs:
                    rec_i = records[i]
                    rec_j = records[j]

                    # Patches keep the batch axis (1, C, kh, kw), so dead channels are axis 1.
                    patch1_i = extract_patch(rec_i['conv_inputs'][idx1], h, w, conv1).cpu().numpy()
                    patch1_j = extract_patch(rec_j['conv_inputs'][idx1], h, w, conv1).cpu().numpy()
                    patch1_i = np.delete(patch1_i, dead_input_channels1, axis=1)
                    patch1_j = np.delete(patch1_j, dead_input_channels1, axis=1)
                    x1 = (patch1_i - patch1_j).flatten()

                    # patch for conv2 on the shortcut branch
                    patch2_i = extract_patch(rec_i['conv_inputs'][idx2], h, w, conv2).cpu().numpy()
                    patch2_j = extract_patch(rec_j['conv_inputs'][idx2], h, w, conv2).cpu().numpy()
                    patch2_i = np.delete(patch2_i, dead_input_channels2, axis=1)
                    patch2_j = np.delete(patch2_j, dead_input_channels2, axis=1)
                    x2 = (patch2_i - patch2_j).flatten()

                    # Bias column is 0: the bias cancels in a difference.
                    feat = np.concatenate([x1, x2,[0.0]])
                    delta_y = activations[i] - activations[j]
                    A_list.append(torch.tensor(feat, dtype=torch.float32))
                    b_list.append(delta_y.item())

                # A sample that is the only positive output in its triplet pins the bias.
                for i in range(3 * k, 3 * k + 3):
                    others = [activations[j] for j in range(3 * k, 3 * k + 3) if j != i]
                    if activations[i] > 0 and all(a <= 0 for a in others):
                        bias = True
                        rec = records[i]
                        patch1 = extract_patch(rec['conv_inputs'][idx1], h, w, conv1).cpu().numpy()
                        patch1 = np.delete(patch1, dead_input_channels1, axis=1).flatten()
                        patch2 = extract_patch(rec['conv_inputs'][idx2], h, w, conv2).cpu().numpy()
                        patch2 = np.delete(patch2, dead_input_channels2, axis=1).flatten()
                        feat = np.concatenate([patch1,  patch2, [1.0]])
                        A_list.append(torch.tensor(feat, dtype=torch.float32))
                        b_list.append(activations[i].item())

        # Give up once k reaches 10 without a single usable row.
        if len(A_list) == 0 and k >= 10:
            print(f"channel {c} cant be recovered")
            return None
        elif len(A_list) == 0 and k <10:
            continue
        A_tmp = torch.stack(A_list).numpy()
        b_tmp = np.array(b_list)
        A_rank = np.linalg.matrix_rank(A_tmp)

        # Early stop once enough independent rows exist (after >= n_triplets triplets).
        if A_rank >= num_params and bias and k >= n_triplets :
            print(f"channel {c} data triplets num = {k}")
            break
        elif A_rank >= num_params - 1 and not bias and k >= n_triplets:
            print(f"channel {c} bias cant be recovered")
            print(f"channel {c} data triplets num = {k}")
            break

    if len(A_list) == 0:
        print(f"channel {c} cant be recovered")
        return None
    
    # Rows were stored as float32 tensors; convert to a float32 NumPy matrix.
    A = np.stack([a.numpy() if isinstance(a, torch.Tensor) else a for a in A_list])

    b = np.array(b_list, dtype=np.float64)

    if not bias:
        # No absolute row seen: the bias column is all zeros, so drop it.
        num_params -= 1
        A = A[:,:-1]

    # A = np.stack([a.numpy() for a in A_list])
    # b = np.array(b_list)

    # Ridge regularisation: 1e-3 * I rows with zero targets.
    A_reg = np.vstack([A, 1e-3 * np.eye(A.shape[1])])
    b_reg = np.concatenate([b, np.zeros(A.shape[1])])

    cond_num = np.linalg.cond(A_reg)
    print(f"channel {c} Condition number: {cond_num:.2e}")

    # Normal equations first; fall back to lstsq if the Gram matrix is singular.
    try:
        W = np.linalg.solve(A_reg.T @ A_reg, A_reg.T @ b_reg)
    except np.linalg.LinAlgError:
        W, *_ = np.linalg.lstsq(A_reg, b_reg, rcond=1e-5)

    # Split the solution vector into [conv1 weights | conv2 weights | bias?].
    in1 = len([i for i in range(in_channels1) if i not in dead_input_channels1])
    in2 = len([i for i in range(in_channels2) if i not in dead_input_channels2])
    W1_len = in1 * k1[0] * k1[1]
    W2_len = in2 * k2[0] * k2[1]

    W1_flat = W[:W1_len]
    if bias:
        W2_flat = W[W1_len:-1]
        b1 = W[-1] 
        b2 = np.zeros(1)
    else:
        W2_flat = W[W1_len:]
        b1 = np.zeros(1) 
        b2 = np.zeros(1)


    # Scatter solved weights back into live input channels; dead ones stay 0.
    weight1 = np.zeros((in_channels1, *k1), dtype=np.float32)
    valid1 = [i for i in range(in_channels1) if i not in dead_input_channels1]
    W1_reshaped = W1_flat.reshape(len(valid1), *k1)
    for idx, ch in enumerate(valid1):
        weight1[ch] = W1_reshaped[idx]

    weight2 = np.zeros((in_channels2, *k2), dtype=np.float32)
    valid2 = [i for i in range(in_channels2) if i not in dead_input_channels2]
    W2_reshaped = W2_flat.reshape(len(valid2), *k2)
    for idx, ch in enumerate(valid2):
        weight2[ch] = W2_reshaped[idx]

    return weight1, b1, weight2, b2



def recover_residual_block(n_triplets, records, conv_layers, conv_idxs, dead_input_channels1=None, dead_input_channels2=None, n_jobs=48):
    """Jointly recover two convs (main branch + shortcut) whose outputs feed one ReLU.

    Each output channel is solved independently by ``process_residual_block``
    on a joblib pool of ``n_jobs`` workers.

    Args:
        n_triplets: minimum number of sample triplets per channel.
        records: activation records. They are consumed as triplets, so
            ``len(records) // 3`` triplets are available.
        conv_layers: ``(conv1, conv2)``. ``conv1.out_channels`` sets the
            channel count.
        conv_idxs: ``(idx1, idx2)``, the indices of the two convs inside each
            record.
        dead_input_channels1, dead_input_channels2: input channels of
            conv1 / conv2 excluded from the unknowns.
        n_jobs: joblib parallelism. The default 48 matches the previous
            hard-coded value.

    Returns:
        ``(weight1, bias1, weight2, bias2, num_zeros)``. The weights have
        shape (out_channels, in_channels_i, kh, kw). The biases are 1-D
        arrays of length out_channels. ``num_zeros`` counts the channels that
        could not be recovered; their entries stay 0.
    """
    if dead_input_channels1 is None:
        dead_input_channels1 = []
    if dead_input_channels2 is None:
        dead_input_channels2 = []

    conv1, conv2 = conv_layers
    idx1, idx2 = conv_idxs
    in_channels1 = conv1.in_channels
    in_channels2 = conv2.in_channels
    out_channels = conv1.out_channels
    k1 = conv1.kernel_size
    k2 = conv2.kernel_size

    weight1 = np.zeros((out_channels, in_channels1, *k1))
    bias1 = np.zeros(out_channels)
    weight2 = np.zeros((out_channels, in_channels2, *k2))
    bias2 = np.zeros(out_channels)

    # Unknowns: live taps of both convs plus one shared bias.
    num_params = ((in_channels1 - len(dead_input_channels1)) * k1[0] * k1[1]
                  + (in_channels2 - len(dead_input_channels2)) * k2[0] * k2[1]
                  + 1)

    l = len(records) // 3

    results = Parallel(n_jobs=n_jobs)(
        delayed(process_residual_block)(n_triplets, c, l, num_params, records, idx1, idx2, conv1, conv2,
                                        in_channels1, in_channels2, k1, k2,
                                        dead_input_channels1, dead_input_channels2)
        for c in range(out_channels)
    )

    num_zeros = 0
    for c, result in enumerate(results):
        if result is None:
            num_zeros += 1
            continue
        w1, b1, w2, b2 = result
        weight1[c] = w1
        weight2[c] = w2
        # Biases may come back as 0-d scalars or size-1 arrays (np.zeros(1));
        # take the single element explicitly instead of relying on NumPy's
        # deprecated array-to-scalar conversion.
        bias1[c] = np.ravel(b1)[0]
        bias2[c] = np.ravel(b2)[0]

    return weight1, bias1, weight2, bias2, num_zeros





def compare_relu_pre_activations(model_path_orig,
                                 model_class,
                                 model_path_reconstructed,
                                 output_dir="relu_pre_diff_channelwise",
                                 input_shape=(1, 3, 32, 32),
                                 device=None):
    """Compare the ReLU inputs (pre-activations) of two models, per channel, for CIFAR-100.

    Both models are built with ``model_class(num_classes=100)`` and loaded
    from the given state-dict files. One random CIFAR-100-like input is fed
    through both: pixels uniform in [0, 1), normalised with the CIFAR-100
    channel mean/std. This is the same recipe the rest of the file uses.

    For every ``nn.ReLU`` (by module name), ``<output_dir>/<name>.txt``
    receives one line per channel with the mean squared difference, then a
    ``=== Layer MSE: ... ===`` line. The squared error is averaged over every
    axis after the channel axis. This handles both (C, H, W) conv activations
    and 1-D activations such as a ReLU after a Linear layer; for the latter,
    each feature counts as a channel.

    NOTE: a ReLU module that is called several times in one forward pass
    keeps only the input of its last call.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _load(path):
        net = model_class(num_classes=100).to(device)
        net.load_state_dict(torch.load(path, map_location='cpu'))
        net.eval()
        return net

    model = _load(model_path_orig)
    reconstructed_model = _load(model_path_reconstructed)

    relu_inputs_orig = {}
    relu_inputs_recon = {}

    def get_relu_pre_hook(name, storage_dict):
        def hook(module, input):
            storage_dict[name] = input[0].detach().cpu()
        return hook

    handles = []
    for prefix, net, storage in (("orig_", model, relu_inputs_orig),
                                 ("recon_", reconstructed_model, relu_inputs_recon)):
        for name, module in net.named_modules():
            if isinstance(module, nn.ReLU):
                handles.append(module.register_forward_pre_hook(get_relu_pre_hook(f"{prefix}{name}", storage)))

    # CIFAR-100-like input: uniform pixels, then the official normalisation.
    mean = torch.tensor([0.5071, 0.4867, 0.4408], device=device).view(1,3,1,1)
    std  = torch.tensor([0.2675, 0.2565, 0.2761], device=device).view(1,3,1,1)
    dummy_input = torch.rand(*input_shape).to(device)
    dummy_input = (dummy_input - mean) / std

    try:
        with torch.no_grad():
            model(dummy_input)
            reconstructed_model(dummy_input)
    finally:
        for handle in handles:
            handle.remove()

    os.makedirs(output_dir, exist_ok=True)

    for key1, captured in relu_inputs_orig.items():
        key2 = key1.replace("orig_", "recon_")
        if key2 not in relu_inputs_recon:
            print(f"⚠️ No match for {key1}")
            continue

        sq_err = (captured.squeeze(0) - relu_inputs_recon[key2].squeeze(0)) ** 2
        if sq_err.dim() == 0:
            sq_err = sq_err.reshape(1)
        if sq_err.dim() == 1:
            mse_per_channel = sq_err
        else:
            # Average over everything after the channel axis (H, W, ...).
            mse_per_channel = sq_err.flatten(1).mean(dim=1)
        mse_layer = mse_per_channel.mean().item()

        layer_name = key1.replace("orig_", "")
        file_path = os.path.join(output_dir, f"{layer_name}.txt")
        lines = [f"{value:.6e}\n" for value in mse_per_channel.tolist()]
        lines.append(f"=== Layer MSE: {mse_layer:.6e} ===\n")
        with open(file_path, "w") as f:
            f.writelines(lines)


def _channelwise_mse(diff):
    """Return the mean of ``diff ** 2`` per channel (dim 1), averaged over every other dim.

    Works for any batch size and for (N, C, H, W), (N, C, L) and (N, C)
    activations. A 1-D tensor is treated as a single sample whose elements
    are the channels.
    """
    if diff.dim() < 2:
        diff = diff.reshape(1, -1)
    reduce_dims = tuple(d for d in range(diff.dim()) if d != 1)
    return (diff ** 2).mean(dim=reduce_dims)


def compare_relu_post_activations(model_path_orig,
                                  model_class,
                                  model_path_reconstructed,
                                  output_dir="relu_post_diff_channelwise",
                                  input_shape=(1, 3, 32, 32),
                                  device=None):
    """
    Compare post-ReLU activations of two checkpoints of the same architecture (CIFAR100).

    Both state_dicts are loaded into ``model_class(num_classes=100)``. A forward
    hook is attached to every ``nn.ReLU`` module, and one shared random input is
    pushed through both models. For every ReLU present in both models:

    * ``<output_dir>/<layer_name>.txt`` receives the per-channel mean squared
      difference (one value per line), followed by a ``=== Layer MSE: ... ===`` line.
    * ``<output_dir>/summary.txt`` receives ``<layer_name>: <layer MSE>``.

    Per-channel values are averaged over the batch and all non-channel dims,
    so any batch size and activation rank is supported.

    Args:
        model_path_orig: path to the original model's state_dict.
        model_class: model constructor accepting ``num_classes``.
        model_path_reconstructed: path to the reconstructed model's state_dict.
        output_dir: directory for the per-layer files and ``summary.txt``.
        input_shape: shape of the random input, ``(batch, 3, H, W)``.
        device: torch device; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load both checkpoints into fresh instances of the same architecture.
    model = model_class(num_classes=100).to(device)
    model.load_state_dict(torch.load(model_path_orig, map_location='cpu'))
    model.eval()

    reconstructed_model = model_class(num_classes=100).to(device)
    reconstructed_model.load_state_dict(torch.load(model_path_reconstructed, map_location='cpu'))
    reconstructed_model.eval()

    relu_outputs_orig = {}
    relu_outputs_recon = {}

    def get_relu_post_hook(name, storage_dict):
        def hook(module, input, output):
            # A ReLU module invoked several times per forward keeps only its last output.
            storage_dict[name] = output.detach().cpu()
        return hook

    handles = []
    for prefix, net, storage in (("orig_", model, relu_outputs_orig),
                                 ("recon_", reconstructed_model, relu_outputs_recon)):
        for name, module in net.named_modules():
            if isinstance(module, nn.ReLU):
                handles.append(module.register_forward_hook(
                    get_relu_post_hook(f"{prefix}{name}", storage)))

    # Random input normalized with the CIFAR100 mean/std.
    # NOTE(review): standard-normal samples are normalized again here, which differs
    # from the uniform-pixel inputs of generate_cifar100_like_inputs; kept unchanged
    # to preserve existing behaviour.
    dummy_input = torch.randn(*input_shape).to(device)
    mean = torch.tensor([0.5071, 0.4867, 0.4408], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.2675, 0.2565, 0.2761], device=device).view(1, 3, 1, 1)
    dummy_input = (dummy_input - mean) / std

    try:
        with torch.no_grad():
            model(dummy_input)
            reconstructed_model(dummy_input)
    finally:
        # Detach the hooks so no stale hooks stay registered on the modules.
        for handle in handles:
            handle.remove()

    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, "summary.txt")
    with open(summary_path, "w") as summary_file:
        for key1, out1 in relu_outputs_orig.items():
            # Keys are built as f"orig_{name}", so strip exactly that prefix.
            layer_name = key1[len("orig_"):]
            key2 = f"recon_{layer_name}"
            if key2 not in relu_outputs_recon:
                print(f"⚠️ No match for {key1}")
                continue

            diff = out1 - relu_outputs_recon[key2]
            mse_per_channel = _channelwise_mse(diff)
            mse_layer = mse_per_channel.mean().item()

            file_path = os.path.join(output_dir, f"{layer_name}.txt")
            with open(file_path, "w") as f:
                f.writelines(f"{value:.6e}\n" for value in mse_per_channel.tolist())
                f.write(f"=== Layer MSE: {mse_layer:.6e} ===\n")

            summary_file.write(f"{layer_name}: {mse_layer:.6e}\n")
            print(f"✅ Saved: {file_path}  [shape={diff.shape}]")
        print(f"📊 Summary saved: {summary_path}")






def forward_one_layer(input_tensor, conv_weight, stride=(1,1), conv_bias=None ,bn_layer=None, padding=None):
    """
    Run conv -> (optional ``bn_layer``) -> ReLU with externally supplied conv weights.

    Args:
        input_tensor: input batch for the convolution, ``(N, C_in, H, W)``.
        conv_weight: conv kernel (tensor, Parameter or array-like),
            ``(C_out, C_in, kH, kW)``.
        stride: convolution stride.
        conv_bias: optional conv bias of shape ``(C_out,)``.
        bn_layer: optional module applied to the conv output before the ReLU.
        padding: convolution padding. ``None`` (default) uses ``k // 2`` per
            spatial dim, which keeps the spatial size at stride 1 for odd
            kernels (3x3 -> 1). NOTE(review): pass it explicitly for convs that
            use a different padding.

    Returns:
        The ReLU output, computed without autograd, on ``input_tensor``'s device.
    """
    kernel_size = tuple(int(k) for k in conv_weight.shape[2:])
    if padding is None:
        padding = tuple(k // 2 for k in kernel_size)

    conv = nn.Conv2d(
        in_channels=conv_weight.shape[1],
        out_channels=conv_weight.shape[0],
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=(conv_bias is not None)
    ).to(input_tensor.device)  # run on the same device as the data

    with torch.no_grad():
        # copy_ converts dtype/device and avoids torch.tensor(<Tensor>) copy warnings;
        # detach() drops any autograd history of a Parameter passed in.
        conv.weight.copy_(torch.as_tensor(conv_weight).detach())
        if conv_bias is not None:
            conv.bias.copy_(torch.as_tensor(conv_bias).detach())

        x = conv(input_tensor)
        if bn_layer is not None:
            x = bn_layer(x)
        x = F.relu(x)
    return x


def forward_resblock_simple(
    x1, x2,
    conv2_w,stride1=(1,1), conv2_b=None,
    shortcut_conv_w=None, shortcut_conv_b=None,
    bn_layer=None, shortcut_bn_layer=None,
    stride2=1, use_relu=True
):
    """
    Evaluate the tail of a residual block with externally supplied weights.

    The residual branch is a 3x3 conv (padding 1, stride ``stride1``) on ``x1``,
    optionally followed by ``bn_layer``. The skip branch is ``x2`` itself, or, when
    ``shortcut_conv_w`` is given, a 1x1 conv (padding 0, stride ``stride2``) on
    ``x2`` optionally followed by ``shortcut_bn_layer``. The two branches are
    summed and passed through a ReLU when ``use_relu`` is true. Everything runs
    without autograd.
    """
    def build_conv(weight, bias, kernel, stride, pad):
        # Create the module first, then swap in the provided parameters.
        layer = nn.Conv2d(
            in_channels=weight.shape[1],
            out_channels=weight.shape[0],
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            bias=bias is not None,
        )
        layer.weight.data = torch.tensor(weight).float()
        if bias is not None:
            layer.bias.data = torch.tensor(bias).float()
        return layer

    residual_conv = build_conv(conv2_w, conv2_b, 3, stride1, 1)
    projection = (
        None if shortcut_conv_w is None
        else build_conv(shortcut_conv_w, shortcut_conv_b, 1, stride2, 0)
    )

    with torch.no_grad():
        residual = residual_conv(x1)
        if bn_layer is not None:
            residual = bn_layer(residual)

        if projection is None:
            skip = x2
        else:
            skip = projection(x2)
            if shortcut_bn_layer is not None:
                skip = shortcut_bn_layer(skip)

        residual += skip
        return F.relu(residual) if use_relu else residual



def recover_fc_weights(inputs, outputs):
    """
    Fit a fully connected layer ``outputs ≈ inputs @ weight.T + bias`` by least squares.

    A column of ones is appended to ``inputs`` so the bias is solved jointly with
    the weight via ``torch.linalg.lstsq``.

    Args:
        inputs: features of shape ``(N, in_features)``.
        outputs: targets of shape ``(N, out_features)``. They are cast to the
            dtype/device of ``inputs`` and detached from any autograd graph.

    Returns:
        ``(weight, bias)`` with shapes ``(out_features, in_features)`` and
        ``(out_features,)``. Neither tensor requires grad.
    """
    device = inputs.device
    dtype = inputs.dtype

    with torch.no_grad():
        X = inputs.detach()
        # Targets often come straight from a module forward (with a graph) and may
        # differ in dtype/device; lstsq needs both operands to match.
        Y = outputs.detach().to(device=device, dtype=dtype)
        N = X.size(0)

        ones = torch.ones(N, 1, device=device, dtype=dtype)
        X_aug = torch.cat([X, ones], dim=1)  # (N, in_features + 1)

        W_aug = torch.linalg.lstsq(X_aug, Y).solution  # (in_features + 1, out_features)

        weight = W_aug[:-1, :].T.contiguous()
        bias = W_aug[-1, :].contiguous()
    return weight, bias










def recover_model(model, reconstructed_model, n_triplets,
                  save_path="reconstructed_modelplus_layers_new_0.pth"):
    """
    Recover the weights of ``model`` layer by layer into ``reconstructed_model``.

    The network is walked in forward order (conv1, layer1[0..1], layer2[0..1],
    layer3[0..1], then fc). For each conv layer the function:

    1. collects the original model's activations for the current block
       (``collect_activations`` / ``collect_block_activations``);
    2. overwrites the recorded ``conv_inputs`` with activations produced by the
       already reconstructed layers, so each recovery step sees the inputs the
       reconstructed network actually produces;
    3. calls the matching recovery routine and copies the recovered conv weight
       into ``<conv>.weight`` and the recovered bias into the following
       ``<bn>.bias``;
    4. propagates the reconstructed activations with ``forward_one_layer`` /
       ``forward_resblock_simple``.

    Finally, fc is fitted by least squares (``recover_fc_weights``). The inputs
    are the pooled reconstructed features; the targets are the original model's
    logits on its own layer3[1] outputs. The state_dict is then saved.

    Args:
        model: original model. Only its modules and forward passes are used.
        reconstructed_model: model whose parameters are overwritten in place.
        n_triplets: forwarded unchanged to every recovery routine.
        save_path: file that receives ``reconstructed_model.state_dict()``.

    Returns:
        ``reconstructed_model`` (also modified in place).

    NOTE(review): probe inputs are generated on the CPU (``device="cpu"`` below).
    Keep both models on the CPU to avoid device mismatches in the helper
    forward passes.
    """
    inputs = generate_cifar100_like_inputs(300, device="cpu")
    records_new = collect_activations(model, input_tensors=inputs, num_samples=300)

    # ---------------------------------- conv1 ----------------------------------
    print("recover 1-st layer")
    conv1_weights, conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=records_new,
        conv_layer=model.conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.conv1.weight.copy_(torch.tensor(conv1_weights).float())
        reconstructed_model.bn1.bias.copy_(torch.tensor(conv1_bias).float())

    # ----------------------------- layer1[0].conv1 -----------------------------
    # Original activations of layer1[0], driven by the original conv1 ReLU outputs.
    input_tensors = [rec['relu_outputs'][0] for rec in records_new]
    block1_1_records_org = collect_block_activations(
        model=model, layer_name="layer1", input_tensors=input_tensors, block_idx=0,
    )

    # Reconstructed conv1 outputs: input of layer1[0].conv1 and identity shortcut of layer1[0].
    output_tensors_re = []
    block_input_tensor = []
    for rec in records_new:
        a = forward_one_layer(rec['conv_inputs'][0],
                              conv_weight=reconstructed_model.conv1.weight,
                              conv_bias=reconstructed_model.bn1.bias)
        output_tensors_re.append(a)
        block_input_tensor.append(a)

    for i in range(len(block1_1_records_org)):
        block1_1_records_org[i]["conv_inputs"][0] = output_tensors_re[i]

    print("recover 2-nd layer")
    layer1_1_conv1, layer1_1_conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=block1_1_records_org,
        conv_layer=model.layer1[0].conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer1[0].conv1.weight.copy_(torch.tensor(layer1_1_conv1).float())
        reconstructed_model.layer1[0].bn1.bias.copy_(torch.tensor(layer1_1_conv1_bias).float())

    # ----------------------------- layer1[0].conv2 -----------------------------
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_one_layer(output_tensors_re[i],
                                                 reconstructed_model.layer1[0].conv1.weight,
                                                 conv_bias=reconstructed_model.layer1[0].bn1.bias)

    for i in range(len(block1_1_records_org)):
        block1_1_records_org[i]["conv_inputs"][1] = output_tensors_re[i]

    print("recover 3-th layer")
    layer1_1_conv2, layer1_1_conv2_bias, _ = recover_residual_identity_oneconv(
        n_triplets,
        records=block1_1_records_org,
        conv_layer=model.layer1[0].conv2,
        conv_idx=1,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer1[0].conv2.weight.copy_(torch.tensor(layer1_1_conv2).float())
        reconstructed_model.layer1[0].bn2.bias.copy_(torch.tensor(layer1_1_conv2_bias).float())

    # Reconstructed output of layer1[0] (identity shortcut = reconstructed conv1 output).
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_resblock_simple(x1=output_tensors_re[i], x2=block_input_tensor[i],
                                                       conv2_w=reconstructed_model.layer1[0].conv2.weight,
                                                       conv2_b=reconstructed_model.layer1[0].bn2.bias)

    # ----------------------------- layer1[1].conv1 -----------------------------
    input_tensors = [rec['relu_outputs'][1] for rec in block1_1_records_org]
    block1_2_records_org = collect_block_activations(
        model=model, layer_name="layer1", input_tensors=input_tensors, block_idx=1,
    )

    block_input_tensor = []
    for i in range(len(block1_2_records_org)):
        block1_2_records_org[i]["conv_inputs"][0] = output_tensors_re[i]
        block_input_tensor.append(output_tensors_re[i])

    print("recover 4-th layer")
    layer1_2_conv1, layer1_2_conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=block1_2_records_org,
        conv_layer=model.layer1[1].conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer1[1].conv1.weight.copy_(torch.tensor(layer1_2_conv1).float())
        reconstructed_model.layer1[1].bn1.bias.copy_(torch.tensor(layer1_2_conv1_bias).float())

    # ----------------------------- layer1[1].conv2 -----------------------------
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_one_layer(output_tensors_re[i],
                                                 reconstructed_model.layer1[1].conv1.weight,
                                                 conv_bias=reconstructed_model.layer1[1].bn1.bias)

    for i in range(len(block1_2_records_org)):
        block1_2_records_org[i]["conv_inputs"][1] = output_tensors_re[i]

    print("recover 5-th layer")
    layer1_2_conv2, layer1_2_conv2_bias, _ = recover_residual_identity_oneconv(
        n_triplets,
        records=block1_2_records_org,
        conv_layer=model.layer1[1].conv2,
        conv_idx=1,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer1[1].conv2.weight.copy_(torch.tensor(layer1_2_conv2).float())
        reconstructed_model.layer1[1].bn2.bias.copy_(torch.tensor(layer1_2_conv2_bias).float())

    # ------------------------- layer2[0].conv1 (stride 2) ----------------------
    input_tensors = [rec['relu_outputs'][1] for rec in block1_2_records_org]
    block2_1_records_org = collect_block_activations(
        model=model, layer_name="layer2", input_tensors=input_tensors, block_idx=0,
    )

    # Reconstructed output of layer1[1].
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_resblock_simple(x1=output_tensors_re[i], x2=block_input_tensor[i],
                                                       conv2_w=reconstructed_model.layer1[1].conv2.weight,
                                                       conv2_b=reconstructed_model.layer1[1].bn2.bias)

    # It feeds both layer2[0].conv1 (conv_inputs[0]) and the shortcut conv (conv_inputs[2]).
    for i in range(len(output_tensors_re)):
        block2_1_records_org[i]["conv_inputs"][0] = output_tensors_re[i]
        block2_1_records_org[i]["conv_inputs"][2] = output_tensors_re[i]

    print("recover 6-th layer")
    layer2_1_conv1, layer2_1_conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=block2_1_records_org,
        conv_layer=model.layer2[0].conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer2[0].conv1.weight.copy_(torch.tensor(layer2_1_conv1).float())
        reconstructed_model.layer2[0].bn1.bias.copy_(torch.tensor(layer2_1_conv1_bias).float())

    # ------------------- layer2[0].conv2 + shortcut[0] (jointly) --------------
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_one_layer(output_tensors_re[i],
                                                 reconstructed_model.layer2[0].conv1.weight,
                                                 stride=(2,2),
                                                 conv_bias=reconstructed_model.layer2[0].bn1.bias)

    for i in range(len(output_tensors_re)):
        block2_1_records_org[i]["conv_inputs"][1] = output_tensors_re[i]

    print("recover 7-th layer")
    layer2_1_conv2, layer2_1_conv2_bias, layer2_1_conv3, layer2_1_conv3_bias, _ = recover_residual_block(
        n_triplets,
        records=block2_1_records_org,
        conv_layers=[model.layer2[0].conv2, model.layer2[0].shortcut[0]],
        conv_idxs=[1,2],
        dead_input_channels1=[],
        dead_input_channels2=[],
    )
    with torch.no_grad():
        reconstructed_model.layer2[0].conv2.weight.copy_(torch.tensor(layer2_1_conv2).float())
        reconstructed_model.layer2[0].bn2.bias.copy_(torch.tensor(layer2_1_conv2_bias).float())
        reconstructed_model.layer2[0].shortcut[0].weight.copy_(torch.tensor(layer2_1_conv3).float())
        # layer2_1_conv3_bias is not written back (presumably shortcut[0] has no bias).

    # Reconstructed output of layer2[0] (projection shortcut, stride 2).
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_resblock_simple(x1=block2_1_records_org[i]["conv_inputs"][1],
                                                       x2=block2_1_records_org[i]["conv_inputs"][2],
                                                       conv2_w=reconstructed_model.layer2[0].conv2.weight,
                                                       stride1=(1,1),
                                                       conv2_b=reconstructed_model.layer2[0].bn2.bias,
                                                       shortcut_conv_w=reconstructed_model.layer2[0].shortcut[0].weight,
                                                       stride2=(2,2))

    # ----------------------------- layer2[1].conv1 -----------------------------
    input_tensors = [rec['relu_outputs'][1] for rec in block2_1_records_org]
    block2_2_records_org = collect_block_activations(
        model=model, layer_name="layer2", input_tensors=input_tensors, block_idx=1,
    )

    for i in range(len(output_tensors_re)):
        block2_2_records_org[i]["conv_inputs"][0] = output_tensors_re[i]

    print("recover 8-th layer")
    layer2_2_conv1, layer2_2_conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=block2_2_records_org,
        conv_layer=model.layer2[1].conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer2[1].conv1.weight.copy_(torch.tensor(layer2_2_conv1).float())
        reconstructed_model.layer2[1].bn1.bias.copy_(torch.tensor(layer2_2_conv1_bias).float())

    # ----------------------------- layer2[1].conv2 -----------------------------
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_one_layer(output_tensors_re[i],
                                                 reconstructed_model.layer2[1].conv1.weight,
                                                 stride=(1,1),
                                                 conv_bias=reconstructed_model.layer2[1].bn1.bias)

    for i in range(len(output_tensors_re)):
        block2_2_records_org[i]["conv_inputs"][1] = output_tensors_re[i]

    print("recover 9-th layer")
    layer2_2_conv2, layer2_2_conv2_bias, _ = recover_residual_identity_oneconv(
        n_triplets,
        records=block2_2_records_org,
        conv_layer=model.layer2[1].conv2,
        conv_idx=1,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer2[1].conv2.weight.copy_(torch.tensor(layer2_2_conv2).float())
        reconstructed_model.layer2[1].bn2.bias.copy_(torch.tensor(layer2_2_conv2_bias).float())

    # Reconstructed output of layer2[1] (identity shortcut = its conv_inputs[0]).
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_resblock_simple(x1=block2_2_records_org[i]["conv_inputs"][1],
                                                       x2=block2_2_records_org[i]["conv_inputs"][0],
                                                       conv2_w=reconstructed_model.layer2[1].conv2.weight,
                                                       conv2_b=reconstructed_model.layer2[1].bn2.bias)

    # ------------------------- layer3[0].conv1 (stride 2) ----------------------
    input_tensors = [rec['relu_outputs'][1] for rec in block2_2_records_org]
    block3_1_records_org = collect_block_activations(
        model=model, layer_name="layer3", input_tensors=input_tensors, block_idx=0,
    )

    for i in range(len(output_tensors_re)):
        block3_1_records_org[i]['conv_inputs'][0] = output_tensors_re[i]
        block3_1_records_org[i]['conv_inputs'][2] = output_tensors_re[i]

    print("recover 10-th layer")
    layer3_1_conv1, layer3_1_conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=block3_1_records_org,
        conv_layer=model.layer3[0].conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer3[0].conv1.weight.copy_(torch.tensor(layer3_1_conv1).float())
        reconstructed_model.layer3[0].bn1.bias.copy_(torch.tensor(layer3_1_conv1_bias).float())

    # ------------------- layer3[0].conv2 + shortcut[0] (jointly) --------------
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_one_layer(output_tensors_re[i],
                                                 conv_weight=reconstructed_model.layer3[0].conv1.weight,
                                                 conv_bias=reconstructed_model.layer3[0].bn1.bias,
                                                 stride=(2,2))

    for i in range(len(output_tensors_re)):
        block3_1_records_org[i]["conv_inputs"][1] = output_tensors_re[i]

    print("recover 11-th layer")
    layer3_1_conv2, layer3_1_conv2_bias, layer3_1_conv3, layer3_1_conv3_bias, _ = recover_residual_block(
        n_triplets,
        records=block3_1_records_org,
        conv_layers=[model.layer3[0].conv2, model.layer3[0].shortcut[0]],
        conv_idxs=[1,2],
        dead_input_channels1=[],
        dead_input_channels2=[],
    )
    with torch.no_grad():
        reconstructed_model.layer3[0].conv2.weight.copy_(torch.tensor(layer3_1_conv2).float())
        reconstructed_model.layer3[0].bn2.bias.copy_(torch.tensor(layer3_1_conv2_bias).float())
        reconstructed_model.layer3[0].shortcut[0].weight.copy_(torch.tensor(layer3_1_conv3).float())

    # Reconstructed output of layer3[0] (projection shortcut, stride 2).
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_resblock_simple(x1=block3_1_records_org[i]["conv_inputs"][1],
                                                       x2=block3_1_records_org[i]["conv_inputs"][2],
                                                       conv2_w=reconstructed_model.layer3[0].conv2.weight,
                                                       stride1=(1,1),
                                                       conv2_b=reconstructed_model.layer3[0].bn2.bias,
                                                       shortcut_conv_w=reconstructed_model.layer3[0].shortcut[0].weight,
                                                       stride2=(2,2))

    # ----------------------------- layer3[1].conv1 -----------------------------
    input_tensors = [rec['relu_outputs'][1] for rec in block3_1_records_org]
    block3_2_records_org = collect_block_activations(
        model=model, layer_name="layer3", input_tensors=input_tensors, block_idx=1,
    )

    for i in range(len(output_tensors_re)):
        block3_2_records_org[i]["conv_inputs"][0] = output_tensors_re[i]

    print("recover 12-th layer")
    layer3_2_conv1, layer3_2_conv1_bias, _ = recover_general_conv_layer(
        n_triplets,
        records=block3_2_records_org,
        conv_layer=model.layer3[1].conv1,
        conv_idx=0,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer3[1].conv1.weight.copy_(torch.tensor(layer3_2_conv1).float())
        reconstructed_model.layer3[1].bn1.bias.copy_(torch.tensor(layer3_2_conv1_bias).float())

    # ----------------------------- layer3[1].conv2 -----------------------------
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_one_layer(output_tensors_re[i],
                                                 conv_weight=reconstructed_model.layer3[1].conv1.weight,
                                                 stride=(1,1),
                                                 conv_bias=reconstructed_model.layer3[1].bn1.bias)

    for i in range(len(output_tensors_re)):
        block3_2_records_org[i]["conv_inputs"][1] = output_tensors_re[i]

    print("recover 13-th layer")
    layer3_2_conv2, layer3_2_conv2_bias, _ = recover_residual_identity_oneconv(
        n_triplets,
        records=block3_2_records_org,
        conv_layer=model.layer3[1].conv2,
        conv_idx=1,
        dead_input_channels=[],
    )
    with torch.no_grad():
        reconstructed_model.layer3[1].conv2.weight.copy_(torch.tensor(layer3_2_conv2).float())
        reconstructed_model.layer3[1].bn2.bias.copy_(torch.tensor(layer3_2_conv2_bias).float())

    # Reconstructed output of layer3[1] (identity shortcut = its conv_inputs[0]).
    for i in range(len(output_tensors_re)):
        output_tensors_re[i] = forward_resblock_simple(x1=block3_2_records_org[i]["conv_inputs"][1],
                                                       x2=block3_2_records_org[i]["conv_inputs"][0],
                                                       conv2_w=reconstructed_model.layer3[1].conv2.weight,
                                                       conv2_b=reconstructed_model.layer3[1].bn2.bias)

    # ----------------------------------- fc ------------------------------------
    # Targets: original logits computed from the original layer3[1] ReLU outputs.
    # Inputs: pooled reconstructed features. no_grad keeps model.fc's parameters
    # out of an autograd graph.
    fc_input = []
    fc_output = []
    with torch.no_grad():
        for i in range(len(output_tensors_re)):
            x_orig = torch.flatten(model.avg_pool(block3_2_records_org[i]["relu_outputs"][1]), 1)
            fc_output.append(model.fc(x_orig))  # [B, num_classes]

            x_re = torch.flatten(model.avg_pool(output_tensors_re[i]), 1)
            fc_input.append(x_re)  # [B, C]

        X = torch.cat(fc_input, dim=0)   # [N_total, C]
        Y = torch.cat(fc_output, dim=0)  # [N_total, num_classes]

    fc_weights, fc_bias = recover_fc_weights(X, Y)

    # copy_ keeps the fc parameters on the reconstructed model's own device/dtype.
    with torch.no_grad():
        reconstructed_model.fc.weight.copy_(fc_weights)
        reconstructed_model.fc.bias.copy_(fc_bias)

    torch.save(reconstructed_model.state_dict(), save_path)
    return reconstructed_model






if __name__ == "__main__":
    # Kept as a module-level name for code that may read it; the comparison
    # helpers below choose their own device internally.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: `model` must remain a module-level global here; other code in this
    # file may reference it.
    model = MiniResNetPlus(num_classes=100)
    model.load_state_dict(torch.load("checkpoints_relu_prob_sigmoid/model_epoch20.pth", map_location='cpu'))
    model.eval()

    # Keep the reconstructed model on the CPU, like `model`. recover_model builds
    # its probe inputs on the CPU and the recorded activations are moved to the
    # CPU, so CUDA-resident weights would cause a device mismatch inside the
    # helper forward passes.
    reconstructed_model = MiniResNetPlus(num_classes=100)
    reconstructed_model.load_state_dict(torch.load("reconstructed_modelplus_layers_new_30.pth", map_location='cpu'))
    reconstructed_model.eval()

    compare_relu_pre_activations(
        model_path_orig="checkpoints_relu_prob_sigmoid/model_epoch20.pth",
        model_class=MiniResNetPlus,
        model_path_reconstructed="reconstructed_modelplus_layers_new_30.pth"
    )
    compare_relu_post_activations(
        model_path_orig="checkpoints_relu_prob_sigmoid/model_epoch20.pth",
        model_class=MiniResNetPlus,
        model_path_reconstructed="reconstructed_modelplus_layers_new_30.pth"
    )

    recover_model(model, reconstructed_model, n_triplets=0)


