import os
from utils.configs import config_path, config
print('Configs loaded from', config_path)
# The environment variables below are set before `jax` is imported further down,
# presumably so CUDA/XLA read them when JAX initialises its backends.
# NOTE(review): os.environ only accepts str values, so config.gpu is assumed to be
# a device-list string such as '0' or '0,1' — confirm in utils.configs.
os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu
# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'false' # default 'false'
# Turn off XLA's up-front GPU memory preallocation (default 'true') so several
# processes/tools can share the GPU; memory is then presumably allocated on demand.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Optional XLA flags, currently disabled:
# os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=' + str(config.n_clients)

import jax
import jax.numpy as jnp  # JAX NumPy
from utils import evo
import numpy as np
import wandb
from backprop import dataset
from backprop import combine_sl as sl
import chex
from evosax import NetworkMapper, ParameterReshaper
from flax import serialization
from tqdm import tqdm
from wandb_code import save_code
import math
from utils.functions import *
from utils.regression import *

def cosine_similarity_error(flat_state, target_states, origin_state):
    """Return the mean cosine distance between the update of ``flat_state`` and each target update.

    ``flat_state`` and ``target_states`` are both expressed as deltas from
    ``origin_state`` and scaled to unit length. The result is the mean over
    targets of ``|1 - cos(delta, target_delta)|``.

    Args:
        flat_state: 1-D parameter vector.
        target_states: either a single 1-D target vector or a 2-D stack with
            one target vector per row.
        origin_state: 1-D reference vector the deltas are measured from.

    Returns:
        Scalar mean cosine distance (0 when every delta points the same way).
        The result is NaN if any delta has zero norm.
    """
    delta = flat_state - origin_state
    target_deltas = target_states - origin_state
    unit_delta = delta / jnp.linalg.norm(delta)
    # Normalise every target row on its own (keepdims lets the norm broadcast
    # back over the row). Dividing a 2-D stack by a single global (Frobenius)
    # norm would shrink every cosine towards 0 and inflate the error. For a 1-D
    # target this is exactly the plain vector norm.
    unit_targets = target_deltas / jnp.linalg.norm(target_deltas, axis=-1, keepdims=True)
    return jnp.mean(jnp.abs(1 - jnp.dot(unit_delta, unit_targets.T)))

def add_noise(x, std, uniform, rng):
    """Return ``x`` plus random noise of the same shape scaled by ``std``.

    Args:
        x: array to perturb.
        std: multiplier applied to the raw samples.
        uniform: if truthy, draw samples from ``jax.random.uniform`` (values in
            [0, 1), so the noise is not zero-centred); otherwise from
            ``jax.random.normal``.
        rng: JAX PRNG key used for the draw.
    """
    sampler = jax.random.uniform if uniform else jax.random.normal
    noise = sampler(rng, shape=x.shape)
    return x + noise * std

def report_error(epoch, flat_state, origin_state, target_states):
    """Log to W&B how closely ``flat_state`` approximates the mean target state.

    Every error is measured against ``t``, the mean of ``target_states`` over axis 0:

    * ``Compression Error``: mean squared error between ``flat_state`` and ``t``.
    * ``Origin Error``: mean squared error between the concatenated
      ``origin_state`` and ``t``.
    * ``Error Ratio``: compression error divided by origin error.
    * ``Cosine Error``: ``cosine_similarity_error`` of both deltas from the origin.

    When ``config.log_histogram`` is set, 10-bin histograms of the
    per-parameter absolute errors are logged as well.

    Args:
        epoch: round number, logged under ``'Round'``.
        flat_state: flat parameter vector being evaluated.
        origin_state: sequence of arrays that ``jnp.concatenate`` joins into the
            reference vector (presumably per-layer chunks; confirm with callers).
        target_states: stack of target vectors that is averaged over axis 0.
    """
    o = jnp.concatenate(origin_state)
    t = jnp.mean(target_states, axis=0)
    c = flat_state
    compression_error = jnp.mean((c - t) ** 2)
    origin_error = jnp.mean((o - t) ** 2)
    compression_cosine_error = cosine_similarity_error(c, t, o)
    # NOTE(review): this is inf/NaN when origin_state already equals the target mean.
    error_ratio = compression_error / origin_error

    # error histogram per parameter
    if config.log_histogram:
        # Copy the per-parameter errors to host once and reduce them with NumPy.
        # Builtin min()/max() would walk the JAX array one element at a time,
        # which is extremely slow for a full parameter vector.
        error_per_param = np.asarray(jnp.abs(c - t))
        error_histogram = np.histogram(
            error_per_param, bins=10,
            range=(error_per_param.min(), error_per_param.max()))
        origin_error_per_param = np.asarray(jnp.abs(o - t))
        # The upper edge is the *mean* origin error, so np.histogram ignores
        # larger values (presumably to zoom in on the bulk; confirm intent).
        origin_error_histogram = np.histogram(
            origin_error_per_param, bins=10,
            range=(origin_error_per_param.min(), origin_error_per_param.mean()))
        wandb.log({
            'Round': epoch,
            'Error Histogram': wandb.Histogram(np_histogram=error_histogram),
            'Origin Error Histogram': wandb.Histogram(np_histogram=origin_error_histogram),
        })
    wandb.log({
        'Round': epoch,
        'Compression Error': compression_error,
        'Cosine Error': compression_cosine_error,
        'Origin Error': origin_error,
        'Error Ratio': error_ratio,
    })


class TaskManager:
    """Holds the datasets, flax train state and evolution-strategy "server" for one run.

    ``run`` executes the training rounds. In each round, clients train in
    parallel (via ``jax.vmap``) on directions supplied by the strategy, their
    fitness values are averaged and sparsified, and the strategy's resulting
    mean is applied to the flat parameter vector.
    """

    def __init__(self, rng: chex.PRNGKey, args):
        """Load client datasets, build the network/train state and initialise the strategy.

        Args:
            rng: PRNG key used for network and strategy initialisation.
            args: run configuration with attribute access (``run()`` passes
                ``wandb.config``).
        """
        self.args = args
        self.train_ds, self.test_ds = dataset.get_datasets(wandb.config.dataset.lower(),
                                                           n_clients=args.n_clients,
                                                           n_shards_per_client=args.n_shards_per_client,
                                                           iid=args.dist == 'IID',
                                                           use_max_padding=args.use_max_padding)

        self.network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)
        rng, init_rng = jax.random.split(rng)
        self.state = sl.create_train_state(init_rng, self.network, self.args.lr, self.args.momentum)
        self.param_reshaper = ParameterReshaper(self.state.params)
        total_params = int(self.param_reshaper.total_params)
        # The flat parameter vector is split into `parts` chunks of `part_size`.
        # `padding` is the number of extra slots needed to make it divide evenly.
        self.part_size = math.ceil(total_params / self.args.parts)
        self.padding = self.part_size * self.args.parts - total_params
        self.strategy, self.es_params = evo.get_strategy_and_params(self.args.pop_size, total_params, self.part_size, self.padding, self.args)
        # `init_rng` was already consumed by create_train_state. Reusing a JAX key
        # correlates the two draws, so the strategy gets its own fresh key.
        rng, strategy_rng = jax.random.split(rng)
        self.server = self.strategy.initialize(strategy_rng, self.es_params)
        self.flat_params = self.param_reshaper.network_to_flat(self.state.params)
        self.grad_backlog = []
        self.update_backlog = []
        # Last applied global update; fed back to the strategy from round 2 onward.
        self.last_global_grad = None
        sl.init_param_reshaper(self.param_reshaper)

        print('Number of Params:', sum(x.size for x in jax.tree_util.tree_leaves(self.state.params)))

    def run(self, rng: chex.PRNGKey):
        """Run ``args.n_rounds`` training rounds, then optionally save the model.

        Evaluation is logged every ``args.eval_every`` rounds. When
        ``args.save_model`` is set, the final params are written to
        ``<save_dir>/model.pkl`` and uploaded as a W&B artifact.
        """
        X, y = self.train_ds['image'], self.train_ds['label']
        # Keep the full training data on the host CPU device.
        cpu = jax.devices('cpu')[0]
        X, y = jax.device_put(X, device=cpu), jax.device_put(y, device=cpu)
        for epoch in tqdm(range(1, self.args.n_rounds + 1)):
            # A 4-way split keeps the key stream identical to earlier runs; the
            # fourth key is unused.
            rng, batch_rng, rng_ask, _ = jax.random.split(rng, 4)
            if epoch > 1:
                self.strategy.update_grads(self.last_global_grad)
            x, g, self.server = self.strategy.ask_strategy(rng_ask, self.server, self.es_params)

            # vmap maps over axis 0 of X and y (presumably one slice per client).
            # All other arguments are shared.
            # NOTE(review): every client receives the same batch_rng. Split it per
            # client if independent batch shuffling is intended.
            w, v, loss, acc = jax.vmap(sl.train_epoch2_m, in_axes=(None, None, None, 0, 0, None, None, None, None, None))(self.flat_params, x, g, X, y, self.args.batch_size, self.args.lr, self.args.parts, self.padding, batch_rng)
            # Average the per-client fitness values over axis 0.
            w_fitness = jnp.mean(w, axis=0)
            v_fitness = jnp.mean(v, axis=0)
            fitness = jnp.concatenate([w_fitness, v_fitness])
            # Fraction handed to sum_and_average (semantics defined in utils.functions).
            percentage = 0.7
            ss, aa = sum_and_average(fitness.reshape(-1), percentage)
            num_effective_fitness = ss + 1
            # Keep roughly `num_effective_fitness` entries (sparsify is from utils.functions).
            fitness = sparsify(fitness, 1 - num_effective_fitness / fitness.size)
            # NOTE(review): splitting by .size assumes fitness is 1-D and that
            # sparsify preserves its shape. Confirm if w/v ever become multi-dimensional.
            w_fitness = fitness[:w_fitness.size]
            v_fitness = fitness[w_fitness.size:]
            self.server = self.strategy.tell_strategy(x, g, w_fitness, v_fitness, self.server, self.es_params)
            # server.mean is applied as an additive update and remembered for the
            # next round's update_grads call.
            self.flat_params = self.flat_params + self.server.mean
            self.last_global_grad = self.server.mean
            self.state = self.state.replace(params=self.param_reshaper.reshape_single_net(self.flat_params))
            if epoch % self.args.eval_every == 0:
                rng, eval_rng = jax.random.split(rng)
                self.eval_model(epoch, w_fitness.reshape(-1), v_fitness.reshape(-1),
                                loss, acc, ss, aa, eval_rng)

        if self.args.save_model:
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
            os.makedirs(self.args.save_dir, exist_ok=True)
            model_path = os.path.join(self.args.save_dir, 'model.pkl')
            bytes_data = serialization.to_bytes(self.state.params)
            with open(model_path, 'wb') as f:
                f.write(bytes_data)
            model_art = wandb.Artifact('model', type='model')
            model_art.add_file(model_path, name='model.pkl')
            wandb.log_artifact(model_art)

    def eval_model(self, epoch, w_fitness, v_fitness, loss, acc, ss, aa, rng: chex.PRNGKey):
        """Evaluate the global model on the test set and log it with the round's client stats.

        Args:
            epoch: current round number.
            w_fitness, v_fitness: flattened fitness vectors; their mean absolute
                values are logged.
            loss, acc: per-client training loss/accuracy from this round (averaged here).
            ss, aa: the two values returned by ``sum_and_average`` this round.
            rng: PRNG key forwarded to ``sl.eval_model``.
        """
        test_loss, test_accuracy = sl.eval_model(self.state.params, self.test_ds, rng)
        wlog = {
            'Round': epoch,
            'Global Loss': test_loss,
            'Local Epoch': epoch * self.args.client_epoch,
            'Local Loss': loss.mean(),
            'Local Accuracy': acc.mean(),
            'Global Accuracy': test_accuracy,
            'Random Fitness': jnp.abs(w_fitness).mean(),
            'Grad Fitness': jnp.abs(v_fitness).mean(),
            'SS': ss,
            'AA': aa,
        }
        wandb.log(wlog)

def run():
    """Start a W&B run, train a ``TaskManager`` with its config, then close the run.

    The run is named ``'<dataset>-<name_me(config)> s<seed> -- <run id>'`` and the
    project code is saved to W&B before training starts.
    """
    wandb.init(project='MA-LoRA', config=config)
    cfg = wandb.config
    wandb.run.name = '{}-{} s{} -- {}'.format(cfg.dataset, name_me(cfg), cfg.seed, wandb.run.id)
    save_code()

    root_key = jax.random.PRNGKey(cfg.seed)
    # Separate keys for manager initialisation and for the training loop.
    _, init_key, run_key = jax.random.split(root_key, 3)
    TaskManager(init_key, cfg).run(run_key)
    wandb.finish()

if __name__ == '__main__':
    # Start a training run only when this file is executed as a script.
    run()