#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import random
import matplotlib

matplotlib.use('Agg')
import copy
import numpy as np
import torch
import time

from models.Fed import FedAvg
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.evaluate import DeltaWeight, evaluate

from models.Update_domain import DomainClientUpdate_Hesian_record

from utils import ada_hessain

def _as_list(value):
    """Return *value* unchanged if it is a list/tuple, else wrap it in a one-element list."""
    return value if isinstance(value, (list, tuple)) else [value]


def _resolve_model_dir(args, dsf_dir, legacy_dsf_dir, bkd_str):
    """Return the directory that holds the pre-trained checkpoints.

    Prefers ``./save/test/{dataset}/learning/{save}/{dsf_dir}/{bkd_str}``. If that
    directory does not exist but the legacy ``.../{legacy_dsf_dir}/{bkd_str}`` one
    does, the legacy directory is returned. If neither exists, the preferred path is
    returned unchanged, and the caller's ``torch.load`` reports the missing file.
    """
    root = f'./save/test/{args.dataset}/learning/{args.save}'
    model_dir = f'{root}/{dsf_dir}/{bkd_str}'
    if not os.path.exists(model_dir):
        legacy_dir = f'{root}/{legacy_dsf_dir}/{bkd_str}'
        if os.path.exists(legacy_dir):
            model_dir = legacy_dir
    return model_dir


def _write_csv_reports(args, ul_clients_str, bd_clients_str,
                       client_time_records, server_time_records, performance_records):
    """Write the per-round timing CSV and the per-round performance CSV.

    Both files go to ``./result/csv/{dataset}/{target}``. Their names are suffixed
    with the unlearned/backdoor client ids, the mask ratios and a Unix timestamp,
    so repeated runs do not overwrite each other (at one-second resolution).
    """
    csv_dir = f'./result/csv/{args.dataset}/{args.target}'
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    suffix = (f'{args.target}_{ul_clients_str}_{bd_clients_str}_'
              f'{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv')
    time_file = os.path.join(csv_dir, f'time_{suffix}')
    perf_file = os.path.join(csv_dir, f'performance_of_clients_{suffix}')

    with open(time_file, 'w') as f:
        f.write('round,client_time,server_time\n')
        for r, (ct, st) in enumerate(zip(client_time_records, server_time_records)):
            f.write(f'{r},{ct},{st}\n')

    header = ['round'] + [f'client{cid}' for cid in range(args.num_users)]
    with open(perf_file, 'w') as f:
        f.write(','.join(header) + '\n')
        for record in performance_records:
            f.write(','.join(map(str, record)) + '\n')


def rapid_retrain(args):
    """Retrain a federated model from its initial weights, excluding the unlearned clients.

    Steps:
      1. Seed the Python, NumPy and torch RNGs, enable deterministic cuDNN and
         set ``args.device``.
      2. Build the train/test/backdoor loaders. ``init_data`` is used when
         ``args.dataset_fullparti`` is truthy, otherwise ``init_data_methodone``.
      3. Load ``weight_init.pth`` into the model to retrain. Load
         ``weight_global.pth`` into a reference copy (``old_net``). Both files
         live under ``./save/test/{dataset}/learning/{save}/dsf_{factor}/{backdoor ids}``,
         with a fallback to the legacy ``dsf_{args.domain_split_factor}`` directory.
      4. For ``args.unlearn_epoch`` rounds:
         - train every client in ``range(args.num_users)`` that is not listed in
           ``args.unlearning_client``;
         - FedAvg the resulting weights and evaluate;
         - print ``DeltaWeight`` between the reference and the new global weights.
      5. Save forget events, final global weights and per-client losses under
         ``./save/test/{dataset}/rapid_retrain/...``. Write timing and per-round
         performance CSVs under ``./result/csv/{dataset}/{target}``.

    Side effects:
        ``args.unlearning_client`` and ``args.backdoor_client_idx`` are normalised
        so that a scalar becomes a one-element list. A list or tuple is kept as is.
        ``args.device`` is set.

    Args:
        args: experiment namespace (e.g. from argparse). Attributes read include
            ``seed``, ``gpu``, ``dataset``, ``save``, ``num_users``,
            ``unlearn_epoch``, ``unlearning_client``, ``backdoor_client_idx``,
            ``domain_split_factor``, ``dataset_fullparti``, ``target``,
            ``mask_ratio`` and ``diff_mask_ratio``. The optional attributes
            ``domain_times_factor`` and ``bkd_domain_idx`` are also read.

    Raises:
        ValueError: if every client id in ``range(args.num_users)`` is listed in
            ``args.unlearning_client``, because there would be nothing to aggregate.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    print(args)

    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    datasets_name = get_dataset(args)

    # Ensure client selections are always treated as lists.
    args.unlearning_client = _as_list(args.unlearning_client)
    args.backdoor_client_idx = _as_list(args.backdoor_client_idx)
    delete_users = args.unlearning_client

    # Clients that actually take part in every round. Counting them directly
    # (instead of num_users - len(delete_users)) stays correct when the
    # unlearning list contains duplicates or ids outside range(num_users).
    excluded = set(delete_users)
    participating = [c for c in range(args.num_users) if c not in excluded]
    if not participating:
        raise ValueError(
            f'No clients left to retrain: every id in range({args.num_users}) '
            f'is in unlearning_client={list(delete_users)}'
        )

    print('-' * 50)
    print('RAPID RETRAIN (Hessian Optimized, No Proto Aggregation)')

    net_glob = init_model(args)
    split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    # 12345 acts as a "no backdoor domain selected" sentinel. In that case the
    # plain domain_split_factor decides the directory.
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    dsf_dir = f"dsf_{split_factor}"
    old_dsf_dir = f"dsf_{args.domain_split_factor}"
    bkd_str = '_'.join(str(i) for i in args.backdoor_client_idx)
    model_load_path = _resolve_model_dir(args, dsf_dir, old_dsf_dir, bkd_str)

    # Retraining starts from the initial weights. The previously trained global
    # model is kept only as a reference for evaluation and weight distance.
    w_glob = torch.load(f'{model_load_path}/weight_init.pth', map_location=args.device)
    net_glob.load_state_dict(w_glob)
    net_glob.to(args.device)

    old_net = copy.deepcopy(net_glob)
    old_w_glob = torch.load(f'{model_load_path}/weight_global.pth', map_location=args.device)
    old_net.load_state_dict(old_w_glob)
    old_net.to(args.device)

    example_stats = [[{} for _ in range(args.num_users)], [{} for _ in range(args.num_users)]]

    save_dir = f'./save/test/{args.dataset}/rapid_retrain/{args.save}/{dsf_dir}/{bkd_str}/'
    os.makedirs(save_dir, exist_ok=True)
    ul_clients_str = '_'.join(str(i) for i in args.unlearning_client)
    base = os.path.join(save_dir, f'{ul_clients_str}')

    print("---- Before Unlearning -----")
    example_stats, _ = evaluate(args=args, train_loaders=train_loaders, test_loaders=test_loaders,
                                net=old_net, example_stats=example_stats, datasets_name=datasets_name, backdoorloader=backdoorloader)

    loss_train = [[] for _ in range(args.num_users)]
    time_s = 0
    client_time_records = []
    server_time_records = []
    performance_records = []

    for epoch in range(args.unlearn_epoch):
        start_time = time.perf_counter()
        # Wall time spent inside local.train() only. Everything else in the
        # round (client setup, copies, FedAvg, state loading) counts as server time.
        client_elapsed = 0.0
        net_glob.train()
        loss_locals = []
        w_locals = []

        print(f"============ Train Epoch {epoch} ============")
        for client_idx in participating:
            local = DomainClientUpdate_Hesian_record(args=args, train_loader=train_loaders[client_idx])
            t0 = time.perf_counter()
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            client_elapsed += time.perf_counter() - t0

            w_locals.append(copy.deepcopy(w))
            loss_locals.append(loss)
            print(f' {datasets_name[client_idx]:<11s} | Train Loss: {loss:.4f}')

            # Drop per-client references so cached GPU memory can be released.
            del local, w, loss
            torch.cuda.empty_cache()

        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)
        net_glob.to(args.device)

        # Take a single end timestamp so that client time + server time
        # adds up exactly to epoch time.
        epoch_time = time.perf_counter() - start_time
        server_time = epoch_time - client_elapsed
        time_s += epoch_time
        client_time_records.append(client_elapsed / len(participating))
        server_time_records.append(server_time)
        print(f'Epoch {epoch}, Time: {epoch_time:.3f}s, Avg Loss: {sum(loss_locals)/len(loss_locals):.3f}')

        example_stats, g_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=net_glob,
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader
        )
        # NOTE(review): g_loss is indexed per client id below; it presumably
        # holds one metric per client. Unpacking also accepts tuples and arrays.
        performance_records.append([epoch, *g_loss])

        for client_idx in participating:
            loss_train[client_idx].append(g_loss[client_idx])

        w_loss = DeltaWeight(old_w_glob, w_glob)
        print(f'Delta Weight: {w_loss:.3f}')

        del w_locals, loss_locals, g_loss
        torch.cuda.empty_cache()

    # If unlearn_epoch == 0, w_glob is still the loaded initial weights.
    torch.save(example_stats, f'{base}_forget_event.pth')
    torch.save(w_glob, f'{base}_weight_global.pth')
    torch.save(loss_train, f'{base}_loss_train.pth')

    print('*' * 50)
    print("Rapid Retrain Complete")
    print(f'Total Time: {time_s:.3f}s')

    _write_csv_reports(args, ul_clients_str, bkd_str,
                       client_time_records, server_time_records, performance_records)

