import json
import os
import pickle

import numpy as np
import torch
from tqdm import tqdm

from networks import nn_registry
from src.metric import Metrics
from src.dataloader import fetch_trainloader
from src import fedlearning_registry
from src.attack import Attacker, grad_inv
from src.compress import compress_registry
from utils import *

            
def _init_eris(config):
    """Build the ERIS state shared by every batch of the run.

    A model is instantiated only to derive the ERIS state and the aggregator
    masks; it is not reused by the per-batch loop in ``main``.

    Args:
        config: Run configuration; reads ``model``, ``k_frac`` and
            ``n_aggregators``.

    Returns:
        Tuple ``(eris_k, masks)``: the ``k`` returned by ``init_eris_state``
        and the masks produced by ``create_mask`` (seed fixed to 1).
    """
    model = nn_registry[config.model](config)
    # Only ``k`` is needed downstream; the other three returned values are unused.
    _, eris_k, _, _ = init_eris_state(model, fl_rounds=-1, k=None, k_frac=config.k_frac)
    model_params = get_parameters_from_model(model)
    masks = create_mask(model_params, config.n_aggregators, seed=1)
    return eris_k, masks


def _compress_grad(config, grad, eris_k=None, masks=None):
    """Apply the gradient post-processing selected by ``config.compress`` in place.

    Args:
        config: Run configuration; ``config.compress`` selects the scheme.
        grad: Mutable list of per-parameter gradients (as returned by
            ``client_grad``); its entries are replaced in place.
        eris_k: ``k`` from ``_init_eris`` (only used by ``"eris"``).
        masks: Aggregator masks from ``_init_eris`` (only used by ``"eris"``).

    Schemes:
        ``"none"``         -- gradients are left untouched.
        ``"eris"``         -- ERIS compressor prefit on a single aggregator shard.
        ``"eris_partial"`` -- gradients are moved to numpy, split into
                              ``n_aggregators`` parts via ``concatenate_weights``,
                              rebuilt with ``deconcatenate_weights`` and written
                              back with ``update_model_parameters``.
        anything else      -- compressor looked up in ``compress_registry``.
    """
    scheme = config.compress
    if scheme == "none":
        return

    if scheme == "eris":
        compressor = compress_registry["eris"](eris_k, config)
        # The generator is re-seeded on every call, so every batch exposes the
        # same aggregator shard (use e.g. seed + rnd*1000 + cid to vary it).
        rng = np.random.default_rng(1)
        aggregator_id = int(rng.integers(0, config.n_aggregators))
        # prefit builds per-tensor keep masks + scale
        compressor.prefit(
            grads_list=grad,
            masks=masks,
            aggregator_id=aggregator_id,
            k=eris_k,
            seed=1,
        )
        for i, g in enumerate(grad):
            grad[i] = compressor.decompress(compressor.compress(g))

    elif scheme == "eris_partial":
        gradient_list = [param.cpu().data.numpy() for param in grad]
        # Flatten and select only one split, zeroing out the others.
        w, s = concatenate_weights(gradient_list, n_splits=config.n_aggregators, random_seed=1)
        r = deconcatenate_weights(w, s)
        update_model_parameters(grad, r)

    else:
        compressor = compress_registry[scheme](config)
        if getattr(compressor, "requires_prefit", False):
            # Compute a global threshold once per step (pruning-style compressors).
            compressor.prefit(grad)
        for i, g in enumerate(grad):
            grad[i] = compressor.decompress(compressor.compress(g))


def main(config_file):
    """Run the gradient-inversion experiment described by ``config_file``.

    For each of the first ``config.n_val_batches`` training batches a fresh
    model is built and the client gradient is computed with the configured
    federated algorithm. The gradient is then post-processed
    (see ``_compress_grad``) and inverted by the attacker. When
    ``config.compress == "random"`` the attack is skipped and Gaussian noise
    (``randn * 0.2 + 0.5``) is used as a reconstruction baseline.

    Reconstructions are scored with ``Metrics`` (PSNR, SSIM, Jaccard, LPIPS).
    The original and reconstructed samples are saved with ``save_batch``.
    The eight values returned by ``Metrics.evaluate`` are written to
    ``<output_dir>/<config.fedalg>.json``.

    Args:
        config_file: Path of the configuration file passed to ``load_config``.

    Raises:
        RuntimeError: If no batch was processed (empty loader or
            ``n_val_batches == 0``), so there is nothing to evaluate.
    """
    config = load_config(config_file)
    output_dir = init_outputfolder(config)
    logger = init_logger(config, output_dir)

    # Load dataset and fetch the data
    train_loader = fetch_trainloader(config, shuffle=True)
    rec_samples, ori_samples = [], []

    # ERIS masks are prepared once and shared by every batch.
    eris_k, masks = _init_eris(config) if config.compress == "eris" else (None, None)
    criterion = cross_entropy_for_onehot

    for batch_idx, (x, y) in tqdm(enumerate(train_loader), total=config.n_val_batches):
        if batch_idx == config.n_val_batches:
            break

        model = nn_registry[config.model](config)
        onehot = label_to_onehot(y, num_classes=config.num_classes)
        x, y, onehot, model = preprocess(config, x, y, onehot, model)

        # federated learning algorithm on a single device
        fedalg = fedlearning_registry[config.fedalg](criterion, model, config)
        grad = fedalg.client_grad(x, onehot)

        if config.compress == "random":
            # random noise as a baseline
            recon_data = torch.randn_like(x).to(config.device) * 0.2 + 0.5
        else:
            _compress_grad(config, grad, eris_k, masks)

            # initialize an attacker and perform the attack
            attacker = Attacker(config, criterion)
            recon_data = grad_inv(attacker, grad, x, onehot, model, config, logger)
            _, recon_data = attacker.joint_postprocess(recon_data, y)

        rec_samples.append(recon_data)
        ori_samples.append(x)

    if not rec_samples:
        # torch.cat would fail on an empty list with a far less helpful message.
        raise RuntimeError(
            "No batches were processed; check the data loader and config.n_val_batches."
        )

    # concatenate all samples
    rec_samples = torch.cat(rec_samples, dim=0)
    ori_samples = torch.cat(ori_samples, dim=0)

    # Report the result first
    logger.info("=== Evaluate the performance ====")
    metrics = Metrics(config)
    (snr, std_snr, ssim, std_ssim, jaccard, std_jaccard,
     lpips, std_lpips) = metrics.evaluate(ori_samples, rec_samples, logger)

    logger.info("\nPSNR: {:.3f} SSIM: {:.3f} Jaccard {:.3f} Lpips {:.3f}".format(snr, ssim, jaccard, lpips))

    save_batch(output_dir, ori_samples, rec_samples)

    # float() turns numpy/torch scalars into plain floats that json can encode.
    record = {
        "snr": float(snr), "std_snr": float(std_snr),
        "ssim": float(ssim), "std_ssim": float(std_ssim),
        "jaccard": float(jaccard), "std_jaccard": float(std_jaccard),
        "lpips": float(lpips), "std_lpips": float(std_lpips),
    }
    with open(os.path.join(output_dir, config.fedalg + ".json"), "w") as fp:
        json.dump(record, fp)

if __name__ == "__main__":
    # Seed torch's global RNG up front so torch-based randomness inside
    # main() (e.g. the random-noise baseline) is repeatable across runs.
    torch.manual_seed(0)
    main(config_file="config.yaml")
