from typing import Dict, List
from scipy.optimize import fsolve
import numpy as np
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

def update_x_hat(x, x_hat, t):
    """Fold the newest iterate ``x`` into the four running estimates in ``x_hat``.

    ``x_hat`` is modified in place and also returned. Its slots hold:

    * ``x_hat[0]`` -- the latest iterate itself,
    * ``x_hat[1]`` -- a uniform running average,
    * ``x_hat[2]`` -- an average with linearly increasing weights,
    * ``x_hat[3]`` -- an average with quadratically increasing weights.

    In every slot the coefficient of ``x`` and the coefficient of the previous
    estimate add up to one, so each slot remains a convex combination.

    Args:
        x: Newest iterate.
        x_hat: Indexable, mutable container with at least four entries.
        t: Step counter used to derive the mixing coefficients.
            NOTE(review): presumably the number of iterates already averaged
            (0 on the first call) -- confirm against the caller.

    Returns:
        The same ``x_hat`` object that was passed in.
    """
    # Slot 0: simply track the last iterate.
    x_hat[0] = x

    # Slot 1: uniform average.
    x_hat[1] = (x + t * x_hat[1]) / (t + 1)

    # Slot 2: linearly increasing weights.
    linear_norm = 2 + t
    x_hat[2] = 2 / linear_norm * x + t / linear_norm * x_hat[2]

    # Slot 3: quadratically increasing weights.
    coeff_new = 6 * (t + 1) / (t + 2) / (2 * t + 3)
    coeff_old = t * (1 + 2 * t) / (t + 2) / (2 * t + 3)
    x_hat[3] = coeff_new * x + coeff_old * x_hat[3]

    return x_hat


def uniform_data_catch(feature, label):
    """Draw one ``(feature, label)`` pair uniformly at random.

    A single position is sampled with ``np.random.choice`` over
    ``len(feature)`` and used to index both containers, so the returned
    feature and label always belong together.

    Args:
        feature: Indexable collection of samples.
        label: Indexable collection of labels, indexed like ``feature``.

    Returns:
        Tuple ``(feature[k], label[k])`` for the sampled position ``k``.
    """
    picked = np.random.choice(len(feature))
    return feature[picked], label[picked]


def sigmoid(scores):
    """Return the logistic function ``1 / (1 + exp(-scores))``.

    Args:
        scores: A number or NumPy array; arrays are processed element-wise.

    Returns:
        The logistic value(s), with the same shape as ``scores``.
    """
    decay = np.exp(-scores)
    return 1 / (1 + decay)


def learning_rate(t, exp, cte):
    """Return the step size for iteration ``t``.

    The base rate is ``2 ** -exp``. The schedule is selected by ``cte``:

    * ``cte == True``: constant schedule, the base rate is always returned.
    * otherwise ``cte`` is indexed as ``(first, second, factor)``:
      the base rate is used while ``t < first``, divided by ``factor`` once
      while ``first <= t < second``, and divided by ``factor`` twice after
      that.

    Args:
        t: Current iteration index.
        exp: Exponent of the base rate (rate is ``2 ** -exp``).
        cte: ``True`` or an indexable with at least three entries.

    Returns:
        The learning rate for step ``t``.
    """
    # Work out how many times the base rate has to be divided.
    if cte == True or t < cte[0]:  # noqa: E712 - equality (not identity) is intentional
        decays = 0
    elif cte[0] <= t < cte[1]:
        decays = 1
    else:
        decays = 2

    rate = 2 ** -exp
    for _ in range(decays):
        rate = rate / cte[2]
    return rate



def fetch(df, name, d):
    """Build a dense 0/1 feature matrix and a label vector from ``df[name]``.

    Every entry of the column ``name`` is parsed with ``extract`` into a list
    of active feature indices and a label. The indices are expanded into a
    length-``d`` indicator row; indices outside ``range(d)`` are ignored.

    Args:
        df: Object exposing ``.index`` and ``df[name][ind]`` lookups
            (presumably a pandas DataFrame -- confirm against the caller).
        name: Column holding the encoded samples.
        d: Number of feature columns in the dense output.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays with one row / entry per
        index value of ``df``.
    """
    column = df[name]  # loop-invariant lookup
    feature = []
    label = []
    for ind in df.index:
        x, y = extract(column[ind])
        # A set gives O(1) membership tests while filling the d columns.
        active = set(x)
        feature.append([int(i in active) for i in range(d)])
        label.append(y)
    return np.array(feature), np.array(label)


def extract(st):
    """Parse one whitespace-separated ``label idx:value idx:value ...`` record.

    The first token is the integer label. For every following token the part
    before the first ``:`` is read as a 1-based feature index and converted
    to a 0-based one; the part after the colon is ignored.

    Args:
        st: The encoded record.

    Returns:
        Tuple ``(indices, label)`` where ``indices`` is the list of 0-based
        feature indices in the order they appear.
    """
    tokens = st.split()
    y = int(tokens[0])
    x = [int(token.partition(":")[0]) - 1 for token in tokens[1:]]
    return x, y


def softmax(u):
    """Return the softmax of ``u``, normalised over *all* of its elements.

    The maximum entry is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing,
    which would otherwise produce ``inf / inf = nan`` for large scores.

    Args:
        u: Array-like of scores.

    Returns:
        Array with the same shape as ``u`` whose entries are positive and sum
        to one. An empty input yields an empty array.
    """
    u = np.asarray(u)
    if u.size == 0:
        # np.max is undefined on empty input; there is nothing to normalise.
        return np.exp(u)
    expu = np.exp(u - np.max(u))
    return expu / np.sum(expu)


def matrix_softmax(u):
    """Return the row-wise softmax of the 2-D array ``u``.

    Each row has its own maximum subtracted before exponentiating. The
    result is mathematically unchanged, but ``np.exp`` can no longer overflow
    and turn a row into ``nan``.

    Args:
        u: 2-D array-like of scores, one sample per row.

    Returns:
        Array with the same shape as ``u`` in which every row sums to one.
        An input without elements is returned as an empty float array.
    """
    u = np.asarray(u)
    if u.size == 0:
        # Row maxima are undefined without elements; nothing to normalise.
        return np.exp(u)
    expu = np.exp(u - np.max(u, axis=1, keepdims=True))
    return expu / np.sum(expu, axis=1, keepdims=True)


def logistic_regression(identity, total_data, weight, feature, label, lr, optimizer, device, criterion):
    """Perform one stochastic update step for the problem named by ``identity[0]``.

    Supported problems:

    * ``"quadratic"`` -- scalar piecewise-quadratic objective with additive
      Gaussian noise; ``identity[1][label]`` is the noise mean and
      ``identity[2]`` its standard deviation.
    * ``"mnist"`` -- softmax regression on one sample; ``weight`` is shrunk by
      ``1 - lr / total_data`` before the gradient step.
    * ``"w8a"`` -- binary logistic regression on one sample, with the same
      shrinkage.
    * ``"cifar10"`` -- one optimiser step of the torch module ``weight`` on a
      mini-batch; the optimiser's learning rate is first set to ``lr``.

    Args:
        identity: Indexable whose first entry selects the problem.
        total_data: Data-set size used in the shrinkage factor.
        weight: Current parameters (number / array / torch module).
        feature: Sample features or image batch.
        label: Sample label(s).
        lr: Step size.
        optimizer: Torch optimiser (``"cifar10"`` only).
        device: Torch device the batch is moved to (``"cifar10"`` only).
        criterion: Loss callable (``"cifar10"`` only).

    Returns:
        ``(new_weight, gradient)``; the gradient is ``None`` for
        ``"cifar10"``. For any other ``identity[0]`` the function returns
        ``None``.
    """
    task = identity[0]

    if task == "quadratic":
        mu = identity[1][label]
        sigma = identity[2]
        # The slope is twice as steep on the right of the minimum at 1.
        deterministic = 2 * (weight - 1) if weight >= 1 else (weight - 1)
        gradient = deterministic + np.random.normal(mu, sigma)
        return weight - lr * gradient, gradient

    if task == "mnist":
        probabilities = softmax(weight @ feature)
        gradient = np.outer(probabilities - label, feature)
        shrink = 1 - lr / total_data
        return weight * shrink - lr * gradient, gradient

    if task == "w8a":
        margin = np.dot(feature, weight)
        scale = -label / (1 + np.exp(label * margin))
        gradient = feature * scale
        shrink = 1 - lr / total_data
        return weight * shrink - lr * gradient, gradient

    if task == "cifar10":
        update_lr(optimizer, lr)
        batch_images = feature.to(device)
        batch_labels = label.to(device)
        # Forward pass.
        batch_loss = criterion(weight(batch_images), batch_labels)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        return weight, None

    return None




def loss(identity, data, weight, criterion, device, test_loader):
    """Evaluate the objective of the problem named by ``identity[0]``.

    Supported problems:

    * ``"quadratic"`` -- ``(w - 1) ** 2`` for ``w >= 1`` and half of that
      otherwise, with ``w = weight[0]``.
    * ``"mnist"`` -- mean cross-entropy of the row-wise softmax plus
      ``||weight||^2 / (2 * n)``; ``data`` is ``(features, one_hot_labels)``.
    * ``"w8a"`` -- mean logistic loss plus ``weight . weight / (2 * n)``;
      ``data`` is ``(features, labels)``.
    * ``"cifar10"`` -- mean of ``criterion`` over the batches in ``data`` and
      the classification accuracy over ``test_loader``.

    Args:
        identity: Indexable whose first entry selects the problem.
        data: Problem data (see above).
        weight: Parameters (indexable / array / torch module).
        criterion: Loss callable (``"cifar10"`` only).
        device: Torch device batches are moved to (``"cifar10"`` only).
        test_loader: Iterable of ``(images, labels)`` batches used for the
            accuracy (``"cifar10"`` only).

    Returns:
        The loss value; for ``"cifar10"`` a tuple ``(mean_loss, accuracy)``.
        For any other ``identity[0]`` the function returns ``None``.
    """
    if identity[0] == "quadratic":
        if weight[0] >= 1:
            return (weight[0] - 1) ** 2
        return (weight[0] - 1) ** 2 / 2
    if identity[0] == "mnist":
        feature = data[0]
        label = data[1]  # one-hot encoded
        total_data = len(label)
        Z = matrix_softmax(np.dot(feature, weight.T))
        l = np.sum(-1 * label * np.log(Z)) / total_data + 1 / 2 / total_data * np.linalg.norm(weight) ** 2
        return l
    if identity[0] == "w8a":
        feature = data[0]
        label = data[1]
        total_data = len(label)
        scores = np.dot(feature, weight)
        # logaddexp(0, z) == log(1 + exp(z)) but does not overflow for large z.
        l = np.average(np.logaddexp(0, -label * scores)) + 1 / 2 / total_data * np.dot(weight, weight)
        return l
    if identity[0] == "cifar10":
        # Remember the current mode: leaving the model in eval mode would make
        # later training steps run BatchNorm with frozen running statistics.
        was_training = weight.training
        weight.eval()
        try:
            l = 0
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in data:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = weight(images)
                    l += criterion(outputs, labels).item()
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = weight(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
        finally:
            # Applies the saved mode to the module and all of its children.
            weight.train(was_training)
        return (l / len(data), correct / total)





def floyd_warshall(graph: Dict[int, List[int]], cost):
    """Find the graph centre and the shortest-path tree rooted at it.

    All-pairs shortest paths are computed with the Floyd-Warshall recurrence.
    The centre is the node whose largest distance to any other node is
    smallest (first such node in iteration order). The returned tree is the
    union of the shortest paths from the centre to every node, stored as an
    undirected adjacency mapping.

    Args:
        graph: Mapping ``node -> list of neighbours``.
        cost: Mapping ``node -> list of edge costs``, aligned position by
            position with ``graph[node]``.

    Returns:
        Tuple ``(tree, center, radius)`` where ``tree`` maps every node to its
        neighbours in the shortest-path tree and ``radius`` is the largest
        distance from ``center`` to any node.

    NOTE(review): pairs that are not connected keep a distance of ``None``,
    so this appears to assume a connected graph -- confirm with callers.
    """
    nodes = list(graph)
    dist = {u: dict.fromkeys(nodes) for u in nodes}
    hop = {u: dict.fromkeys(nodes) for u in nodes}  # first step on the best u -> v path

    # Seed the tables with the direct edges.
    for u in nodes:
        for v, weight in zip(graph[u], cost[u]):
            dist[u][v] = weight
            hop[u][v] = v

    # Relax every pair through each intermediate node in turn.
    for via in nodes:
        for src in nodes:
            for dst in nodes:
                if dist[src][via] is None or dist[via][dst] is None:
                    continue
                candidate = dist[src][via] + dist[via][dst]
                current = dist[src][dst]
                if current is None or current > candidate:
                    dist[src][dst] = candidate
                    hop[src][dst] = hop[src][via]

    for u in nodes:
        dist[u][u] = 0

    # Eccentricity of every node; the centre minimises it.
    eccentricity = {u: max(dist[u].values()) for u in nodes}
    center = min(eccentricity, key=eccentricity.get)

    # Walk from the centre to each node, recording every edge once per side.
    tree = {u: [] for u in nodes}
    for target in nodes:
        here = center
        while here != target:
            step = hop[here][target]
            if step not in tree[here]:
                tree[here].append(step)
            if here not in tree[step]:
                tree[step].append(here)
            here = step

    return tree, center, eccentricity[center]


def find_streams(tree, root, parent):
    """Decompose the subtree below ``root`` into chains of nodes ("streams").

    Each stream lists nodes from the deepest one upwards. A node that is not
    the top-level root and has exactly two neighbours continues the stream
    of its only child. Every other node closes the streams arriving from its
    children (after adding itself to their end) and starts a new stream
    containing only itself.

    Args:
        tree: Mapping ``node -> list of neighbours``.
        root: Node whose subtree is processed.
        parent: Neighbour the call came from; ``-1`` marks the top-level call.

    Returns:
        Tuple ``(open_stream, closed_streams)``: the stream that still
        continues upwards through ``root`` and the list of streams already
        completed inside the subtree.
    """
    neighbours = tree[root]
    collected = []
    for node in neighbours:
        if node != parent:
            stream, finished = find_streams(tree, root=node, parent=root)
            collected += finished
            stream.append(root)
            collected.append(stream)

    passes_through = parent != -1 and len(neighbours) == 2
    if not passes_through:
        return [root], collected
    return collected[-1], collected[:-1]

def max_depth(tree, root, parent):
    """Return the number of nodes on the longest downward path from ``root``.

    Args:
        tree: Mapping ``node -> list of neighbours``.
        root: Node the depth is measured from.
        parent: Neighbour to skip, so the walk never goes back up.

    Returns:
        ``1`` for a node without children, otherwise one more than the
        deepest child subtree.
    """
    child_depths = (
        max_depth(tree, child, root) for child in tree[root] if child != parent
    )
    return 1 + max(child_depths, default=0)

def unbalanced_data_split(num_of_node, total_data, first_node_data_size, num_of_chunk):
    """Split ``total_data`` samples into contiguous, unevenly sized node ranges.

    The nodes are organised in ``num_of_chunk`` chunks of
    ``num_of_node // num_of_chunk`` nodes; each chunk covers
    ``total_data // num_of_chunk`` samples. Inside the first chunk the first
    node receives ``first_node_data_size`` samples and later nodes receive
    geometrically growing shares with a ratio ``q`` found by ``fsolve``; the
    last node of the chunk is stretched to the chunk boundary. Every further
    chunk repeats the widths of the first one, shifted behind the previous
    chunk.

    Args:
        num_of_node: Total number of nodes.
        total_data: Total number of samples.
        first_node_data_size: Samples assigned to the first node of a chunk.
        num_of_chunk: Number of chunks the nodes are grouped into.

    Returns:
        List of ``(start, end)`` index pairs, one per node, in node order.
        The list is also printed.
    """
    nodes_per_chunk = num_of_node // num_of_chunk
    chunk_size = total_data // num_of_chunk

    def func(q):
        # NOTE(review): a finite geometric sum would use (q**m - 1); the
        # original expression is kept because the last range is clamped
        # to the chunk boundary anyway -- confirm this is intended.
        return first_node_data_size * (q ** nodes_per_chunk) / (q - 1) - chunk_size

    # fsolve returns a 1-element array; take the scalar explicitly because
    # int() on a non-0-d array is deprecated in NumPy.
    q = float(fsolve(func, 3)[0])

    data_split = [(0, first_node_data_size)]
    for n in range(1, nodes_per_chunk):
        start = data_split[-1][1]
        data_split.append((start, int(start + first_node_data_size * q ** n)))
    data_split[-1] = (data_split[-1][0], chunk_size)

    # Replicate the widths of the first chunk for every remaining chunk.
    # Iterating over the actual widths keeps this correct when the number of
    # nodes per chunk differs from the number of chunks.
    widths = [end - start for start, end in data_split]
    for _ in range(1, num_of_chunk):
        for width in widths:
            start = data_split[-1][1]
            data_split.append((start, start + width))

    print(data_split)
    return data_split

def aggrigate(models, ws):
    """Return a new model whose state is the ``ws``-weighted sum of ``models``.

    Every entry of the state dicts is cast with ``.float()``, scaled by the
    matching weight and summed in model order. The result is loaded into a
    deep copy of ``models[0]``; the input models are left untouched.

    Args:
        models: Sequence of torch modules with identical state-dict keys.
        ws: Sequence of scalar weights, ``ws[i]`` belonging to ``models[i]``.

    Returns:
        A fresh module carrying the combined state.
    """
    state_dicts = [model.state_dict() for model in models]
    # Start from a copy of the first state dict so the container (including
    # any attributes attached to it) matches what load_state_dict expects.
    combined = copy.deepcopy(state_dicts[0])
    for name in state_dicts[0]:
        total = state_dicts[0][name].float() * ws[0]
        for idx, other in enumerate(state_dicts[1:], start=1):
            total += other[name].float() * ws[idx]
        combined[name] = total

    merged_model = copy.deepcopy(models[0])
    merged_model.load_state_dict(combined)
    return merged_model

def differ(models, final, ws, layers):
    """Measure, per layer, how far the models are from ``final``.

    For every name in ``layers`` the squared L2 norm of
    ``model[layer] - final[layer]`` is computed for ``models[1:]`` and summed
    with the weights ``ws[j]``.

    NOTE(review): ``models[0]`` (and ``ws[0]``) do not enter the sum -- confirm
    that skipping the first model is intended.

    Args:
        models: Sequence of torch modules.
        final: Reference torch module.
        ws: Sequence of scalar weights, ``ws[j]`` belonging to ``models[j]``.
        layers: State-dict keys to compare.

    Returns:
        List with one weighted squared distance per entry of ``layers``.
    """
    state_dicts = [model.state_dict() for model in models]
    reference = final.state_dict()

    distances = []
    for layer in layers:
        weighted_sq_dist = 0
        for j in range(1, len(models)):
            gap = state_dicts[j][layer] - reference[layer]
            weighted_sq_dist += (torch.norm(gap, p=2) ** 2).item() * ws[j]
        distances.append(weighted_sq_dist)
    return distances

# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    """Create a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

# Residual block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection.

    NOTE: this file defines ``ResidualBlock`` a second time further down with
    the same layers; the later definition is the one bound to the name after
    import.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (default 1).
        downsample: Optional callable applied to the input before it is added
            back, used when the input and output shapes differ.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Compare against None explicitly: the truth value of a module
        # container depends on its length, not on whether one was supplied.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

##########################
### MODEL: ResNet building blocks and network
### NOTE: conv3x3 and ResidualBlock are defined a second time below.
##########################


# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    """Create a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

# Residual block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (default 1).
        downsample: Optional callable applied to the input before it is added
            back, used when the input and output shapes differ.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Compare against None explicitly: the truth value of a module
        # container depends on its length, not on whether one was supplied.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet
class ResNet(nn.Module):
    """Small three-stage residual network for image classification.

    The stem is a 3->16 channel 3x3 convolution with batch norm and ReLU.
    Three stages of residual blocks follow with 16, 32 and 64 channels; the
    second and third stage start with stride 2. An 8x8 average pool and a
    64-input linear layer produce the class scores.

    NOTE(review): the fixed ``AvgPool2d(8)`` in front of a 64-input linear
    layer presumably assumes 32x32 input images -- confirm.

    Args:
        block: Residual block class, constructed as
            ``block(in_channels, out_channels, stride, downsample)``.
        layers: Number of blocks in each of the three stages.
        num_classes: Size of the output layer (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # Running channel count, advanced by make_layer after each stage.
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block changes stride / channel count; it receives a
        conv + batch-norm shortcut whenever the input and output shapes
        differ. ``self.in_channels`` is updated to ``out_channels``.
        """
        shape_changes = stride != 1 or self.in_channels != out_channels
        shortcut = None
        if shape_changes:
            shortcut = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.in_channels, out_channels, stride, shortcut)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to class scores."""
        out = self.relu(self.bn(self.conv(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        out = self.avg_pool(out)
        flat = out.view(out.size(0), -1)
        return self.fc(flat)

def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts.
        lr: Learning rate written to each group's ``'lr'`` entry.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr
