import json
import logging
import os
import random
import shutil
import sys
import warnings

import numpy as np
import torch
from methods.fedmuscle_multi_modal import (
    create_fedmuscle_setup,
    client_rep_gen_function,
    aggregation_func,
    client_rep_align_function,
)
from utils.log import log_metrics
from utils.dataset import log_dataset_stats, setup_dataset
from utils.args import args_setup


"""Argumants"""
args = args_setup()

"""Setting up Log Directory"""
exp_name = "basic"
output_dir = os.path.join(
    ".",
    "results",
    f"{args.adapter_method}",
    f"{args.aggregation_method}",
    f"{exp_name}",
    f"{args.public_dataset_name}",
    f"seed_{args.seed}",
    f"n_MLC_clients_{args.num_MLC_clients}",
    f"n_IC100_clients_{args.num_IC100_clients}",
    f"n_IC10_clients_{args.num_IC10_clients}",
    f"n_semantic_seg_clients_{args.num_semantic_segmentation_clients}",
    f"n_yahoo_clients{args.num_yahoo_topic_classification_clients}",
    f"n_epochs_{args.num_epochs}",
    f"lr_{args.lr}",
    f"rank_{args.rank}",
    f"basic",
)

if os.path.exists(output_dir + "/log.txt"):
    print("###############")
    print("Warning! Might overright logs")
    print("###############")
os.makedirs(output_dir, exist_ok=True)

"""Setting up logger"""
print(f"| logging to {output_dir + '/log.txt'}")
logging.basicConfig(
    level=logging.INFO,
    filename=output_dir + "/log.txt",
    filemode="w",
    format="%(message)s",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(f"{vars(args)}\n")
logging.info("output_dir: " + output_dir)
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.INFO)
warnings.filterwarnings("ignore")

"""Saving argumants and the python file in log directory"""
argsfile = os.path.join(output_dir, "args.json")
with open(argsfile, "w", encoding="UTF-8") as f:
    json.dump(vars(args), f)  # Save the argumants in a json file
shutil.copy(__file__, f"{output_dir}/script.py")  # Save the python file

"""Set Device"""
device = torch.device(f"cuda:{args.device}")

"""Iterate over seeds"""
for exp in range(3):
    """Set seed"""
    torch.manual_seed(args.seed + exp)
    torch.cuda.manual_seed(args.seed + exp)
    np.random.seed(args.seed + exp)
    random.seed(args.seed + exp)

    """Load the dataset"""
    dataset_path_dict = setup_dataset(args)

    """Setup clients"""
    if args.aggregation_method == "fedmuscle":
        clients = create_fedmuscle_setup(
            args,
            args.models_path,
            dataset_path_dict,
            device,
        )
    else:
        raise NotImplementedError(
            f"Aggregation method {args.aggregation_method} is not implemeneted. Check the name"
        )

    logging.info("| Loaded clients successfully.")
    log_dataset_stats(clients)

    """Log before starting the training"""
    logging.info(f"| Logging before starting the training")
    for client_id, client in enumerate(clients):
        # Test
        round_metrics = client.local_eval_fun("test")
        log_metrics(
            round_metrics,
            client_id,
            0,
            output_dir,
            "test",
            exp=exp,
            write_from_beginning=True,
        )
        # Train
        round_metrics = client.local_eval_fun("train")
        log_metrics(
            round_metrics,
            client_id,
            0,
            output_dir,
            "train",
            exp=exp,
            write_from_beginning=True,
        )

    pub_dataset_size = len(dataset_path_dict["public_dataset"])
    all_indices = list(range(pub_dataset_size))

    """Iteration"""
    for communication_round in range(1, args.num_epochs + 1):
        logging.info(f"| Communication Round: {communication_round}")

        # Local Update
        for client_idx, client in enumerate(clients):
            logging.info(f"| Local training client {client_idx}")
            client.local_train_func()


        for cl_epoch in range(args.num_cl_epochs):

            avail_indices = all_indices.copy()

            while len(avail_indices) > 0:

                current_batch_size = min(args.pub_batch_size, len(avail_indices))

                # The server randomly samples current batch indices and sends them to the clients
                batch_indices = random.sample(avail_indices, current_batch_size) 

                # Remove the selected indices from previously available indices 
                for idx in batch_indices:
                    avail_indices.remove(idx)

                # Clients obtain the representations corresponding the batch of shared public dataset
                server_rec_rep = {}
                for client_idx, client in enumerate(clients):
                    server_rec_rep[client_idx] = client_rep_gen_function(args, client, dataset_path_dict["public_dataset"], batch_indices, device)

                # The server computes the aggregated representations and the weighting coefficients for each client
                agg_rep, alpha = aggregation_func(args, server_rec_rep, device)


                # Clients train their representation models using Muscle
                for client_idx, client in enumerate(clients):
                    client_rep_align_function(args, client, dataset_path_dict["public_dataset"], batch_indices, agg_rep[client_idx], alpha[client_idx], device)



        # Log the accuracy and save the models
        if communication_round % args.log_interval == 0:
            logging.info("\n")
            logging.info(f"| Logging at Round {communication_round} -- Exp {exp}")
            logging.info(f"| ###############")

            for client_id, client in enumerate(clients):
                # test set
                round_metrics = client.local_eval_fun(evaluation_set="test")
                log_metrics(
                    round_metrics,
                    client_id,
                    communication_round,
                    output_dir,
                    "test",
                    exp=exp,
                )

                # train set
                round_metrics = client.local_eval_fun(evaluation_set="train")
                log_metrics(
                    round_metrics,
                    client_id,
                    communication_round,
                    output_dir,
                    "train",
                    exp=exp,
                )

    """Post training"""
    for post_training_round in range(0, args.post_training):
        logging.info(f"| Post Training Round: {post_training_round}")

        # Local Update
        for client_idx, client in enumerate(clients):
            logging.info(f"| Post training client {client_idx}")
            client.local_train_func(num_epochs=1)

        # Log the accuracy
        logging.info("\n")
        logging.info(f"| Logging at Round {post_training_round} -- Exp {exp}")
        logging.info(f"| ###############")

        for client_id, client in enumerate(clients):
            # test set
            round_metrics = client.local_eval_fun(evaluation_set="test")
            log_metrics(
                round_metrics,
                client_id,
                communication_round + post_training_round + 1,
                output_dir,
                "test",
                exp=exp,
            )

            # train set
            round_metrics = client.local_eval_fun(evaluation_set="train")
            log_metrics(
                round_metrics,
                client_id,
                communication_round + post_training_round + 1,
                output_dir,
                "train",
                exp=exp,
            )
            # train set
            round_metrics = client.local_eval_fun(evaluation_set="train")
            log_metrics(
                round_metrics,
                client_id,
                communication_round + post_training_round + 1,
                output_dir,
                "train",
                exp=exp,
            )
