#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import numpy as np
import torch
import time
from torch import nn
from models.Update_domain import DomainClientUpdate, DomainClientUpdate_avg, DomainClientUnlearningBl3
from models.Fed import FedAvg
from utils.options import args_parser
from utils.evaluate import evaluate
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.increase_loss_utils import get_distance
from utils.fpl import proto_aggregation
import random

def _as_list(value):
    """Return *value* unchanged if it is a list/tuple, else wrap it in a list."""
    return value if isinstance(value, (list, tuple)) else [value]


def _seed_everything(seed):
    """Seed torch (CPU and CUDA), NumPy and ``random``; force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def _build_reference_model(args, initial_model, net_glob, net_client):
    """Return a model whose parameters are ``N/(N-1) * global - 1/(N-1) * client``.

    ``N`` is ``args.num_users``. Assuming the global model is the plain average
    of the ``N`` client models, this is the average of the remaining ``N - 1``
    clients (TODO confirm that the stored global weights were averaged that
    way). Only parameters are rewritten; buffers keep the values they have in
    ``initial_model``.

    Raises:
        ZeroDivisionError: if ``args.num_users`` is 1.
    """
    num_users = args.num_users
    # Pure arithmetic on weights: no autograd graph is needed.
    with torch.no_grad():
        ref_vec = (
            num_users / (num_users - 1)
            * nn.utils.parameters_to_vector(net_glob.parameters())
            - 1 / (num_users - 1)
            * nn.utils.parameters_to_vector(net_client.parameters())
        )
    # Move the copy to the target device first so that parameters and buffers
    # end up on the same device as the vector written into them.
    net_ref = copy.deepcopy(initial_model).to(args.device)
    nn.utils.vector_to_parameters(ref_vec, net_ref.parameters())
    return net_ref


def _save_unlearning_state(base, example_stats, net_glob, loss_train):
    """Write the evaluation stats, global weights and per-client loss history."""
    torch.save(example_stats, f'{base}_forget_event.pth')
    torch.save(net_glob.state_dict(), f'{base}_weight_global.pth')
    torch.save(loss_train, f'{base}_loss_train.pth')


def _write_csv_reports(args, ul_clients_str, bd_clients_str,
                       client_time_records, server_time_records,
                       performance_records):
    """Write the per-round timing CSV and the per-round performance CSV."""
    csv_dir = f'./result/csv/{args.dataset}/{args.target}'
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    time_file = os.path.join(
        csv_dir,
        f'time_{args.target}_{ul_clients_str}_{bd_clients_str}_{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv'
    )
    perf_file = os.path.join(
        csv_dir,
        f'performance_of_clients_{args.target}_{ul_clients_str}_{bd_clients_str}_{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv'
    )
    with open(time_file, 'w') as f:
        f.write('round,client_time,server_time\n')
        for r, (ct, st) in enumerate(zip(client_time_records, server_time_records)):
            f.write(f'{r},{ct},{st}\n')

    with open(perf_file, 'w') as f:
        header = 'round' + ''.join(f',client{cid}' for cid in range(args.num_users))
        f.write(header + '\n')
        for record in performance_records:
            f.write(','.join(map(str, record)) + '\n')


def increase_loss(args):
    """Run the "increase_loss" unlearning pipeline followed by federated retraining.

    Steps, in order:

    1. Seed all RNGs, pick ``args.device`` and load the data loaders.
    2. Load the previously trained global weights (``weight_global.pth``) and
       per-client weights (``weight_local.pth``) from the ``learning`` save
       directory, falling back to the legacy ``dsf_<domain_split_factor>``
       directory when the preferred one does not exist.
    3. For every client in ``args.unlearning_client``: build a reference model
       from the global and client weights, derive a radius equal to one third
       of the mean distance between the reference model and 10 freshly
       initialised models, and run ``DomainClientUnlearningBl3`` with it. The
       resulting states are averaged with ``FedAvg`` into the global model.
    4. Run ``args.unlearn_epoch`` federated rounds over the clients that are
       not being unlearned, evaluating after each round and checkpointing
       every 10 rounds and once more at the end.
    5. Write timing and performance CSV files under ``./result/csv``.

    Args:
        args: Namespace-like options object. Reads ``seed``, ``gpu``,
            ``dataset_fullparti``, ``domain_split_factor``, ``dataset``,
            ``save``, ``backdoor_client_idx``, ``unlearning_client``,
            ``num_users``, ``unlearn_epoch``, ``target``, ``mask_ratio``,
            ``diff_mask_ratio`` and, optionally, ``domain_times_factor``,
            ``bkd_domain_idx`` and ``proto``. ``args.device`` is overwritten.

    Returns:
        None. Results are written to disk and progress is printed.

    Raises:
        ZeroDivisionError: if ``args.num_users`` is 1 and at least one client
            is to be unlearned (the reference model is undefined).
    """
    _seed_everything(args.seed)
    use_cuda = torch.cuda.is_available() and args.gpu != -1
    args.device = torch.device('cuda:{}'.format(args.gpu) if use_cuda else 'cpu')
    print(args)

    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    datasets_name = get_dataset(args)

    # Directory naming: prefer ``domain_times_factor`` unless ``bkd_domain_idx``
    # is missing or equal to 12345 (presumably a "not set" sentinel - TODO confirm).
    split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    dsf_dir = f"dsf_{split_factor}"
    old_dsf_dir = f"dsf_{args.domain_split_factor}"
    bkd_str = '_'.join(str(i) for i in _as_list(args.backdoor_client_idx))

    base_dir = f'./save/test/{args.dataset}/increase_loss/{args.save}/{dsf_dir}/{bkd_str}'
    os.makedirs(base_dir, exist_ok=True)
    delete_users = _as_list(args.unlearning_client)
    ul_clients_str = "_".join(str(i) for i in delete_users)
    base = f"{base_dir}/{ul_clients_str}"

    initial_model = init_model(args)
    net_glob = copy.deepcopy(initial_model).to(args.device)

    model_load_path = f'./save/test/{args.dataset}/learning/{args.save}/{dsf_dir}/{bkd_str}'
    if not os.path.exists(model_load_path):
        legacy_path = f'./save/test/{args.dataset}/learning/{args.save}/{old_dsf_dir}/{bkd_str}'
        if os.path.exists(legacy_path):
            model_load_path = legacy_path
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    net_glob.load_state_dict(torch.load(f'{model_load_path}/weight_global.pth', map_location=args.device))
    client_weights = torch.load(f'{model_load_path}/weight_local.pth', map_location=args.device)
    example_stats = [[{} for _ in range(args.num_users)], [{} for _ in range(args.num_users)]]
    loss_train = [[] for _ in range(args.num_users)]

    # Baseline evaluation of the loaded global model before any unlearning.
    example_stats, _ = evaluate(
        args=args,
        train_loaders=train_loaders,
        test_loaders=test_loaders,
        net=copy.deepcopy(net_glob),
        example_stats=example_stats,
        datasets_name=datasets_name,
        backdoorloader=backdoorloader
    )

    # ---- Phase 1: per-client unlearning ---------------------------------
    updated_models = []
    start_time = time.perf_counter()
    for ul_idx in delete_users:
        net_unlearning_client = copy.deepcopy(initial_model).to(args.device)
        net_unlearning_client.load_state_dict(client_weights[ul_idx])

        net_ref = _build_reference_model(args, initial_model, net_glob, net_unlearning_client)

        # Calibrate the radius against freshly initialised models, which are
        # moved to the same device as the reference model before comparing.
        dist_ref_random_lst = torch.tensor([
            get_distance(net_ref, init_model(args).to(args.device))
            for _ in range(10)
        ])
        print(f'Mean distance of Reference Model to random: {torch.mean(dist_ref_random_lst)}')
        threshold = torch.mean(dist_ref_random_lst) / 3
        print(f'Radius for model_ref: {threshold}')
        dist_ref_party = get_distance(net_ref, net_unlearning_client)
        print(f'Distance of Reference Model to unlearning_model: {dist_ref_party}')

        unlearning = DomainClientUnlearningBl3(
            args=args,
            train_loader=train_loaders[ul_idx],
            threshold=threshold
        )
        w = unlearning.train(
            net=copy.deepcopy(net_ref).to(args.device),
            net_ref=net_ref.to(args.device),
            net_unlearning_client=net_unlearning_client.to(args.device)
        )
        updated_models.append(w)

    if updated_models:
        net_glob.load_state_dict(FedAvg(updated_models))
    time_s = time.perf_counter() - start_time

    # ---- Phase 2: federated rounds over the remaining clients -----------
    client_time_records = []
    server_time_records = []
    performance_records = []

    use_proto = getattr(args, "proto", False)
    trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
    global_protos = {}
    last_epoch = args.unlearn_epoch - 1

    for epoch in range(args.unlearn_epoch):
        print("============ Train epoch {} ============".format(epoch))
        epoch_start = time.perf_counter()
        client_elapsed = 0
        w_locals = []
        local_protos = [{} for _ in range(args.num_users)]

        for client_idx in range(args.num_users):
            if client_idx in delete_users:
                continue

            local_trainer = trainer_cls(
                args=args,
                train_loader=train_loaders[client_idx]
            )
            t0 = time.perf_counter()
            if use_proto:
                _, client_state, client_proto, result = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device),
                    global_protos=global_protos,
                )
                client_loss = result[0] if result else 0.0
                w_locals.append(client_state)
                local_protos[client_idx] = client_proto
            else:
                client_model, client_loss = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device)
                )
                w_locals.append(client_model.state_dict())
            client_elapsed += time.perf_counter() - t0
            loss_train[client_idx].append(client_loss)
            print(f'Client {datasets_name[client_idx]} | Loss: {client_loss:.4f}')

        if w_locals:
            net_glob.load_state_dict(FedAvg(w_locals))
            if use_proto:
                global_protos = proto_aggregation(args, local_protos)
        else:
            # Every client is being unlearned: there is nothing to aggregate,
            # so the global model is left as it is for this round.
            print('No remaining clients to train; global model left unchanged.')
        server_time = time.perf_counter() - epoch_start - client_elapsed

        example_stats, global_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=copy.deepcopy(net_glob),
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader
        )
        performance_records.append([epoch] + global_loss)

        # Periodic checkpoint; the final round is saved once after the loop.
        if (epoch + 1) % 10 == 0 and epoch != last_epoch:
            _save_unlearning_state(base, example_stats, net_glob, loss_train)

        time_s += time.perf_counter() - epoch_start
        # Average over the clients that actually trained this round, which
        # stays correct with duplicate ids and avoids dividing by zero.
        participants = len(w_locals)
        client_time_records.append(client_elapsed / participants if participants else 0.0)
        server_time_records.append(server_time)

    _save_unlearning_state(base, example_stats, net_glob, loss_train)
    print(f'Total time: {time_s:.2f}s')

    _write_csv_reports(
        args,
        ul_clients_str,
        bkd_str,
        client_time_records,
        server_time_records,
        performance_records,
    )

