import os
from utils.configs import config_path, config
print('Configs loaded from', config_path)
os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu
# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'false' # default 'false'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # default 'true' for XLA to work with JAX
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"  # Use 100% of GPU memory

# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=' + str(config.n_clients)

import jax
import jax.numpy as jnp  # JAX NumPy
from utils import evo
import numpy as np
import wandb
from backprop import sl, dataset
import chex
from evosax import NetworkMapper, ParameterReshaper, FitnessShaper
from flax import serialization
from tqdm import tqdm
from wandb_code import save_code
import math
from utils.functions import *
from utils.regression import *

def cosine_similarity_error(flat_state, target_states, origin_state):
    """Mean absolute cosine error between update directions.

    ``flat_state`` and each target are turned into deltas relative to
    ``origin_state`` and scaled to unit length; the result is
    ``mean(|1 - cos|)`` over the targets.

    Args:
        flat_state: 1-D flattened candidate parameters.
        target_states: a single 1-D flattened parameter vector, or a 2-D
            array whose rows are flattened parameter vectors.
        origin_state: 1-D flattened reference parameters the deltas are
            measured from.

    Returns:
        Scalar array holding the mean of ``|1 - cosine_similarity|``. It is
        NaN when any delta has zero norm (the direction is undefined).
    """
    delta = flat_state - origin_state
    target_deltas = target_states - origin_state
    unit_delta = delta / jnp.linalg.norm(delta)
    # Normalise every target row on its own. A single norm over a 2-D array
    # (Frobenius norm) would shrink every cosine by about 1/sqrt(n_targets).
    unit_targets = target_deltas / jnp.linalg.norm(target_deltas, axis=-1, keepdims=True)
    return jnp.mean(jnp.abs(1 - jnp.dot(unit_delta, unit_targets.T)))

def add_noise(x, std, uniform, rng):
    """Return ``x`` plus random noise scaled by ``std``.

    Args:
        x: array to perturb.
        std: multiplier applied to the raw noise sample.
        uniform: when truthy, noise comes from ``jax.random.uniform``
            (values in ``[0, 1)`` before scaling, so it is NOT zero-mean);
            otherwise it comes from ``jax.random.normal``.
        rng: JAX PRNG key used for the draw.

    Returns:
        A new array ``x + noise * std`` with the shape of ``x``.
    """
    sampler = jax.random.uniform if uniform else jax.random.normal
    noise = sampler(rng, shape=x.shape)
    return x + noise * std

def report_error(epoch, flat_state, origin_state, target_states):
    """Log how closely a reconstructed state matches the clients' average.

    Args:
        epoch: round number, logged under ``'Round'``.
        flat_state: flattened reconstructed global parameters.
        origin_state: parameter chunks that ``jnp.concatenate`` joins into the
            flat pre-round parameter vector.
        target_states: stacked flattened client parameters; the mean over
            axis 0 is the target.

    Logs to wandb: the MSE between reconstruction and target
    ('Compression Error'), the MSE between pre-round params and target
    ('Origin Error'), their ratio ('Error Ratio') and the cosine error of the
    update direction ('Cosine Error'). When ``config.log_histogram`` is set,
    it also logs histograms of the per-parameter absolute errors.
    """
    o = jnp.concatenate(origin_state)
    t = jnp.mean(target_states, axis=0)
    c = flat_state
    compression_error = jnp.mean((c - t) ** 2)
    origin_error = jnp.mean((o - t) ** 2)
    compression_cosine_error = cosine_similarity_error(c, t, o)
    error_ratio = compression_error / origin_error

    if config.log_histogram:
        # jnp.min / jnp.max reduce in one call. The builtin min/max would
        # iterate the model-sized array element by element in Python.
        error_per_param = jnp.abs(c - t)
        error_histogram = np.histogram(
            error_per_param,
            bins=10,
            range=(float(jnp.min(error_per_param)), float(jnp.max(error_per_param))),
        )
        origin_error_per_param = jnp.abs(o - t)
        # NOTE(review): the upper bound is the mean, not the max, so errors
        # above the mean fall outside the range and np.histogram drops them.
        # Kept as-is; confirm this zoom is intended.
        origin_error_histogram = np.histogram(
            origin_error_per_param,
            bins=10,
            range=(float(jnp.min(origin_error_per_param)), float(origin_error_per_param.mean())),
        )
        wandb.log({
            'Round': epoch,
            'Error Histogram': wandb.Histogram(np_histogram=error_histogram),
            'Origin Error Histogram': wandb.Histogram(np_histogram=origin_error_histogram),
        })
    wandb.log({
        'Round': epoch,
        'Compression Error': compression_error,
        'Cosine Error': compression_cosine_error,
        'Origin Error': origin_error,
        'Error Ratio': error_ratio,
    })


class TaskManager:
    """Federated training in which client updates may be compressed before averaging.

    Each round (see ``run``):
      1. every client trains locally from the global state;
      2. client updates are taken relative to the per-part server mean;
      3. the updates are optionally encoded (ES fitness, linear regression or
         projection), sparsified and/or quantized;
      4. the server averages them and applies the result to the global model.
    """

    def __init__(self, rng: chex.PRNGKey, args):
        """Load the data, build the model and ES strategy, and seed the server state.

        Args:
            rng: JAX PRNG key for model and strategy initialisation.
            args: run configuration (``wandb.config`` when created by ``run``).
                It supplies the dataset, optimiser, ES and compression settings.
        """
        self.args = args
        self.train_ds, self.test_ds = dataset.get_datasets(wandb.config.dataset.lower(),
                                                           n_clients=args.n_clients,
                                                           n_shards_per_client=args.n_shards_per_client,
                                                           iid=args.dist == 'IID',
                                                           use_max_padding=args.use_max_padding)

        self.network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)
        rng, init_rng = jax.random.split(rng)
        self.state = sl.create_train_state(init_rng, self.network, self.args.lr, self.args.momentum)
        self.param_reshaper = ParameterReshaper(self.state.params)
        self.total_params = int(self.param_reshaper.total_params)
        # The flat parameter vector is cut into `parts` equal chunks. It is
        # padded by `self.padding` trailing entries so it divides evenly; `run`
        # strips them again before rebuilding the network.
        part_size = math.ceil(self.total_params / self.args.parts)
        self.padding = part_size * self.args.parts - self.total_params
        self.strategy, self.es_params = evo.get_strategy_and_params(self.args.pop_size, part_size, part_size, self.padding, self.args)
        self.fit_shaper = FitnessShaper(centered_rank=self.args.centered_rank, z_score=self.args.z_score,
                                        w_decay=self.args.w_decay, maximize=self.args.maximize)
        # Separate key for the strategy: init_rng was already consumed by the
        # model initialisation, and JAX keys must not be reused.
        rng, es_rng = jax.random.split(rng)
        server = self.strategy.initialize(es_rng, self.es_params)
        params = self.param_reshaper.network_to_flat(self.state.params, self.padding).reshape(self.args.parts, -1).copy()
        # One ES state per part, each with its mean set to that part of the model.
        self.server = jax.vmap(server.replace)(mean=params)
        self.grad_backlog = []
        self.update_backlog = []
        # Global update applied in the previous round (None before the first round).
        self.last_global_grad = None

        print('Number of Params:', sum(x.size for x in jax.tree_util.tree_leaves(self.state.params)))

    def evofed_fitness(self, x, flat_states):
        """Score every ES candidate against every client update with ``l2``, then shape.

        Args:
            x: candidate perturbations, mapped over parts (axis 0) and then the
                population (axis 1); presumably (parts, pop_size, part_size).
            flat_states: client updates, mapped over clients (axis 0) and parts
                (axis 1); presumably (n_clients, parts, part_size).

        Returns:
            ``l2`` scores passed through ``self.fit_shaper.apply``, one set per
            client and part.
        """
        fn_vmap_l2 = jax.vmap(l2, in_axes=(0, None))
        fn_vmap_parts = jax.vmap(fn_vmap_l2, in_axes=(0, 0))
        fn_vmap_clients = jax.vmap(fn_vmap_parts, in_axes=(None, 0))
        fitness = fn_vmap_clients(x, flat_states)
        fitness = jax.vmap(jax.vmap(self.fit_shaper.apply), in_axes=(None, 0))(x, fitness)
        return fitness

    def linear_reg_fitness(self, x, flat_states):
        """Encode client updates with ``gradient_descent`` and return the fitted weights.

        ``gradient_descent`` is mapped over parts (``x`` and ``flat_states``) and
        clients (``flat_states`` only). It uses ``args.reg_lr``,
        ``args.reg_iter`` and ``args.pop_size``. The loss, averaged over
        clients and parts, is logged once per remaining loss entry (presumably
        once per iteration). Only ``w`` is returned; the bias ``b`` is discarded.
        """
        fn_vmap_parts = jax.vmap(gradient_descent, in_axes=(0, 0, None, None, None))
        fn_vmap_clients = jax.vmap(fn_vmap_parts, in_axes=(None, 0, None, None, None))
        w, b, loss = fn_vmap_clients(x, flat_states, self.args.reg_lr, self.args.reg_iter, self.args.pop_size)
        for step_loss in loss.mean(axis=(0, 1)):
            wandb.log({'Linear Regression Loss': step_loss})
        return w

    def close_form_linear_reg_fitness(self, x, flat_states):
        """Encode client updates with ``closed_form_linear_regression``.

        The solver is mapped over parts (``x`` and ``flat_states``) and clients
        (``flat_states`` only). Its output is returned unchanged.
        """
        fn_vmap_parts = jax.vmap(closed_form_linear_regression)
        fn_vmap_clients = jax.vmap(fn_vmap_parts, in_axes=(None, 0))
        return fn_vmap_clients(x, flat_states)

    def projection_fitness(self, x, flat_states):
        """Apply ``projection`` to every (candidate, client update) pair, per part.

        The mapping is over the population axis of ``x`` (update fixed), then
        over parts (both arguments), then over clients (``flat_states`` only).
        """
        fn_vmap_projection = jax.vmap(projection, in_axes=(0, None))
        fn_vmap_parts = jax.vmap(fn_vmap_projection, in_axes=(0, 0))
        fn_vmap_clients = jax.vmap(fn_vmap_parts, in_axes=(None, 0))
        return fn_vmap_clients(x, flat_states)

    def run(self, rng: chex.PRNGKey):
        """Run ``args.n_rounds`` federated rounds, then optionally save the model.

        Per round:
          1. all clients train from the global state (one epoch, then
             ``args.client_epoch`` further local epochs);
          2. client updates relative to the server mean are computed. With
             ``args.momentum_gain``, the ``momentum_projection`` component
             against the previous global update is removed and added back
             (client-averaged) on the server;
          3. updates are optionally encoded (``args.evofed``), sparsified and/or
             quantized, then averaged across clients;
          4. the averaged update (plus optional noise) is applied to the global
             model, and the per-part server means are refreshed.

        Args:
            rng: JAX PRNG key driving batching, local training, ES sampling,
                noise and evaluation.
        """
        X, y = self.train_ds['image'], self.train_ds['label']
        # Place the training arrays on the host CPU when requested (always for
        # cifar10), presumably to spare accelerator memory.
        if self.args.dataset.lower() == 'cifar10' or self.args.cpu_batching:
            cpu = jax.devices('cpu')[0]
            X = jax.device_put(X, device=cpu)
            y = jax.device_put(y, device=cpu)

        for epoch in tqdm(range(1, self.args.n_rounds + 1)):
            rng, batch_rng, train_rng = jax.random.split(rng, 3)

            # All clients start from the same global state (in_axes=None) and
            # train on their own slice of X / y (in_axes=0).
            states, loss, acc, l1_loss, l2_loss = jax.vmap(
                sl.train_epoch, in_axes=(None, 0, 0, None, None))(self.state, X, y, batch_rng, train_rng)
            for c_epoch in range(self.args.client_epoch):
                rng, client_rng = jax.random.split(rng)
                # Fresh keys for every local epoch. Reusing the round's
                # batch_rng / train_rng would feed identical keys each time.
                c_batch_rng, c_train_rng = jax.random.split(client_rng)
                states, _loss, _acc, _l1, _l2 = jax.vmap(
                    sl.train_epoch, in_axes=(0, 0, 0, None, None))(states, X, y, c_batch_rng, c_train_rng)
                wandb.log({
                    'Round': epoch,
                    'Local Epoch': epoch * self.args.client_epoch + c_epoch,
                    'Local Loss': _loss.mean(),
                    'Local Accuracy': _acc.mean(),
                    'Local L1 Loss': _l1.mean(),
                    'Local L2 Loss': _l2.mean(),
                })

            origin_state = self.server.mean
            target_states = jax.vmap(self.param_reshaper.network_to_flat, in_axes=(0, None))(states.params, self.padding)
            flat_states = target_states.reshape(self.args.n_clients, self.args.parts, -1).copy()
            grad_signals = jax.vmap(jax.vmap(get_diff), in_axes=(0, None))(flat_states, self.server.mean)

            use_momentum = self.args.momentum_gain and self.last_global_grad is not None
            if use_momentum:
                momentum_gains = jax.vmap(jax.vmap(momentum_projection), in_axes=(0, None))(grad_signals, self.last_global_grad)
                grad_signals = grad_signals - momentum_gains
                # Client-averaged; added back on the server side below.
                momentum_gains = jnp.mean(momentum_gains, axis=0)
            coded_states = grad_signals

            if self.args.evofed:
                rng, rng_ask = jax.random.split(rng)
                x, self.server = jax.vmap(self.strategy.ask, in_axes=(None, 0, None))(rng_ask, self.server, self.es_params)
                e = jax.vmap(get_diff)(x, self.server.mean)
                if self.args.linear:
                    coding_fn = self.linear_reg_fitness
                elif self.args.projection:
                    coding_fn = self.projection_fitness
                else:
                    coding_fn = self.evofed_fitness
                coded_states = coding_fn(e, grad_signals)

            if self.args.sparsify:
                coded_states = jax.vmap(sparsify, in_axes=(0, None))(coded_states, self.args.percentage)
            if self.args.quantize:
                # Per-client (axis 0) value range; dequantize needs it to map
                # the codes back to values.
                min_val = jax.vmap(jnp.min)(coded_states)
                max_val = jax.vmap(jnp.max)(coded_states)
                coded_states = jax.vmap(quantize, in_axes=(0, None))(coded_states, self.args.n_bits)
                coded_states = jax.vmap(dequantize, in_axes=(0, 0, 0, None))(coded_states, min_val, max_val, self.args.n_bits)

            # --- Server side: average the coded client messages. ---
            coded_state_parted = jnp.mean(coded_states, axis=0)

            if self.args.evofed:
                # The averaged fitness updates every part's ES state; the shift
                # of the ES mean is the decoded global update.
                self.server = jax.vmap(self.strategy.tell, in_axes=(0, 0, 0, None))(x, coded_state_parted,
                                                                                    self.server, self.es_params)
                coded_state_parted = self.server.mean - origin_state
            if use_momentum:
                coded_state_parted = coded_state_parted + momentum_gains
            self.last_global_grad = coded_state_parted

            grad_signal_padded = coded_state_parted.reshape(-1)
            if self.args.noise:
                rng, noise_rng = jax.random.split(rng)
                grad_signal_padded = add_noise(grad_signal_padded, self.args.noise_std, self.args.uniform_noise, noise_rng)

            flat_state_padded = origin_state.reshape(-1) + grad_signal_padded
            # Drop the trailing padding that made the parts equal-sized.
            flat_state = flat_state_padded[:-self.padding] if self.padding > 0 else flat_state_padded

            self.state = self.state.replace(params=self.param_reshaper.reshape_single_net(flat_state))
            params = self.param_reshaper.network_to_flat(self.state.params, self.padding).reshape(self.args.parts, -1)
            self.server = self.server.replace(mean=params)
            if epoch % self.args.eval_every == 0:
                rng, eval_rng = jax.random.split(rng)
                self.eval_model(epoch, loss, acc, l1_loss, l2_loss, eval_rng)

        if self.args.save_model:
            self._save_model()

    def _save_model(self):
        """Write the global params to ``<save_dir>/model.pkl`` and log them as a wandb artifact."""
        os.makedirs(self.args.save_dir, exist_ok=True)
        model_path = os.path.join(self.args.save_dir, 'model.pkl')
        # Despite the .pkl name, these are flax msgpack bytes, not a pickle.
        with open(model_path, 'wb') as f:
            f.write(serialization.to_bytes(self.state.params))
        model_art = wandb.Artifact('model', type='model')
        model_art.add_file(model_path, name='model.pkl')
        wandb.log_artifact(model_art)

    def eval_model(self, epoch, loss, acc, l1_loss, l2_loss, rng: chex.PRNGKey):
        """Evaluate the global model on the test set and log the metrics to wandb.

        Args:
            epoch: current round number.
            loss, acc, l1_loss, l2_loss: per-client training metrics (``run``
                passes those of the round's first local epoch); their means
                are logged.
            rng: PRNG key forwarded to ``sl.eval_model``.
        """
        test_loss, test_accuracy, class_accuracy, fairness_metric = sl.eval_model(self.state.params, self.test_ds, rng)
        # Cumulative communication in MB, counting 64 bits per value sent.
        # Dense: the factor 2 presumably counts upload and download.
        # EvoFed: pop_size values per part per round.
        # Sparsify is checked last, so it overrides the EvoFed figure.
        comm = 2 * epoch * self.total_params * 64 / 8 / 1024 / 1024
        if self.args.evofed:
            comm = epoch * self.args.pop_size * self.args.parts * 64 / 8 / 1024 / 1024
        if self.args.sparsify:
            comm = 2 * epoch * self.total_params * (1 - self.args.percentage) * 64 / 8 / 1024 / 1024
        wlog = {
            'Round': epoch,
            'Global Loss': test_loss,
            'Local Epoch': epoch * self.args.client_epoch,
            'Local Loss': loss.mean(),
            'Local Accuracy': acc.mean(),
            'Local L1 Loss': l1_loss.mean(),
            'Local L2 Loss': l2_loss.mean(),
            'Global Accuracy': test_accuracy,
            'Communication Cost': comm,
        }
        if self.args.fairness:
            wlog.update({
                'Min Class Accuracy': jnp.min(class_accuracy),
                'Max Class Accuracy': jnp.max(class_accuracy),
                'Diff Class Accuracy': jnp.max(class_accuracy) - jnp.min(class_accuracy),
                'Fairness Metric': fairness_metric,
            })
            wlog.update({f'Class {i} Accuracy': a for i, a in enumerate(class_accuracy)})
        wandb.log(wlog)

def run():
    """Script entry point: open a wandb run, train a TaskManager, close the run.

    The run is initialised from the module-level ``config``. All further
    settings are read back through ``wandb.config``.
    """
    wandb.init(project='MA-LoRA', config=config)
    cfg = wandb.config
    wandb.run.name = '{}-{} s{} -- {}'.format(cfg.dataset, name_me(cfg), cfg.seed, wandb.run.id)
    save_code()

    root_key = jax.random.PRNGKey(cfg.seed)
    # Three-way split: the first key is not used further.
    _, init_key, run_key = jax.random.split(root_key, 3)
    TaskManager(init_key, cfg).run(run_key)
    wandb.finish()


if __name__ == '__main__':
    run()