#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import numpy as np
import torch
import time
from pathlib import Path
from torch import nn
from torch.utils.data import Subset
from models.Update_domain import DomainClientUpdate, DomainClientUpdate_avg
from models.Fed import FedAvg
from utils.options import args_parser
from utils.evaluate import evaluate
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.pruning_utils import unlearn_prune
from utils.fpl import proto_aggregation
import random

def _ensure_int_list(value):
    """Normalize a client-index argument to a list of ints.

    Accepts a comma-separated string (e.g. ``"1,3"``; empty fields are
    skipped), a list/tuple, or a single scalar. Every element is coerced to
    ``int`` so that later membership tests against integer client indices
    (``range(args.num_users)``) match even when the values arrived as strings.
    """
    if isinstance(value, str):
        return [int(v) for v in value.split(',') if v.strip()]
    if isinstance(value, (list, tuple)):
        return [int(v) for v in value]
    return [int(value)]


def _infer_num_classes(train_loaders):
    """Return the number of distinct labels of the first labelled dataset, or None.

    ``Subset`` wrappers are unwrapped (at any nesting depth), so the count is
    taken over the underlying dataset's ``targets`` (preferred) or ``labels``
    attribute. Loaders whose dataset exposes neither are skipped.
    """
    for loader in train_loaders:
        dataset = loader.dataset
        while isinstance(dataset, Subset):
            dataset = dataset.dataset
        labels = getattr(dataset, 'targets', None)
        if labels is None:
            labels = getattr(dataset, 'labels', None)
        if labels is None:
            continue
        if torch.is_tensor(labels):
            labels = labels.tolist()
        return len(set(labels))
    return None


def _save_checkpoint(base, example_stats, net_glob, loss_train):
    """Save forgetting stats, global weights and per-client losses next to ``base``."""
    torch.save(example_stats, f'{base}_forget_event.pth')
    torch.save(net_glob.state_dict(), f'{base}_weight_global.pth')
    torch.save(loss_train, f'{base}_loss_train.pth')


def _write_reports(args, client_time_records, server_time_records, performance_records):
    """Write the per-round timing and performance CSVs under ./result/csv/<dataset>/<target>/."""
    csv_dir = f'./result/csv/{args.dataset}/{args.target}'
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    ul_clients_str = "_".join(str(i) for i in args.unlearning_client)
    bd_clients_str = "_".join(str(i) for i in args.backdoor_client_idx)
    suffix = (
        f'{args.target}_{ul_clients_str}_{bd_clients_str}_'
        f'{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv'
    )
    time_file = os.path.join(csv_dir, f'time_{suffix}')
    perf_file = os.path.join(csv_dir, f'performance_of_clients_{suffix}')

    with open(time_file, 'w', encoding='utf-8') as f:
        f.write('round,client_time,server_time\n')
        for r, (ct, st) in enumerate(zip(client_time_records, server_time_records)):
            f.write(f'{r},{ct},{st}\n')

    with open(perf_file, 'w', encoding='utf-8') as f:
        header = ['round'] + [f'client{cid}' for cid in range(args.num_users)]
        f.write(','.join(header) + '\n')
        for record in performance_records:
            f.write(','.join(map(str, record)) + '\n')


def class_pruning(args):
    """Unlearn ``args.unlearning_client`` by pruning, then fine-tune on the rest.

    Steps:
      1. Seed all RNGs, select the device, and normalize the client-index
         arguments to lists of ints.
      2. Build the per-client train/test loaders and the backdoor loader.
      3. Load the trained global and per-client weights from
         ``./save/<args.load>/<args.dataset>/learning/<args.model>/dsf_<k>/<bkd>/``.
      4. Apply ``unlearn_prune`` to the global model.
      5. Run ``args.unlearn_epoch`` federated rounds over the non-deleted
         clients (FedAvg of their weights, plus prototype aggregation when
         ``args.proto`` is truthy), evaluating after every round.
      6. Checkpoint every 10 rounds and at the end, then write per-round
         timing/performance CSVs under ``./result/csv/<dataset>/<target>/``.

    Args:
        args: Parsed options namespace. Mutated in place: ``device``,
            ``unlearning_client``, ``backdoor_client_idx`` and (when it can be
            inferred from the data) ``num_classes`` are overwritten.

    Raises:
        FileNotFoundError: If ``weight_global.pth`` or ``weight_local.pth`` is
            missing from the resolved load directory.
        ValueError: If every client is marked for unlearning while
            ``args.unlearn_epoch > 0`` (FedAvg would have nothing to average).
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    print(args)

    # Client indices may come from the CLI as "1,2", a list, or a scalar.
    args.unlearning_client = _ensure_int_list(args.unlearning_client)
    args.backdoor_client_idx = _ensure_int_list(args.backdoor_client_idx)

    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    # Keep args.num_classes consistent with the labels actually loaded.
    num_classes = _infer_num_classes(train_loaders)
    if num_classes is not None:
        args.num_classes = num_classes

    datasets_name = get_dataset(args)

    # Prefer domain_times_factor for the directory name; 12345 appears to be a
    # sentinel for "no backdoor domain configured" (NOTE(review): confirm),
    # in which case domain_split_factor is used instead.
    split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    dsf_dir = f"dsf_{split_factor}"
    old_dsf_dir = f"dsf_{args.domain_split_factor}"
    bkd_str = "_".join(str(i) for i in args.backdoor_client_idx)

    load_base = Path('./save') / args.load / args.dataset
    save_base = Path('./save') / args.save / args.dataset

    # Fall back to the directory keyed by domain_split_factor when the
    # split_factor-keyed one does not exist.
    model_load_path = load_base / "learning" / args.model / dsf_dir / bkd_str
    if not model_load_path.exists():
        legacy_path = load_base / "learning" / args.model / old_dsf_dir / bkd_str
        if legacy_path.exists():
            model_load_path = legacy_path

    base_dir = save_base / "class_pruning" / dsf_dir / bkd_str
    base_dir.mkdir(parents=True, exist_ok=True)
    delete_users = list(args.unlearning_client)
    ul_clients_str = "_".join(str(i) for i in delete_users)
    # Checkpoint files are written as "<base>_<name>.pth".
    base = base_dir / ul_clients_str

    # Only clients that actually train count towards aggregation and timing;
    # duplicate or out-of-range entries in delete_users do not shrink this list.
    delete_set = set(delete_users)
    participants = [idx for idx in range(args.num_users) if idx not in delete_set]
    if args.unlearn_epoch > 0 and not participants:
        raise ValueError(
            f"All {args.num_users} clients are marked for unlearning; "
            "no client is left to fine-tune the pruned model."
        )

    initial_model = init_model(args)
    net_glob = copy.deepcopy(initial_model).to(args.device)
    for f in ("weight_global.pth", "weight_local.pth"):
        if not (model_load_path / f).exists():
            raise FileNotFoundError(f"Missing {f} in {model_load_path}")
    net_glob.load_state_dict(torch.load(model_load_path / "weight_global.pth", map_location=args.device))
    client_weights = torch.load(model_load_path / "weight_local.pth", map_location=args.device)

    loss_train = [[] for _ in range(args.num_users)]
    example_stats = [
        [{"acc": [[], []]} for _ in range(args.num_users)],
        [{"acc": [[], []]} for _ in range(args.num_users)]
    ]

    pruned_weights = unlearn_prune(
        args=args,
        delete_usr=delete_users,
        w_locals=client_weights,
        w_glob=net_glob.state_dict(),
        net_glob=net_glob,
        train_loaders=train_loaders
    )
    net_glob.load_state_dict(pruned_weights)

    time_s = 0.0
    client_time_records = []
    server_time_records = []
    performance_records = []

    use_proto = getattr(args, "proto", False)
    trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
    global_protos = {}

    for epoch in range(args.unlearn_epoch):
        print("============ Train epoch {} ============".format(epoch))
        epoch_start = time.perf_counter()
        client_elapsed = 0.0
        w_locals = []
        local_protos = [{} for _ in range(args.num_users)]

        for client_idx in participants:
            local_trainer = trainer_cls(
                args=args,
                train_loader=train_loaders[client_idx]
            )
            t0 = time.perf_counter()
            if use_proto:
                _, client_state, client_proto, result = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device),
                    global_protos=global_protos,
                )
                # Assumes result[0] is the training loss -- TODO confirm
                # against DomainClientUpdate.train.
                client_loss = result[0] if result else 0.0
                local_protos[client_idx] = client_proto
            else:
                client_model, client_loss = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device)
                )
                client_state = client_model.state_dict()
            w_locals.append(client_state)
            client_elapsed += time.perf_counter() - t0
            loss_train[client_idx].append(client_loss)
            print(f'Client {datasets_name[client_idx]} | Loss: {client_loss:.4f}')

        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)
        if use_proto:
            global_protos = proto_aggregation(args, local_protos)
        # Server time = round wall time so far minus time spent in client training.
        server_time = time.perf_counter() - epoch_start - client_elapsed

        example_stats, global_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=copy.deepcopy(net_glob),
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader
        )
        # Unpacking works whether evaluate returns a list, tuple or array.
        performance_records.append([epoch, *global_loss])

        # Periodic checkpoint; the final state is always saved after the loop.
        if (epoch + 1) % 10 == 0:
            _save_checkpoint(base, example_stats, net_glob, loss_train)

        time_s += time.perf_counter() - epoch_start
        client_time_records.append(client_elapsed / len(participants))
        server_time_records.append(server_time)

    _save_checkpoint(base, example_stats, net_glob, loss_train)
    print(f'Total time: {time_s:.2f}s')

    _write_reports(args, client_time_records, server_time_records, performance_records)

