# import imp
import os
import math
import datetime
import cv2
import numpy as np

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from typing import List


def concatenate_weights(weights_list, n_splits=0, random_seed=1):
    """
    Flattens a list of weight arrays into a single 1D vector and records their original shapes.
    Optionally keeps only a random subset of the entries and zeroes out all the others.

    Parameters:
    - weights_list: Non-empty list of weight matrices (numpy arrays).
    - n_splits: If greater than 0, only ``total_length // n_splits`` randomly chosen entries
                keep their value; every other entry of the returned vector is zero.
    - random_seed: (Optional) Seed for the index sampling. The sampling uses a private
                   generator, so the global ``np.random`` state is left untouched.

    Returns:
    - concatenated_weights: 1D numpy array of all concatenated weights, with some elements possibly zeroed out.
    - shapes: List of original shapes of each weight matrix for reconstruction.

    Raises:
    - ValueError: if weights_list is empty (raised by np.concatenate) or if n_splits is so
                  large that no entry would be kept.
    """
    shapes = [weight_matrix.shape for weight_matrix in weights_list]
    concatenated_weights = np.concatenate(
        [weight_matrix.flatten() for weight_matrix in weights_list]
    )

    # Nothing to mask: return the plain concatenation
    if n_splits <= 0:
        return concatenated_weights, shapes

    total_length = len(concatenated_weights)
    chunk_size = int(total_length // n_splits)
    if chunk_size == 0:
        raise ValueError("n_splits is too large, resulting in chunk_size=0.")

    # A local RandomState yields the same draws as np.random.seed(seed) followed by
    # np.random.choice, but does not reseed the process-wide generator.
    rng = np.random.RandomState(random_seed)
    keep_indices = rng.choice(total_length, size=chunk_size, replace=False)

    # Start from zeros and copy over only the selected entries
    masked_weights = np.zeros_like(concatenated_weights)
    masked_weights[keep_indices] = concatenated_weights[keep_indices]

    return masked_weights, shapes

def deconcatenate_weights(flat_vector, shapes):
    """
    Reconstructs the list of weight matrices from the flat concatenated vector based on the provided shapes.

    Consecutive slices of flat_vector are consumed in the order of `shapes`; any elements left
    over after the last shape are ignored.

    Parameters:
    - flat_vector: 1D numpy array of concatenated weights.
    - shapes: List of shapes for each weight matrix. Scalar shapes ``()`` are supported.

    Returns:
    - reconstructed_weights: List of weight matrices with their original shapes.

    Raises:
    - ValueError: if flat_vector holds too few elements for one of the shapes (raised by reshape).
    """
    reconstructed_weights = []
    idx = 0

    for shape in shapes:
        # np.prod(()) is the float 1.0, which is not a valid slice bound; force an integer size
        size = int(np.prod(shape, dtype=np.int64))
        reconstructed_weights.append(flat_vector[idx:idx + size].reshape(shape))
        idx += size

    return reconstructed_weights

def update_model_parameters(gradients, new_weights):
    """
    Overwrites every tensor in `gradients` in place with the matching array from `new_weights`.

    The two sequences are paired positionally with zip, so surplus entries in the longer
    one are ignored.

    Parameters:
    - gradients: Iterable of PyTorch tensors to overwrite.
    - new_weights: A list of numpy arrays holding the new values, one per tensor.

    Raises:
    - TypeError: if an element of new_weights is not a numpy array.
    - ValueError: if an array's shape differs from the shape of its target tensor.
    """
    with torch.no_grad():
        for param, new_weight in zip(gradients, new_weights):
            # Ensure the new weight is a numpy array
            if not isinstance(new_weight, np.ndarray):
                raise TypeError("All elements in new_weights must be numpy arrays.")

            # Convert to a tensor on the target's device with the target's dtype
            new_weight_tensor = torch.from_numpy(new_weight).to(param.device).type_as(param)

            # The shape is metadata, so it can be compared without copying the tensor to the CPU
            if param.shape != new_weight_tensor.shape:
                raise ValueError(f"Shape mismatch: Parameter shape {param.shape} vs new weight shape {new_weight_tensor.shape}")

            # Copy the data into the existing tensor
            param.copy_(new_weight_tensor)

def count_zero_nonzero(original_dy_dx):
    """
    Tallies how many elements across a list of tensors are exactly zero and how many are not,
    prints both totals and returns them.

    Parameters:
    - original_dy_dx: List of PyTorch tensors.

    Returns:
    - total_zero: Number of elements equal to zero, summed over all tensors.
    - total_non_zero: Number of elements different from zero, summed over all tensors.
    """
    total_zero = 0
    total_non_zero = 0

    for tensor in original_dy_dx:
        # Work on a CPU-side NumPy copy of the tensor's values
        values = tensor.cpu().data.numpy()
        zero_mask = values == 0

        total_zero += np.sum(zero_mask)
        total_non_zero += np.sum(~zero_mask)

    print(f"Total number of zero elements: {total_zero}")
    print(f"Total number of non-zero elements: {total_non_zero}")

    return total_zero, total_non_zero
    
    
def get_parameters_from_model(model) -> List[np.ndarray]:
    """
    Returns every tensor in model.state_dict() as a numpy array.

    The list follows the iteration order of state_dict().
    """
    arrays = []
    for tensor in model.state_dict().values():
        arrays.append(tensor.detach().cpu().numpy())
    return arrays

def flatten_params(params):
    """
    Flattens a sequence of arrays into one 1D vector.

    Returns:
    - flattened_params: concatenation of every flattened array, in input order.
    - shape_list: the original shape of each array, for later reconstruction.
    """
    # First pass: remember every shape
    shape_list = []
    for array in params:
        shape_list.append(array.shape)

    # Second pass: collect the flattened pieces and join them
    pieces = []
    for array in params:
        pieces.append(array.flatten())
    flattened_params = np.concatenate(pieces)

    return flattened_params, shape_list

def unflatten_params(flattened_params, shape_list):
    """
    Splits a flat vector back into a list of arrays with the given shapes.

    Consecutive slices of flattened_params are consumed in the order of shape_list; any
    elements left over after the last shape are ignored.

    Parameters:
    - flattened_params: 1D numpy array holding all values back to back.
    - shape_list: List of target shapes. Scalar shapes ``()`` are supported.

    Returns:
    - params: List of arrays, one per entry of shape_list.

    Raises:
    - ValueError: if flattened_params holds too few elements for one of the shapes (raised by reshape).
    """
    params = []
    start_idx = 0
    for shape in shape_list:
        # np.prod(()) is the float 1.0, which is not a valid slice bound; force an integer size
        size = int(np.prod(shape, dtype=np.int64))
        params.append(flattened_params[start_idx:start_idx + size].reshape(shape))
        start_idx += size
    return params

def create_mask(params, n_splits, seed):
    """
    Assigns every element of `params` to one of `n_splits` aggregators at random.

    The flattened parameter vector is divided into n_splits contiguous fragments whose sizes
    differ by at most one, each fragment is filled with its aggregator id, and the ids are
    then shuffled with an MT19937 generator seeded by `seed`.

    Returns:
    - A list of arrays with the same shapes as `params`, holding the aggregator id of each
      element (stored with the dtype of the flattened parameters).
    """
    flat_params, shape_list = flatten_params(params)
    base_size, remainder = divmod(len(flat_params), n_splits)

    assignment = np.zeros_like(flat_params)
    offset = 0
    for aggregator_id in range(n_splits):
        # The first `remainder` aggregators each take one extra element
        span = base_size + 1 if aggregator_id < remainder else base_size
        assignment[offset:offset + span] = aggregator_id
        offset += span

    # Seeded generator so the same seed always yields the same assignment
    rng = np.random.Generator(np.random.MT19937(seed=seed))
    rng.shuffle(assignment)

    return unflatten_params(assignment, shape_list)

def count_state_params(model: torch.nn.Module) -> int:
    """Returns the total number of elements over all tensors in model.state_dict()."""
    total = 0
    for tensor in model.state_dict().values():
        total += tensor.numel()
    return int(total)

def init_eris_state(
    model: torch.nn.Module,
    fl_rounds: int,
    k: int | None = None,
    k_frac: float | None = None,
) -> tuple[int, int, float, List[np.ndarray]]:
    """
    Returns:
      d: total params
      k: number of kept coordinates for random-k
      gamma: SoteriaFL gamma
      s: reference vector list (zeros), one np.ndarray per tensor
    """
    d = count_state_params(model)
    if k is None:
        if k_frac is not None and k_frac > 0 and k_frac < 1:
            k = max(1, int(d * k_frac))
        else:
            # paper’s heuristic (your earlier code)
            k = max(1, int(d / max(1.0, math.log2(max(2, fl_rounds)))))

    # gamma formula used in your prior implementation
    w = (d / k) - 1.0
    gamma = math.sqrt((1.0 + 2.0 * w) / (2.0 * (1.0 + w) ** 3))

    s = []
    for _, t in model.state_dict().items():
        s.append(np.zeros_like(t.detach().cpu().numpy()))
    return d, k, float(gamma), s

def init_outputfolder(config):
    """
    Creates a timestamped run directory (with an "img" sub-directory) and returns its path.

    The directory is ``<config.output_folder>/<MMDD_HHMM>`` based on the current local time.
    Missing parent directories are created as well.

    Parameters:
    - config: object exposing an `output_folder` attribute (base directory path).

    Returns:
    - output_dir: path of the timestamped run directory.
    """
    current_time_str = datetime.datetime.now().strftime('%m%d_%H%M')
    output_dir = os.path.join(config.output_folder, current_time_str)

    # exist_ok avoids the check-then-create race and guarantees that img/ exists even when
    # the timestamped directory was already created earlier (e.g. a second run in the same minute)
    os.makedirs(os.path.join(output_dir, "img"), exist_ok=True)

    return output_dir

def save_batch(output_dir, original_img, recon_img, keyword="recon", save_ori=True):
    """
    Writes a batch of images as PNG files into ``<output_dir>/img``.

    Files are named ``<index>_ori.png`` for the originals and ``<index>_<keyword>.png`` for
    the reconstructions, where index is the position within the batch.

    Parameters:
    - output_dir: run directory that contains an "img" sub-directory.
    - original_img: iterable of image tensors; only read when save_ori is True.
    - recon_img: iterable of reconstructed image tensors.
    - keyword: suffix used in the reconstruction file names.
    - save_ori: whether the original images are written as well.
    """
    img_dir = os.path.join(output_dir, "img")

    def _write(img, filename):
        # Converts one tensor to an image array and writes it to img_dir
        img_numpy = tensor2img(img)
        if img_numpy.ndim == 3:
            # reverse the channel axis (RGB -> BGR, the order OpenCV writes);
            # 2D grayscale arrays have no channel axis and are written as they are
            img_numpy = img_numpy[:, :, ::-1]
        cv2.imwrite(os.path.join(img_dir, filename), img_numpy)

    if save_ori:
        for i, img in enumerate(original_img):
            _write(img, '{:d}_ori.png'.format(i))

    for i, img in enumerate(recon_img):
        _write(img, '{:d}_{:s}.png'.format(i, keyword))

def label_to_onehot(target, num_classes=1000):
    """
    Turns a tensor of class indices into a one-hot matrix with num_classes columns.

    The result is created on the same device as `target`; row i has a 1 in column target[i].
    """
    index = target.unsqueeze(1)
    onehot_target = torch.zeros(index.size(0), num_classes, device=index.device)
    # scatter_ works in place and returns the same tensor
    return onehot_target.scatter_(1, index, 1)

def cross_entropy_for_onehot(pred, target):
    """
    Cross-entropy between logits `pred` and a (one-hot or soft) target distribution.

    Applies log-softmax over the last dimension of pred, sums ``-target * log_prob`` over
    dimension 1 and averages the per-sample values.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = torch.sum(- target * log_probs, 1)
    return torch.mean(per_sample)

def freeze(model):
    """Disables gradient tracking for every parameter of `model` (in place)."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)

def jaccard(a, b):
    """
    Jaccard index |A ∩ B| / |A ∪ B| of the sets of values contained in `a` and `b`.

    Both the intersection and the union are computed on unique values, so repeated
    entries in either input do not distort the result.

    Parameters:
    - a, b: array-likes of values (e.g. index arrays).

    Returns:
    - Jaccard index as a float in [0, 1].

    Raises:
    - ZeroDivisionError: if both inputs are empty.
    """
    intersection = np.intersect1d(a, b)
    union = np.union1d(a, b)
    return float(intersection.shape[0]) / union.shape[0]

def overlap_idx(a, b):
    """
    Fraction of `b` that also occurs in `a`: the number of unique shared values divided by
    the length of b's first dimension.
    """
    shared = np.intersect1d(a, b)
    shared_count = float(shared.shape[0])
    # The denominator is the size of b (not the min or max of the two sizes)
    return shared_count / b.shape[0]

def preprocess(config, x, y, onehot, model):
    """
    Moves the inputs and the model to config.device, optionally casting to half precision.

    When config.half is truthy, x, onehot and the model are converted with .half() first;
    y is never cast. Returns the tuple (x, y, onehot, model) after the move.
    """
    target_device = config.device

    if config.half:
        # y is intentionally left in its original dtype
        x, onehot, model = x.half(), onehot.half(), model.half()

    return tuple(item.to(target_device) for item in (x, y, onehot, model))


def single2tensor4(img):
    """
    Converts a 3-D numpy array into a float tensor with a leading batch dimension.

    The last axis of `img` is moved to the front (permute(2, 0, 1)) before a batch
    dimension of size 1 is prepended.
    """
    contiguous = np.ascontiguousarray(img)
    channels_first = torch.from_numpy(contiguous).permute(2, 0, 1)
    return channels_first.float().unsqueeze(0)


def tensor2img(tensor, min_max=(0, 1), out_type=np.uint8):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(1,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order.
           All size-1 dimensions are squeezed away first, so the result must be 3D or 2D.
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    The input tensor is never modified; tensors that require grad are accepted.
    Raises TypeError when the squeezed tensor is neither 3D nor 2D.
    '''
    low, high = min_max
    # detach + out-of-place clamp: an in-place clamp would write through to the caller's
    # tensor whenever squeeze/float/cpu return views of it
    tensor = tensor.detach().squeeze().float().cpu().clamp(low, high)
    tensor = (tensor - low) / (high - low)  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 3:
        img_np = np.transpose(tensor.numpy(), (1, 2, 0))  # HWC, RGB
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)


class Adam16(Optimizer):
    """
    Adam optimizer for half-precision parameters.

    A float32 copy of every parameter is kept in `fp32_param_groups`. Each step updates that
    copy with the Adam rule and then writes it back to the parameter as float16.

    Arguments:
        params (iterable): parameters to optimize, or dicts defining parameter groups
        lr (float, optional): learning rate
        betas (Tuple[float, float], optional): coefficients of the running averages of the
            gradient and of its square
        eps (float, optional): term added to the denominator for numerical stability
        weight_decay (float, optional): L2 penalty coefficient
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        params = list(params)
        super(Adam16, self).__init__(params, defaults)

        # Mirror self.param_groups with float32 copies. Each copy lives on the device of its
        # parameter, so the optimizer also works without CUDA and with parameter-group dicts.
        self.fp32_param_groups = [
            {'params': [p.detach().to(dtype=torch.float32, copy=True)
                        for p in group['params']]}
            for group in self.param_groups
        ]

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The value returned by closure, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, fp32_group in zip(self.param_groups, self.fp32_param_groups):
            beta1, beta2 = group['betas']
            for p, fp32_p in zip(group['params'], fp32_group['params']):
                if p.grad is None:
                    continue

                # All arithmetic is carried out in float32
                grad = p.grad.detach().float()
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(fp32_p, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Update the float32 copy, then publish it to the model as float16
                fp32_p.addcdiv_(exp_avg, denom, value=-step_size)
                p.data = fp32_p.half()

        return loss

    
        
