import torch
import time
import copy
import re
import random
import numpy as np

from sim.algorithms.FL_base import SFLClient, SFLServer
from sim.data.datasets import build_dataset
from sim.data.partition import build_partition
from sim.models.build_models import build_model
from sim.utils.record_utils import logconfig, add_log, record_exp_result2
from sim.utils.utils import setup_seed, replace_with_samples, label_distribution, distance, relu, client_probabilities, initialize_state_probabilities
from sim.utils.optim_utils import OptimKit, LrUpdater
from torch.utils.data import Dataset, DataLoader, Subset
from sim.utils.options import args_parser
from collections import defaultdict

# Parse the experiment configuration once at import time; the rest of the
# module (record name, main loop) reads it as a module-level global.
args = args_parser()

# Cap intra-op CPU threads and seed via the project's setup_seed helper.
torch.set_num_threads(4)
setup_seed(args.seed)

# Use the requested CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(args.device))
else:
    device = torch.device("cpu")

# For the 'exdir' partition the first beta entry is coerced to an int;
# every other partition keeps beta exactly as parsed.
if args.partition == 'exdir':
    args.beta = [int(args.beta[0]), args.beta[1]]

def customize_record_name(args):
    """Build the record name that identifies one experiment run.

    The name encodes the run's hyper-parameters (N, M, S, R, T, K, model,
    dataset, partition scheme, seed, capacity, distance and set), in that
    order.

    Args:
        args: Parsed options object. Reads ``partition``, ``beta``, ``N``,
            ``M``, ``S``, ``R``, ``T``, ``K``, ``m``, ``d``, ``seed``,
            ``capacity``, ``distance`` and ``set``.

    Returns:
        The record name as a string.

    Raises:
        ValueError: If ``args.partition`` is not one of ``'exdir'``,
            ``'iid'`` or ``'dir'``.
    """
    if args.partition == 'exdir':
        # 'exdir' is described by two beta values.
        partition = f'{args.partition}{args.beta[0]},{args.beta[1]}'
    elif args.partition == 'iid':
        partition = f'{args.partition}'
    elif args.partition == 'dir':
        partition = f'{args.partition}{args.beta[0]}'
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"unsupported partition {args.partition!r}; "
            "expected 'exdir', 'iid' or 'dir'"
        )
    record_name = (
        f'DRSR_N{args.N}_M{args.M}_S{args.S}_R{args.R}_T{args.T}_K{args.K}_{args.m}_{args.d}_{partition}'
        f'_seed{args.seed}_{args.capacity}_{args.distance}_set{args.set}'
    )
    return record_name
# Run identifier computed once at import time; main() reads it for the log
# name, the result records and the saved-model filename.
record_name = customize_record_name(args)

# In our scenario the amount of state data is much larger than a client's data capacity,
# so we have made slight adjustments to the DRSR algorithm to account for this.
def drsr_update_streaming(current_indices, new_indices, max_size, get_label, theta_t=0.5):
    """Refresh a bounded buffer of sample indices with newly arrived indices.

    For every label, a target count is computed as
    ``(1 - theta_t) * n_old + theta_t * n_new``. If the targets sum to more
    than ``max_size`` they are scaled down proportionally; each target is then
    truncated to an int. Per label, indices are taken from the new data
    first, then from the current buffer, until the target is met. If
    truncation leaves the buffer short of ``max_size``, it is topped up with
    indices drawn uniformly at random (without replacement) from the unused
    remainder of both pools. The result is shuffled.

    An index is never selected twice, even when it appears in both
    ``current_indices`` and ``new_indices``.

    Args:
        current_indices: Iterable of sample indices currently held.
        new_indices: Iterable of sample indices of the incoming data.
        max_size: Maximum number of indices to return.
        get_label: Callable mapping an index to its (hashable) label.
        theta_t: Weight of the new data's label counts in the target mix.

    Returns:
        A shuffled list of at most ``max_size`` distinct indices.
    """
    current_indices = list(current_indices)
    new_indices = list(new_indices)

    # Group both pools by label; per-label counts are the group lengths.
    current_by_label = defaultdict(list)
    for idx in current_indices:
        current_by_label[get_label(idx)].append(idx)

    new_by_label = defaultdict(list)
    for idx in new_indices:
        new_by_label[get_label(idx)].append(idx)

    all_labels = set(current_by_label).union(new_by_label)

    # Blend old and new label counts into a (fractional) per-label target.
    target_label_count = {
        label: (1 - theta_t) * len(current_by_label.get(label, ()))
        + theta_t * len(new_by_label.get(label, ()))
        for label in all_labels
    }

    # Scale the targets down only when they would overflow the buffer.
    total_target = sum(target_label_count.values())
    scaling_factor = max_size / total_target if total_target > max_size else 1.0

    updated_indices = []
    chosen = set()  # guards against picking the same index from both pools

    for label in all_labels:
        needed = int(target_label_count[label] * scaling_factor)
        taken = 0
        # Prefer new samples, then fall back to the ones already held.
        for pool in (new_by_label.get(label, ()), current_by_label.get(label, ())):
            for idx in pool:
                if taken >= needed:
                    break
                if idx in chosen:
                    continue
                chosen.add(idx)
                updated_indices.append(idx)
                taken += 1

    # Integer truncation can leave free slots: top the buffer up to max_size
    # with random, not-yet-selected indices from either pool.
    if len(updated_indices) < max_size:
        available_indices = list(set(current_indices + new_indices) - chosen)
        if available_indices:
            num_extra = min(max_size - len(updated_indices), len(available_indices))
            additional = np.random.choice(available_indices, num_extra, replace=False)
            updated_indices.extend(additional.tolist())

    np.random.shuffle(updated_indices)

    return updated_indices

def _extract_labels(dataset):
    """Return the label of every sample in ``dataset`` as a 1-D tensor."""
    # NOTE(review): assumes the dataset exposes per-sample labels as
    # ``targets`` (torchvision convention) -- confirm against build_dataset.
    targets = getattr(dataset, 'targets', None)
    if targets is None:
        # Fallback: assumes each sample is an (input, label) pair -- TODO confirm.
        targets = [int(dataset[i][1]) for i in range(len(dataset))]
    return torch.as_tensor(targets)


def main():
    """Run the DRSR federated-learning simulation described by ``args``.

    Reads the module-level ``args``, ``record_name`` and ``device``. Sets up
    the client, server, data partition, model and optimizer, then runs
    ``args.R`` rounds. In each round the selected clients each walk through
    ``args.T`` sampled states: the client's index buffer is refreshed from the
    state's data (via ``drsr_update_streaming`` after round 0 when the state
    exceeds ``args.capacity``), and ``args.K`` local steps are trained
    sequentially on a per-client temp model. The server then aggregates the
    clients' final parameters, the global model is evaluated on the test set,
    and the result is recorded. Optionally saves the final global parameters.
    """
    import os  # function-scope import: only needed for the optional checkpoint

    logconfig(name=record_name, flag=args.log)
    add_log('{}'.format(args), flag=args.log)
    add_log('record_name: {}'.format(record_name), flag=args.log)

    client = SFLClient()
    server = SFLServer()

    state_dataidx_map = build_partition(args.d, args.M, args.partition, args.beta)

    train_dataset, test_dataset = build_dataset(args.d)
    # Label lookup used by the DRSR update (previously referenced but never defined).
    train_labels = _extract_labels(train_dataset)

    client.setup_test_dataset(test_dataset)

    global_model = build_model(model=args.m, dataset=args.d)
    server.setup_model(global_model.to(device))
    add_log('{}'.format(global_model), flag=args.log)

    client_optim_kit = OptimKit(optim_name=args.optim, batch_size=args.batch_size, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    client_optim_kit.setup_lr_updater(LrUpdater.exponential_lr_updater, mul=args.lr_decay)
    client.setup_optim_kit(client_optim_kit)
    client.setup_criterion(torch.nn.CrossEntropyLoss())
    server.setup_optim_settings(lr=args.global_lr)

    pi_list, data_map_list, nonzero_indices_list = initialize_state_probabilities(args.N, args.M, args.S, args.set)
    add_log('pi_list: {}'.format(pi_list), flag=args.log)

    client_probs = client_probabilities(args.N)
    record_exp_result2(record_name, {'round':0})
    for round_idx in range(0, args.R):
        selected_clients = server.select_clients_prob(args.N, client_probs)
        add_log('selected clients: {}'.format(selected_clients), flag=args.log)
        local_param_list = []
        # One mixing weight per state step, shared by all clients this round.
        theta_t_list = np.random.rand(args.T)
        for c_id in selected_clients:
            # Every client starts the round from a fresh copy of the global model.
            server.setup_temp_model(copy.deepcopy(server.global_model.to(device)))
            state_list = np.random.choice(args.M, args.T, p=pi_list[c_id], replace=True)
            add_log('state_list: {}'.format(state_list), flag=args.log)
            for idx, state in enumerate(state_list):
                state_indices = state_dataidx_map[state]
                if len(state_indices) <= args.capacity:
                    # The whole state fits: the buffer becomes the state's data.
                    data_map_list[c_id] = np.array(state_indices)
                elif round_idx == 0:
                    # Round 0: draw a random subset of the state's data.
                    data_map_list[c_id] = np.random.choice(state_indices, args.capacity, replace=False)
                else:
                    data_map_list[c_id] = drsr_update_streaming(
                        data_map_list[c_id],
                        state_indices,
                        args.capacity,
                        lambda x: train_labels[int(x)].item(),
                        theta_t=theta_t_list[idx],
                    )

                local_dataset = Subset(train_dataset, data_map_list[c_id])
                local_param, local_delta = client.local_update_step(local_dataset=local_dataset, model=copy.deepcopy(server.temp_model), num_steps=args.K, device=device, clip=args.clip)
                # Carry the trained weights over to the next state step.
                torch.nn.utils.vector_to_parameters(local_param, server.temp_model.parameters())

            # Only the parameters after the client's last state step are aggregated.
            local_param_list.append(local_param)

        param_after_FL = server.aggregate_update(local_param_list)
        torch.nn.utils.vector_to_parameters(param_after_FL, server.global_model.parameters())

        client.optim_kit.update_lr()

        # Evaluate on the resolved torch.device, matching local_update_step above.
        # NOTE(review): presumably also keeps the CPU fallback working, since
        # args.device is the raw device option -- confirm in evaluate_dataset.
        test_losses, test_top1, test_top5 = client.evaluate_dataset(model=server.global_model, dataset=client.test_dataset, device=device)
        add_log("Round {}'s server  test  acc: {:6.2f}%, test  loss: {:.4f}".format(round_idx+1, test_top1.avg, test_losses.avg), 'red', flag=args.log)

        record_exp_result2(record_name, {'round':round_idx+1, 'test_loss'  : test_losses.avg, 'test_top1'  : test_top1.avg})

    if args.save_model == 1:
        # torch.save does not create missing directories.
        os.makedirs('./save_model', exist_ok=True)
        torch.save({'model': torch.nn.utils.parameters_to_vector(server.global_model.parameters())}, './save_model/{}.pt'.format(record_name))

# Invoke main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()