import matplotlib
matplotlib.use('Agg')
import copy
import torch
import random
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import numpy as np

def get_reset_probability(current_round, m, p, initial_mu):
    """Return the Reset probability to use in ``current_round``.

    Piecewise schedule:
      * rounds ``<= m``: constant ``initial_mu``;
      * rounds in ``(m, m + p]``: linear decay from ``initial_mu`` towards 0;
      * any later round: ``0.0``.
    """
    if current_round <= m:
        return initial_mu
    if current_round <= m + p:
        # Fraction of the decay window already elapsed, in (0, 1].
        elapsed = (current_round - m) / p
        return initial_mu * (1 - elapsed)
    return 0.0

    
 


def zero_out_next_layer_weights(model, reset_percent):
    def get_conv_linear_layers(module, layers):
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            layers.append(module)
        elif hasattr(module, 'children') and len(list(module.children())) > 0:
            for child in module.children():
                get_conv_linear_layers(child, layers)
        return layers

    layers = get_conv_linear_layers(model, [])
    #print(f"Total Conv/Linear layers: {len(layers)}")

    for i in range(len(layers) - 1):
        current_layer = layers[i]
        next_layer = layers[i + 1]
        
        if isinstance(current_layer, nn.Linear):
            num_neurons = current_layer.out_features
        elif isinstance(current_layer, nn.Conv2d):
            num_neurons = current_layer.out_channels

        num_to_zero_out = max(1, int(np.ceil(reset_percent* num_neurons)))
        indices_to_zero_out = np.random.choice(num_neurons, num_to_zero_out, replace=False)

        # print(f"Layer {i} zeros out {num_to_zero_out} neurons: {indices_to_zero_out}")

        if isinstance(next_layer, nn.Linear):
            if max(indices_to_zero_out) < next_layer.weight.data.shape[1]:
                next_layer.weight.data[:, indices_to_zero_out] = 0.0
            # else:
            #     print(f"Skipping layer {i+1} due to index out of bounds")
        elif isinstance(next_layer, nn.Conv2d):
            if max(indices_to_zero_out) < next_layer.weight.data.shape[1]:
                next_layer.weight.data[:, indices_to_zero_out, :, :] = 0.0
            # else:
            #     print(f"Skipping layer {i+1} due to index out of bounds")

    # model.to(device)




def zero_out_model_params(model, percent_set_zero_base):
    """Randomly zero a fraction of Conv2d weights in ``model`` (in place).

    Per-layer zeroing ratio:
      * ``nn.Conv2d``: ``percent_set_zero_base`` (multiplier 1);
      * ``nn.Linear``: 0 (multiplier 0) -- no weights are zeroed, but
        ``np.random.choice`` is still called with size 0, which advances
        NumPy's global RNG;
      * all other modules: skipped.

    For an affected layer, ``int(numel * ratio)`` weight entries are picked
    uniformly without replacement and multiplied by 0. Biases are untouched.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            multiplier = 1
        elif isinstance(module, nn.Linear):
            multiplier = 0
        else:
            continue
        if not hasattr(module, 'weight'):
            continue

        weight = module.weight
        numel = weight.numel()
        drop_count = int(numel * (percent_set_zero_base * multiplier))

        keep_mask = torch.ones_like(weight)
        drop_positions = np.random.choice(numel, drop_count, replace=False)
        keep_mask.view(-1)[drop_positions] = 0
        weight.data.mul_(keep_mask)
    return model

def reset_model_params_with_adaptive_noise(model, percent_set_zero_base):
    """Re-draw a random fraction of every Conv2d/Linear weight tensor (in place).

    For each ``nn.Conv2d`` or ``nn.Linear`` module, ``int(numel *
    percent_set_zero_base)`` weight entries are chosen uniformly without
    replacement (NumPy global RNG) and overwritten with samples from
    ``N(mean(weight), std(weight))`` (torch global RNG), where the statistics
    are taken over that layer's weights before modification. Biases and all
    other module types are left alone.

    The weight's ``.data`` is rebound to a freshly allocated tensor
    (``weight * mask``) before the noise is written into it.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        # Both layer kinds currently use the base ratio unchanged.
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        if not hasattr(module, 'weight'):
            continue

        weight = module.weight
        current = weight.data
        mu = current.mean()
        sigma = current.std()
        numel = weight.numel()
        reset_count = int(numel * percent_set_zero_base)

        fresh_values = torch.normal(mu, sigma, size=(reset_count,)).to(weight.device)
        keep_flat = torch.ones_like(weight).view(-1)
        chosen = np.random.choice(numel, reset_count, replace=False)
        keep_flat[chosen] = 0

        weight.data = current * keep_flat.view(weight.size())
        weight.data.view(-1)[chosen] = fresh_values

    return model

from pytorch_msssim import ssim
import torchvision.utils as vutils
import os
from datetime import datetime





def get_test_image(dataset, seed=42):
    """Pick one sample from ``dataset`` reproducibly and undo its normalization.

    The sample index is drawn from a private ``torch.Generator`` seeded with
    ``seed``. This selects the same index that ``torch.manual_seed(seed)``
    followed by ``torch.randint`` would, but it does not reseed the global
    torch RNG, so randomness elsewhere in the program is unaffected.

    Args:
        dataset: indexable dataset returning ``(image_tensor, label)``; the
            image is assumed to be a 3-channel tensor normalized with the
            CIFAR-10 mean/std constants below (TODO confirm against the
            dataset's transform).
        seed: random seed for choosing the sample.

    Returns:
        ``(orig_img, img, label)``: the de-normalized image, the image exactly
        as returned by the dataset, and its label.
    """
    # Private generator: avoids clobbering global RNG state as a side effect.
    generator = torch.Generator().manual_seed(seed)
    idx = torch.randint(len(dataset), (1,), generator=generator).item()
    img, label = dataset[idx]

    # Denormalize the image for display: x * std + mean, per channel.
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
    orig_img = img * std + mean

    return orig_img, img, label

def reset_conv_kernels_stair(model, current_iter, transition_period, reset_ratio=0.03, scale_factor=0.5):
    """Re-draw random conv kernels, retiring layers shallow-to-deep over time.

    Conv2d layers are taken in ``model.modules()`` order. The layer at 1-based
    position ``k`` (of ``n``) keeps being reset while
    ``current_iter < k * (transition_period / n)``. Shallower layers therefore
    stop first, and all layers have stopped once ``current_iter >=
    transition_period``.

    For an active layer, ``int(out_channels * reset_ratio)`` output kernels
    are chosen with ``random.sample``. Each one is replaced by
    ``randn * std * scale_factor + mean``, where ``mean``/``std`` are taken
    over the layer's whole weight tensor before any replacement. Biases are
    not touched.

    Args:
        model: PyTorch model, modified in place.
        current_iter: current training iteration.
        transition_period: iterations after which every layer has stopped.
        reset_ratio: fraction of kernels per layer to re-draw, in [0, 1].
        scale_factor: multiplier on the layer std for the new kernels.

    Returns:
        Number of Conv2d layers in which at least one kernel was re-drawn.

    Raises:
        ValueError: if ``reset_ratio`` is outside [0, 1].
    """
    if not (0 <= reset_ratio <= 1):
        raise ValueError("Reset ratio must be between 0 and 1")

    convs = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
    n_convs = len(convs)
    layers_touched = 0

    for position, conv in enumerate(convs, start=1):
        # This layer's "stair step": once reached, it is frozen.
        if current_iter >= position * (transition_period / n_convs):
            continue

        mu = conv.weight.data.mean().item()
        sigma = conv.weight.data.std().item()

        kernel_total = conv.weight.shape[0]
        how_many = int(kernel_total * reset_ratio)
        if how_many <= 0:
            continue

        for kernel_idx in random.sample(range(kernel_total), how_many):
            conv.weight.data[kernel_idx] = (
                torch.randn_like(conv.weight[kernel_idx]) * sigma * scale_factor + mu
            )
        layers_touched += 1

    return layers_touched



def _make_reset_kernel(template, init_method, mean, std, scale_factor, fan_in):
    """Draw one replacement kernel shaped like ``template`` using ``init_method``."""
    if init_method == 'uniform':
        return torch.zeros_like(template).uniform_(
            mean - std * scale_factor,
            mean + std * scale_factor,
        )
    if init_method == 'kaiming_uniform':
        bound = np.sqrt(6.0 / fan_in) * scale_factor
        return torch.zeros_like(template).uniform_(-bound, bound)
    if init_method == 'kaiming_normal':
        return torch.randn_like(template) * (np.sqrt(2.0 / fan_in) * scale_factor)
    if init_method == 'ori_normal':
        return torch.randn_like(template) * std * scale_factor + mean
    raise ValueError("Unsupported initialization method")


def reset_kernels_and_neurons_stair(model, current_iter, conv_transition_period,conv_reset_ratio=0.03125, fc_reset_ratio=0, scale_factor=1, init_method='ori_normal'):
    """Re-draw random conv kernels, retiring layers shallow-to-deep over time.

    Conv2d layers are taken in ``model.modules()`` order. The layer at 0-based
    depth ``d`` (of ``n``) keeps being reset while
    ``current_iter < (d + 1) * (conv_transition_period / n)``. For an active
    layer, ``int(out_channels * conv_reset_ratio)`` output kernels are chosen
    with ``random.sample`` and replaced according to ``init_method``. Biases
    are not touched.

    NOTE(review): ``fc_reset_ratio`` is validated but currently NOT applied;
    Linear layers are never modified by this function.

    Args:
        model: PyTorch model, modified in place.
        current_iter: current training iteration.
        conv_transition_period: iterations after which every conv layer has
            stopped being reset.
        conv_reset_ratio: fraction of kernels per conv layer to re-draw, in [0, 1].
        fc_reset_ratio: must be in [0, 1]; otherwise unused (see note above).
        scale_factor: scaling factor applied to the chosen init's spread.
        init_method: one of
            'ori_normal' (default): ``randn * layer_std * scale + layer_mean``;
            'uniform': U(layer_mean - layer_std*scale, layer_mean + layer_std*scale);
            'kaiming_uniform': U(-b, b) with ``b = sqrt(6 / fan_in) * scale``;
            'kaiming_normal': ``randn * sqrt(2 / fan_in) * scale``.
            The layer mean/std are taken over the whole weight tensor before
            any replacement.

    Returns:
        ``(reset_count, stopped_count)``: the number of conv layers in which
        at least one kernel was re-drawn, and the number of conv layers past
        their stop iteration.

    Raises:
        ValueError: if a ratio is outside [0, 1] or ``init_method`` is unknown.
    """
    if not (0 <= conv_reset_ratio <= 1 and 0 <= fc_reset_ratio <= 1):
        raise ValueError("Reset ratio must be between 0 and 1")
    # Fail fast, rather than only when a kernel happens to be reset.
    if init_method not in ('uniform', 'kaiming_uniform', 'kaiming_normal', 'ori_normal'):
        raise ValueError("Unsupported initialization method")

    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]

    reset_count = 0
    stopped_count = 0  # conv layers past their stop iteration
    total_conv = len(conv_layers)

    for depth, layer in enumerate(conv_layers):
        stop_iter = (depth + 1) * (conv_transition_period / total_conv)
        if current_iter >= stop_iter:
            stopped_count += 1
            continue

        weight = layer.weight.data  # shares storage with the parameter
        current_mean = weight.mean().item()
        current_std = weight.std().item()

        num_kernels = weight.shape[0]
        num_reset = int(num_kernels * conv_reset_ratio)
        if num_reset <= 0:
            continue

        # fan_in depends only on the layer shape; compute it once per layer.
        fan_in = None
        if init_method in ('kaiming_uniform', 'kaiming_normal'):
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)

        for idx in random.sample(range(num_kernels), num_reset):
            weight[idx] = _make_reset_kernel(
                weight[idx], init_method, current_mean, current_std, scale_factor, fan_in
            )

        reset_count += 1

    return reset_count, stopped_count
