import os, importlib
import numpy as np
import pandas as pd
from scipy.linalg import circulant
import matplotlib.pyplot as plt
import matplotlib as mpl
import wandb
from datetime import datetime
from keras import datasets
import jax
import jax.numpy as jnp
import dm_pix as pix
import sys
import tensorflow as tf

def generate_local_epoch_distribution(client_number: int, RNG: np.random.Generator, iteration_type: str,
                                      does_it_vary: bool, max_iteration: int,
                                      mean: float, std: float, beta: float, coefficient: int) -> np.ndarray:
    """Draw the number of local iterations each client runs at every global iteration.

    Supported ``iteration_type`` values:
      * ``'constant'``    – every client uses ``mean``.
      * ``'uniform'``     – per-client draw from U(mean - std, mean + std). The draw is not rounded;
        the final int16 cast truncates it.
      * ``'gaussian'``    – N(mean, std), rounded and floored at 1. With ``does_it_vary`` the per-client
        means come from ``round(Dirichlet(beta) * coefficient)``, and a fresh value is drawn for every
        iteration.
      * ``'exponential'`` – same as ``'gaussian'`` but with an exponential draw (``mean`` is the scale).
      * ``'dirichlet'``   – ``round(Dirichlet(beta) * coefficient)``; values may be 0.

    Returns:
        int16 array of shape ``(max_iteration, client_number)``. One-dimensional draws are repeated for
        every iteration.

    Raises:
        ValueError: if ``iteration_type`` is not one of the names above.
    """
    def dirichlet_means() -> np.ndarray:
        # Shared by the "varying" gaussian/exponential cases and the plain dirichlet case.
        return (RNG.dirichlet(np.full(client_number, beta)) * coefficient).round()

    if iteration_type == 'constant':
        epochs = np.full(client_number, mean)
    elif iteration_type == 'uniform':
        epochs = RNG.uniform(mean - std, mean + std, size=client_number)
    elif iteration_type == 'gaussian':
        if does_it_vary:
            raw = RNG.normal(dirichlet_means(), std, (max_iteration, client_number))
        else:
            raw = RNG.normal(mean, std, client_number)
        epochs = np.maximum(raw.round(), 1)
    elif iteration_type == 'exponential':
        if does_it_vary:
            raw = RNG.exponential(dirichlet_means(), (max_iteration, client_number))
        else:
            raw = RNG.exponential(mean, size=client_number)
        epochs = np.maximum(raw.round(), 1)
    elif iteration_type == 'dirichlet':
        epochs = dirichlet_means()
    else:
        raise ValueError("Invalid iteration type")

    print(f"The local iteration distribution is {epochs.shape} shaped")
    if epochs.ndim == 1:
        # Same per-client value at every global iteration.
        epochs = np.tile(epochs, (max_iteration, 1))
    print(f"The local iteration distribution is expanded to {epochs.shape}")
    return epochs.astype(np.int16)


def split_data_by_labels(data: np.ndarray, labels: np.ndarray) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]:
    """Group ``data`` rows by their label.

    Returns:
        ``(positions_by_label, samples_by_label)``. The first maps each label to the row positions
        carrying it (pandas groupby ``indices``). The second maps each label to ``data`` at those
        positions. Both dicts share the same key order.
    """
    positions_by_label = pd.Series(labels).groupby(labels).indices
    samples_by_label = {}
    for label, positions in positions_by_label.items():
        samples_by_label[label] = data[positions]
    return positions_by_label, samples_by_label


def sample_data_per_label(sample_number: int, RNG: np.random.Generator, data: np.ndarray,
                          split_labels: dict[int, np.ndarray]) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]:
    """Draw ``sample_number`` distinct positions per label and gather the matching rows of ``data``.

    Args:
        sample_number: Positions to draw per label, without replacement (``RNG.choice`` raises if a
            label has fewer candidates).
        RNG: Generator; one ``choice`` call is made per label, in ``split_labels`` order.
        data: Array indexed by the positions in ``split_labels``.
        split_labels: Mapping label -> candidate positions into ``data``.

    Returns:
        ``(chosen_positions, chosen_rows)``, both keyed by label.
    """
    chosen_positions = {}
    for label, candidates in split_labels.items():
        chosen_positions[label] = RNG.choice(candidates, sample_number, replace=False)
    chosen_rows = {label: data[positions] for label, positions in chosen_positions.items()}
    return chosen_positions, chosen_rows


def flatten_nested_lists(data: list[list]) -> list:
    """Concatenate each inner list of arrays into one array (outer list order kept).

    An empty inner list makes ``np.concatenate`` raise ``ValueError``.
    """
    return list(map(np.concatenate, data))


def allocate_client_datasets(client_number: int, RNG: np.random.Generator, allocation_type: str, class_ratio: int,
                             split_data: dict[int, np.ndarray], beta: float, data_shape: tuple[int, int, int]) -> tuple[
    list, list]:
    client_data = list([] for _ in range(client_number))
    client_labels = list([] for _ in range(client_number))
    total_dropped_samples = 0
    match allocation_type:
        case "class-clients":  # Each client has data from a single class
            for client in range(client_number):
                client_data[client] = split_data[client]
                client_labels[client] = np.full(len(split_data[client]), client, dtype=np.int8)
                # done with this case

        case "1/n-class-clients":  # Each client has data from n classes with one class being dominant, n = class_ratio
            # This doesn't work when the number of clients is higher than a certain number
            class_number_per_client = client_number * class_ratio
            allocation_map = np.arange(client_number + class_ratio - 1, class_number_per_client, class_ratio - 1,
                                       dtype=np.int8)
            allocation_order = circulant(np.arange(client_number, dtype=np.int8)).T
            allocation_row = 0
            for label, data in split_data.items():
                sample_number = len(data)
                class_length = sample_number // (class_number_per_client)
                excess_samples = sample_number % class_number_per_client
                total_dropped_samples += excess_samples
                classed_data = data[:-excess_samples, :, :] if excess_samples else data[:]
                classed_data = classed_data.reshape((class_number_per_client), class_length, *data_shape)
                split_classed_data = np.split(classed_data, allocation_map, axis=0)

                for client in range(client_number):
                    client_idx = allocation_order[allocation_row, client]
                    temp = split_classed_data[client].reshape(-1, height, width)
                    client_data[client_idx].append(temp)
                    client_labels[client_idx].append(np.full(len(temp), label, dtype=np.int8))
                allocation_row += 1

        case "uniform":
            class_number_per_client = client_number
            for label, data in split_data.items():
                sample_number = len(data)
                class_length = sample_number // (class_number_per_client)
                excess_samples = sample_number % class_number_per_client
                total_dropped_samples += excess_samples
                uniformly_classed_data = data[:-excess_samples, :, :] if excess_samples else data[:]
                uniformly_classed_data = uniformly_classed_data.reshape((class_number_per_client), class_length,
                                                                        *data_shape)
                for client in range(client_number):
                    client_data[client].append(uniformly_classed_data[client])
                    client_labels[client].append(np.full(class_length, label, dtype=np.int8))

        case "random":
            for label, data in split_data.items():
                allocation_map = np.sort(RNG.integers(len(data), size=client_number - 1))
                randomly_classed_data = np.split(data, allocation_map, axis=0)

                for client in range(client_number):
                    client_data[client].append(randomly_classed_data[client])
                    client_labels[client].append(np.full(len(randomly_classed_data[client]), label, dtype=np.int8))

        case "dirichlet":
            sample_rate = RNG.dirichlet(np.full(client_number, beta), size=len(split_data))
            for label, data in split_data.items():
                allocation_map = np.cumsum(np.int32(sample_rate[label] * len(data)))[:-1]
                dirichlet_classed_data = np.split(data, allocation_map, axis=0)

                for client in range(client_number):
                    client_data[client].append(dirichlet_classed_data[client])
                    client_labels[client].append(np.full(len(dirichlet_classed_data[client]), label, dtype=np.int8))
        case _:
            raise ValueError("Invalid allocation type")
    print(f"The number of dropped samples is equal to {total_dropped_samples}")
    if allocation_type != "class-clients":
        client_data, client_labels = flatten_nested_lists(client_data), flatten_nested_lists(client_labels)
        for client in range(client_number):
            randomized_idx = np.arange(len(client_data[client]))
            RNG.shuffle(randomized_idx)
            client_data[client], client_labels[client] = client_data[client][randomized_idx], client_labels[client][
                randomized_idx]

    return client_labels, client_data


def generate_active_client_matrix(inactive_probability: float, RNG: np.random.Generator, max_iteration: int,
                                  client_number: int):
    """Sample which clients participate at each global iteration.

    Returns:
        Boolean array of shape ``(max_iteration, client_number)``. Each entry is independently False
        with probability ``inactive_probability`` and True otherwise.
    """
    weights = [inactive_probability, 1 - inactive_probability]
    return RNG.choice([False, True], size=(max_iteration, client_number), replace=True, p=weights)


def plot_data_distribution(client_number: int, client_labels: list[np.ndarray], num_classes: int = 10):
    """Save a grid of per-client label histograms to ``logs/Client_Histogram.png``.

    Args:
        client_number: Number of clients. Up to 20 clients use a 5-row grid, more use a 10-row grid;
            figure size scales with this value.
        client_labels: One label array per client; one subplot is drawn per entry.
        num_classes: Number of distinct label values. Bins cover ``0..num_classes-1``. Defaults to 10,
            the value this function previously hard-coded.
    """
    if client_number <= 20:
        rows, figsize = 5, (max(1, int(8 * client_number / 10)), 8)
    else:
        rows, figsize = 10, (max(1, int(4 * client_number / 10)), 16)
    # max(1, ...) above: a single client would otherwise get a zero-width figure, which matplotlib rejects.
    cols = int(np.ceil(client_number / rows))
    bins = np.arange(num_classes + 1)

    fig = plt.figure(1, figsize=figsize)
    for client, labels in enumerate(client_labels):
        plt.subplot(rows, cols, client + 1)
        plt.hist(labels, color="lightblue", ec="red", align="left", bins=bins)
        plt.title("Client " + str(client + 1))
    plt.suptitle("Label Distributions of Clients")
    plt.tight_layout()
    os.makedirs('logs', exist_ok=True)
    plt.savefig('logs/' + 'Client_Histogram.png')
    # Figure number 1 is reused. Closing it keeps a later call from drawing over these axes and
    # releases the figure's memory.
    plt.close(fig)


def load_and_preprocess_dataset(dataset_name, JKEY):
    def normalize(img, mean, std):
        img = img / 255.0
        img = (img - mean) / std
        return img

    def augment(JKEY, img):
        img = pix.random_flip_left_right(JKEY, img)
        img = pix.pad_to_size(img, 40, 40)
        img = pix.random_crop(JKEY, img, (32, 32, 3))
        return img

    match str.upper(dataset_name):
        case 'CIFAR10' | 'CIFAR100':
            dataset = datasets.cifar10 if str.upper(dataset_name) == 'CIFAR10' else datasets.cifar100
            data_shape = (32, 32, 3)
            data_mean, data_std = jnp.array([0.491, 0.482, 0.447]), jnp.array([0.247, 0.243, 0.262])
            (train_images, train_labels), (test_images, test_labels) = dataset.load_data()
            train_labels, test_labels = train_labels.reshape(-1), test_labels.reshape(-1)

            print('Normalizing and augmenting Train Images...')
            train_images = np.asarray(jax.vmap(normalize, in_axes=(0, None, None))(train_images, data_mean, data_std))
            train_images = np.asarray(
                jax.vmap(augment, in_axes=(0, 0))(jax.random.split(JKEY, len(train_images)), train_images))
            print('Train images normalized and augmented.')
            test_images = np.asarray(jax.vmap(normalize, in_axes=(0, None, None))(test_images, data_mean, data_std))
            print(f"The mean and std of training data are {np.mean(train_images)} and {np.std(train_images)}")
            print(f"The mean and std of training data are {np.mean(test_images)} and {np.std(test_images)}")

        case 'MNIST' | 'FASHION_MNIST':
            dataset = datasets.mnist if str.upper(dataset_name) == 'MNIST' else datasets.fashion_mnist
            data_shape = (28, 28, 1)
            data_mean, data_std = (0.1307, 0.3081)
            (train_images, train_labels), (test_images, test_labels) = dataset.load_data()
            train_images, test_images = np.expand_dims(train_images, axis=-1), np.expand_dims(test_images, axis=-1)
            train_labels, test_labels = train_labels.reshape(-1), test_labels.reshape(-1)
            train_images = np.asarray(jax.vmap(normalize, in_axes=(0, None, None))(train_images, data_mean, data_std))
            test_images = np.asarray(jax.vmap(normalize, in_axes=(0, None, None))(test_images, data_mean, data_std))

        case _:
            raise ValueError("Invalid dataset")

    return (train_images, train_labels), (test_images, test_labels), data_shape

def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested dict into one level with upper-cased, ``sep``-joined keys.

    Example: ``{'a': 1, 'b': {'c': 2}}`` -> ``{'A': 1, 'B_C': 2}``. Only ``dict`` instances are
    recursed into. Keys must support ``.upper()`` (i.e. be strings). If two paths flatten to the same
    key, the later one's value wins.
    """
    flat = {}
    for key, value in d.items():
        full_key = (f"{parent_key}{sep}{key}" if parent_key else key).upper()
        if isinstance(value, dict):
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat
def load_and_extract_configs(dir_path, include=None, exclude=None):
    """Import modules from ``dir_path`` (relative to this file) and collect their ``config`` attribute.

    Args:
        dir_path: Directory relative to this file. It is also turned into the package path by replacing
            '/' with '.', so it must be importable from ``sys.path``.
        include: If given, only entries whose stem (text before the first '.') occurs as a substring of
            some include entry are loaded. This branch does not require a ``.py`` suffix.
        exclude: Used only when ``include`` is None. ``.py`` files whose stem occurs as a substring of
            some exclude entry are skipped.

    Returns:
        List of ``config`` objects, in sorted filename order. Modules that fail to import or lack
        ``config`` are reported and skipped. Each module is removed from ``sys.modules`` afterwards so
        later calls re-import it fresh.
    """
    if exclude is None:
        exclude = []
    else:
        print("Exclude: ", exclude)

    curr_path = os.path.dirname(os.path.realpath(__file__))
    # os.listdir returns entries in arbitrary order; sort so the returned configs are reproducible.
    param_files = sorted(os.listdir(os.path.join(curr_path, dir_path)))

    print(param_files)

    # NOTE(review): matching is substring-based ("stem in entry"), not exact equality.
    if include is not None:
        param_files = [f.split('.')[0] for f in param_files if any(f.split('.')[0] in incl for incl in include)]
    else:
        param_files = [f.split('.')[0] for f in param_files
                       if f.endswith('.py') and not any(f.split('.')[0] in excl for excl in exclude)]

    configs = []
    if not param_files:
        print("No Python files found based on the include/exclude criteria.")
        return configs

    for f in param_files:
        module_name = f"{dir_path.replace('/', '.')}.{f}"
        try:
            module = importlib.import_module(module_name)
            print(f"Loading Configs from {f}.py...")
            configs.append(getattr(module, 'config'))
            print("Done.")
        except AttributeError as e:
            print(f"Could not find 'config' in {f}.py. Skipping.")
            print(f"The error is {e}")
        except ImportError as e:
            print(f"Failed to import {f}.py: {str(e)}")
        finally:
            # Drop the cached module so a later call picks up edits to the config file.
            sys.modules.pop(module_name, None)

    return configs


def log(row: pd.Series, step: int, logees: dict[str, any]):
    """Send ``logees`` to Weights & Biases and append them to ``row``.

    For a list value, each element's ``np.mean`` is logged under the same key in its own
    ``wandb.log`` call. Any other value is logged as-is.

    NOTE(review): ``step`` is currently unused. Each ``wandb.log`` call here is a separate commit, so
    several keys (or list elements) may land on different W&B steps. Confirm this is intended.

    Returns:
        ``pd.Series(logees)`` if ``row`` is empty, otherwise ``row`` concatenated with it.
    """
    for key, value in logees.items():
        if not isinstance(value, list):
            wandb.log(data={key: value})
            continue
        for entry in value:
            wandb.log(data={key: np.mean(entry)})

    new_entries = pd.Series(logees)
    if row.empty:
        return new_entries
    return pd.concat([row, new_entries], axis=0)


def get_run_name(config, run, sgd=False):
    """Build a run name from the experiment config.

    Fields (model, dataset, allocation, server/worker/compressor settings, run index, timestamp) are
    joined with '|'. Then every '.' becomes ',' and every '|' becomes '-'. Compressor fields appear only
    when ``config.compressor.uplink_samples > 0``; otherwise the marker ``CFalse`` is used.

    Args:
        config: Nested attribute-style config (``config.train.model``, ``config.data.*``, ...).
        run: Run index, emitted as ``R{run}``.
        sgd: Whether to add the server SGD marker and server learning rate.

    Returns:
        The name string, ending in a ``%y%m%d%H%M%S`` timestamp.
    """
    parts = [
        # * Model
        f"{config.train.model.__name__.lower()}",
        # * Data
        f"{config.data.name}",
        f"{config.data.alloc_type}",
    ]
    if config.data.alloc_type == 'dirichlet':
        parts += [f"{config.data.beta}", f"S-{config.seed}"]

    # * Server
    parts.append(f"SE-{config.server.num_epochs}")
    if sgd:
        parts += ["SGD", f"SLR-{config.server.lr}"]

    # * Worker
    parts += [
        f"WE-{config.worker.epoch.mean}",
        f"NW-{config.worker.num}",
        f"TX-{config.worker.tx.fn.__name__}",
        f"LR-{config.worker.tx.lr}",
    ]

    # * Compressor
    comp = config.compressor
    if comp.uplink_samples > 0:
        parts += [
            "CTrue",
            f"RA{comp.reset_aggregation}",
            f"IR{comp.use_indiv_reference}",
            f"PP{comp.use_posterior_prior}",
            f"UL{comp.uplink_samples}",
            f"DL{comp.downlink_samples}",
            f"CDP{comp.common_dl_prior}",
            f"RS{comp.reuse_samples}",
            f"PK{comp.project_kl_divergences}",
            f"PBK{comp.project_block_kl_divergences}",
            f"AU{comp.adaptive_blocks_ul}",
            f"AD{comp.adaptive_blocks_dl}",
            f"AA{comp.adaptive_avg}",
            f"KU{comp.kl_rate_ul}",
            f"KD{comp.kl_rate_dl}",
            f"SD{comp.split_dl}",
            f"MS{comp.max_block_size}",
            f"D{comp.avg_dev_factor}",
        ]
        # Non-default values only.
        if config.worker.num != 10:
            parts.append(f"W{config.worker.num}")
        if comp.block_size != 256:
            parts.append(f"B{comp.block_size}")
        if comp.sample_size != 256:
            parts.append(f"S{comp.sample_size}")
        if getattr(comp, "optimize_prior", False):
            parts.append(f"OP{comp.optimize_prior}")
    else:
        parts.append("CFalse")

    # * Run. Joining all parts guarantees a separator before R{run}; previously the compressor
    # branch ran its last field straight into it (e.g. "D2R7").
    parts.append(f"R{run}")
    parts.append(datetime.now().strftime("%y%m%d%H%M%S"))
    return "|".join(parts).replace('.', ',').replace('|', '-')