import os.path as osp
import pickle
from mpi4py import MPI
from tqdm import tqdm
import numpy as np

import jax
import jax.numpy as jnp
import wandb

from learned_optimization import checkpoints
from learned_optimization.outer_trainers import (
    gradient_learner,
    truncated_pes,
    truncation_schedule,
    full_es,
)

from meta_trainers import get_meta_trainer
from helpers import get_resume_ckpt, save_checkpoint, set_non_hashable_args, cast_to_bf16, Timing
import globals

def broadcast_wandb_name(wandb_run_name, axis_name='i'):
    """Reduce a wandb run name over a named JAX axis and decode the result.

    The name is right-padded with spaces to a width of 100 characters,
    encoded as an integer array of code points, combined across
    ``axis_name`` with ``jax.lax.psum``, and decoded back into a string
    with surrounding whitespace stripped.

    Args:
        wandb_run_name: The run name held by the calling participant.
        axis_name: Name of the mapped axis passed to ``jax.lax.psum``.

    Returns:
        The string decoded from the summed code-point array.

    NOTE(review): ``psum`` adds the arrays of every participant on the axis,
    so the result presumably matches rank 0's name only when that is the
    sole non-zero contribution -- confirm against the intended call site.
    The padding characters are spaces (code point 32), so the ``!= 0``
    filter below does not remove padding; ``strip()`` does.
    """
    padded_name = wandb_run_name.ljust(100)  # fixed width so shapes agree
    code_points = jnp.array(list(map(ord, padded_name)))

    # Collective sum over the named axis; requires `axis_name` to be bound.
    reduced = jax.lax.psum(code_points, axis_name=axis_name)

    decoded_chars = [chr(code) for code in reduced if code != 0]
    return ''.join(decoded_chars).strip()


def meta_train(args):
    """Run the outer (meta-training) loop for this MPI rank, then exit.

    The function builds the meta trainer, optionally restores its state from
    a checkpoint, starts a wandb run on rank 0, and then repeatedly calls
    ``meta_trainer.update``. Timing statistics are max-reduced across ranks
    with a non-blocking MPI all-reduce that overlaps the next update, so the
    reduced values logged at iteration ``i`` were measured at iteration
    ``i - 1`` (``-1.0`` on the first iteration). Every rank writes its own
    checkpoint; only rank 0 logs to wandb and uploads checkpoint files.

    Args:
        args: Parsed run configuration. Attributes read here: ``rank``,
            ``needs_state``, ``num_grads``, ``num_local_steps``,
            ``local_batch_size``, ``use_pmap``, ``num_devices``,
            ``finetune``, ``test_checkpoint``, ``from_checkpoint``,
            ``auto_resume``, ``meta_train_name``, ``train_project``,
            ``use_bf16``, ``num_outer_steps``, ``task``, ``steps_per_jit``,
            ``learning_rate`` and ``save_iter``.

    Raises:
        AssertionError: If ``use_pmap`` is set and ``num_grads`` is not a
            multiple of ``num_devices``.
        SystemExit: Always raised with code 0 once training has finished.
    """
    # Import the submodule explicitly: `import jax` alone does not guarantee
    # that `jax.experimental.multihost_utils` has been loaded.
    from jax.experimental import multihost_utils

    args = set_non_hashable_args(args)
    meta_trainer, meta_opt = get_meta_trainer(args)

    # Seed derived from the rank, so each rank gets its own key.
    key = jax.random.PRNGKey(args.rank * 10000)
    key, key1 = jax.random.split(key)
    outer_trainer_state = meta_trainer.init(key1)

    # Publish run-wide settings through the project's `globals` module.
    globals.needs_state = args.needs_state
    globals.num_grads = args.num_grads
    globals.num_local_steps = args.num_local_steps
    globals.local_batch_size = args.local_batch_size[0]
    globals.use_pmap = args.use_pmap
    globals.num_devices = args.num_devices

    if args.use_pmap:
        assert args.num_grads % args.num_devices == 0, "The number of devices for parallelism should be a divisor of the number of clients (gradients)"

    if args.finetune:
        # NOTE(review): the loaded parameters are never used below, so
        # `--finetune` currently only checks that the file can be unpickled.
        # SECURITY: pickle.load executes arbitrary code; only point
        # `test_checkpoint` at trusted files.
        with open(args.test_checkpoint, "rb") as f:
            meta_params = pickle.load(f)  # noqa: F841

    # ---- Restore trainer state -------------------------------------------
    resume_ckpt = None  # set only when auto-resume finds a checkpoint
    if args.from_checkpoint:
        dirname = osp.join("checkpoints", args.meta_train_name)
        # The "latest" file holds the name of the checkpoint to load.
        with open(osp.join(dirname, "latest"), "r") as f:
            ckpt = f.readline().strip()
        outer_trainer_state = checkpoints.load_state(
            osp.join(dirname, "{}.ckpt".format(ckpt)), outer_trainer_state
        )
    elif args.auto_resume:
        resume_ckpt = get_resume_ckpt("checkpoints", args.meta_train_name)
        if resume_ckpt is not None:
            outer_trainer_state = checkpoints.load_state(
                osp.join(
                    resume_ckpt,
                    "rank-{}_outer_trainer_state.ckpt".format(args.rank),
                ),
                outer_trainer_state,
            )

    # ---- Start wandb on rank 0 -------------------------------------------
    run = None
    wandb_run_id = ''
    if args.rank == 0:
        wandb_kwargs = dict(
            project=args.train_project,
            group=args.meta_train_name,
            config=vars(args),
        )
        if resume_ckpt is not None:
            # Continue the previous run; its id is taken from the first 8
            # characters of the second path component of the checkpoint.
            wandb_kwargs.update(
                resume='must', id=resume_ckpt.split('/')[1][:8]
            )
        run = wandb.init(**wandb_kwargs)
        # Always record the id so every code path (including
        # --from_checkpoint) saves checkpoints under the same prefix.
        wandb_run_id = run.id

    if args.use_bf16:
        outer_trainer_state = cast_to_bf16(outer_trainer_state)

    # Share rank 0's run id so all ranks use the same checkpoint prefix.
    wandb_run_id = MPI.COMM_WORLD.bcast(wandb_run_id, root=0)

    # Print the result from each process
    print(f"Rank {args.rank}: Wandb run name is {wandb_run_id}")

    i = None
    iteration = int(
        outer_trainer_state.gradient_learner_state.theta_opt_state.iteration
    )
    pbar = tqdm(
        range(iteration, args.num_outer_steps),
        initial=iteration,
        total=args.num_outer_steps,
        ascii=True,
        desc="Outer Loop",
        mininterval=0,  # update as often as possible
        miniters=1,      # update every iteration
    )
    logging_task_name = args.task[0] if len(args.task) == 1 else "multi-task-with_" + args.task[0]

    meta_train_update, metric_all_reduce_time = [], []
    all_reduce_wait = None  # pending MPI request from the previous iteration
    data = None             # buffer owned by the pending all-reduce

    # Negative start index selecting the most recent data-loading timings
    # (floor division, so e.g. steps_per_jit=4 keeps the last 13 entries).
    timing_window = -50 // args.steps_per_jit

    for i in range(iteration, args.num_outer_steps):
        key, key1 = jax.random.split(key)

        with Timing('meta train update', meta_train_update):
            outer_trainer_state, meta_loss, metrics = meta_trainer.update(
                outer_trainer_state, key1, with_metrics=True
            )
            # synchronize to get correct step time
            multihost_utils.sync_global_devices('sync')

        # update truncation length
        for estimator in meta_trainer.gradient_estimators:
            if isinstance(estimator, truncated_pes.TruncatedPES):
                estimator.update_truncation_length(i)

        # Local statistics over the most recent data-loading timings.
        recent_timings = meta_trainer.gradient_estimators[0].truncated_step.timings[timing_window:]
        local_mean_data_time = np.mean(recent_timings)
        local_total_time = np.sum(recent_timings)
        local_max_data_time = np.max(recent_timings)

        with Timing('AR time', metric_all_reduce_time):
            if all_reduce_wait is None:
                # Nothing has been reduced yet on the first iteration.
                max_data_time = -1.0
                meta_train_time = -1.0
            else:
                # Collect the maxima posted during the previous iteration.
                all_reduce_wait.Wait()
                max_data_time = data[0]
                meta_train_time = data[1]

            # Post this iteration's values; the reduction overlaps with the
            # next meta-training update.
            data = np.array([local_max_data_time, meta_train_update[-1]])
            all_reduce_wait = MPI.COMM_WORLD.Iallreduce(MPI.IN_PLACE, data, op=MPI.MAX)

        if args.rank == 0:
            more_to_log = {
                "iteration": i,
                "meta loss": meta_loss,
                "PES Gather": Timing.run_times_dict['PES Gather'][-1],
                "Global AR": Timing.run_times_dict['meta train all reduce'][-1],
                'Unroll Time': Timing.run_times_dict['meta train unroll'][-1],
                "AR metric time": round(metric_all_reduce_time[-1], 4),
                "meta iter time": round(meta_train_time, 4),
                "local data time mean": round(local_mean_data_time, 7),
                "local data time total": round(local_total_time, 7),
                "Data time max": round(max_data_time, 7),
                # Use the optimizer's schedule when it has one, otherwise
                # fall back to the constant configured learning rate.
                "learning rate": meta_opt.__dict__.get(
                    "schedule_", lambda x: args.learning_rate
                )(
                    outer_trainer_state.gradient_learner_state.theta_opt_state.iteration
                    - 1
                ),
            }

            pbar.set_postfix({
                "meta loss": round(float(meta_loss), 2),  # this has been all-reduced
                "Global AR": more_to_log["Global AR"],
                "Metric AR": more_to_log["AR metric time"],
                "PES Gather": more_to_log["PES Gather"],
                "Iter T": more_to_log["meta iter time"],
                "max Data T": more_to_log["Data time max"],
                "Unroll T": more_to_log['Unroll Time'],
                "L-Data T total": more_to_log["local data time total"],
                "LR:": round(more_to_log["learning rate"], 5),
            })
            pbar.update(1)

            metrics.update(more_to_log)
            run.log(metrics)

        if (i + 1) % args.save_iter == 0 or i == 1:
            # TODO: add support for saving meta-training checkpoints in parallel
            savepath = save_checkpoint(
                prefix=wandb_run_id, i=i, args=args, outer_trainer_state=outer_trainer_state, rank=args.rank,
            )
            if args.rank == 0:
                wandb.save(savepath)

        multihost_utils.sync_global_devices('sync')

    # Complete the last outstanding non-blocking all-reduce so no MPI request
    # is left pending at shutdown.
    if all_reduce_wait is not None:
        all_reduce_wait.Wait()
    pbar.close()

    if args.rank == 0:
        # Todo: check if this is a fix to error when resuming from final checkpoint
        if i is None:
            i = iteration

        savepath = save_checkpoint(
            prefix=wandb_run_id, i=i, args=args, outer_trainer_state=outer_trainer_state, rank=args.rank
        )

        wandb.save(savepath)
        run.finish()

    # all procs wait for wandb to finish
    multihost_utils.sync_global_devices('sync')

    # Same exit status as the built-in `exit(0)`, without relying on the
    # `site` module having installed that helper.
    raise SystemExit(0)
        

