import argparse
import time
from functools import partial, reduce

import jax
import jax.numpy as jnp
import numpy as np
import pprint
import wandb
from tqdm import tqdm

import globals
from batched_image_augmentations import (
    autoaugment_cifar10,
    cutmix,
    mixup,
    normalize_images,
    random_crop,
    random_flip
)
from helpers import (
    Timing,
    cast_to_bf16,
    cast_to_fp32,
    set_non_hashable_args
)
from optimizers import get_optimizer
from tasks import get_task
from helpers import print_rank_0

def is_leaf(x):
    """Return True when no value of the mapping *x* is itself a ``dict``.

    ``get_mup_lrs`` uses this to decide where to stop recursing. An empty
    mapping counts as a leaf rather than raising. Only exact ``dict``
    instances count as nested (``type(v) is dict``); dict subclasses are
    treated as leaf values.
    """
    return all(type(value) is not dict for value in x.values())

def add_prefix(prefix, s):
    """Join *prefix* and *s* as a ``/``-separated path.

    An empty *prefix* contributes nothing, so the result is just ``'' + s``.
    """
    separator = '/' if prefix != '' else ''
    return prefix + separator + s

def get_mup_lrs(state, prefix):
    """Flatten a nested dict into ``'a/b/c'``-style keys, stopping at "leaf" dicts.

    A value ``v`` for which ``is_leaf(v)`` holds (no nested dict among its
    values) is stored whole under its joined path. Anything else is recursed
    into with its own key as prefix. Every ``'/mup_lrs'`` substring is then
    removed from the resulting keys.

    Args:
        state: dict whose values are themselves dicts (``is_leaf`` calls
            ``.values()`` on each of them).
        prefix: path prefix prepended with ``add_prefix``; ``''`` for none.
    """
    flat = {}
    for name, sub in state.items():
        if is_leaf(sub):
            flat[add_prefix(prefix, name)] = sub
            continue
        nested = get_mup_lrs(sub, name)
        for nested_key, nested_val in nested.items():
            flat[add_prefix(prefix, nested_key)] = nested_val

    return {path.replace('/mup_lrs', ''): val for path, val in flat.items()}
# lrs = get_mup_lrs({k:{'mup_lrs':v['mup_lrs']} for k,v in state.items() if 'mup_lrs'in v.keys()}, 
#                         prefix='')

def rename_batch(batch):
    """Return a new dict with *batch*'s keys mapped onto ``'image'`` / ``'label'``.

    ``'obs'`` becomes ``'image'`` and ``'target'`` becomes ``'label'``;
    ``'image'`` and ``'label'`` pass through unchanged. Any other key raises
    ``KeyError``.
    """
    canonical = dict(obs='image', target='label', image='image', label='label')
    renamed = {}
    for name, value in batch.items():
        renamed[canonical[name]] = value
    return renamed

def count_parameters(params):
    """Return the total number of elements over all leaves of the *params* pytree."""
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += jnp.size(leaf)
    return total

def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested dict of arrays into scalar summary statistics.

    Nested keys are joined with *sep*. For each non-dict leaf, five entries
    are emitted, keyed ``<path>_mean``, ``_std``, ``_max``, ``_min`` and
    ``_2norm``. The suffixes always use ``'_'``, whatever *sep* is.

    ``_2norm`` is ``jnp.linalg.norm(v, ord=2)`` for 1-D and 2-D leaves: the
    Euclidean norm of a vector or the spectral norm of a matrix. ``ord=2`` is
    undefined for other ranks and raises ``ValueError``. For 0-D scalars and
    >=3-D tensors (e.g. conv kernels), the Euclidean norm of the flattened
    leaf is reported instead.

    Returns:
        dict mapping flattened key -> Python scalar.
    """
    items = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, new_key, sep=sep))
            continue

        items[new_key + '_mean'] = jnp.mean(v).item()
        items[new_key + '_std'] = jnp.std(v).item()
        items[new_key + '_max'] = jnp.max(v).item()
        items[new_key + '_min'] = jnp.min(v).item()
        if jnp.ndim(v) in (1, 2):
            norm = jnp.linalg.norm(v, ord=2)
        else:
            # ord=2 only supports 1 or 2 axes; use the flattened Euclidean norm.
            norm = jnp.linalg.norm(jnp.ravel(v))
        items[new_key + '_2norm'] = norm.item()

    return items

def get_params_and_state(needs_state, task, key):
    """Initialize a task's parameters, and its model state when it has one.

    Args:
        needs_state: if truthy, call ``task.init_with_state(key)``;
            otherwise call ``task.init(key)``.
        task: object providing ``init`` and/or ``init_with_state``.
        key: PRNG key forwarded unchanged to the initializer.

    Returns:
        ``(params, state)``. Without state this is ``(task.init(key), None)``.
        With state it is whatever ``task.init_with_state`` returns; callers
        unpack it as a ``(params, state)`` pair.
    """
    if not needs_state:
        return task.init(key), None
    print_rank_0("calling init_with_state in get_params_and_state")
    return task.init_with_state(key)



def _pmean_over_i(x):
    """pmap body: mean of ``x`` over the named mapped axis ``'i'``."""
    return jax.lax.pmean(x, axis_name='i')


def _mean_across_processes(x):
    """All-reduce-mean a scalar metric with a one-element ``jax.pmap`` over axis ``'i'``.

    Adds a leading axis of size 1, ``pmean``s over it under ``pmap``, then
    squeezes the axis away again.
    NOTE(review): whether this averages across *processes* depends on which
    devices ``pmap`` spans in the current JAX runtime; confirm for
    multi-process runs.
    """
    x = jnp.expand_dims(x, axis=0)
    x = jax.pmap(_pmean_over_i, axis_name='i')(x)
    return jnp.squeeze(x)


def _apply_sweep_overrides(args):
    """Start the wandb sweep run on rank 0 and apply ``a__b``-style overrides to ``args``.

    On rank 0, ``args`` is rebuilt from ``run.config``. Every config key that
    contains ``'__'`` is treated as a path into a dict attribute of ``args``;
    for example ``'opt__lr'`` sets ``args.opt['lr']``. When
    ``args.world_size > 1``, those overrides are broadcast from rank 0 over
    MPI so that every rank applies the same ones. Also forces
    ``args.num_runs = 1``.

    Returns:
        ``(args, run)``: ``run`` is the wandb run on rank 0 and ``None`` on
        every other rank.
    """
    run = None
    override_config = None
    if args.rank == 0:
        run = wandb.init(project=args.test_project, group=args.name, config=vars(args))
        args = argparse.Namespace(**run.config)
        override = [x for x in args.__dict__.keys() if '__' in x]
        override_config = {k: args.__dict__[k] for k in override}
    else:
        print("type(args): ", type(args))

    # Broadcast override keys from rank 0 to all ranks
    if args.world_size > 1:
        from mpi4py import MPI
        # Non-root ranks pass None; bcast returns rank 0's dict everywhere.
        override_config = MPI.COMM_WORLD.bcast(override_config, root=0)
        # Synchronize to ensure all ranks have received the override list
        MPI.COMM_WORLD.Barrier()

    args.num_runs = 1

    # Apply overrides
    print_rank_0("Overriding sweep args:")
    for key, value in override_config.items():
        print_rank_0(f"Setting {key} to {value}")
        parts = key.split('__')
        target = args
        for i, part in enumerate(parts):
            if i == len(parts) - 1:
                target[part] = value
            else:
                # First component is a Namespace attribute; deeper ones are dict keys.
                target = target.__dict__.get(part) if i == 0 else target.get(part)

    # Update wandb config with the overridden args
    if args.rank == 0:
        run.config.update(vars(args), allow_val_change=True)

    return args, run


def benchmark(args, sweep=False):
    """Train the configured task with the configured optimizer and log to wandb.

    Runs ``args.num_runs`` runs of ``args.num_inner_steps`` optimizer steps
    each. The model is evaluated on one test batch every
    ``args.test_interval`` steps and on the final step, unless
    ``args.skip_test``. Only rank 0 opens wandb runs, logs metrics and shows
    progress bars.

    Args:
        args: ``argparse.Namespace`` of run settings (task, optimizer, seed,
            ``rank``/``world_size``, logging flags, ...).
        sweep: when True, this call is one wandb-sweep trial. The run is
            started up front, ``'__'`` overrides are applied to ``args``, and
            ``args.num_runs`` is forced to 1.
    """
    import warnings

    run = None
    if sweep:
        args, run = _apply_sweep_overrides(args)

    args = set_non_hashable_args(args)
    # Set up globals used in truncated step for benchmarking
    globals.needs_state = args.needs_state
    globals.num_grads = args.num_grads
    globals.num_local_steps = args.num_local_steps
    globals.local_batch_size = args.local_batch_size
    globals.use_pmap = args.use_pmap
    globals.num_devices = args.num_devices

    key = jax.random.PRNGKey(args.seed)
    task = get_task(args)[0]
    # test_task = get_task(args, is_test=True)

    key, key1 = jax.random.split(key)
    params, state = get_params_and_state(args.needs_state, task, key1)

    # NOTE(review): this first initialization keeps the dtype produced by the
    # task, whereas re-initializations for later runs are cast to fp32 in the
    # loop below. Confirm the asymmetry is intended.
    # if args.use_bf16:
    #     params = cast_to_bf16(params)
    #     state = cast_to_bf16(state)
    # else:
    # params = cast_to_fp32(params)
    # state = cast_to_fp32(state)

    print_rank_0("====================================================================================")
    num_params_m = count_parameters(params)/1e6
    print_rank_0("Model parameters (M): ", num_params_m)
    num_tensors = len(jax.tree_util.tree_leaves(params))
    print_rank_0("Number of tensors: ", num_tensors)
    print_rank_0("====================================================================================")

    args.model_num_params = num_params_m
    args.model_num_tensors = num_tensors

    print_rank_0("params:")
    if args.rank == 0:
        pprint.pprint(jax.tree_util.tree_map(lambda x: x.shape, params))

    # If the task's state provides 'mup_lrs_to_use' (one entry per top-level
    # param key), hand it to the optimizer through args.runtime_mup_lrs.
    if state is not None:
        try:
            lrs = state['mup_lrs_to_use']
            set_diff = set(lrs.keys()) - set(params.keys())

            assert len(lrs) == len(params), f"Number of learning rates ({len(lrs)}) should be equal to number of parameters ({len(params)}), but differed by: " + str("; ".join(set_diff))
            assert set(lrs.keys()) == set(params.keys()), "Learning rates should have the same keys as parameters"
            args.runtime_mup_lrs = lrs
            print_rank_0("Set runtime_mup_lrs")
        except KeyError:
            print_rank_0("No mup_lrs_to_use in state, for task "+args.task[0])
    else:
        print_rank_0("State is None for task "+args.task[0])

    opt, update = get_optimizer(args)

    if args.use_pmap:
        assert args.num_grads % args.num_devices == 0, "The number of devices for pmap should be a multiple of the number of clients (gradients)"

    test_acc = 0
    print_rank_0('\nstarting loop')
    for run_idx in tqdm(range(args.num_runs), ascii=True, desc="Outer Loop", disable=args.rank != 0):
        if not sweep and args.rank == 0:
            run = wandb.init(project=args.test_project, group=args.name, config=vars(args))

        if run_idx > 0:
            params, state = get_params_and_state(args.needs_state, task, key1)
            # if args.use_bf16:
            #     params = cast_to_bf16(params)
            #     state = cast_to_bf16(state)
            # else:
            params = cast_to_fp32(params)
            state = cast_to_fp32(state)

        opt_state = opt.init(params, model_state=state, num_steps=args.num_inner_steps)
        if args.use_localsgd_batches:
            try:
                local_opt = opt.get_local_optimizer(task.get_mup_state({})['mup_lrs_to_use'])
            except AttributeError:
                local_opt = opt.get_local_optimizer(None)
            local_inner_opt_state = local_opt.init(params, model_state=state)
            # Stack args.num_grads copies of the local optimizer state along a new leading axis.
            local_inner_opt_state = jax.tree_util.tree_map(
                    lambda x: jnp.stack([x] * args.num_grads),
                    local_inner_opt_state)
        prev_params = params

        pbar = tqdm(
            range(args.num_inner_steps),
            initial=0,
            total=args.num_inner_steps,
            ascii=True,
            desc="Inner Loop",
            disable=args.rank != 0
        )
        train_load_time, grad_time, test_time = [], [], []
        for iteration in pbar:

            # update
            with Timing('get traing batch', train_load_time):
                batch = rename_batch(next(task.datasets.train))

            key, key1 = jax.random.split(key)

            # Data augmentation / normalization; both branches are currently
            # disabled by the trailing `and False`.
            if 'cifar' in args.task[0] and False:
                batch_images,batch_labels = batch['image'], batch['label']

                key, subkey = jax.random.split(key)
                batch_images, batch_labels = random_flip(batch_images, batch_labels, subkey)

                # autoaugment_cifar10
                key, subkey = jax.random.split(key)
                batch_images, batch_labels = autoaugment_cifar10(batch_images, batch_labels, subkey)

                mean_cifar10 = (0.49139968, 0.4821584,  0.44653094)  # CIFAR-10 dataset mean (per channel)
                std_cifar10 = (0.24703221, 0.24348514, 0.26158786)

                batch_images = normalize_images(batch_images,mean_cifar10,std_cifar10)

                # key, subkey = jax.random.split(key)
                # batch_images, batch_labels = random_crop(batch_images, batch_labels, subkey)

                # key, subkey = jax.random.split(key)
                # batch_images, batch_labels = mixup(batch_images, batch_labels, subkey)

                # key, subkey = jax.random.split(key)
                # batch_images, batch_labels = cutmix(batch_images, batch_labels, subkey)

                batch = {'image':batch_images,'label':batch_labels}
            elif 'imagenet' in args.task[0] and False:

                batch_images,batch_labels = batch['image'], batch['label']

                key, subkey = jax.random.split(key)
                batch_images, batch_labels = random_flip(batch_images, batch_labels, subkey)

                # autoaugment_cifar10
                key, subkey = jax.random.split(key)
                batch_images, batch_labels = autoaugment_cifar10(batch_images, batch_labels, subkey)

                mean_inet = (0.485, 0.456, 0.406)
                std_inet = (0.229, 0.224, 0.225)

                batch_images = normalize_images(batch_images,mean_inet,std_inet)

                # key, subkey = jax.random.split(key)
                # batch_images, batch_labels = random_crop(batch_images, batch_labels, subkey)

                # key, subkey = jax.random.split(key)
                # batch_images, batch_labels = mixup(batch_images, batch_labels, subkey)

                # key, subkey = jax.random.split(key)
                # batch_images, batch_labels = cutmix(batch_images, batch_labels, subkey)

                batch = {'image':batch_images,'label':batch_labels}

            with Timing('fw bw full', grad_time):
                if args.use_localsgd_batches:
                    opt_state, local_inner_opt_state, loss, grad = update(opt_state, local_inner_opt_state, key1, batch)
                else:
                    opt_state, loss, grad = update(opt_state, key1, batch)

                to_log = {
                        "train loss": loss,
                    }

            params = opt.get_params(opt_state)
            state = opt.get_state(opt_state)

            with Timing('test',test_time):
                # Evaluate every `test_interval` steps (this includes step 0) and on the final step.
                run_test = (not args.skip_test) and (
                    iteration % args.test_interval == 0
                    or iteration == args.num_inner_steps - 1
                )
                if run_test:
                    # Fetch outside the try below, for two reasons. The loss-only
                    # fallback then always has a fresh batch. An AttributeError
                    # from the data pipeline is also not mistaken for a
                    # missing accuracy method.
                    test_batch = rename_batch(next(task.datasets.test))

                    if 'cifar' in args.task[0] and False:
                        mean_cifar10 = (0.49139968, 0.4821584,  0.44653094)  # CIFAR-10 dataset mean (per channel)
                        std_cifar10 = (0.24703221, 0.24348514, 0.26158786)

                        test_batch = {
                            'image':normalize_images(test_batch['image'],mean_cifar10,std_cifar10),
                            'label':test_batch['label']
                        }
                    elif 'imagenet' in args.task[0] and False:
                        mean_inet = (0.485, 0.456, 0.406)
                        std_inet = (0.229, 0.224, 0.225)
                        test_batch = {
                            'image':normalize_images(test_batch['image'],mean_inet,std_inet),
                            'label':test_batch['label']
                        }

                    # if args.use_bf16:
                    #     test_batch = cast_to_bf16(test_batch)
                    has_acc = True
                    try:
                        key, key1 = jax.random.split(key)

                        if args.needs_state:
                            state = opt.get_state(opt_state)
                            test_loss, test_acc = task.loss_and_accuracy_with_state(params, state, key1, test_batch)
                        else:
                            test_loss, test_acc = task.loss_and_accuracy(params, key1, test_batch)
                    except AttributeError:
                        # No accuracy method on the task: report the loss only.
                        has_acc = False
                        warnings.warn("test_task does not have loss_and_accuracy method, defaulting to loss")
                        key, key1 = jax.random.split(key)
                        if args.needs_state:
                            state = opt.get_state(opt_state)
                            test_loss, state = task.loss_with_state(params, state, key1, test_batch)
                        else:
                            test_loss = task.loss(params, key1, test_batch)

                    # All-reduce mean the test metrics across all processes
                    if args.world_size > 1:
                        test_loss = _mean_across_processes(test_loss)
                        if has_acc:
                            test_acc = _mean_across_processes(test_acc)

                    test_log = {"test loss": test_loss}
                    if has_acc:
                        test_log["test accuracy"] = test_acc
                    to_log.update(test_log)
                else:
                    test_loss = 0

            if args.rank == 0:
                pbar.set_postfix({
                    "data":round(train_load_time[-1],4),
                    "fwbw":round(Timing.run_times_dict["fw bw"][-1],4),
                    "opt":round(Timing.run_times_dict["optimizer step"][-1],4),
                    "AR":round(Timing.run_times_dict["AR"][-1],4),
                    "test":round(test_time[-1],4),
                    "train loss":round(float(loss),2),
                    "test loss":round(float(test_loss),2) if not args.skip_test else 0,
                    "test acc":round(float(test_acc),2) if not args.skip_test else 0,
                    "LR": opt.get_current_lr(iteration) if args.optimizer.lower() == 'sgd' else 0,
                })

                # log
                to_log.update({
                    "optimizer step": Timing.run_times_dict["optimizer step"][-1],
                    "fwbw time": Timing.run_times_dict["fw bw"][-1],
                    "AR time":Timing.run_times_dict["AR"][-1],
                })

            # to_log.update(flatten_dict(grad, parent_key='', sep='_'))
            # to_log.update(flatten_dict(jax.tree_util.tree_map(lambda x,y:x-y,prev_params,params), parent_key='delta', sep='_'))

            if args.log_activations and args.rank == 0:
                idxkey = 'mlp' if 'mlp' in state else 'mu_mlp'
                if iteration == 0:
                    # Step-0 snapshot of the non-'l1' act/logit entries: baseline for *_std_delta.
                    initial_tensors_only = {k:v for k,v in state[idxkey].items() if ('act' in k or 'logit' in k) and 'l1' not in k}

                # 'l1' act/logit entries are logged directly as scalars.
                to_log.update({k:v.item() for k,v in state[idxkey].items() if ('act' in k or 'logit' in k) and 'l1' in k})

                tensors_only = {k:v for k,v in state[idxkey].items() if ('act' in k or 'logit' in k) and 'l1' not in k}
                std_delta = jax.tree_util.tree_map(lambda x,y : jnp.std(x - y), tensors_only, initial_tensors_only)
                to_log.update({k+'_std_delta':v.item() for k,v in std_delta.items()})

            if args.rank == 0:
                run.log(to_log)

            prev_params = params

        if args.rank == 0:
            run.finish()


def sweep(args):
    """Drive a wandb hyper-parameter sweep over ``benchmark``.

    Sets ``WANDB_LOG_LEVEL=debug`` and ``args.SWEEP_CONTINUE = True``.

    Rank 0 prints every list-valued arg and the sweep config. It registers a
    new sweep when ``args.sweep_id`` is None, then hands
    ``benchmark(args, True)`` to ``wandb.agent``.

    Every other rank calls ``benchmark(args, True)`` repeatedly for as long
    as ``args.SWEEP_CONTINUE`` is truthy.

    NOTE(review): nothing visible here ever clears ``SWEEP_CONTINUE``, so
    non-zero ranks keep looping after rank 0's agent returns. Confirm how
    the job is meant to terminate.
    """
    import os
    os.environ['WANDB_LOG_LEVEL'] = 'debug'

    args.SWEEP_CONTINUE = True

    if args.rank != 0:
        while args.SWEEP_CONTINUE:
            benchmark(args, True)
        return

    for name, value in args.__dict__.items():
        if type(value) is list:
            print(name, type(value))

    print(args.sweep_config)

    if args.sweep_id is None:
        args.sweep_id = wandb.sweep(sweep=args.sweep_config, project=args.test_project)

    trial = partial(benchmark, args, True)
    wandb.agent(args.sweep_id, trial, project=args.test_project)
