import sys, os
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

import torch
from torch import nn, optim
import time
import copy
from nets.models import ResNet_20 
import argparse
import numpy as np
import torchvision
import torchvision.transforms as transforms
from utils import data_utils, AverageMeter
from tensorboardX import SummaryWriter
import torchvision.models as models
import time

class ClientOpt():
    """Client-side local optimizer for one federated round.

    Runs ``args.E`` local epochs on a client model. When
    ``args.weighting == 'exp'`` it also measures how much the value returned
    by ``log_loss`` changed across the local training (used as ``alpha``).
    """

    def __init__(self, args, loss_fun, device):
        self.args = args          # parsed CLI namespace; reads ``E`` and ``weighting``
        self.loss_fun = loss_fun  # criterion called as loss_fun(output, target)
        self.device = device      # device that every batch is moved to
        self.epoch = args.E       # number of local epochs per call to train()

    def train(self, model, train_loader, optimizer, server_model, global_grad=None, local_grad=None, shift_label=False):
        """Run ``self.epoch`` local epochs of training on ``model`` in place.

        Args:
            model: client model; its parameters are updated by ``optimizer``.
            train_loader: iterable of ``(x, y)`` batches.
            optimizer: optimizer bound to ``model``'s parameters.
            server_model, global_grad, local_grad, shift_label: accepted for
                interface compatibility; not used by this implementation.

        Returns:
            ``(avg_loss, accuracy, alpha)``:
            - ``avg_loss``: mean per-batch loss over every batch processed in
              all epochs (0.0 if no batch was processed);
            - ``accuracy``: correctly classified samples / samples seen over all
              epochs (0.0 if no sample was seen);
            - ``alpha``: ``|log_loss(after) - log_loss(before)|`` when
              ``args.weighting == 'exp'``, else ``None``.
        """
        use_exp = self.args.weighting == 'exp'
        # Snapshot the evaluation loss before local training so its change can be measured.
        alpha = log_loss(model, train_loader, self.device) if use_exp else None

        model.train()
        num_data = 0
        num_batches = 0
        correct = 0
        loss_all = 0.0
        for _ in range(self.epoch):
            for x, y in train_loader:
                x = x.to(self.device).float()
                y = y.to(self.device).long()

                optimizer.zero_grad()
                output = model(x)
                loss = self.loss_fun(output, y)
                loss.backward()
                optimizer.step()

                loss_all += loss.item()
                num_batches += 1
                num_data += y.size(0)
                pred = output.detach().max(1)[1]
                correct += pred.eq(y.view(-1)).sum().item()

        if use_exp:
            F_t_plus_1 = log_loss(model, train_loader, self.device)
            alpha = abs(F_t_plus_1 - alpha)

        # Normalize by the batches actually processed (all epochs), not by len(train_loader).
        avg_loss = loss_all / num_batches if num_batches else 0.0
        accuracy = correct / num_data if num_data else 0.0
        return avg_loss, accuracy, alpha


def test(model, test_loader, loss_fun, device):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.to(device).float()
        target = target.to(device).long()
        output = model(data)
        test_loss += loss_fun(output, target).item()
        pred = output.data.max(1)[1]
        correct += pred.eq(target.view(-1)).sum().item()
    
    return test_loss/len(test_loader), correct /len(test_loader.dataset)

################# Key Function ########################
def my_collate(batch):
    """Collate ``(sample, label)`` items without stacking the samples.

    Returns ``[samples, labels]`` where ``samples`` is a plain list of each
    item's first element and ``labels`` is a ``torch.LongTensor`` built from
    each item's second element.
    """
    def column(position):
        # Gather one field from every item of the batch.
        return [entry[position] for entry in batch]

    samples = column(0)
    labels = torch.LongTensor(column(1))
    return [samples, labels]

def _group_cross_entropy(model, xs, ys):
    """Return the mean cross-entropy of ``model`` on the concatenated batches."""
    # Newest batch first; the mean itself does not depend on the order.
    x = torch.cat(xs[::-1], dim=0)
    y = torch.cat(ys[::-1], dim=0)
    return torch.nn.functional.cross_entropy(model(x), y, reduction='mean').item()


def log_loss(model, data_loader, device, collect_freq=2):
    """Return minus the average cross-entropy of ``model`` over ``data_loader``.

    Consecutive groups of ``collect_freq`` batches are concatenated. The mean
    cross-entropy is computed per group, and the result is minus the average
    of those group losses. A trailing group with fewer than ``collect_freq``
    batches is evaluated on its own rather than dropped.

    The model is evaluated in eval mode under ``torch.no_grad()``, and its
    previous train/eval mode is restored afterwards.

    Args:
        model: classifier producing logits.
        data_loader: iterable of ``(x, y)`` batches.
        device: device that every batch is moved to.
        collect_freq: number of batches concatenated per evaluation group (>= 1).

    Returns:
        A float (<= 0); 0.0 for an empty loader.

    Raises:
        ValueError: if ``collect_freq < 1``.
    """
    if collect_freq < 1:
        raise ValueError('collect_freq must be >= 1, got {}'.format(collect_freq))

    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_groups = 0
    xs, ys = [], []
    try:
        with torch.no_grad():
            for x, y in data_loader:
                xs.append(x.to(device).float())
                ys.append(y.to(device).long())
                if len(xs) < collect_freq:
                    continue
                total_loss += _group_cross_entropy(model, xs, ys)
                num_groups += 1
                xs, ys = [], []
            if xs:
                # Leftover batches that did not fill a whole group.
                total_loss += _group_cross_entropy(model, xs, ys)
                num_groups += 1
    finally:
        model.train(was_training)

    if num_groups == 0:
        return 0.0
    return -total_loss / num_groups


class ServerOpt():
    """Server-side aggregation for federated training.

    Computes per-client aggregation weights and averages client model states
    into the server model, then broadcasts the averaged state back to the
    clients.
    """

    def __init__(self, args, server_model, loss_fun, device):
        # ``server_model`` is accepted for interface compatibility; it is not stored.
        self.args = args          # parsed CLI namespace; reads ``weighting`` and ``temperature``
        self.device = device
        self.loss_fun = loss_fun

    def aggregation_weight(self, client_alpha):
        """Return normalized aggregation weights for the 'exp' scheme.

        ``weights = softmax(-alpha / temperature)``, computed with the
        log-sum-exp trick for numerical stability. A client with a smaller
        ``alpha`` receives a larger weight, and the weights sum to 1.

        Args:
            client_alpha: sequence of per-client scalar scores.

        Returns:
            numpy array of weights, one per client.

        Raises:
            ValueError: if ``args.weighting`` is not ``'exp'`` (no other scheme
                is implemented here).
        """
        if self.args.weighting != 'exp':
            raise ValueError(
                "aggregation_weight only supports weighting='exp', got {!r}".format(self.args.weighting))
        scaled = -np.asarray(client_alpha, dtype=float) / self.args.temperature
        max_alpha = scaled.max()
        log_weight = scaled - (max_alpha + np.log(np.exp(scaled - max_alpha).sum()))
        return np.exp(log_weight)

    def communication(self, server_model, models, client_weights):
        """Aggregate client states into ``server_model`` and broadcast them back.

        For every state-dict entry, the server tensor becomes
        ``sum_i client_weights[i] * models[i][key]``. Entries whose key contains
        ``num_batches_tracked`` are copied from ``models[0]`` instead of being
        averaged. The resulting server tensor is then copied into each of the
        first ``len(client_weights)`` client models. All updates happen in place.

        Returns:
            ``(server_model, models)``: the same objects that were passed in.
        """
        # state_dict() tensors share storage with the live parameters/buffers,
        # so copy_ into them updates the modules in place. Build each dict once
        # instead of once per key per client.
        server_state = server_model.state_dict()
        client_states = [model.state_dict() for model in models]
        num_clients = len(client_weights)
        with torch.no_grad():
            for key, server_tensor in server_state.items():
                if 'num_batches_tracked' in key:
                    # Integer BatchNorm counter: not meaningfully averageable.
                    server_tensor.copy_(client_states[0][key])
                else:
                    temp = torch.zeros_like(server_tensor)
                    for client_idx in range(num_clients):
                        temp += client_weights[client_idx] * client_states[client_idx][key]
                    server_tensor.copy_(temp)
                # Distribute the aggregated entry back to the local models.
                for client_idx in range(num_clients):
                    client_states[client_idx][key].copy_(server_tensor)

        return server_model, models


if __name__ == '__main__':
    # Entry point: parse arguments, build one server and num_client local
    # ResNet-20 models, then alternate local training / weighted aggregation
    # for args.iters rounds. The best-validation server checkpoint is kept and
    # evaluated on the test loaders at the end.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Device:', device)
    parser = argparse.ArgumentParser()
    parser.add_argument('--log', action='store_true', help ='whether to make a log')
    parser.add_argument('--test', action='store_true', help ='test the pretrained model')
    parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('--batch', type = int, default= 128, help ='batch size')
    parser.add_argument('--iters', type = int, default=200, help = 'iterations for communication')
    parser.add_argument('--wk_iters', type = int, default=1, help = 'optimization iters in local worker between communication')
    parser.add_argument('--mode', type = str, default='fedavg', help='fedavg')
    parser.add_argument('--save_path', type = str, default='../checkpoint/', help='path to save the checkpoint')
    parser.add_argument('--local_iter', type = int, default= 10, help ='number of local iterations per epoch')
    parser.add_argument('--name', type = str, default= 'base', help ='name of the experiments')
    parser.add_argument('--E', type = int, default= 2, help ='number of local epochs')
    parser.add_argument('--temperature', default=0.5, type=float, help='temperature parameters for Exp-a')
    parser.add_argument('--num_client', default=10, type=int, help='number of client for cifar')
    parser.add_argument('--seed', default=1, type=int, help='random seed for experiments')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-w', '--weighting', type = str, default= 'propoto', help ='aggregation type: [propoto, exp]')
    # --test mode reads args.dataset to locate the snapshot; it was previously
    # never defined, so --test always crashed with AttributeError.
    parser.add_argument('--dataset', type=str, default='cifar', help='dataset name used to locate snapshots in --test mode')

    args = parser.parse_args()
    exp_folder = str(args.iters) + '_' + str(args.E) + '_' + args.weighting + '_' + str(args.seed)

    args.save_path = os.path.join(args.save_path, args.mode, args.name, exp_folder)

    if args.log:
        log_path = os.path.join('./logs', args.mode, args.name, exp_folder)
        os.makedirs(log_path, exist_ok=True)
        logfile = open(os.path.join(log_path, '{}.log'.format(args.mode)), 'a')
        logfile.write('==={}===\n'.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
        logfile.write('===Setting===\n')
        logfile.write('    lr: {}\n'.format(args.lr))
        logfile.write('    batch: {}\n'.format(args.batch))
        logfile.write('    local epoch: {}\n'.format(args.E))
        logfile.write('    glboal iters: {}\n'.format(args.iters))
        logfile.write('    wk_iters: {}\n'.format(args.wk_iters))
        if args.weighting == 'exp':
            logfile.write('    lr temperature: {}\n'.format(args.temperature))
        writer = SummaryWriter(log_path)

    os.makedirs(args.save_path, exist_ok=True)
    SAVE_PATH = os.path.join(args.save_path, '{}'.format(args.mode))

    # prepare CIFAR validation and test sets and model
    val_loaders, test_loaders = data_utils.prepare_cifar_data(args)
    server_model = ResNet_20(num_classes=10).to(device)
    datasets = ['cifar_{}'.format(i) for i in range(args.num_client)]
    loss_fun = nn.CrossEntropyLoss()

    # federated setting: start from uniform aggregation weights
    client_weights = [1 / args.num_client for _ in range(args.num_client)]
    client_alpha = [1 / args.num_client for _ in range(args.num_client)]

    # create local models
    models = [copy.deepcopy(server_model).to(device) for _ in range(args.num_client)]

    # initialize FedOpt
    server_update = ServerOpt(args, server_model, loss_fun, device)
    client_update = ClientOpt(args, loss_fun, device)

    if args.test:
        print('Loading checkpoints...')
        checkpoint = torch.load('../snapshots/' + args.dataset + '/{}/{}'.format(args.name.lower(), args.mode.lower()))
        server_model.load_state_dict(checkpoint['server_model'])
        for test_idx, test_loader in enumerate(test_loaders):
            _, test_acc = test(server_model, test_loader, loss_fun, device)
            print(' {:<11s}| Test  Acc: {:.4f}'.format(datasets[test_idx], test_acc))
        if args.log:
            logfile.close()
            writer.close()
        sys.exit(0)

    resume_iter = 0
    optimizers = [optim.SGD(params=models[idx].parameters(), lr=args.lr) for idx in range(args.num_client)]
    best_acc = 0

    iter_time = AverageMeter('Time', ':6.3f')
    end = time.time()

    # candidate per-client imbalance factors, resampled every round
    imbalance_list = [1.0, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001]

    # start training
    for a_iter in range(resume_iter, args.iters):
        # resample clients
        imbalance_factor = np.take(imbalance_list, np.random.randint(0, len(imbalance_list), args.num_client))
        train_loaders = data_utils.shuffle_cifar_data(args, imbalance_factor)

        for wi in range(args.wk_iters):
            print("============ Train Global Iter. {} ============".format(wi + a_iter * args.wk_iters))
            if args.log:
                logfile.write("============ Train epoch {} ============\n".format(wi + a_iter * args.wk_iters))
            for client_idx in range(args.num_client):
                model, train_loader, optimizer = models[client_idx], train_loaders[client_idx], optimizers[client_idx]
                _, _, client_alpha[client_idx] = client_update.train(model, train_loader, optimizer, server_model)

        # aggregation (round 0 keeps the uniform initial weights)
        if args.weighting == 'exp' and a_iter > 0:
            client_weights = server_update.aggregation_weight(client_alpha)
        if args.log and a_iter > 0:
            for ci, weight in enumerate(client_weights):
                logfile.write('{:.4f} '.format(weight))
                writer.add_scalar('client_{}'.format(ci), weight, a_iter)
            logfile.write('\n')
        print(client_weights)

        # synchronization
        server_model, models = server_update.communication(server_model, models, client_weights)

        # start validation
        val_avg_acc = 0
        val_avg_loss = 0
        if len(val_loaders) == 1:
            # Single shared validation set: evaluate the aggregated server model.
            # Previously val_avg_loss was never set here, so 0 was logged.
            val_avg_loss, val_avg_acc = test(server_model, val_loaders[0], loss_fun, device)
            print(' Val  Loss: {:.4f} | Val  Acc: {:.4f}'.format(val_avg_loss, val_avg_acc))
            if args.log:
                logfile.write('  Val  Loss: {:.4f} | Val  Acc: {:.4f}\n'.format(val_avg_loss, val_avg_acc))
        else:
            for val_idx, val_loader in enumerate(val_loaders):
                val_loss, val_acc = test(models[val_idx], val_loader, loss_fun, device)
                print(' {:<11s}| Val  Loss: {:.4f} | Val  Acc: {:.4f}'.format(datasets[val_idx], val_loss, val_acc))
                val_avg_acc += val_acc
                val_avg_loss += val_loss
                if args.log:
                    logfile.write(' {:<11s}| Val  Loss: {:.4f} | Val  Acc: {:.4f}\n'.format(datasets[val_idx], val_loss, val_acc))

        val_avg_acc = val_avg_acc / len(val_loaders)
        val_avg_loss = val_avg_loss / len(val_loaders)
        # All client optimizers share args.lr; read it from the list so this
        # does not depend on the inner loop having run (wk_iters == 0).
        current_lr = optimizers[-1].param_groups[-1]['lr']
        print(' Leraning rate:{:.4f}. Average val accuracy:{:.4f}'.format(current_lr, val_avg_acc))
        if args.log:
            logfile.write(' Leraning rate:{:.4f}. Average val accuracy:{:.4f}\n'.format(current_lr, val_avg_acc))
            writer.add_scalar('val_loss', val_avg_loss, a_iter)

        # Record best
        if val_avg_acc > best_acc:
            best_acc = val_avg_acc
            if args.log:
                writer.add_scalar('best_val_acc', best_acc, a_iter)
            # Save best checkpoint
            print(' Saving checkpoints to {}...'.format(SAVE_PATH))
            if args.log:
                logfile.write(' Saving the local and server checkpoint to {}...\n'.format(SAVE_PATH))
            torch.save({
                'server_model': server_model.state_dict(),
            }, SAVE_PATH)

        print(' Best average val accuracy:{:.4f}\n'.format(best_acc))
        if args.log:
            logfile.write(' Best average val accuracy:{:.4f}\n'.format(best_acc))
        iter_time.update(time.time() - end)
        end = time.time()
        print('Iter Time {iter_time.val:.3f} ({iter_time.avg:.3f})\t'.format(iter_time=iter_time))

    # Final testing on the best checkpoint
    print('Loading checkpoints...')
    checkpoint = torch.load(SAVE_PATH)
    server_model.load_state_dict(checkpoint['server_model'])
    test_avg_acc = 0
    for test_idx, test_loader in enumerate(test_loaders):
        _, test_acc = test(server_model, test_loader, loss_fun, device)
        test_avg_acc += test_acc
        print(' {:<11s}| Test  Acc: {:.4f}'.format(datasets[test_idx], test_acc))
        if args.log:
            logfile.write(' {:<11s}| Test  Acc: {:.4f}\n'.format(datasets[test_idx], test_acc))
    test_avg_acc = test_avg_acc / len(test_loaders)
    print('Average Test accuracy:{:.4f}'.format(test_avg_acc))
    if args.log:
        logfile.write('Average Test accuracy:{:.4f}\n'.format(test_avg_acc))

    if args.log:
        logfile.flush()
        logfile.close()
        # SummaryWriter buffers events; close it so they are flushed to disk.
        writer.close()


