from utils.options import args_parser
import torch
from models.resnet import resnet8, resnet10
from utils.util import set_logger, setup_seed
import copy
import os
from tqdm import tqdm
from utils.get_dataset import get_dataset
# Module-level state shared by the functions below. ``logger`` and ``n`` are
# (re)assigned by init_args(); ``total_data`` is not assigned anywhere in this file.
logger = None
n = 0
total_data = 0
# Exported at import time, before any model is built. Not read in this file;
# presumably consumed by downstream attention/diffusion code — verify.
os.environ.update(ATTN_PRECISION="fp16")
def log_info(message):
    """Print ``message`` and forward it to the module-level ``logger`` if one is set.

    ``logger`` starts out as ``None`` and is assigned by ``init_args`` (via
    ``set_logger``). Messages emitted before that are printed only, instead of
    failing with ``AttributeError`` on ``None.info``.

    Args:
        message: Object to emit; passed unchanged to ``print`` and ``logger.info``.
    """
    print(message)
    if logger is not None:
        logger.info(message)

def log_parameter(args):
    """Emit the main run hyper-parameters through ``log_info``, one line each.

    Args:
        args: Parsed options exposing ``epochs``, ``num_users``, ``local_ep``,
            ``local_bs``, ``noniid``, ``alpha`` and ``seed``.
    """
    # (printed label, attribute on args) — attributes are read lazily, one per
    # line, so lines already logged stay logged if a later attribute is missing.
    labelled_attrs = (
        ('Communication Rounds', 'epochs'),
        ('Client Number', 'num_users'),
        ('Local Epochs', 'local_ep'),
        ('Local Batch Size', 'local_bs'),
        ('noniid', 'noniid'),
        ('alpha', 'alpha'),
        ('seed', 'seed'),
    )
    for label, attr_name in labelled_attrs:
        log_info(f'{label}: {getattr(args, attr_name)}')

def create_model(args, device):
    """Build the ResNet classifier for ``args.dataset`` and move it to ``device``.

    Side effect: sets ``args.num_classes`` for supported datasets (only after
    the dataset has been validated).

    Args:
        args: Parsed options; ``args.dataset`` selects the model.
        device: Target passed to ``Module.to``.

    Returns:
        The constructed model on ``device``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not supported. The message
            names the offending dataset.
    """
    num_classes_by_dataset = {
        "cifar-10": 10,
        "tiny-imagenet": 200,
        "cifar-100": 100,
        "cinic-10": 10,
    }
    if args.dataset not in num_classes_by_dataset:
        raise NotImplementedError(
            f"Unsupported dataset {args.dataset!r}; expected one of "
            f"{sorted(num_classes_by_dataset)}"
        )
    args.num_classes = num_classes_by_dataset[args.dataset]
    # CIFAR-10 uses the smaller ResNet-8; every other supported dataset uses ResNet-10.
    builder = resnet8 if args.dataset == "cifar-10" else resnet10
    return builder(num_labels=args.num_classes).to(device)

def _reset_gen_data_dir(gen_data_dir):
    """Create ``gen_data_dir`` if missing and delete the regular files directly inside it.

    Sub-directories (and their contents) are left untouched.
    """
    os.makedirs(gen_data_dir, exist_ok=True)
    for entry in os.listdir(gen_data_dir):
        entry_path = os.path.join(gen_data_dir, entry)
        if os.path.isfile(entry_path):
            os.remove(entry_path)


def train(args):
    """Run ``args.epochs`` federated communication rounds and log test accuracy.

    Each round:
    1. every client trains starting from ``server.global_model``;
    2. the server runs ``aggreate_model``, ``supply_data`` and ``sync_model``,
       with all model state-dict keys passed to ``aggreate_model`` and ``sync_model``;
    3. ``server.evaluate()`` is called and the round's and best accuracy are logged.

    Args:
        args: Parsed command-line options. ``create_model`` also sets
            ``args.num_classes`` as a side effect.

    Returns:
        list: The first value returned by ``server.evaluate()`` for each round,
        in round order.
    """
    from src.client import Client
    from src.server import Server

    # Same rule as before: CUDA only if available and not disabled via gpu == -1.
    use_cuda = torch.cuda.is_available() and args.gpu != -1
    device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')
    local_model = create_model(args, device)
    log_info(f'Model Structure: {local_model}')

    # The directory is keyed by dataset/alpha/seed; files left over from an
    # earlier run with the same configuration are removed before training.
    # (Presumably written by server.supply_data — verify.)
    gen_data_dir = os.path.join(args.gen_data_dir, f'{args.dataset}', f'{args.alpha}', f'{args.seed}')
    _reset_gen_data_dir(gen_data_dir)

    train_dataset, test_dataset, train_user_groups = get_dataset(args)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)

    # Every client and the server receive an independent copy of the fresh model.
    clients = [
        Client(device=device, local_model=copy.deepcopy(local_model), train_dataset=train_dataset,
               test_dataset=test_dataset, train_idxs=train_user_groups[idx],
               args=args, logger=logger)
        for idx in range(args.num_users)
    ]
    server = Server(device=device, clients=clients, global_model=copy.deepcopy(local_model), args=args,
                    logger=logger, test_loader=test_loader)

    # All state-dict entries (parameters and persistent buffers) are handed to
    # aggregation and synchronisation.
    avg_keys = list(local_model.state_dict().keys())

    test_acc_list = []
    best_acc = 0
    for epoch in tqdm(range(args.epochs)):
        log_info(f'Start Training round: {epoch}')
        for client in clients:
            client.train(epoch, server.global_model)

        server.aggreate_model(avg_keys)
        server.supply_data(epoch)
        server.sync_model(avg_keys)

        test_acc_avg, _ = server.evaluate()
        test_acc_list.append(test_acc_avg)
        best_acc = max(best_acc, test_acc_avg)
        # Log text kept verbatim (including spelling) so existing log parsers keep working.
        log_info(f'Epoch: {epoch}: Test Avarage Accuracy:{test_acc_avg:.6f}')
        log_info(f"Epoch {epoch}: Best Accuracy: {best_acc:.6f}")

    return test_acc_list
        

def init_args(cli_args=None):
    """Configure logging and device settings for the run, then log the parameters.

    Assigns the module-level ``logger`` (via ``set_logger``) and ``n`` (the
    number of clients). It also sets ``PIPELINE_DEVICE`` in the environment
    and ``device`` on the options object.

    Args:
        cli_args: Parsed options. When omitted, the module-level ``args`` created
            in the ``__main__`` section is used, so ``init_args()`` keeps working
            unchanged. Passing the options explicitly allows calling this
            function from other code.
    """
    global logger
    global n
    if cli_args is None:
        cli_args = args
    # log_dir is prepended verbatim (no separator inserted), so it should end
    # with a path separator unless a filename prefix is intended.
    log_file_name = cli_args.log_dir + f'FedDM_{cli_args.dataset}_{cli_args.train_num}_alpha({cli_args.alpha})_{cli_args.mode}_{cli_args.max_supply_num}_{cli_args.supply_alpha}_{cli_args.seed}.log'
    logger = set_logger(log_file_name)
    n = cli_args.num_users
    # Mirror the device selection in train(): use CPU when CUDA is unavailable
    # or disabled with gpu == -1, rather than advertising e.g. 'cuda:-1'.
    use_cuda = torch.cuda.is_available() and cli_args.gpu != -1
    device_name = f'cuda:{cli_args.gpu}' if use_cuda else 'cpu'
    os.environ['PIPELINE_DEVICE'] = device_name
    cli_args.device = device_name
    log_parameter(cli_args)

# Before launching a run:
# 1. Check the log file location (args.log_dir).
# 2. Check every parameter.
if __name__ == "__main__":
    # ``args`` must be a module-level name: init_args() reads it as a global.
    args = args_parser()
    init_args()
    # setup_seed presumably seeds the RNGs; it runs before train() builds the
    # model and datasets.
    setup_seed(args.seed)
    train(args)
