import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import random
from collections import OrderedDict
from functions.utils import Dataset, get_mdl_params, transform_param
from .training_methods import *
import time


def AdamGC(data_obj, lr, epoch=1, n_agents=10, opt=None):
    """Run the AdamGC federated training loop and track test loss/accuracy.

    Every round a random subset of clients is trained locally with
    ``AdamGC_Local_train``. The global parameter vector is then moved by the
    ``opt.lr_g``-scaled average of the sampled clients' displacements from the
    round's starting point, and the result is loaded into the global model and
    into every client model. A (possibly smaller) random subset of the sampled
    clients also refreshes its correction vector ``y[i]``; the mean change of
    those vectors over all ``n_agents`` rows is added to the shared ``ycen``.

    Args:
        data_obj: Object exposing ``clnt_x`` / ``clnt_y`` (indexed by client)
            and ``tst_x`` / ``tst_y`` (evaluation data).
        lr: Initial local learning rate; multiplied by ``opt.lr_decay`` after
            every round.
        epoch: Local epochs per round, forwarded to ``AdamGC_Local_train``.
        n_agents: Number of clients.
        opt: Options object. Attributes read here: ``gpu_id``, ``sample``,
            ``y_sample``, ``batch``, ``lr_g``, ``rounds``, ``count_rounds``,
            ``lr_decay``, ``target_acc``, ``dataset``. It is also passed to
            ``get_model`` and ``AdamGC_Local_train``.

    Returns:
        ``(error_list, acc_list, rounds_used)``. The two lists hold the test
        loss and accuracy before training and after every round.
        ``rounds_used`` is the 1-based round at which accuracy first reached
        ``opt.target_acc``, or ``None`` if that never happened.
    """
    device = torch.device('cuda:{}'.format(opt.gpu_id))
    input_list = data_obj.clnt_x
    label_list = data_obj.clnt_y

    sample_size = int(n_agents*opt.sample)
    # NOTE(review): get_model is assumed to return (per-client models indexable
    # by client id, global model) -- that is how both values are used below.
    f, global_f = get_model(opt)

    # Every client starts from the global weights.
    central_f_dict = global_f.state_dict()
    for i in range(n_agents):
        f[i].load_state_dict(central_f_dict)

    n_par = len(get_mdl_params([global_f])[0])
    init_par_list = get_mdl_params([global_f], n_par)[0]
    clnt_params_list = np.ones(n_agents).astype('float32').reshape(-1, 1) * init_par_list.reshape(1, -1)
    y = np.zeros((n_agents, n_par)).astype('float32')
    ycen = np.zeros((n_par)).astype('float32')
    grad_list = np.zeros((n_agents, n_par)).astype('float32')

    error_list = []
    acc_list = []
    iters = 0

    # Baseline evaluation before any training.
    loss, acc = get_acc_loss(data_obj.tst_x, data_obj.tst_y, global_f, device, opt.dataset)
    acc_list.append(acc)
    error_list.append(loss)
    print("[AdamGC][%d agents]: " % n_agents, '\tIters:', iters, '\tLoss: %.2E' % error_list[-1], '\t\t\t\t\r', end = '')

    lr_g = opt.lr_g
    rounds = opt.rounds if not opt.count_rounds else 1000

    for t in range(int(rounds)):
        sample_list = random.sample(list(range(n_agents)), sample_size)
        f_old = get_mdl_params([global_f], n_par)[0]
        delta_y_list = np.zeros((n_agents, n_par)).astype('float32')

        # Subset of the sampled clients whose correction vector is refreshed.
        y_sample_set = set(random.sample(sample_list, int(sample_size*opt.y_sample)))

        for i in sample_list:
            f[i], grad_list[i] = AdamGC_Local_train(f[i], ycen, y[i], lr, input_list[i], label_list[i], epoch, device, opt)
            clnt_params_list[i] = get_mdl_params([f[i]], n_par)[0]
            if i in y_sample_set:
                # Normalise by this client's own step count: the local trainer
                # scales its correction by ceil(n_trn_i / batch), so using
                # client 0's count would be wrong for unequal dataset sizes.
                n_iter_i = int(np.ceil(input_list[i].shape[0]/opt.batch))
                y_i_new = 1/(epoch*n_iter_i)*grad_list[i]
                delta_y_list[i] = y_i_new - y[i]
                y[i] = y_i_new

        # Consensus step.
        with torch.no_grad():
            f_new = get_mdl_params([global_f], n_par)[0]
            for i in sample_list:
                f_new += lr_g*1/sample_size*(clnt_params_list[i] - f_old)

            # Rows of clients that did not refresh y are zero, so this is the
            # mean change over all n_agents clients.
            ycen = ycen + np.mean(delta_y_list, axis = 0)

            # The converted state is the same for every model; build it once.
            new_state = transform_param(global_f, f_new, device)
            global_f.load_state_dict(new_state)
            for i in range(n_agents):
                f[i].load_state_dict(new_state)

        lr = lr*opt.lr_decay
        iters += 1

        loss, acc = get_acc_loss(data_obj.tst_x, data_obj.tst_y, global_f, device, opt.dataset)
        acc_list.append(acc)
        error_list.append(loss)

        print("[AdamGC][%d agents]: " % n_agents, '\tIters:', iters, '\tLoss: %.2E' % error_list[-1], "\tAcc: %.2f" % (100*acc_list[-1]), '\t\t\t\t\r', end = '')
        if acc_list[-1] >= opt.target_acc:
            print('\n', end = '')
            return error_list, acc_list, t+1
    print('\n', end = '')
    return error_list, acc_list, None


def AdamGC_Local_train(f_i, ycen, y_i, lr, trn_x, trn_y, epoch, device, opt):
    """Train one client model with AMSGrad-Adam plus a linear correction term.

    Each step minimises the batch-mean cross-entropy plus ``<theta, zi>``,
    where ``theta`` is the flattened ``state_dict`` of ``f_i`` and
    ``zi = ycen - y_i``. Gradients are clipped to a total norm of 10 and
    accumulated (after clipping) over all steps.

    Args:
        f_i: Client model (``torch.nn.Module``); trained in place.
        ycen: Shared correction vector. Must be array-like and broadcastable
            against the flattened ``state_dict`` of ``f_i``.
        y_i: This client's correction vector, same layout as ``ycen``.
        lr: Adam learning rate.
        trn_x: Training inputs; ``trn_x.shape[0]`` is the sample count.
        trn_y: Training labels.
        epoch: Number of local epochs. Values below 1 run a single partial
            pass of about ``epoch * ceil(n_trn / opt.batch)`` batches.
        device: Torch device for the model and the batches.
        opt: Options object; reads ``opt.batch`` and ``opt.dataset``.

    Returns:
        ``(f_i, grad)``. ``f_i`` is the trained model, switched to eval mode
        with ``requires_grad`` disabled on every parameter. ``grad`` is a CPU
        tensor holding the accumulated clipped gradients minus
        ``epoch * ceil(n_trn / opt.batch) * zi``.

    Raises:
        ValueError: If the data loader yields no batches.
    """
    n_trn = trn_x.shape[0]
    n_iter_per_epoch = np.ceil(n_trn/opt.batch)

    for params in f_i.parameters():
        params.requires_grad = True
    f_i.train()
    f_i = f_i.to(device)

    trn_gen = torch.utils.data.DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=opt.dataset), batch_size=opt.batch, shuffle=True)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    weight_decay = 1e-8 if opt.dataset == 'TinyImage' else 1e-3

    optimizer = torch.optim.Adam(f_i.parameters(), lr=lr, amsgrad=True, betas=(0.9, 0), eps=1e-8, weight_decay=weight_decay)
    zi = torch.tensor(ycen - y_i, dtype=torch.float32, device=device)

    # Parameter objects are stable across steps, so build the lookup once.
    param_by_name = dict(f_i.named_parameters())

    # For a fractional epoch, stop after this many batches.
    terminate = epoch*n_iter_per_epoch if epoch < 1 else None

    total_grad_list = None
    for _ in range(int(max(1, epoch))):
        for i, (batch_x, batch_y) in enumerate(trn_gen):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            y_pred = f_i(batch_x)
            loss_alg = loss_fn(y_pred, batch_y.reshape(-1).long())
            loss_alg = loss_alg / batch_y.shape[0]

            # Flatten in state_dict order. Trainable entries come from
            # named_parameters so they stay attached to the autograd graph;
            # other entries (buffers) are used as returned by state_dict().
            local_par_list = torch.cat([
                param_by_name.get(name, value).reshape(-1)
                for name, value in f_i.state_dict().items()
            ], 0)

            loss_algo = torch.sum(local_par_list * zi)
            loss = loss_alg + loss_algo
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(parameters=f_i.parameters(), max_norm=10) # Clip gradients

            # Collect the (clipped) gradients in the same layout; entries
            # without a gradient contribute zeros.
            grad_chunks = []
            for name, value in f_i.state_dict().items():
                grad = param_by_name[name].grad if name in param_by_name else None
                if grad is None:
                    grad_chunks.append(torch.zeros(value.numel(), device=device))
                else:
                    grad_chunks.append(grad.reshape(-1))
            # torch.cat returns a new tensor, so the running total never
            # aliases a live ``.grad`` buffer.
            iter_grad_list = torch.cat(grad_chunks, 0)

            if total_grad_list is None:
                total_grad_list = iter_grad_list
            else:
                total_grad_list += iter_grad_list

            optimizer.step()

            # ``i`` is zero-based, so ``i + 1`` batches are done at this point.
            if terminate is not None and i + 1 >= terminate:
                break

    if total_grad_list is None:
        raise ValueError("AdamGC_Local_train received no training batches")

    for params in f_i.parameters():
        params.requires_grad = False
    f_i.eval()
    return f_i, (total_grad_list - float(epoch*n_iter_per_epoch)*zi).cpu()




def get_acc_loss(data_x, data_y, model, device, dataset_name):
    """Evaluate ``model`` on ``(data_x, data_y)``.

    The data is wrapped in the project ``Dataset`` and processed without
    gradient tracking in batches of at most 6000 samples. The model is
    switched to eval mode and moved to ``device`` for the evaluation, then
    switched back to train mode before returning.

    Args:
        data_x: Inputs; ``data_x.shape[0]`` is taken as the sample count.
        data_y: Labels matching ``data_x``.
        model: Model to evaluate.
        device: Torch device used for the model and the batches.
        dataset_name: Forwarded to ``Dataset`` as ``dataset_name``.

    Returns:
        ``(loss, accuracy)``: summed cross-entropy divided by the sample
        count, and the fraction of samples whose arg-max prediction equals
        the label.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    n_tst = data_x.shape[0]
    batch_size = min(6000, n_tst)
    loader = torch.utils.data.DataLoader(
        Dataset(data_x, data_y, dataset_name=dataset_name),
        batch_size=batch_size,
        shuffle=False,
    )

    model.eval()
    model = model.to(device)

    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        batches = iter(loader)
        for _ in range(int(np.ceil(n_tst/batch_size))):
            inputs, targets = next(batches)
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)

            total_loss += loss_value(criterion, logits, targets) if False else criterion(logits, targets.reshape(-1).long()).item()

            # Compare arg-max class predictions with the integer labels.
            predicted = np.argmax(logits.cpu().numpy(), axis=1).reshape(-1)
            expected = targets.cpu().numpy().reshape(-1).astype(np.int32)
            n_correct += np.sum(predicted == expected)

    model.train()
    return total_loss / n_tst, n_correct / n_tst


