import os
import argparse
from copy import deepcopy
import torch
import torch.optim as optim
from torch.utils.tensorboard.writer import SummaryWriter
import torch.nn.functional as F

# Script-mode bootstrap: when this file has no package context (`__package__` is None),
# make the project importable and pull in the project-local modules used below.
# NOTE(review): every project import lives inside this conditional, so when
# `__package__` is not None none of these names are bound by this file and main()
# would raise NameError unless they are provided some other way — confirm intended.
if __package__ is None:
    import sys
    from os import path
    # Append the directory two levels above the one containing this file to sys.path,
    # so the absolute imports below (network.*, data.*, utils.*) can be resolved from there.
    sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))

    from network.get_network import GetNetwork
    from data.pacs_dataset import PACS_FedDG, get_transforms, ra_transforms, base_transforms
    from utils.buffer import ReplayBuffer
    from utils.augmentation import build_augmentation
    from utils.classification_metric import Classification
    from utils.loss import NonSaturatingLoss
    # Wildcard import; presumably the source of Gen_Log_Dir, Get_Logger and
    # Save_Hyperparameter used in main() — verify against utils/log_utils.py.
    from utils.log_utils import *
    from utils.fed_merge import Cal_Weight_Dict, FedAvg, FedUpdate
    from utils.teachaugment import TeachAugment
    from utils.trainval_func import site_evaluation, site_train, site_train_ta, site_train_swad, GetFedModel, SaveCheckPoint

def get_argparse(argv=None):
    """Declare and parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()

    # General experiment setup.
    parser.add_argument("--dataset", type=str, default='pacs', choices=['pacs'], help='Name of dataset')
    parser.add_argument("--model", type=str, default='resnet18',
                        choices=['resnet18', 'resnet50'], help='model name')
    parser.add_argument("--test_domain", type=str, default='p',
                        choices=['p', 'a', 'c', 's'], help='the domain name for testing')
    parser.add_argument('--num_classes', help='number of classes default 7', type=int, default=7)
    parser.add_argument('--batch_size', help='batch_size', type=int, default=16)
    parser.add_argument('--local_epochs', help='epochs number', type=int, default=5)
    # main() iterates range(comm + 1) federated rounds, so this is a round count, not epochs.
    parser.add_argument('--comm', help='number of communication rounds', type=int, default=40)
    parser.add_argument('--lr', help='learning rate', type=float, default=0.001)
    parser.add_argument("--lr_policy", type=str, default='step', choices=['step'],
                        help="learning rate scheduler policy")
    parser.add_argument('--note', help='note of experimental settings', type=str, default='fedavg')
    parser.add_argument('--display', help='display in controller', action='store_true')

    # Options attached to --sam.
    parser.add_argument('--sam', action='store_true')
    parser.add_argument('--rho', help='Rho parameter for SAM and ASAM minimizers', type=float, default=0.02)
    parser.add_argument('--eta', help='Eta parameter for SAM and ASAM minimizers', type=float, default=0)

    # Options attached to --lwf.
    parser.add_argument('--lwf', action='store_true')
    parser.add_argument('--T', help='temperature for ce', type=float, default=2.0)
    parser.add_argument('--lamda', type=float, default=1.0)

    # Options attached to --ta (trainable augmentation path in main()).
    parser.add_argument('--ta', action='store_true')
    parser.add_argument('--aug', type=str, default='standard', choices=['standard', 'ra', 'aa', 'cutout'])
    parser.add_argument('--sampling_freq', default=10, type=int, help='sampling augmentation frequency')
    parser.add_argument('--chunks', help='the number of data splits', type=int, default=4)
    parser.add_argument('--ema', type=float, default=0.0)
    parser.add_argument('--c_reg_coef', default=10, type=float, help='coefficient of the color regularization')
    parser.add_argument('--g_scale', default=0.125, type=float,
                        help='the search range of the magnitude of geometric augmentation')
    parser.add_argument('--c_scale', default=0.2, type=float,
                        help='the search range of the magnitude of color augmentation')
    parser.add_argument('--n_inner', default=10, type=int,
                        help='the number of iterations for inner loop (i.e., updating classifier)')

    # Options attached to --swad.
    parser.add_argument('--swad', action='store_true')
    parser.add_argument('--n_converge', default=3, type=int)
    parser.add_argument('--n_tolerance', default=6, type=int)

    parser.add_argument('--flwf', action='store_true')
    parser.add_argument('--fta', action='store_true')

    parser.add_argument('--flat', help='get flatness online', action='store_true')

    # argv=None makes argparse read sys.argv[1:], exactly as before.
    return parser.parse_args(argv)
 
def main():
    """Run the federated training loop end to end.

    Steps performed, in order:
      1. Create the log directory, TensorBoard writer and file logger; save the hyperparameters.
      2. Build the PACS data loaders with ``args.test_domain`` as the test domain.
      3. Build the global model plus one model / optimizer / scheduler per training domain.
      4. With ``--ta`` or ``--fta``: build an averaged (EMA) copy of the global model, a shared
         replay buffer, one trainable augmentation model per domain and their AdamW optimizers.
      5. For ``args.comm + 1`` rounds: train each training domain, evaluate it, merge with
         ``FedAvg``, evaluate the merged model, and checkpoint whenever the weighted validation
         accuracy is at least the best seen so far.
      6. Save the final global and per-domain models and close the TensorBoard writer.
    """
    # ---- logging ----
    file_name = 'fedavg_' + os.path.splitext(os.path.basename(__file__))[0]
    args = get_argparse()
    log_dir, tensorboard_dir = Gen_Log_Dir(args, file_name=file_name)
    log_ten = SummaryWriter(log_dir=tensorboard_dir)
    # log_dir is concatenated directly, so it presumably ends with a path separator — TODO confirm.
    log_file = Get_Logger(file_name=log_dir + 'train.log', display=args.display)
    Save_Hyperparameter(log_dir, args)

    use_teach_aug = args.ta or args.fta
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')

    # ---- dataset and dataloader ----
    if use_teach_aug:
        base_aug, normalizer = get_transforms()
    # NOTE(review): the return value is discarded; presumably called for a side effect — confirm.
    base_transforms(args.aug)

    dataobj = PACS_FedDG(test_domain=args.test_domain, batch_size=args.batch_size)
    dataloader_dict, dataset_dict = dataobj.GetData()

    # ---- model ----
    metric = Classification()
    global_model, model_dict, optimizer_dict, scheduler_dict = GetFedModel(args, args.num_classes)
    weight_dict = Cal_Weight_Dict(dataset_dict, site_list=dataobj.train_domain_list)
    # FedUpdate is expected to copy the global weights into every site model — confirm in utils.fed_merge.
    FedUpdate(model_dict, global_model)

    if use_teach_aug:
        def ema_avg_fn(averaged_model_parameter, model_parameter, num_averaged):
            # Exponential moving average with decay args.ema; num_averaged is unused.
            return args.ema * averaged_model_parameter + (1 - args.ema) * model_parameter

        # AveragedModel deep-copies the model it receives, so no explicit deepcopy is needed here.
        ema_model = optim.swa_utils.AveragedModel(global_model, avg_fn=ema_avg_fn, use_buffers=True)
        for ema_p in ema_model.parameters():
            ema_p.requires_grad_(False)
        ema_model.train()

        # A single replay buffer instance is shared by the global and all per-domain augmentation models.
        rbuffer = ReplayBuffer(0.9)
        trainable_aug_global = build_augmentation(
            args.num_classes, args.g_scale, args.c_scale, args.c_reg_coef, normalizer, rbuffer, args.chunks).cuda()
        trainable_aug_dict = {domain_name: build_augmentation(
            args.num_classes, args.g_scale, args.c_scale, args.c_reg_coef, normalizer, rbuffer, args.chunks).cuda()
                              for domain_name in dataobj.train_domain_list}
        # buffer_dict = {domain_name: ReplayBuffer(0.9) for domain_name in dataobj.train_domain_list}
        # trainable_aug_dict = {domain_name: build_augmentation(
        #     args.num_classes, 0.5, 0.8, 10, normalizer, buffer_dict[domain_name], 8).cuda()
        #                       for domain_name in dataobj.train_domain_list}
        base_aug = torch.nn.Sequential(*base_aug).cuda()
        optim_aug_dict = {domain_name: optim.AdamW(
            trainable_aug_dict[domain_name].parameters(), lr=1e-3, weight_decay=1e-2)
            for domain_name in dataobj.train_domain_list}
        adv_criterion = NonSaturatingLoss(0.1)
        # for name, param in trainable_aug_global.named_parameters():
        #     print(name, param.requires_grad)
        # exit()

    best_val = 0.
    if args.swad:
        # NOTE(review): swa_model_dict is only reassigned in the --ta/--fta branch below; with
        # --swad alone FedAvg would merge these initial copies every round — confirm intended usage.
        swa_model_dict = deepcopy(model_dict)

    for i in range(args.comm + 1):
        FedUpdate(model_dict, global_model)
        if use_teach_aug:
            FedUpdate(trainable_aug_dict, trainable_aug_global)
            for domain_name in dataobj.train_domain_list:
                model_dict[domain_name].train()
            # Each domain gets its own deep copy of the averaged model for this round.
            objective_dict = {domain_name: TeachAugment(
                model_dict[domain_name], deepcopy(ema_model), trainable_aug_dict[domain_name], adv_criterion, 0,
                base_aug, normalizer).cuda() for domain_name in dataobj.train_domain_list}

        # ---- local training and per-site evaluation ----
        for domain_name in dataobj.train_domain_list:
            if use_teach_aug:
                if args.swad:
                    swa_model_dict[domain_name] = site_train_swad(
                        i, domain_name, args, objective_dict[domain_name], optimizer_dict[domain_name],
                        optim_aug_dict[domain_name], scheduler_dict[domain_name], dataloader_dict[domain_name]['train'],
                        dataloader_dict[domain_name]['val'], log_ten, metric)
                else:
                    site_train_ta(i, domain_name, args, objective_dict[domain_name], optimizer_dict[domain_name],
                                  optim_aug_dict[domain_name], scheduler_dict[domain_name],
                                  dataloader_dict[domain_name]['train'], log_ten, metric)
                    # if i % args.sampling_freq == 0:
                    #     buffer_dict[domain_name].store(trainable_aug_dict[domain_name].get_augmentation_model())
            else:
                site_train(i, domain_name, args, model_dict[domain_name], optimizer_dict[domain_name],
                           scheduler_dict[domain_name], dataloader_dict[domain_name]['train'], log_ten, metric)

            site_evaluation(i, domain_name, args, model_dict[domain_name], dataloader_dict[domain_name]['val'], log_file, log_ten, metric, note='before_fed')
            site_evaluation(i, domain_name+'_val', args, model_dict[domain_name], dataobj.val_dataloader, log_file, log_ten, metric, note='before_fed')
            if args.swad:
                site_evaluation(i, domain_name, args, swa_model_dict[domain_name], dataloader_dict[domain_name]['val'], log_file, log_ten, metric, note='swad')
                site_evaluation(i, domain_name+'_val', args, swa_model_dict[domain_name], dataobj.val_dataloader, log_file, log_ten, metric, note='swad')
            # if args.ta or args.fta:
            #     site_evaluation(i, domain_name+'_ema', args, ema_model, dataloader_dict[domain_name]['val'], log_file, log_ten, metric, note='before_fed')

        # ---- merge site models into the global model ----
        if args.swad:
            FedAvg(swa_model_dict, weight_dict, global_model)
        else:
            FedAvg(model_dict, weight_dict, global_model)

        # ---- evaluate the merged model ----
        fed_val = 0.
        for domain_name in dataobj.train_domain_list:
            results_dict = site_evaluation(i, domain_name, args, global_model, dataloader_dict[domain_name]['val'], log_file, log_ten, metric)
            fed_val += results_dict['acc'] * weight_dict[domain_name]
        site_evaluation(i, 'val', args, global_model, dataobj.val_dataloader, log_file, log_ten, metric)

        # Validation result: checkpoint when the weighted accuracy ties or beats the best so far.
        if fed_val >= best_val:
            best_val = fed_val
            # NOTE(review): args.comm (not the current round i) is passed as the third argument — TODO confirm.
            SaveCheckPoint(args, global_model, args.comm, checkpoint_dir, note='best_val_model')
            for domain_name in dataobj.train_domain_list:
                SaveCheckPoint(args, model_dict[domain_name], args.comm, checkpoint_dir, note=f'best_val_{domain_name}_model')

            log_file.info(f'Model saved! Best Val Acc: {best_val*100:.2f}%')
        site_evaluation(i, args.test_domain, args, global_model, dataloader_dict[args.test_domain]['test'], log_file, log_ten, metric, note='test_domain')

        if use_teach_aug:
            # site_evaluation(i, 'val_ema', args, ema_model, dataobj.val_dataloader, log_file, log_ten, metric)
            # site_evaluation(i, args.test_domain+'_ema', args, ema_model, dataloader_dict[args.test_domain]['test'], log_file, log_ten, metric, note='test_domain')

            FedAvg(trainable_aug_dict, weight_dict, trainable_aug_global)
            if i % args.sampling_freq == 0:
                rbuffer.store(trainable_aug_global.get_augmentation_model())
                # Report only when something was actually stored (previously printed every round).
                print(f'store augmentation (buffer length: {len(rbuffer)})')
            # print(f'store augmentation (buffer length: {len(buffer_dict[domain_name])})')
            # update_parameters only reads from the given model, so the per-round deepcopy is unnecessary.
            ema_model.update_parameters(global_model)
            ema_model.train()

    SaveCheckPoint(args, global_model, args.comm, checkpoint_dir, note='last_model')
    for domain_name in dataobj.train_domain_list:
        SaveCheckPoint(args, model_dict[domain_name], args.comm, checkpoint_dir, note=f'last_{domain_name}_model')

    # Flush pending TensorBoard events and release the writer.
    log_ten.close()

# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()