import copy
import numpy as np
import torch
import random
import math
import struct
import time
import statistics
from models.Fed import FedAvg
from models.Nets import MLP, Mnistcnn, CifarCnn
from models.Sia import SIA, DatasetSplit
from models.Update import LocalUpdate
from utils.dataset import get_dataset, exp_details
from utils.options import args_parser
from torch.utils.data import DataLoader, Dataset
from torchvision import models
import torch.nn as nn
from models.test import test_fun,test_fun_topk
from sklearn.linear_model import LinearRegression
from proposed_mechanism import *

"""
Original file by Hongsheng Hu et al., from the work "Source Inference Attacks:
Beyond Membership Inference Attacks in Federated Learning".
We modified their code to include the reconstruction attacks [Section 4].
"""

# fixed parameters:
MAX_RUNS = 5  # number of independent runs; the final report averages over them
PERCN_OF_SHADOW = 0.05  # default: 0.05; fraction of each client's data used as that client's shadow dataset
top_k = 1  # default: 1; a prediction counts as correct if the label is among the top_k classes
r = 4  # default: 4; digits after the decimal point taken into account (passed to RNS_ENCODE/RNS_DECODE)


#####
# Debugging
# The following switches run only parts of the code, to speed up experiments.
RUN_SIA = 1  # default: 1; runs the source inference attacks
RUN_RECONSTRUCTION = 1  # default: 1; runs the reconstruction attacks of Section 5
RUN_PROPOSED_MECHANISM = 1  # default: 1; runs Algorithm 1 (proposed solution)
RUN_ACCURACY = 1  # default: 1; trains the clients' local models for the proposed mechanism so its accuracy can be measured
pern_of_parameter_to_reconc = 1  # default: 1; fraction of each remapped layer's parameters (from the start of the flattened tensor) that is reconstructed; decrease to speed up execution
SMALLEST_EPOCH_TO_START_REMAPPING = -1  # default: -1; epochs up to and including this one skip everything after FedAvg (attacks, proposed mechanism, accuracy report); for complex models like ResNet set it to 3 to speed up execution

#####

# weight/bias names of the model's last fully connected layer; set per model in the main loop
names_of_last_fc = []
# layers remapped parameter-by-parameter in PARAMETER mode; set per model in the main loop
layers_to_remap = []

# list with one {layer_name: list of remapped values} dict per client,
# filled by reconstruct_model_parameter
clients_remap_parameters = []


# {client index: [index of the shuffled model guessed for that client]}, filled by reconstruct_model
remapped_model = {}
# NOTE(review): only ever assigned (reset to 0 in PARAMETER mode); nothing in this file reads it
param_multplier = 1



def reconstruct_model(w_locals, net_glob,dataset_train,  dict_sample_user, TARGET_CLIENT,correct_position, shadow_dataset):
    """
    Guess which of the shuffled models in ``w_locals`` belongs to TARGET_CLIENT.

    Each candidate state dict is loaded into a private copy of ``net_glob`` and
    scored with ``accuracy_on_target_data`` on TARGET_CLIENT's shadow loader.
    The best-scoring index (the last one among equally accurate candidates) is
    stored as ``remapped_model[TARGET_CLIENT] = [index]`` in the module-level
    dict.

    ``dataset_train`` and ``dict_sample_user`` are accepted but not used.

    Returns:
        bool: whether the guessed index equals ``correct_position``.
    """
    remapped_model[TARGET_CLIENT] = []
    candidate_net = copy.deepcopy(net_glob)

    top_accuracy, guess = 0, -1
    for position, state in enumerate(w_locals):
        candidate_net.load_state_dict(state)
        candidate_net.eval()
        score = accuracy_on_target_data(candidate_net, shadow_dataset[TARGET_CLIENT])
        # ">=" lets a later model with equal accuracy replace an earlier one
        if score >= top_accuracy:
            top_accuracy, guess = score, position

    remapped_model[TARGET_CLIENT].append(guess)
    return guess == correct_position



def reconstruct_model_parameter(w_locals, net_glob,dataset_train,  dict_sample_user, TARGET_CLIENT,args,w_glob,shadow_dataset):
    """
    Greedy parameter-level reconstruction of TARGET_CLIENT's remapped layers.

    Starting from a private copy of ``net_glob`` (the argument itself is not
    modified), for every layer in the module-level ``layers_to_remap`` and for
    the first ``pern_of_parameter_to_reconc`` fraction of its flattened
    entries:

    * if all clients' values for the entry are (nearly) equal, the first
      client's value is recorded and the model entry is left as it is;
    * otherwise each client's value is written into the model in turn, the
      model is scored on TARGET_CLIENT's shadow data, and the value with the
      highest accuracy is kept. Kept values stay in the model, so later
      entries are chosen given the earlier choices (greedy).

    The recorded values are stored in
    ``clients_remap_parameters[TARGET_CLIENT][layer_name]``.

    Args:
        w_locals: list of local state dicts, one per client.
        net_glob: network whose current weights are the starting point.
        dataset_train, dict_sample_user, args, w_glob: unused; kept for
            call-site compatibility.
        TARGET_CLIENT: index of the client whose parameters are reconstructed.
        shadow_dataset: list of per-client loaders; entry TARGET_CLIENT is used.

    Returns:
        0. The per-parameter correctness check is disabled, so callers that
        sum the result always add 0.
    """
    corrects = 0
    print("*** Remapping parameters of client ", TARGET_CLIENT)
    test_model = copy.deepcopy(net_glob)

    # Flatten each candidate layer of each client once, up front.
    flattened_params = {
        i: {
            layer: torch.flatten(w_locals[i][layer]).tolist()
            for layer in layers_to_remap
        }
        for i in range(len(w_locals))
    }

    for layer_name in layers_to_remap:
        remapped_parameters = []
        w_with_param = test_model.state_dict()
        layer_shape = w_with_param[layer_name].shape
        layer_dtype = w_with_param[layer_name].dtype
        layer_device = w_with_param[layer_name].device

        num_to_remap = int(len(flattened_params[0][layer_name]) * pern_of_parameter_to_reconc)
        for param in range(num_to_remap):
            possible_choices = np.array([flattened_params[i][layer_name][param] for i in range(len(w_locals))])

            # all clients agree on this entry: nothing to search
            if np.allclose(possible_choices, np.mean(possible_choices), rtol=0.00001, atol=0.0001):
                remapped_parameters.append(possible_choices[0])
                continue

            # position of this flat entry inside the (possibly multi-dim) layer
            flat_index = np.unravel_index(param, layer_shape)
            best_accuracy = 0
            best_position = -1
            for i in range(len(possible_choices)):
                w_with_param[layer_name][flat_index] = torch.tensor(possible_choices[i], dtype=layer_dtype, device=layer_device)
                test_model.load_state_dict(w_with_param)
                accuracy = accuracy_on_target_data(test_model, shadow_dataset[TARGET_CLIENT])
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_position = i

            # NOTE: if no candidate scores above 0%, best_position is still -1
            # and the last candidate is chosen.
            chosen = possible_choices[best_position]
            remapped_parameters.append(float(chosen))
            w_with_param[layer_name][flat_index] = torch.as_tensor(chosen, dtype=layer_dtype, device=layer_device)

        clients_remap_parameters[TARGET_CLIENT][layer_name] = remapped_parameters
        # push the final choices into the model before the next layer is searched
        test_model.load_state_dict(w_with_param)

    return corrects



def accuracy_on_target_data(test_model, shadow_dataset, device=None, k=None):
    """
    Return the top-k accuracy (in percent) of ``test_model`` on ``shadow_dataset``.

    Args:
        test_model: network to evaluate; it is switched to eval mode (and left so).
        shadow_dataset: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        device: device the batches are moved to. Defaults to the module-level
            ``args.device`` (the previous, implicit behavior).
        k: a sample counts as correct if its label is among the ``k``
            highest-scoring outputs. Defaults to the module-level ``top_k``.

    Returns:
        ``100 * correct / total``, or 0 when the dataset yields no samples.
    """
    if device is None:
        device = args.device
    if k is None:
        k = top_k

    correct = 0
    total = 0

    test_model.eval()
    with torch.no_grad():
        for images, labels in shadow_dataset:
            images, labels = images.to(device), labels.to(device)
            outputs = test_model(images)

            _, topk_preds = outputs.topk(k, dim=1, largest=True, sorted=True)
            # a row is correct if any of its k predictions equals the label
            correct += torch.sum(torch.any(topk_preds == labels.view(-1, 1), dim=1)).item()
            total += labels.size(0)

    return 100 * correct / total if total != 0 else 0


def avg(alist):
    """
    Return the arithmetic mean of the numbers in ``alist``.

    Accepts any iterable. An empty input returns ``float("nan")`` instead of
    raising ZeroDivisionError, so the final report can still be printed when,
    for example, no proposed-defense result was recorded.
    """
    values = list(alist)
    if not values:
        return float("nan")
    return sum(values) / len(values)


# parse args
args = args_parser()
print("users = ", args.num_users)
print("alpha =", args.alpha)
print("Local Epochs=", args.local_ep)
# MODE selects what is shuffled before the attacker sees the updates:
#   MODEL     - whole local models are shuffled
#   LAYER     - only the last fully connected layer is shuffled
#   PARAMETER - each parameter of the remapped layers is shuffled
MODE = args.mode

if MODE not in ("MODEL", "LAYER", "PARAMETER"):
    print("Reconstruction attack should run by setting mode to either MODEL, LAYER or PARAMETER")
    # Non-zero status so calling scripts can detect the invalid configuration;
    # the bare exit() used before reported success and depends on the `site` module.
    raise SystemExit(1)

print(">", MODE, " Reconstruction Attack!")


args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

# results accumulated over all runs
num_of_succeses = 0  # correctly remapped models; not printed in the final report
SIA_attacks_before = []  # best vanilla-FL SIA accuracy of each run
SIA_attacks_after = []  # best SIA accuracy on the reconstructed models of each run
SIA_attacks_shadow = []  # NOTE(review): not written anywhere in this file
SIA_attacks_prop_defense = []  # SIA accuracy against the proposed mechanism, one entry per evaluated epoch

executed_rounds = MAX_RUNS  # reduced when training fails and a run is skipped

# NOTE(review): the three lists below are not referenced elsewhere in this file
total_distances = []
total_size_weights = []
total_SIAs = []

# Main experiment: MAX_RUNS independent federated-learning runs. Each epoch
# trains all clients, aggregates with FedAvg, optionally reconstructs the
# shuffled updates (Section 5), runs the source inference attacks, evaluates
# the proposed mechanism, and reports the accuracy of the joint models.
for run in range(MAX_RUNS):
    remapped_model = {}
    print("[Run: ", run, "/", MAX_RUNS, "]")
    # load dataset and split data for users
    dataset_train, dataset_test, dict_party_user, dict_sample_user = get_dataset(args)

    # build the model and record the names of its last fully connected layer,
    # which is what LAYER mode shuffles and PARAMETER mode remaps
    if args.model == 'cnn' and args.dataset == 'MNIST':
        net_glob = Mnistcnn(args=args).to(args.device)
        names_of_last_fc = ["fc3.weight", "fc3.bias"]
        layers_to_remap = ["fc3.weight", "fc3.bias"]

    elif args.model == 'cnn' and args.dataset == 'CIFAR10':
        net_glob = CifarCnn(args=args).to(args.device)
        names_of_last_fc = ["fc3.weight", "fc3.bias"]
        layers_to_remap = ["fc3.weight", "fc3.bias"]

    elif args.model == 'cnn' and args.dataset == 'CIFAR100':
        net_glob = models.resnet18(pretrained=False)
        num_features = net_glob.fc.in_features
        net_glob.fc = nn.Linear(num_features, 100)
        net_glob = net_glob.to(args.device)

        names_of_last_fc = ["fc.weight", "fc.bias"]
        layers_to_remap = ["fc.weight", "fc.bias"]

    elif args.model == 'mlp':
        len_in = 1
        dataset_train = dataset_train.dataset
        dataset_test = dataset_test.dataset
        img_size = dataset_train[0][0].shape

        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)

        names_of_last_fc = ["layer_hidden.weight", "layer_hidden.bias"]
        layers_to_remap = ["layer_hidden.weight", "layer_hidden.bias"]
    else:
        raise SystemExit('Error: unrecognized model')

    # one {layer_name: remapped values} dict per client; built only now that
    # layers_to_remap is set for this run's model (before, the first run used
    # the empty initial list)
    clients_remap_parameters = [
        {layer: [] for layer in layers_to_remap} for _ in range(args.num_users)
    ]

    net_glob.train()
    net_glob_encoded = copy.deepcopy(net_glob)

    # FedAvg weights: each client's share of the total number of samples
    size_per_client = [len(dict_party_user[i]) for i in range(args.num_users)]
    total_size = sum(size_per_client)
    size_weight = np.array(size_per_client) / total_size
    # copy weights
    w_glob = net_glob.state_dict()
    ### training
    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]
        w_locals_encoded = [w_glob for i in range(args.num_users)]

    skip = 0
    best_SIA_vanilla = 0
    best_SIA_of_recon = 0
    chance_level = (1 / args.num_users) * 100  # SIA accuracy of a uniform random guess

    for curr_epoch in range(args.epochs):
        print("***** EPOCH:", curr_epoch, "******")
        if not args.all_clients:
            # Both lists must be rebuilt every epoch. w_locals_encoded used to
            # be left out here: unless defined elsewhere, its first append
            # raised NameError, which the except below reported as a data
            # split failure (and otherwise it grew across epochs).
            w_locals = []
            w_locals_encoded = []
        for idx in range(args.num_users):
            try:
                print("--- Training client ", idx)
                local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_party_user[idx], shadow=False, PERCN_OF_SHADOW=PERCN_OF_SHADOW)
                w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
                if args.all_clients:
                    w_locals[idx] = copy.deepcopy(w)
                else:
                    w_locals.append(copy.deepcopy(w))

                # same local training for our mechanism, from its own global model
                if RUN_ACCURACY == 1:
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_party_user[idx], shadow=False, PERCN_OF_SHADOW=PERCN_OF_SHADOW)
                    w, loss = local.train(net=copy.deepcopy(net_glob_encoded).to(args.device))
                    if args.all_clients:
                        w_locals_encoded[idx] = copy.deepcopy(w)
                    else:
                        w_locals_encoded.append(copy.deepcopy(w))

            except Exception as e:
                print("Error occured when training the model, probably the dirichlet distribution returned no values for a client")
                print(e)
                skip = 1
                break
        if skip == 1:
            # Abandon the whole run once. The data split is drawn once per run
            # (get_dataset above); the old `continue` retrained on every
            # remaining epoch and decremented executed_rounds each time.
            # NOTE: the run still appends its best-so-far (or chance-level)
            # SIA values below, as before.
            print("Skipping this round")
            executed_rounds -= 1
            break

        w_glob = FedAvg(w_locals, size_weight)
        net_glob.load_state_dict(w_glob)
        net_glob.eval()
        # server-side shuffle: after reordering, slot j holds the update of client order[j]
        order = list(range(len(w_locals)))
        random.shuffle(order)

        new_w_locals = []

        if curr_epoch <= SMALLEST_EPOCH_TO_START_REMAPPING:
            continue

        ## Defeat shuffling by mapping the models back to their original owners (Section 5)
        if RUN_RECONSTRUCTION == 1:
            if MODE == "MODEL":
                # shuffle whole models
                for i in order:
                    new_w_locals.append(copy.deepcopy(w_locals[i]))
            elif MODE == "LAYER":
                # shuffle only the last FC layer; every other layer is the FedAvg result
                for i in order:
                    a_model = copy.deepcopy(w_glob)
                    a_model[names_of_last_fc[0]] = copy.deepcopy(w_locals[i][names_of_last_fc[0]])
                    a_model[names_of_last_fc[1]] = copy.deepcopy(w_locals[i][names_of_last_fc[1]])
                    new_w_locals.append(a_model)
            elif MODE == "PARAMETER":
                new_w_locals = copy.deepcopy(w_locals)
            else:
                print("Wrong mode...")
                raise SystemExit(1)

            # shadow data: the first PERCN_OF_SHADOW fraction (+1 sample) of each client's data
            shadow_dataset = []
            for client in range(len(w_locals)):
                db_splited = DatasetSplit(dataset_train, dict_party_user[client])
                db_to_use = [db_splited[i] for i in range(int(PERCN_OF_SHADOW * len(db_splited)) + 1)]
                shadow_dataset.append(DataLoader(db_to_use, batch_size=64, shuffle=False))

            if MODE in ["MODEL", "LAYER"]:
                for TARGET_CLIENT in range(len(w_locals)):
                    correct_position = order.index(TARGET_CLIENT)
                    success = reconstruct_model(new_w_locals, copy.deepcopy(net_glob), dataset_train, dict_party_user, TARGET_CLIENT, correct_position, shadow_dataset)
                    if success:
                        num_of_succeses += 1
            else:
                for TARGET_CLIENT in range(len(w_locals)):
                    result = reconstruct_model_parameter(w_locals, copy.deepcopy(net_glob), dataset_train, dict_party_user, TARGET_CLIENT, args, w_glob, shadow_dataset)
                    num_of_succeses += result

        # Run the SIA attack
        if RUN_SIA == 1:
            # NOTE(review): SIA.attack is given a network together with
            # w_locals (models/Sia.py is not shown here). If it loads those
            # weights into the network, handing it net_glob itself (the old
            # `empty_net = net_glob` alias) would replace the aggregated model
            # that is evaluated below and trained from next epoch, so each
            # attack gets its own deep copy.
            vanilla_attack_net = copy.deepcopy(net_glob)
            recon_attack_net = copy.deepcopy(net_glob)

            # initial SIA attack
            SIA_attack = SIA(args=args, w_locals=w_locals, dataset=dataset_train, dict_mia_users=dict_sample_user)
            attack_acc_initial = SIA_attack.attack(net=vanilla_attack_net.to(args.device))
            SIA_result = max(round(attack_acc_initial.item(), 2), chance_level)
            # keep the best accuracy out of all epochs
            if SIA_result > best_SIA_vanilla:
                best_SIA_vanilla = SIA_result
            print("[Vanilla FL] SIA accuracy :", SIA_result, "%")

            if RUN_RECONSTRUCTION == 0:
                reconstructed_w_locals = w_locals
            elif MODE == "MODEL":
                # give each client the model that reconstruct_model assigned to it
                reconstructed_w_locals = [copy.deepcopy(new_w_locals[remapped_model[i][0]]) for i in range(len(w_locals))]
            else:
                # every client starts from the FedAvg model
                reconstructed_w_locals = [copy.deepcopy(w_glob) for i in range(len(w_locals))]
                if MODE == "LAYER":
                    # ...and receives its remapped last FC layer
                    for i in range(len(w_locals)):
                        pos = remapped_model[i][0]
                        reconstructed_w_locals[i][names_of_last_fc[0]] = copy.deepcopy(new_w_locals[pos][names_of_last_fc[0]])
                        reconstructed_w_locals[i][names_of_last_fc[1]] = copy.deepcopy(new_w_locals[pos][names_of_last_fc[1]])
                elif MODE == "PARAMETER":
                    # ...and the per-parameter values chosen by
                    # reconstruct_model_parameter for the remapped layers
                    param_multplier = 0
                    for i in range(len(w_locals)):
                        for layer_name in layers_to_remap:
                            flat = reconstructed_w_locals[i][layer_name].view(-1)
                            n_remapped = int(flat.numel() * pern_of_parameter_to_reconc)
                            flat[:n_remapped] = torch.as_tensor(
                                clients_remap_parameters[i][layer_name][:n_remapped],
                                dtype=flat.dtype,
                                device=flat.device,
                            )

            # run the attack on the reconstructed models
            SIA_attack = SIA(args=args, w_locals=reconstructed_w_locals, dataset=dataset_train, dict_mia_users=dict_sample_user)
            attack_acc_after = SIA_attack.attack(net=recon_attack_net.to(args.device))
            SIA_result = max(round(attack_acc_after.item(), 2), chance_level)
            if SIA_result > best_SIA_of_recon:
                best_SIA_of_recon = SIA_result
            print("[Reconstructed Models (Section 5)] SIA accuracy :", SIA_result, "%")

        if RUN_PROPOSED_MECHANISM == 1:
            co_primes = find_coprimes(args.num_users * ((10 ** r) - 1))

            # instead of encoding then summing as per Algorithm 8 it is faster to first sum and then encode, i.e. take the w_glob
            # NOTE(review): encoded_w_glob starts as the FedAvg of the encoded
            # local models, but the loop below overwrites every entry of every
            # layer of w_glob with the RNS round-trip of the *vanilla* w_glob,
            # so the encoded local training appears not to reach
            # net_glob_encoded. Confirm whether encoded_w_glob[layer_name] was
            # meant as the source here.
            encoded_w_glob = copy.deepcopy(FedAvg(w_locals_encoded, size_weight))
            for layer_name in w_glob:
                parameters = torch.flatten(w_glob[layer_name]).tolist()
                for param in range(len(parameters)):
                    encoded_w_glob[layer_name].view(-1)[param] = RNS_DECODE(RNS_ENCODE(parameters[param], r, co_primes), r, co_primes)
            # every client is represented by the same (aggregated) model
            SIA_attack = SIA(args=args, w_locals=[encoded_w_glob for i in range(len(w_locals))], dataset=dataset_train, dict_mia_users=dict_sample_user)
            # a fresh copy: the old code reused the SIA network, which is
            # undefined when RUN_SIA != 1
            attack_acc = SIA_attack.attack(net=copy.deepcopy(net_glob).to(args.device))
            SIA_result = max(round(attack_acc.item(), 2), chance_level)
            print("[Proposed Solution (Alg. 1)] SIA accuracy :", SIA_result, "%")
            SIA_attacks_prop_defense.append(SIA_result)
            net_glob_encoded.load_state_dict(encoded_w_glob)
            net_glob_encoded.eval()

        acc_train, loss_train_ = test_fun_topk(net_glob, dataset_train, args, top_k=top_k)
        acc_test, loss_test = test_fun_topk(net_glob, dataset_test, args, top_k=top_k)
        print("[Vanilla FL] Training accuracy of the joint model: {:.2f}%".format(acc_train))
        print("[Vanilla FL] Testing accuracy of the joint model: {:.2f}%".format(acc_test))

        if RUN_PROPOSED_MECHANISM == 1:
            acc_train_encoded, loss_train_ = test_fun_topk(net_glob_encoded, dataset_train, args, top_k=top_k)
            acc_test_encoded, loss_test = test_fun_topk(net_glob_encoded, dataset_test, args, top_k=top_k)
            print("[Proposed Solution (Alg. 1)] Training accuracy of the joint model: {:.2f}%".format(acc_train_encoded))
            print("[Proposed Solution (Alg. 1)] Testing accuracy of the joint model: {:.2f}%".format(acc_test_encoded))
        print("------")
    # append the best accuracy of this run (never below chance level)
    SIA_attacks_before.append(max(best_SIA_vanilla, chance_level))
    SIA_attacks_after.append(max(best_SIA_of_recon, chance_level))

# ---------------------------------------------------------------------------
# Final report: experiment configuration, SIA accuracies averaged over the
# runs, then the key parameters echoed once more at the end of the log.
# ---------------------------------------------------------------------------
exp_details(args)
print("MODE = ", MODE)
print("MAX RUNS (actually executed) = ", executed_rounds)


if RUN_SIA == 1:
    print("Initial SIA accuracy:", avg(SIA_attacks_before))
    print("After reconstruction SIA accuracy:", avg(SIA_attacks_after))

if RUN_PROPOSED_MECHANISM == 1:
    print("SIA accuracy on proposed defense:", avg(SIA_attacks_prop_defense))

for _caption, _setting in (
    ("SHADOW MODL PERCN: ", PERCN_OF_SHADOW),
    ("with db= ", args.dataset),
    ("users = ", args.num_users),
    ("alpha =", args.alpha),
    ("Local Epochs=", args.local_ep),
):
    print(_caption, _setting)
