import copy
import torch
from torchvision import datasets, transforms

import sys

from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
from update import LocalUpdate, test_inference, DatasetSplit
from math import exp
import numpy as np
from numpy import linalg
from options import args_parser
import math
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor,Resize,Normalize
import pdb


def get_dataset(args):
    """Load the requested dataset and partition its training split among users.

    Args:
        args: Options object. Reads ``dataset`` ('cifar', 'mnist' or
            'fmnist'), ``iid``, ``unequal`` and ``num_users``.

    Returns:
        Tuple ``(train_dataset, test_dataset, user_groups)``.
        ``user_groups`` is whatever the selected sampling helper returns
        (presumably a per-user collection of training-sample indices --
        confirm against the ``sampling`` module).

    Raises:
        ValueError: ``args.dataset`` is not a supported dataset name.
        NotImplementedError: CIFAR was requested with a non-IID, unequal
            split. This is raised only after the datasets have been loaded.
    """
    from torchvision import datasets, transforms

    name = args.dataset
    if name not in ('cifar', 'mnist', 'fmnist'):
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    is_cifar = name == 'cifar'

    # Per-dataset settings: storage root, torchvision class and the
    # normalisation statistics handed to ``transforms.Normalize``.
    if is_cifar:
        root, dataset_cls = '../data/cifar/', datasets.CIFAR10
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    elif name == 'mnist':
        root, dataset_cls = '../data/mnist', datasets.MNIST
        mean, std = (0.1307,), (0.3081,)
    else:  # fmnist
        root, dataset_cls = '../data/fmnist', datasets.FashionMNIST
        mean, std = (0.1307,), (0.3081,)

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = dataset_cls(root, train=True, download=True,
                                transform=preprocess)
    test_dataset = dataset_cls(root, train=False, download=True,
                               transform=preprocess)

    # Choose the partitioning helper that matches the requested split.
    if args.iid:
        partition = cifar_iid if is_cifar else mnist_iid
    elif args.unequal:
        if is_cifar:
            raise NotImplementedError("Unequal splits for CIFAR not implemented.")
        partition = mnist_noniid_unequal
    else:
        partition = cifar_noniid if is_cifar else mnist_noniid

    user_groups = partition(train_dataset, args.num_users)
    return train_dataset, test_dataset, user_groups


def average_weights(w):
    """Aggregate client weight dicts into the sign of their element-wise sum.

    Despite the name, no division by the number of clients takes place:
    every entry is summed across all clients and then passed through
    ``torch.sign``.

    Args:
        w: Indexable sequence of weight mappings sharing the same keys.
            The first one is deep-copied, so the inputs are not modified.

    Returns:
        A deep copy of ``w[0]`` whose values are replaced by the sign of
        the per-key sum over all entries of ``w``.
    """
    aggregated = copy.deepcopy(w[0])
    for name in aggregated:
        total = aggregated[name]
        # Accumulate the remaining clients into the copied first entry.
        for client_state in w[1:]:
            total += client_state[name]
        aggregated[name] = torch.sign(total)
    return aggregated

def sign_attack(w, scale=1):
    """Return a copy of ``w`` with every value negated and scaled.

    Args:
        w: Mapping of names to values (e.g. tensors) supporting unary
            minus and multiplication.
        scale: Multiplier applied after negation. Defaults to 1.

    Returns:
        A deep copy of ``w`` in which each value is ``-w[key] * scale``.
        The input mapping is left untouched.
    """
    flipped = copy.deepcopy(w)
    for name, value in w.items():
        flipped[name] = -value * scale
    return flipped

def exp_details(args):
    """Print a summary of the experiment configuration to stdout.

    Args:
        args: Options object. Reads ``dataset``, ``model``, ``lr``,
            ``epochs``, ``iid``, ``unequal``, ``frac`` and ``local_bs``.

    Returns:
        None.
    """
    print('\nExperimental details:')
    print(f'    Dataset     : {args.dataset}')
    print(f'    Model     : {args.model}')
    print(f'    Learning  : {args.lr}')
    print(f'    Global Rounds   : {args.epochs}\n')

    print('    Federated parameters:')
    # Only iid == 1 is reported as IID; every other value is Non-IID.
    print('    IID' if args.iid == 1 else '    Non-IID')
    print('    Unbalanced' if args.unequal else '    balanced')
    print(f'    Fraction of users  : {args.frac}')
    print(f'    Local Batch size   : {args.local_bs}')





