import os, argparse, json, copy, time

# import matplotlib.pyplot as plt
from tqdm import tqdm
from functools import partial
import torch, torchvision
import numpy as np
import torch.nn as nn
from collections import OrderedDict
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from scipy.ndimage.interpolation import rotate as scipyrotate
import wandb

from args import parse_argument



# Run configuration, parsed once when this module is imported. The training
# helpers below read args.grad_clip, args.clip_to, args.client_id,
# args.test_client and args.lr from this module-level object.
args = parse_argument()

# Device that input batches are moved to: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def eval_epoch(model, loader):
    """Return the sample-weighted mean cross-entropy of ``model`` on ``loader``.

    Runs under ``torch.no_grad()``. The model's train/eval mode is not
    touched here; it stays as the caller left it.

    Args:
        model: Callable mapping an input batch to class scores.
        loader: Iterable of ``(x, y)`` batches.

    Returns:
        Sum of per-batch mean losses weighted by batch size, divided by the
        total number of samples.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = targets.shape[0]
            # criterion returns the batch mean, so re-weight by batch size.
            total_loss += criterion(model(inputs), targets).item() * batch_size
            seen += batch_size
        mean_loss = total_loss / seen
    return mean_loss


def train_loss(model, loader):
    """Return the sample-weighted mean cross-entropy of ``model`` on ``loader``.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: ``nn.Module`` mapping an input batch to class scores.
        loader: Iterable of ``(x, y)`` batches.

    Returns:
        Sum of per-batch mean losses weighted by batch size, divided by the
        total number of samples.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    running_loss, samples = 0.0, 0
    # Only the scalar loss value is consumed, so there is no reason to record
    # an autograd graph for every batch.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = criterion(model(x), y)
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]
    return running_loss / samples


def train_op(model, loader, optimizer, epochs,  quant_fn=None, lambda_fedprox=0.0, id=None):
    """Train ``model`` in place with quantized gradients and weights.

    Every step runs: forward pass with cross-entropy (plus an optional
    proximal penalty toward the starting weights), backward pass, optional
    gradient-norm clipping (``args.grad_clip`` / ``args.clip_to``), gradient
    quantization, ``optimizer.step()``, then weight quantization.

    Args:
        model: ``nn.Module`` to train; it is put into training mode.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer whose ``step()`` is called without arguments.
        epochs: Number of passes over ``loader``.
        quant_fn: Mapping holding callables under ``'weight_Q'`` and
            ``'grad_Q'``; each receives a tensor and returns an object with a
            ``.data`` attribute. ``None`` (the default) disables quantization.
        lambda_fedprox: Weight of the squared-distance penalty between the
            current parameters and their values at entry; ``0.0`` disables it.
        id: Client identifier. When it equals ``args.client_id`` and
            ``args.test_client`` is truthy, each step's loss is logged to
            wandb. (The name shadows the builtin but is part of the public
            signature.)

    Returns:
        ``{"loss": mean_loss}`` where ``mean_loss`` is the sample-weighted
        mean of the per-step loss, proximal term included.

    Raises:
        ZeroDivisionError: If no sample was processed.
    """
    def _no_quant(tensor):
        return tensor

    model.train()
    running_loss, samples = 0.0, 0

    # The signature advertises quant_fn=None, so treat it as "no quantization"
    # rather than failing on the subscript.
    if quant_fn is None:
        weight_Q = grad_Q = _no_quant
    else:
        weight_Q = quant_fn['weight_Q']
        grad_Q = quant_fn['grad_Q']

    use_prox = lambda_fedprox > 0.0
    if use_prox:
        W0 = {k: v.detach().clone() for k, v in model.named_parameters()}
        # The reference weights never change, so flatten them once.
        w0_flat = flatten(W0)

    criterion = nn.CrossEntropyLoss()
    log_to_wandb = id == args.client_id and args.test_client

    for ep in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            if use_prox:
                # Both operands already live on the parameters' device; no
                # hard-coded .cuda() so CPU-only runs keep working.
                current = flatten(dict(model.named_parameters()))
                loss = loss + lambda_fedprox * torch.sum((w0_flat - current) ** 2)
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]
            loss.backward()

            if args.grad_clip:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_to
                )

            with torch.no_grad():
                for param in model.parameters():
                    # Frozen or unused parameters have no gradient to quantize.
                    if param.grad is None:
                        continue
                    param.grad.data = grad_Q(param.grad.data).data

            optimizer.step()

            # Re-quantize the weights after every update.
            with torch.no_grad():
                for p in model.parameters():
                    p.data = weight_Q(p.data).data

            if log_to_wandb:
                # Log a plain float rather than the loss tensor.
                wandb.log({'Client loss': loss.item(),})

    return {"loss": running_loss / samples}


def train_op_scaf(model, server, loader, optimizer, epochs,  quant_fn=None, lambda_fedprox=0.0, val=None):
    """Train ``model`` with control-variate-corrected steps and quantization.

    Same per-step pipeline as plain quantized training (cross-entropy plus an
    optional proximal penalty, optional gradient clipping, gradient
    quantization, optimizer step, weight quantization), except that the
    optimizer step receives ``server.control`` and ``model.control``. After
    each epoch the model is validated on ``val`` and a snapshot is kept
    according to the selection rule below. Finally the snapshot's ``control``,
    ``delta_y`` and ``delta_control`` entries are recomputed per parameter
    name relative to the model state at entry.

    Snapshot rule: while fewer than ``min_epochs`` (10) epochs have run, the
    snapshot always tracks the current model; afterwards it is replaced only
    when the validation loss drops below the best value seen so far, which
    starts at 10.
    NOTE(review): the initial threshold of 10 means a validation loss that
    never falls below 10 freezes the snapshot at epoch ``min_epochs - 1`` -
    confirm this is intended.

    Args:
        model: ``nn.Module`` with dict-like ``control``, ``delta_y`` and
            ``delta_control`` attributes indexed by parameter name.
        server: Object exposing a dict-like ``control`` indexed by parameter
            name.
        loader: Sized iterable of ``(x, y)`` batches.
        optimizer: Optimizer whose ``step`` accepts
            ``(server_controls, client_controls)``.
        epochs: Number of passes over ``loader``.
        quant_fn: Mapping with callables under ``'weight_Q'`` and
            ``'grad_Q'``; ``None`` disables quantization.
        lambda_fedprox: Weight of the squared-distance penalty toward the
            parameters at entry; ``0.0`` disables it.
        val: Iterable of validation batches passed to ``get_val_loss``.
            ``None`` skips validation and uses the final model state.

    Returns:
        ``{"model": snapshot, "loss": mean_loss}`` where ``mean_loss`` is the
        sample-weighted mean training loss over all steps.

    Raises:
        ZeroDivisionError: If no training sample was processed.
    """
    def _no_quant(tensor):
        return tensor

    model.train()
    running_loss, samples = 0.0, 0

    # quant_fn defaults to None in the signature; treat that as identity.
    if quant_fn is None:
        weight_Q = grad_Q = _no_quant
    else:
        weight_Q = quant_fn['weight_Q']
        grad_Q = quant_fn['grad_Q']

    use_prox = lambda_fedprox > 0.0
    if use_prox:
        W0 = {k: v.detach().clone() for k, v in model.named_parameters()}
        # Reference weights are constant; flatten them once.
        w0_flat = flatten(W0)

    origin = copy.deepcopy(model)
    best_model = None
    min_epochs = 10
    min_val_loss = 10
    criterion = nn.CrossEntropyLoss()

    for ep in range(epochs):
        # get_val_loss() leaves the model in eval mode, so training mode has
        # to be restored before every epoch, not only before the first.
        model.train()

        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            if use_prox:
                # No hard-coded .cuda(): both operands are already on the
                # parameters' device.
                current = flatten(dict(model.named_parameters()))
                loss = loss + lambda_fedprox * torch.sum((w0_flat - current) ** 2)
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]
            loss.backward()

            if args.grad_clip:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_to
                )

            with torch.no_grad():
                for param in model.parameters():
                    # Parameters without a gradient have nothing to quantize.
                    if param.grad is None:
                        continue
                    param.grad.data = grad_Q(param.grad.data).data

            optimizer.step(server.control, model.control)

            # Re-quantize the weights after every update.
            with torch.no_grad():
                for p in model.parameters():
                    p.data = weight_Q(p.data).data

        if val is None:
            # No validation data: nothing to select on, keep training.
            continue

        val_loss = get_val_loss(model, val)
        if ep + 1 >= min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = copy.deepcopy(model)
        elif ep + 1 < min_epochs:
            best_model = copy.deepcopy(model)

    # best_model is already a private copy; fall back to a copy of the final
    # model when no snapshot was taken (no validation data).
    model_now = best_model if best_model is not None else copy.deepcopy(model)

    temp = {name: param.data.clone() for name, param in model_now.named_parameters()}

    local_steps = epochs * len(loader)
    for name, param in origin.named_parameters():
        # param is the value at entry, temp[name] the value after training.
        model_now.control[name] = (
            model_now.control[name]
            - server.control[name]
            + (param.data - temp[name]) / (local_steps * args.lr)
        )
        model_now.delta_y[name] = temp[name] - param.data
        model_now.delta_control[name] = model_now.control[name] - origin.control[name]

    return {"model": model_now,"loss": running_loss / samples}

def get_val_loss(model, Val):
    """Return the unweighted mean of per-batch cross-entropy losses on ``Val``.

    The model is switched to evaluation mode and left in that mode; callers
    that continue training must call ``model.train()`` again.

    Args:
        model: ``nn.Module`` mapping an input batch to class scores.
        Val: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``np.mean`` over the per-batch losses; every batch counts equally
        regardless of its size. An empty ``Val`` yields ``nan``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    val_loss = []
    with torch.no_grad():
        for seq, label in Val:
            seq = seq.to(device)
            label = label.to(device)
            val_loss.append(criterion(model(seq), label).item())

    return np.mean(val_loss)


## test
def train_op_ma(model, loader, optimizer, epochs,  quant_fn=None, moving_weight=0.1):
    """Train ``model`` using moving averages of quantized gradients and weights.

    On the first batch of every epoch the averages are reset to the current
    gradient / parameter. On later batches each average becomes
    ``moving_weight * Q(previous average) + (1 - moving_weight) * Q(current)``.
    The quantized gradient average is what the optimizer sees, and the
    quantized parameter average is written back into the model after the
    optimizer step.

    NOTE(review): if ``grad_Q`` / ``weight_Q`` return their input tensor
    itself, the stored averages share storage with the live gradient /
    parameter - confirm the quantizers return new tensors.

    Args:
        model: ``nn.Module`` to train in place; put into training mode.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer whose ``step()`` is called without arguments.
        epochs: Number of passes over ``loader``.
        quant_fn: Mapping with callables under ``'weight_Q'`` and
            ``'grad_Q'``; ``None`` disables quantization.
        moving_weight: Weight given to the previous average.

    Returns:
        ``{"loss": mean_loss}``, the sample-weighted mean training loss.

    Raises:
        ZeroDivisionError: If no sample was processed.
    """
    def _no_quant(tensor):
        return tensor

    model.train()
    running_loss, samples = 0.0, 0

    # quant_fn defaults to None in the signature; treat that as identity.
    if quant_fn is None:
        weight_Q = grad_Q = _no_quant
    else:
        weight_Q = quant_fn['weight_Q']
        grad_Q = quant_fn['grad_Q']

    # Filled on the first batch of each epoch, so no zero pre-allocation.
    grad_moving_avg = {}
    param_moving_avg = {}
    criterion = nn.CrossEntropyLoss()

    for ep in range(epochs):

        for it, (x, y) in enumerate(loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]
            loss.backward()

            if args.grad_clip:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_to
                )

            with torch.no_grad():
                for name, param in model.named_parameters():
                    # Parameters without a gradient keep no gradient average.
                    if param.grad is None:
                        continue
                    if it == 0:
                        grad_moving_avg[name] = param.grad.data
                    else:
                        grad_moving_avg[name] = (
                            moving_weight * grad_Q(grad_moving_avg[name]).data
                            + (1 - moving_weight) * grad_Q(param.grad.data).data
                        )
                    param.grad.data = grad_Q(grad_moving_avg[name].data).data

            optimizer.step()

            with torch.no_grad():
                for name, p in model.named_parameters():
                    if it == 0:
                        param_moving_avg[name] = p.data
                    else:
                        param_moving_avg[name] = (
                            moving_weight * weight_Q(param_moving_avg[name]).data
                            + (1 - moving_weight) * weight_Q(p.data).data
                        )
                    p.data = weight_Q(param_moving_avg[name].data).data

    return {"loss": running_loss / samples}



def eval_op(model, loader):
    """Compute top-1 accuracy and mean cross-entropy of ``model`` on ``loader``.

    Args:
        model: ``nn.Module`` mapping an input batch to per-class scores.
        loader: Iterable of ``(x, y)`` batches.

    Returns:
        ``{"accuracy": correct / samples, "loss": weighted_loss / samples}``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    # NOTE(review): the model is put in *training* mode for evaluation, unlike
    # eval_op_ensemble which uses eval mode - confirm this is intentional.
    model.train()
    criterion = nn.CrossEntropyLoss()
    n_seen = 0
    n_correct = 0
    loss_sum = 0.0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = targets.shape[0]

            outputs = model(inputs)
            # Index of the largest score along dim 1 is the predicted class.
            predicted = torch.max(outputs.detach(), 1)[1]

            loss_sum += criterion(outputs, targets).item() * batch_size
            n_seen += batch_size
            n_correct += (predicted == targets).sum().item()

    return {"accuracy": n_correct / n_seen, "loss": loss_sum / n_seen}


def eval_op_ensemble(model, test_loader):
    """Evaluate ``model`` in eval mode on ``test_loader``.

    Args:
        model: ``nn.Module`` mapping an input batch to per-class scores.
        test_loader: Iterable of ``(x, y)`` batches.

    Returns:
        ``{"test_accuracy": ..., "test_loss": ...}``: top-1 accuracy and the
        sample-weighted mean cross-entropy.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    n_seen, n_correct = 0, 0
    loss_sum = 0.0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            predicted = torch.max(outputs.detach(), 1)[1]
            # The batch mean is re-weighted by the number of output rows.
            loss_sum += criterion(outputs, targets).item() * outputs.shape[0]

            n_seen += targets.shape[0]
            n_correct += (predicted == targets).sum().item()

    return {
        "test_accuracy": n_correct / n_seen,
        "test_loss": loss_sum / n_seen,
    }


def reduce_average(target, sources):
    """Overwrite every entry of ``target`` with the element-wise mean of ``sources``.

    Args:
        target: Mapping from name to a tensor whose ``.data`` is replaced.
        sources: Sequence of mappings that contain every name in ``target``.
    """
    for key in target:
        stacked = torch.stack([src[key].data.detach() for src in sources])
        target[key].data = stacked.mean(dim=0).clone()


def reduce_median(target, sources):
    """Overwrite every entry of ``target`` with the element-wise median of ``sources``.

    Args:
        target: Mapping from name to a tensor whose ``.data`` is replaced.
        sources: Sequence of mappings that contain every name in ``target``.
    """
    for key in target:
        stacked = torch.stack([src[key].detach() for src in sources])
        # median along dim 0 returns (values, indices); only values are kept.
        target[key].data = stacked.median(dim=0)[0].clone()


def reduce_weighted(target, sources, weights):
    """Overwrite every entry of ``target`` with a weighted sum over ``sources``.

    The source tensors are stacked along a new last dimension, multiplied by
    ``weights`` (broadcast against that dimension) and summed over it.

    Args:
        target: Mapping from name to a tensor whose ``.data`` is replaced.
        sources: Sequence of mappings that contain every name in ``target``.
        weights: One weight per source, broadcastable against the stacked
            last dimension.
    """
    for key in target:
        stacked = torch.stack([src[key].detach() for src in sources], dim=-1)
        weighted = weights * stacked
        target[key].data = weighted.sum(dim=-1).clone()


def flatten(source):
    """Concatenate all tensors in ``source`` into one 1-D tensor.

    Args:
        source: Mapping whose values are tensors; iteration order of
            ``source.values()`` fixes the order of the output.

    Returns:
        A 1-D tensor holding every value flattened and joined end to end.
    """
    pieces = []
    for tensor in source.values():
        pieces.append(tensor.flatten())
    return torch.cat(pieces)


def parse_dict(d, args):
    """Copy the leaf entries of a nested dict onto ``args`` as default attributes.

    Values whose type is exactly ``dict`` are descended into; every other
    value is stored under its own key, but only when ``args`` does not
    already have that attribute.

    Args:
        d: Possibly nested dictionary of settings.
        args: Object with a ``__dict__`` (e.g. an argparse namespace),
            updated in place.
    """
    attributes = vars(args)
    for key in d:
        value = d[key]
        if type(value) is dict:
            parse_dict(value, args)
            continue
        # setdefault keeps any value that is already present on args.
        attributes.setdefault(key, value)


def moving_average(net1, net2, alpha=1):
    """Blend ``net2``'s parameters into ``net1`` in place.

    Each parameter of ``net1`` becomes
    ``alpha * net1_param + (1 - alpha) * net2_param``; ``net2`` is not
    modified. Parameters are paired in ``parameters()`` order.

    Args:
        net1: Module updated in place.
        net2: Module supplying the second operand.
        alpha: Weight kept from ``net1``; the default ``1`` leaves ``net1``
            unchanged.
    """
    pairs = zip(net1.parameters(), net2.parameters())
    for mine, theirs in pairs:
        contribution = theirs.data * (1.0 - alpha)
        mine.data.mul_(alpha)
        mine.data.add_(contribution)



class ScaffoldOptimizer(torch.optim.Optimizer):
    """Gradient-descent optimizer with a control-variate correction.

    Each step applies ``p <- p - lr * (grad + c - ci)``, where ``c`` and
    ``ci`` are the matching entries of the server and client control
    mappings.

    ``weight_decay`` is stored in the param-group defaults, but ``step`` does
    not apply it.
    """

    def __init__(self, params, lr, weight_decay):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, server_controls, client_controls, closure=None):
        """Apply one corrected gradient step.

        Args:
            server_controls: Mapping whose values are the server control
                tensors, ordered like the optimizer's parameters across all
                param groups.
            client_controls: Mapping whose values are the client control
                tensors, in the same order.
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # Call the closure; gradients must be enabled in case the caller
            # runs step() under no_grad.
            with torch.enable_grad():
                loss = closure()

        # One shared iterator so that the controls stay aligned with the
        # parameters across multiple param groups. zip() stops as soon as the
        # group's parameters run out, without consuming an extra control pair.
        controls = zip(server_controls.values(), client_controls.values())
        for group in self.param_groups:
            lr = group['lr']
            for p, (c, ci) in zip(group['params'], controls):
                if p.grad is None:
                    continue
                dp = p.grad.data + c.data - ci.data
                p.data = p.data - dp.data * lr

        return loss


def get_class_number(clients, n_class):
    """Count how many samples of each class every client holds.

    Args:
        clients: Sequence of objects exposing ``id`` (used as the row index)
            and ``loader`` (iterable of ``(x, labels)`` batches).
        n_class: Number of classes, i.e. the number of columns.

    Returns:
        Float array of shape ``(len(clients), n_class)`` where entry
        ``[client.id, k]`` is the number of labels equal to ``k`` seen in
        that client's loader.
    """
    counts = np.zeros((len(clients), n_class))
    for client in clients:
        for _, batch_labels in client.loader:
            for label in batch_labels:
                counts[client.id, label.item()] += 1

    return counts


def generate_labels(number, class_num):
    """Draw ``number`` shuffled labels whose class frequencies follow ``class_num``.

    The index range ``[0, number)`` is cut at the truncated cumulative class
    proportions, each piece is filled with its class index, and the result
    is shuffled with NumPy's global random state.

    Args:
        number: Total number of labels to produce.
        class_num: 1-D array of per-class counts or weights.

    Returns:
        Integer array of length ``number`` containing class indices.
    """
    fractions = class_num / class_num.sum()
    # Cut points between consecutive classes; the last cumulative value
    # (the end of the range) is dropped.
    boundaries = (np.cumsum(fractions) * number).astype(int)[:-1]
    chunks = np.split(np.arange(number), boundaries)
    for class_idx, chunk in enumerate(chunks):
        chunk.fill(class_idx)
    shuffled = np.concatenate(chunks)
    np.random.shuffle(shuffled)
    return shuffled.astype(int)

def get_batch_weight(labels, class_client_weight):
    """Look up the per-client weight row for every label in a batch.

    Args:
        labels: NumPy integer array of class indices.
        class_client_weight: 2-D array indexed as ``[class, client]``.

    Returns:
        Float array of shape ``(labels.size, n_clients)`` whose row ``i`` is
        ``class_client_weight[labels[i], :]``.
    """
    n_samples = labels.size
    n_clients = class_client_weight.shape[1]
    weights = np.zeros((n_samples, n_clients))
    weights[:, :] = class_client_weight[labels, :]
    return weights