from FedNovel_ours import FedNovel_model
from ResNet import *
import torch
import copy
import random
import os.path as osp
import os
from myNetwork import network
from Fed_utils import * 
from iCIFAR100 import *
from mini_imagenet import *
from tiny_imagenet import *
from option import args_parser
from sampler import Dirichlet_sampler


# Parse the command-line options once; everything below reads its settings from `args`.
args = args_parser()
torch.set_num_threads(10)

## training log
# One log directory per (method, seed) pair: ./training_log/<method>/seed<seed>
output_dir = osp.join('./training_log', args.method, 'seed' + str(args.seed))
# os.makedirs creates missing parent directories and tolerates an existing
# directory. It replaces the former shell call (`mkdir -p`), which was
# non-portable and handed an unquoted, argument-derived path to the shell.
os.makedirs(output_dir, exist_ok=True)
# Checkpoints of the base model are written to / read from this directory later on.
os.makedirs('saved_model', exist_ok=True)
# The log file path is stored on `args` so helpers that receive `args` can log too.
args.out_file = osp.join(output_dir, 'log_tar_' + args.dataset + '_'.join(str(i) for i in args.task_classes) + '.txt')
log_print('method_{}, task_size{}, learning_rate_{}'.format(args.method, '_'.join(str(i) for i in args.task_classes), args.learning_rate), args.out_file)

## parameters for learning, no need for old clients
# Backbone selection: resnet18_cbam for cifar100 / tiny_imagenet, resnet34_cbam for any other dataset.
feature_extractor = resnet18_cbam() if args.dataset in ('cifar100', 'tiny_imagenet') else resnet34_cbam()

# Container for the per-client models; it is filled further down the script.
models = []
num_clients = args.num_clients

## seed settings
# Note: the backbone above is constructed before the seed is applied.
setup_seed(args.seed)

## model settings
# No previous-task snapshot exists yet; it is assigned when a later task starts.
model_old = None
# Global model: the classifier is sized for the first task's classes, then placed on the target device.
model_g = model_to_device(network(args.task_classes[0], feature_extractor, args), False, args.device)

# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): the same statistics are applied to every dataset option — confirm this is intended.
_norm_mean = (0.5071, 0.4867, 0.4408)
_norm_std = (0.2675, 0.2565, 0.2761)

# Training pipeline: padded random crop, horizontal flip, brightness jitter, tensor conversion, normalisation.
_train_steps = [
    transforms.RandomCrop((args.img_size, args.img_size), padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.24705882352941178),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: resize, tensor conversion, normalisation (no augmentation).
_test_steps = [
    transforms.Resize(args.img_size),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
test_transform = transforms.Compose(_test_steps)

if args.dataset == 'cifar100':
    # CIFAR-100 uses two separate dataset objects, one per split.
    train_dataset = iCIFAR100('dataset', transform=train_transform, download=True)
    test_dataset = iCIFAR100('dataset', test_transform=test_transform, train=False, download=True)
else:
    # Both ImageNet variants receive the train and test transforms in a single
    # object, which is then used for training and evaluation alike.
    if args.dataset == 'tiny_imagenet':
        train_dataset = Tiny_Imagenet('./dataset/tiny-imagenet-200', train_transform=train_transform, test_transform=test_transform)
    else:
        train_dataset = Mini_Imagenet('./dataset/train', train_transform=train_transform, test_transform=test_transform)
    train_dataset.get_data()
    test_dataset = train_dataset

# Build one FedNovel_model per client. Every client receives the same
# `feature_extractor` object and the same training dataset object.
# NOTE(review): whether the backbone is copied inside FedNovel_model is not visible here — confirm.
models.extend(
    FedNovel_model(args, args.task_classes[0], feature_extractor, args.batch_size, args.task_classes,
                   args.epochs_local, args.learning_rate, train_dataset, args.device)
    for _ in range(num_clients)
)

# ---------------------------------------------------------------------------
# Main federated loop.
#
# Each task spans `args.round_per_task` consecutive global rounds. A run with
# `args.source` set trains on the first task only, saves the resulting global
# model and stops. A run without it skips the first task's rounds, loads that
# saved model when task 1 begins and continues with the later tasks.
# ---------------------------------------------------------------------------
classes_learned = args.task_classes[0]  # class count handed to Incremental_learning
old_task_id = -1                        # -1 means "no round has been processed yet"
global_centers = []
best_k = 0
# The base-model checkpoint is written by the source run and read back by the
# non-source run; build the path once so both sides always agree on it.
base_ckpt_path = './saved_model/{}_tasks{}.pt'.format(args.dataset, '_'.join(str(i) for i in args.task_classes))

for ep_g in range(args.epochs_global):
    task_id = ep_g // args.round_per_task

    # Non-source runs do not train on the first task at all.
    if (not args.source) and ep_g < args.round_per_task:
        old_task_id = task_id
        continue

    # Neither task_id nor old_task_id changes again until the end of the round,
    # so both conditions can be evaluated once up front.
    task_changed = task_id != old_task_id
    entering_later_task = task_changed and old_task_id != -1

    if task_changed:
        # Re-partition the data whenever a new task starts.
        if old_task_id == -1:
            cur_classes = list(range(args.task_classes[0]))
        else:
            cur_classes = list(range(sum(args.task_classes[:task_id]), sum(args.task_classes[:task_id + 1])))
            # cur_classes = [i for i in range(sum(args.task_classes[:task_id]), sum(args.task_classes[:task_id])+args.test_cls_num) ]
        clients_data_dict = Dirichlet_sampler(args, train_dataset, cur_classes)
        # print(cur_classes, clients_data_dict)

    if entering_later_task:
        if args.source:
            # Source run: the first task is finished, so persist the model and stop.
            torch.save(model_g.state_dict(), base_ckpt_path)
            break
        elif task_id == 1:
            # Load onto CPU first so a checkpoint saved on another device can
            # still be read; load_state_dict copies the values into the
            # model's existing parameters.
            w_g = torch.load(base_ckpt_path, map_location='cpu')
            model_g.load_state_dict(w_g)
            _, __, acc_global = model_global_eval(model_g, test_dataset, old_task_id, args)
            log_print('loaded trained base model accuracy = {:.2f}%\n'.format(acc_global), args.out_file)

        # Snapshot of the global model as it stood before the new task.
        model_old = copy.deepcopy(model_g)
        if args.method == 'one_stage':
            best_k = 20
        else:
            best_k = 0

    log_print('federated global round: {}, task_id: {}, best_k: {}'.format(ep_g, task_id, best_k), args.out_file)

    w_local = []
    local_center_pool, local_dis_pool = [], []  # local_dis_pool is not used in the visible code
    if entering_later_task:
        # First round of a new task: every client takes part.
        clients_index = range(num_clients)
        log_print('all clients are selected to prepare potential local prototypes', args.out_file)
        log_print('{}'.format(clients_index), args.out_file)
    else:
        clients_index = random.sample(range(num_clients), args.local_clients)
        log_print('select part of clients to conduct local training', args.out_file)
        log_print('{}'.format(clients_index), args.out_file)

    print('current fc num: {}'.format(model_g.fc.weight.data.size(0)))

    for c in clients_index:
        local_model, local_centers = local_train(args, models, c, model_g, task_id, old_task_id, model_old, ep_g, clients_data_dict[c])
        w_local.append(local_model)
        if len(local_centers) > 0:
            local_center_pool.append(local_centers)

        if task_id > 0:
            # NOTE(review): every client model was constructed with the script-level
            # `args`; if `models[c].args` is that same object, this decay is applied
            # once per selected client per round rather than once per round — confirm.
            models[c].args.ema_alpha -= models[c].args.ema_decay

    log_print('federated aggregation...', args.out_file)
    w_g_new, global_centers = FedAvg(w_local, task_id, local_center_pool, model_g, args.device, best_k, args)

    # Identity comparison with None (PEP 8) instead of `!= None`.
    if w_g_new is not None:
        model_g.load_state_dict(w_g_new)
        extend_classifier = len(global_centers) > 0
    else:
        # No aggregated weights were returned: record the number of centers as
        # best_k and always extend the classifier.
        best_k = len(global_centers)
        extend_classifier = True

    if extend_classifier:
        print('incremental learning')
        classes_learned = model_old.fc.weight.data.size(0) + len(global_centers)
        model_g.Incremental_learning(args, classes_learned, global_centers)

    model_g = model_to_device(model_g, False, args.device)

    acc_known, acc_novel, acc_global = model_global_eval(model_g, test_dataset, task_id, args)
    log_str = 'Task {}, Round {} Known Acc {:.2f}%, Novel Acc {:.2f}%, All Acc {:.2f}%\n'.format(task_id, ep_g, acc_known, acc_novel, acc_global)
    log_print(log_str, args.out_file)

    old_task_id = task_id
