import os
import random
import time
import copy
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from PIL import Image
from torch import nn

from models.Update_domain import DomainClientUpdate, atk_train, DomainClientUpdate_avg
from models.vggmodule import vgg
from models.backdoor import create_trigger_model
from utils.fpl import proto_aggregation
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, svhn_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar, Lenet5, LeNet, DigitModel
from models.Fed import FedAvg
from models.test import test_img
from utils import data_utils
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.forget_event import compute_forgetting_statistics, order_examples_of_forget, sort_examples_by_forgetting
from utils.evaluate import evaluate
from utils.backdoor_process import backdoor_process, plt_img


def _item_label(item):
    """Return the label of a single dataset item.

    Items are either sequences whose second element is the label, or
    mapping-like objects exposing a ``'label'`` key.
    """
    return item[1] if isinstance(item, (list, tuple)) else item.get('label')


def _class_histogram(dataset, num_classes=10):
    """Count how many samples of ``dataset`` carry each class label.

    The histogram starts with ``num_classes`` zero bins and grows on demand
    when a larger label is encountered, instead of raising ``IndexError`` for
    datasets with more classes than expected.

    Note: every item is fetched via ``dataset[i]``, so any per-item transform
    of the dataset runs once per sample here.

    Returns:
        ``(count, total)``: ``count[c]`` is the number of samples labelled
        ``c`` and ``total`` is ``len(dataset)``.
    """
    count = [0] * num_classes
    for i in range(len(dataset)):
        # int() accepts plain ints as well as 0-d tensors / numpy scalars.
        label = int(_item_label(dataset[i]))
        if label >= len(count):
            count.extend([0] * (label + 1 - len(count)))
        count[label] += 1
    return count, len(dataset)


def _print_class_distribution(loaders, names, split, num_classes):
    """Print the per-class sample count of each loader's dataset.

    ``split`` is the line prefix (``'Train'`` or ``'Test'``); ``names[i]`` is
    the display name used for ``loaders[i]``.
    """
    for num, loader in enumerate(loaders):
        count, total = _class_histogram(loader.dataset, num_classes)
        print(
            f'{split} Dataset {names[num]}, Each Class Number: {count} , All Number: {total}'
        )


def _write_csv_reports(csv_dir, file_tag, num_users, client_time_records,
                       server_time_records, performance_records):
    """Write the per-round timing CSV and the per-client performance CSV.

    Both files are created in ``csv_dir`` (created if missing) and share
    ``file_tag`` plus the current Unix timestamp in their names.
    """
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    time_file = os.path.join(csv_dir, f'time_{file_tag}_{timestamp}.csv')
    perf_file = os.path.join(
        csv_dir, f'performance_of_clients_{file_tag}_{timestamp}.csv'
    )

    with open(time_file, 'w') as f:
        f.write('round,client_time,server_time\n')
        for r, (ct, st) in enumerate(zip(client_time_records, server_time_records)):
            f.write(f'{r},{ct},{st}\n')

    with open(perf_file, 'w') as f:
        header = ['round'] + [f'client{cid}' for cid in range(num_users)]
        f.write(','.join(header) + '\n')
        for record in performance_records:
            f.write(','.join(map(str, record)) + '\n')


def learning(args):
    """Run federated training over all clients and persist the results.

    Steps:
      1. Seed the RNGs and select the device.
      2. Build the per-client train/test loaders and print their class
         distributions.
      3. Save the loaders.
      4. Either initialise a fresh global model or, when
         ``args.pre_train > 0``, resume weights and forgetting statistics
         from ``base_path``.
      5. Run ``args.epochs`` rounds. In each round every client trains
         locally, the server aggregates with ``FedAvg`` (plus
         ``proto_aggregation`` when ``args.proto`` is set), and the global
         model is evaluated.
      6. Save the weights, losses and statistics under ``./save/test/...``
         and write the timing/performance CSVs under ``./result/csv/...``.

    Args:
        args: parsed options object (see ``utils.options.args_parser``).
            ``args.device`` is overwritten with the selected torch device.

    Returns:
        None. All results are written to disk.
    """
    # --- Reproducibility and device selection ---
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    args.device = torch.device(
        f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu != -1 else 'cpu'
    )
    print(args)

    # --- Data ---
    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    datasets_name = get_dataset(args)

    num_classes = getattr(args, 'num_classes', None) or 10
    _print_class_distribution(train_loaders, datasets_name, 'Train', num_classes)
    _print_class_distribution(test_loaders, datasets_name, 'Test', num_classes)

    # --- Output locations ---
    # NOTE(review): 12345 appears to act as a "no backdoor domain" sentinel;
    # only when a backdoor domain is set does domain_times_factor (if present)
    # replace domain_split_factor.
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    else:
        split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    dsf_dir = f"dsf_{split_factor}"

    bkd_ids = args.backdoor_client_idx
    if not isinstance(bkd_ids, (list, tuple)):
        bkd_ids = [bkd_ids]
    # Filesystem-friendly form of the backdoor client id(s), e.g. "1_2".
    bkd_str = "_".join(str(i) for i in bkd_ids)

    dataset_save = f'./save/datasets/{args.dataset}/{dsf_dir}/{bkd_str}/{args.verify}/'
    os.makedirs(dataset_save, exist_ok=True)
    torch.save(train_loaders, os.path.join(dataset_save, 'train_loaders.pth'))
    torch.save(test_loaders, os.path.join(dataset_save, 'test_loaders.pth'))
    base_path = f'./save/test/{args.dataset}/learning/{args.save}/{dsf_dir}/{bkd_str}'
    os.makedirs(f'{base_path}/models', exist_ok=True)

    # --- Global model: resume or fresh start ---
    net_glob = init_model(args)
    if args.pre_train > 0:
        # Resume from artifacts written by a previous run of this function.
        w_glob = torch.load(f'{base_path}/weight_global.pth', map_location=args.device)
        w_init = torch.load(f'{base_path}/weight_init.pth', map_location=args.device)
        example_stats = torch.load(f'{base_path}/forget_event.pth')
        net_glob.load_state_dict(w_glob)
        net_glob.to(args.device)
        print(len(example_stats[0][0]["acc"][1]))
    else:
        net_glob.to(args.device)
        w_glob = net_glob.state_dict()
        w_init = copy.deepcopy(w_glob)
        # Two rows of one dict per client; filled in by evaluate().
        example_stats = [[{} for _ in range(args.num_users)] for _ in range(2)]
        print(example_stats)

    torch.save(net_glob, f'{base_path}/models/global_epoch-0.pth')

    # --- Training ---
    use_proto = getattr(args, "proto", False)
    trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
    global_protos = {}
    loss_train = [[] for _ in range(args.num_users)]
    time_s = 0
    client_time_records = []
    server_time_records = []
    performance_records = []
    # Bound before the loop so the final save also works when args.epochs == 0.
    w_locals = []

    for epoch in range(args.epochs):
        print(f"============ Train epoch {epoch} ============")
        round_start = time.perf_counter()
        client_elapsed = 0
        w_locals = []
        loss_locals = []
        local_protos = [{} for _ in range(args.num_users)]

        for client_idx in range(args.num_users):
            t0 = time.perf_counter()
            local_trainer = trainer_cls(
                args=args,
                train_loader=train_loaders[client_idx],
            )
            local_net = copy.deepcopy(net_glob).to(args.device)
            if use_proto:
                _, client_state, client_proto, result = local_trainer.train(
                    net=local_net,
                    global_protos=global_protos,
                )
                client_loss = result[0] if result else 0.0
                local_protos[client_idx] = client_proto
            else:
                client_model, client_loss = local_trainer.train(net=local_net)
                client_state = client_model.state_dict()
            w_locals.append(client_state)
            client_elapsed += time.perf_counter() - t0
            loss_locals.append(client_loss)
            loss_train[client_idx].append(client_loss)
            print(f'Client {datasets_name[client_idx]} | Loss: {client_loss:.4f}')

        # --- Server-side aggregation ---
        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)
        net_glob.to(args.device)
        if use_proto:
            global_protos = proto_aggregation(args, local_protos)
        # Wall time of this round not spent inside client training.
        server_time = time.perf_counter() - round_start - client_elapsed

        loss_avg = sum(loss_locals) / len(loss_locals)
        print(
            'Round {:3d}, Average loss {:.3f}, Time {:.3f}'.format(
                int(epoch), float(loss_avg), time.perf_counter() - round_start
            )
        )

        time_s += time.perf_counter() - round_start
        client_time_records.append(client_elapsed / args.num_users)
        server_time_records.append(server_time)

        # --- Evaluation ---
        print(f"============ Test epoch {epoch} ============")
        example_stats, g_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=net_glob,
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader,
        )
        performance_records.append([epoch, *g_loss])

        for client_idx in range(args.num_users):
            # NOTE(review): the global evaluation value is appended to the same
            # per-client list as the local training loss above, so each list
            # alternates [local, global, local, global, ...] — confirm intended.
            loss_train[client_idx].append(g_loss[client_idx])

    print('*' * 50)
    print("finish")
    print('training Time : {:.3f}s'.format(time_s))

    # --- Persist artifacts ---
    torch.save(loss_train, f'{base_path}/loss_train.pth')
    torch.save(example_stats, f'{base_path}/forget_event.pth')
    torch.save(w_glob, f'{base_path}/weight_global.pth')
    torch.save(w_locals, f'{base_path}/weight_local.pth')
    torch.save(w_init, f'{base_path}/weight_init.pth')

    file_tag = (
        f'{args.target}_{args.unlearning_client}_{bkd_str}_'
        f'{args.mask_ratio}_{args.diff_mask_ratio}'
    )
    _write_csv_reports(
        f'./result/csv/{args.dataset}/{args.target}',
        file_tag,
        args.num_users,
        client_time_records,
        server_time_records,
        performance_records,
    )
