import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from opacus.validators.module_validator import ModuleValidator
from argparse import ArgumentParser
from tqdm import tqdm


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    converter maps the usual spellings explicitly and rejects anything else.

    Raises:
        ValueError: if *value* is not a recognised boolean spelling; argparse
            turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, matching what ``bool("")`` used to give.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the experiment's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``bs``, ``eta``,
        ``cp``, ``device``, ``seed``, ``n``, ``alpha``, ``alg``, ``run``,
        ``json`` and ``model_type``.
    """
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--eta', type=float, default=0.01)
    parser.add_argument('--cp', type=int, default=1)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--seed', type=int, default=-1)
    parser.add_argument('--n', type=int, default=20)
    parser.add_argument('--alpha', type=float, default=-1.0)
    parser.add_argument('--alg', type=str, default="x")
    parser.add_argument('--run', type=int, default=-1)
    # ``--json False`` now really yields False; a bare ``--json`` means True.
    parser.add_argument('--json', type=_str_to_bool, nargs='?', const=True,
                        default=False)
    parser.add_argument('--model_type', type=str, default='resnet')
    return parser.parse_args()


class LinearCNN(nn.Module):
    """Single-layer 1-D CNN over non-overlapping patches with two outputs.

    The only learnable layer is ``conv1``: a bias-free ``Conv1d`` with one
    input channel and ``2 * out_channel`` filters whose kernel size and
    stride both equal ``int(input_dim / patch_num)``.
    """

    def __init__(self, input_dim, out_channel, patch_num):
        """Build the convolution.

        Args:
            input_dim: length of the flattened input signal.
            out_channel: number of filters feeding each of the two outputs.
            patch_num: number of patches the input is divided into.
        """
        super().__init__()
        patch_len = int(input_dim / patch_num)
        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=out_channel * 2,
            kernel_size=patch_len,
            stride=patch_len,
            bias=False,
        )
        self.out_channel = out_channel

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor of scores.

        ``x`` goes through ``conv1`` and ReLU, is summed over the patch axis,
        and then the first and the last ``out_channel`` filters are averaged
        separately to give the two output columns.
        """
        pooled = F.relu(self.conv1(x)).sum(dim=2)
        first_half = pooled[:, :self.out_channel].mean(dim=1)
        second_half = pooled[:, self.out_channel:].mean(dim=1)
        return torch.stack([first_half, second_half]).transpose(1, 0)

class Resnet(nn.Module):
    """RESNET model with BatchNorm replaced with GroupNorm.

    Wraps a torchvision ResNet whose final fully connected layer is replaced
    by a fresh ``nn.Linear`` with ``num_classes`` outputs; the result is then
    passed through Opacus' ``ModuleValidator.fix``.
    """

    def __init__(self, num_classes, resnet_size, pretrained=False):
        """Build the backbone.

        Args:
            num_classes: number of outputs of the replacement ``fc`` layer.
            resnet_size: one of 18, 34, 50, 101 or 152.
            pretrained: when true, request torchvision's pretrained weights
                for the selected architecture.

        Raises:
            AssertionError: if ``resnet_size`` is not a supported depth.
        """
        super().__init__()

        # Map each supported depth to its torchvision constructor.
        constructors = {
            18: models.resnet18,
            34: models.resnet34,
            50: models.resnet50,
            101: models.resnet101,
            152: models.resnet152,
        }
        assert (
            resnet_size in constructors
        ), f"Resnet size {resnet_size} is not supported!"

        self._name = f"Resnet{resnet_size}"
        build = constructors[resnet_size]

        # Honour resnet_size for pretrained models too. Previously the
        # pretrained path always loaded resnet18, whatever size was requested,
        # after first building (and discarding) a randomly initialised network.
        self.backbone = build(pretrained=True) if pretrained else build()

        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
        self.backbone = ModuleValidator.fix(self.backbone)

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.backbone(x)

    def name(self):
        """Return the model name, e.g. ``"Resnet18"``."""
        return self._name
    

def initialize_weights(m):
    """Re-initialise the weights of ``nn.Conv1d`` modules in place.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation 0.01. Any other module type is left untouched. The one-argument
    signature matches what ``nn.Module.apply`` expects.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.normal_(m.weight.data, mean=0.0, std=0.01)


def __getDirichletData__(y, n, alpha, num_c):
    """Partition sample indices among ``n`` clients using Dirichlet priors.

    Every client draws a class distribution from ``Dirichlet(alpha, ...)``.
    For each class, its (shuffled) sample indices are split among the clients
    in proportion to the probability each client assigned to that class.
    Per-client class counts and relative dataset sizes are printed.

    Uses the global ``np.random`` state, so seed it beforehand for
    reproducible partitions.

    Args:
        y: array-like of integer class labels, one per sample (list, NumPy
            array, ...).
        n: number of clients.
        alpha: Dirichlet concentration parameter shared by all classes.
        num_c: number of classes; labels are matched against
            ``0 .. num_c - 1``.

    Returns:
        A list of ``n`` lists; entry ``j`` holds the sample indices assigned
        to client ``j``.
    """
    # Accept any array-like: a plain list would otherwise break both the
    # ``== k`` comparison and the fancy indexing used for the statistics.
    labels = np.asarray(y)

    # One class distribution per client (row = client, column = class).
    class_probs = np.zeros((n, num_c))
    for i in range(n):
        class_probs[i] = np.random.dirichlet(np.repeat(alpha, num_c))

    idx_batch = [[] for _ in range(n)]

    for k in range(num_c):
        idx_k = np.where(labels == k)[0]
        np.random.shuffle(idx_k)

        proportions = class_probs[:, k]
        total = proportions.sum()
        if np.isfinite(total) and total > 0:
            proportions = proportions / total
        else:
            # Very small alpha can make the whole column underflow to zero
            # (or NaN); split evenly instead of producing garbage cut points.
            proportions = np.ones(n) / max(n, 1)

        cuts = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
        idx_batch = [
            own + part.tolist()
            for own, part in zip(idx_batch, np.split(idx_k, cuts))
        ]

    net_cls_counts = {}
    for j, client_idx in enumerate(idx_batch):
        # Shuffle in place so each client's classes are interleaved.
        np.random.shuffle(client_idx)
        unq, unq_cnt = np.unique(
            labels[np.asarray(client_idx, dtype=int)], return_counts=True
        )
        net_cls_counts[j] = dict(zip(unq, unq_cnt))

    local_sizes = np.array([len(client_idx) for client_idx in idx_batch])
    weights = local_sizes / np.sum(local_sizes)

    print('Data statistics: %s' % str(net_cls_counts))
    print('Data ratio: %s' % str(weights))

    return idx_batch

def get_activations(model, data):
    """Run ``model`` on ``data`` once and capture selected layer outputs.

    A forward hook is attached to every sub-module whose qualified name
    contains ``'conv'``, and to every ``nn.Conv2d`` whose name contains
    ``'features'``. The forward pass runs under ``torch.no_grad()``.

    Args:
        model: the ``nn.Module`` to probe.
        data: input passed directly to ``model``.

    Returns:
        dict mapping each hooked module's name (as given by
        ``model.named_modules()``) to a detached clone of its output.
    """
    activation_maps = {}
    hooks = []

    def make_hook(layer_name):
        def hook(_module, _inputs, output):
            # Detach and clone so the stored tensor is independent of the
            # graph and of any later in-place modification.
            activation_maps[layer_name] = output.detach().clone()
        return hook

    try:
        for name, module in model.named_modules():
            if 'conv' in name or ('features' in name and isinstance(module, nn.Conv2d)):
                hooks.append(module.register_forward_hook(make_hook(name)))

        with torch.no_grad():
            model(data)
    finally:
        # Always unregister, even if the forward pass raised; otherwise the
        # hooks would keep firing on every later call of the model.
        for handle in hooks:
            handle.remove()

    return activation_maps


def test_img(net_g, datatest, args, use_squared = False):

    net_g.eval()
    test_loss = 0
    correct = 0
    size = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)

    for (data, target) in data_loader:
        data, target = data.to(args.device), target.to(args.device)
        data = data.view(data.size(0), 1, -1)  # Flatten to (batch, 1, 3072)
        target = (target == 1).long()  # Convert labels to 0,1
        
        log_probs = net_g(data)
        size += len(target)
        if(use_squared == False):
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        else:
            target_new = torch.nn.functional.one_hot(target, num_classes = 10).float()
            test_loss += F.mse_loss(log_probs, target_new, reduction='sum').item()

        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    correct = correct.item()
    accuracy = 100.00 * correct / size
    net_g = net_g.to('cpu')
    return accuracy, test_loss



def measure_alignment_linear_cnn(dataloader, net_1, net_2):
    """Measure alignment between two LinearCNN models.

    For the first batch of ``dataloader``, the signs of both models'
    ``conv1`` outputs are multiplied and averaged per filter. A filter counts
    as misaligned when that average is negative. Both models are moved to the
    CPU before returning.

    NOTE(review): only the first batch is used, exactly as before (the
    original returned from inside its batch loop). Confirm that this is
    intended rather than an indentation slip.

    Args:
        dataloader: iterable of ``(data, target)`` batches.
        net_1, net_2: models exposing a ``conv1`` layer.

    Returns:
        ``(ind, ratio, count)``: ``ind`` is a one-element list holding the CPU
        tensor of misaligned filter indices, ``ratio`` the misaligned
        fraction and ``count`` the number of misaligned filters. An empty
        ``dataloader`` yields ``([], 0.0, 0)`` instead of an implicit
        ``None``.
    """
    ind = []
    count = 0

    batch = next(iter(dataloader), None)
    if batch is None:
        return ind, 0.0, count
    data, _target = batch

    with torch.no_grad():
        # Flatten data for LinearCNN
        data = data.view(data.size(0), 1, -1)

        # Feed each model on its own device so models that live on an
        # accelerator work with CPU batches; compare the signs on the CPU.
        signs1 = torch.sign(net_1.conv1(data.to(net_1.conv1.weight.device))).cpu()
        signs2 = torch.sign(net_2.conv1(data.to(net_2.conv1.weight.device))).cpu()

        # Mean sign agreement per filter, over the batch and patch axes.
        alignment = torch.mean(signs1 * signs2, dim=(0, 2))
        num_filters = len(alignment)
        misaligned_index = torch.where(alignment < 0)[0].cpu()
        count += len(misaligned_index)
        ind.append(misaligned_index)

    net_1.to('cpu')
    net_2.to('cpu')
    return ind, count / num_filters, count


def measure_alignment(dataloader, net_1, net_2):
    """Measure per-filter sign alignment between two models' activations.

    For the first batch of ``dataloader``, activations of both models are
    collected with ``get_activations``. For every captured layer, each
    channel's activation signs are multiplied across the two models and
    averaged; a channel counts as misaligned when that average is negative.
    Both models are moved to the CPU before returning.

    NOTE(review): only the first batch is used, exactly as before (the
    original returned from inside its batch loop). Confirm that this is
    intended rather than an indentation slip.

    Args:
        dataloader: iterable of ``(data, target)`` batches.
        net_1, net_2: models whose hooked layers share the same names.

    Returns:
        ``(ind, ratio, total_d)``: ``ind`` is a list with one CPU index
        tensor of misaligned channels per layer, ``ratio`` the misaligned
        fraction over all channels, and ``total_d`` the sum of all
        per-channel alignment scores. An empty ``dataloader`` yields
        ``([], 0.0, 0)``.
    """
    ind = []
    count = 0
    num_filters = 0
    total_d = 0

    batch = next(iter(dataloader), None)
    if batch is None:
        return ind, 0.0, total_d
    data, _target = batch

    # Use the GPU when one exists; fall back to the CPU instead of crashing
    # on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    try:
        with torch.no_grad():
            data = data.to(device)
            net_1, net_2 = net_1.to(device), net_2.to(device)
            activations_1 = get_activations(net_1, data)
            activations_2 = get_activations(net_2, data)

            for key, act_1 in activations_1.items():
                n_channels = act_1.shape[1]
                num_filters += n_channels
                # Bring channels to the front and flatten everything else so
                # each row holds all activations of one channel.
                signs_1 = torch.sign(
                    torch.transpose(act_1, 0, 1).reshape(n_channels, -1)
                )
                signs_2 = torch.sign(
                    torch.transpose(activations_2[key], 0, 1).reshape(n_channels, -1)
                )
                d = torch.mean(signs_1 * signs_2, dim=1)
                total_d += torch.sum(d)
                misaligned_index = torch.where(d < 0)[0].cpu()
                count += len(misaligned_index)
                ind.append(misaligned_index)
        print (num_filters, count)
    finally:
        # Release accelerator memory even if something above failed.
        net_1.to('cpu')
        net_2.to('cpu')

    # Avoid dividing by zero when no layer matched the hook criteria.
    ratio = count / num_filters if num_filters else 0.0
    return ind, ratio, total_d
    