import os
import random

import matplotlib

from models.Update_domain import DomainClientUpdate, DomainClientUpdate_avg
from utils.fpl import proto_aggregation

matplotlib.use('Agg')
import copy
import numpy as np
import torch
import time

from utils.options import args_parser
from models.Fed import FedAvg
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.evaluate import DeltaWeight,test,evaluate
from utils.forget_event import order_examples_of_forget

def _seed_everything(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``; force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def _as_list(value):
    """Return *value* as a list: lists/tuples are copied, any other value is wrapped."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def _join_ids(value):
    """Join a scalar-or-sequence of ids with ``_`` (e.g. ``[1, 3]`` -> ``'1_3'``, ``2`` -> ``'2'``)."""
    return '_'.join(str(i) for i in _as_list(value))


def _resolve_model_dir(args, dsf_dir, legacy_dsf_dir, bkd_str):
    """Return the directory holding the learning-phase checkpoints.

    Prefers ``./save/test/{dataset}/learning/{save}/{dsf_dir}/{bkd_str}``; falls back
    to the same path under ``legacy_dsf_dir`` only when the preferred directory is
    missing and the legacy one exists. Otherwise the preferred path is returned and
    the subsequent ``torch.load`` reports the missing file.
    """
    root = f'./save/test/{args.dataset}/learning/{args.save}'
    model_dir = f'{root}/{dsf_dir}/{bkd_str}'
    if not os.path.exists(model_dir):
        legacy_dir = f'{root}/{legacy_dsf_dir}/{bkd_str}'
        if os.path.exists(legacy_dir):
            return legacy_dir
    return model_dir


def _write_time_csv(path, client_times, server_times):
    """Write one ``round,client_time,server_time`` row per federated round."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write('round,client_time,server_time\n')
        for r, (ct, st) in enumerate(zip(client_times, server_times)):
            f.write(f'{r},{ct},{st}\n')


def _write_perf_csv(path, num_users, records):
    """Write per-round evaluation records under a ``round,client0,...`` header.

    Each record is written as ``','.join(map(str, record))``. Values are stringified
    with ``str`` on purpose, so the output format does not depend on the ``csv``
    module's float handling.
    """
    header = ','.join(['round'] + [f'client{cid}' for cid in range(num_users)])
    with open(path, 'w', encoding='utf-8') as f:
        f.write(header + '\n')
        for record in records:
            f.write(','.join(map(str, record)) + '\n')


def retrain(args):
    """Retrain the federated model from its initial weights, excluding the forgotten clients.

    Loads ``weight_init.pth`` from the learning-phase save directory as the starting
    point. Every client not listed in ``args.unlearning_client`` then trains locally for
    ``args.unlearn_epoch`` rounds. The server aggregates the states with ``FedAvg``, and
    additionally aggregates prototypes with ``proto_aggregation`` when ``args.proto`` is set.
    The learning-phase ``weight_global.pth`` is evaluated once beforehand as a reference.

    Args:
        args: parsed options (see ``utils.options.args_parser``). ``args.device`` is
            overwritten here from ``args.gpu``.

    Side effects:
        * Saves ``example_stats`` and the final global weights under
          ``./save/test/{dataset}/retrain/{save}/{dsf_dir}/{bkd_str}/``.
        * Writes a timing CSV and a per-round performance CSV under
          ``./result/csv/{dataset}/{target}/``.

    Raises:
        ValueError: if every client is marked for unlearning while
            ``args.unlearn_epoch > 0`` (there would be nothing to aggregate).
    """
    _seed_everything(args.seed)
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    print(args)

    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    datasets_name = get_dataset(args)

    # 12345 is the getattr fallback for bkd_domain_idx and presumably the parser's
    # "no backdoor domain" sentinel (TODO confirm). Only a real backdoor domain index
    # selects the domain_times_factor directory layout.
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    else:
        split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    dsf_dir = f"dsf_{split_factor}"
    old_dsf_dir = f"dsf_{args.domain_split_factor}"
    bkd_str = _join_ids(args.backdoor_client_idx)

    delete_users = _as_list(args.unlearning_client)
    deleted = set(delete_users)
    # Clients that actually take part in retraining, in index order.
    participants = [i for i in range(args.num_users) if i not in deleted]
    if not participants and args.unlearn_epoch > 0:
        raise ValueError('retrain: every client is marked for unlearning; no client left to train')
    forgotten_names = [datasets_name[i] for i in delete_users]

    # RETRAIN
    print('-' * 50)
    print('RETRAIN')

    model_dir = _resolve_model_dir(args, dsf_dir, old_dsf_dir, bkd_str)
    net_glob = init_model(args).to(args.device)
    w_glob = torch.load(f'{model_dir}/weight_init.pth', map_location=args.device)
    net_glob.load_state_dict(w_glob)

    # Final global weights of the learning phase, evaluated once as the
    # "before unlearning" reference.
    old_net = copy.deepcopy(net_glob).to(args.device)
    old_w_glob = torch.load(f'{model_dir}/weight_global.pth', map_location=args.device)
    old_net.load_state_dict(old_w_glob)

    # Two lists of per-client dicts, threaded through evaluate(); their contents are
    # managed by evaluate().
    example_stats = [[{} for _ in range(args.num_users)], [{} for _ in range(args.num_users)]]
    print(example_stats)

    retrain_dir = f'./save/test/{args.dataset}/retrain/{args.save}/{dsf_dir}/{bkd_str}'
    os.makedirs(retrain_dir, exist_ok=True)
    # NOTE: the list's repr (brackets, quotes) is deliberately part of the file-name
    # base; it is kept unchanged so existing consumers still find these files.
    base = f'{retrain_dir}/{forgotten_names}'

    print("---- before unlearning -----")
    example_stats, _ = evaluate(args=args, train_loaders=train_loaders, test_loaders=test_loaders, net=old_net, example_stats=example_stats, datasets_name=datasets_name, backdoorloader=backdoorloader)

    print(f'**** forget {forgotten_names} idx:  *****')

    use_proto = getattr(args, "proto", False)
    trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
    global_protos = {}
    time_s = 0
    client_time_records = []
    server_time_records = []
    performance_records = []

    for epoch in range(args.unlearn_epoch):
        print("============ Train epoch {} ============".format(epoch))
        epoch_start = time.perf_counter()
        client_elapsed = 0
        w_locals = []
        loss_locals = []
        local_protos = [{} for _ in range(args.num_users)]

        for client_idx in participants:
            local_trainer = trainer_cls(
                args=args,
                train_loader=train_loaders[client_idx]
            )
            # Only the local train() call (including the model deepcopy) counts as client time.
            t0 = time.perf_counter()
            if use_proto:
                _, client_state, client_proto, result = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device),
                    global_protos=global_protos,
                )
                client_loss = result[0] if result else 0.0
                w_locals.append(client_state)
                local_protos[client_idx] = client_proto
            else:
                client_model, client_loss = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device)
                )
                w_locals.append(client_model.state_dict())
            client_elapsed += time.perf_counter() - t0
            loss_locals.append(client_loss)
            print(f'Client {datasets_name[client_idx]} | Loss: {client_loss:.4f}')

        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)
        if use_proto:
            global_protos = proto_aggregation(args, local_protos)
        # Server time = round wall time so far minus the time spent inside local training.
        server_time = time.perf_counter() - epoch_start - client_elapsed

        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}, Time {:.3f}'.format(epoch, loss_avg, time.perf_counter() - epoch_start))

        # Evaluation below is intentionally excluded from the recorded times.
        time_s += time.perf_counter() - epoch_start
        client_time_records.append(client_elapsed / len(participants))
        server_time_records.append(server_time)

        # Testing and validation
        print("============ Test epoch {} ============".format(epoch))
        example_stats, g_loss = evaluate(args=args, train_loaders=train_loaders, test_loaders=test_loaders, net=net_glob, example_stats=example_stats, datasets_name=datasets_name, backdoorloader=backdoorloader)
        performance_records.append([epoch, *g_loss])

    torch.save(example_stats, f'{base}_forget_event.pth')
    torch.save(w_glob, f'{base}_weight_global.pth')
    print('*' * 50)
    print("finish")
    print('retrain Time : {:.3f}s'.format(time_s))

    csv_dir = f'./result/csv/{args.dataset}/{args.target}'
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    suffix = f'{args.target}_{_join_ids(args.unlearning_client)}_{bkd_str}_{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv'
    _write_time_csv(os.path.join(csv_dir, f'time_{suffix}'), client_time_records, server_time_records)
    _write_perf_csv(os.path.join(csv_dir, f'performance_of_clients_{suffix}'), args.num_users, performance_records)

