import os, gc

os.environ['OMP_NUM_THREADS'] = '1'

import numpy as np

import sys, os, wandb, argparse
sys.path.append("..")

from functools import partial

import torch
import torch.nn as nn

from copy import deepcopy

import optimizers
from models import get_model
from sparse_fedavg_exp import load_datasets, test


def get_model_norm(model):
    """Return the norm of all trainable parameters of ``model`` as one vector.

    Only parameters with ``requires_grad`` set are included. Each one is
    flattened, the pieces are concatenated, and ``torch.norm`` is applied
    to the resulting 1-D tensor.
    """
    flat_pieces = []
    for weight in model.parameters():
        if not weight.requires_grad:
            continue
        flat_pieces.append(weight.flatten())
    return torch.norm(torch.cat(flat_pieces))

def get_control_norm(states):
    """Return the norm of every tensor in ``states`` stacked into one vector.

    ``states`` is a mapping whose values are tensors; the keys are ignored.
    """
    flattened = [tensor.flatten() for tensor in states.values()]
    stacked = torch.cat(flattened)
    return torch.norm(stacked)
    
def get_control_sum(states):
    """Sum the ``'control'`` entries of several client states, key by key.

    ``states`` is a non-empty sequence of dicts, each holding a ``'control'``
    mapping. The key set is taken from the first state; every state must
    provide each of those keys. Returns a new dict mapping each key to the
    sum of the corresponding entries.
    """
    totals = {}
    for param_name in states[0]['control']:
        totals[param_name] = sum(client['control'][param_name] for client in states)
    return totals

def get_grad_norm(model):
    """Return the norm of the gradients of all trainable parameters of ``model``.

    Every parameter with ``requires_grad`` set must already have a populated
    ``.grad``; the gradients are flattened and concatenated before the norm
    is taken.
    """
    grad_pieces = []
    for weight in model.parameters():
        if weight.requires_grad:
            grad_pieces.append(weight.grad.flatten())
    return torch.norm(torch.cat(grad_pieces))

def get_sparsity(model):
    """Return the fraction of trainable parameter entries that are zero.

    Entries are rounded to 10 decimals before the zero test, so values with
    a smaller magnitude count as zero. Parameters without ``requires_grad``
    are ignored.
    """
    total_entries = 0
    nonzero_entries = 0
    with torch.no_grad():
        trainable = (p for p in model.parameters() if p.requires_grad)
        for tensor in trainable:
            total_entries += tensor.numel()
            nonzero_entries += torch.count_nonzero(tensor.round(decimals=10)).item()
    return 1 - nonzero_entries / total_entries

def train(config):
    """Run one federated sparse-training experiment and log it to wandb.

    Each communication round samples a subset of clients, trains a copy of
    the current global model on every sampled client for a number of local
    steps (SGD for FedAvg, ``optimizers.ProxSkipClient`` for ProxSkip),
    averages the resulting parameters into the global model, optionally
    prunes with ``optimizers.top_k_unstructured``, and periodically
    evaluates a pruned copy of the global model on the test set.

    Args:
        config: Dict of experiment settings (e.g. ``vars(args)``), or ``None``
            when launched by a wandb sweep agent, in which case the settings
            are read from ``wandb.config``. Keys read here: ``project``,
            ``disable_wandb``, ``optimizer``, ``prox_loc``, ``rounds``,
            ``seed`` (optional), ``weight_decay``, ``fraction_fit``,
            ``target_sparsity``, ``num_clients``, ``dual_lr``, ``non_iid``,
            ``eval_every``. ``lr``, ``local_steps``, ``grad_clip_value`` and
            ``randomized`` are overwritten from the ``repeat_exp`` table and
            the optimizer choice. The dict is also forwarded to
            ``load_datasets``, ``get_model`` and ``test``.

    Raises:
        KeyError: If the ``optimizer`` / ``prox_loc`` combination is not in
            the ``repeat_exp`` table.
        RuntimeError: If a participating client's data loader yields no
            batches.
    """
    # A sweep agent calls train(None); wandb is then never disabled from here.
    mode = 'disabled' if (config and config.get('disable_wandb')) else 'online'

    wandb.init(config=config, project=config['project'] if config else None, allow_val_change=True, mode=mode)
    # Disable randomized adaptation at the moment
    print("Randomization option ignored. Determined by optimizer choice.")
    wandb.config.update({'randomized' : wandb.config.optimizer in ('ProxSkip', 'ProxSkip_mod_steps')}, allow_val_change=True)

    # (lr, local_steps, grad_clip_value) per optimizer and pruning location.
    repeat_exp = {
        "FedAvg": {
            "global": (1, 256, 47),
            "local": (1e-1, 128, 20),
            "comm_mod": (1e-1, 1, 20),
            "final": (1e-1, 1, 20),
        },
        "ProxSkip": {
            "global": (2.7, 64, 165),
            "local": (3e-1, 256, 30),
            "comm_mod": (1.3, 256, 14),
            "final": (1.3, 256, 14),
        },
    }

    lr, local_steps, grad_clip_value = repeat_exp[wandb.config['optimizer']][wandb.config['prox_loc']]
    print(f"lr, local_steps and grad_clip_value overwritten to {lr}, {local_steps} and {grad_clip_value}.")
    wandb.config.update({'lr' : lr, 'local_steps': local_steps, 'grad_clip_value':grad_clip_value}, allow_val_change=True)

    # Work on a plain dict copy and re-apply the overrides, so they take
    # effect even if the wandb.config.update calls above were ignored.
    config = dict(wandb.config.items())
    config['randomized'] = wandb.config.optimizer in ('ProxSkip', 'ProxSkip_mod_steps')
    config['rounds'] = 500 if wandb.config.optimizer == 'ProxSkip' and wandb.config.prox_loc == "global" else config['rounds']
    config['lr'], config['local_steps'], config['grad_clip_value'] = lr, local_steps, grad_clip_value
    print("We ignore that wandb ignores and just change the config directly")

    if 'seed' in config:
        rng = np.random.default_rng(config['seed'])
        torch.manual_seed(config['seed'])
    else:
        rng = np.random.default_rng()

    client_loaders, valloader, test_loader = load_datasets(config)

    # Taken from LogR, needs to be adapted.
    l2_regularizer = config['weight_decay']

    # Prefer the GPU, but keep the script runnable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Exact comparison: only the plain "ProxSkip" optimizer uses the ProxSkip client.
    use_proxskip_client = config['optimizer'] == "ProxSkip"

    with torch.no_grad():
        model = get_model(config).to(device, dtype=torch.double)
        criterion = nn.CrossEntropyLoss(reduction='mean')

        num_models = len(client_loaders)
        num_clients_participating = int(config['fraction_fit']*num_models)
        clients_participated_at_least_once = set()
        print(f"Number of participating clients: {num_clients_participating}")
        train_model = deepcopy(model).to(device)
        # NOTE: avg_model is the same object as model (Module.to returns self).
        avg_model = model.to(device)
        num_samples = sum(len(loader.dataset) for loader in client_loaders)

        lr = config['lr']
        comm_rounds = config['rounds']
        # Probability parameter of the geometric number of local steps.
        p = 1/config['local_steps']
        target_sparsity = config['target_sparsity']

        client_states = [{'control':{}, "par_hat":{}} for _ in range(num_models)]

        _, accuracy = test(config, model, test_loader, device)

        measured_sparsity = get_sparsity(model)
        wandb.log({'distance':1, 'sparsity':measured_sparsity, 'test':accuracy, 'communication_cost':0}, step=0, commit=False)

    previous_steps = config['local_steps']
    communication_cost = 0

    for comm_round in range(int(comm_rounds)):

        train_loss, unregularized_train_loss = 0, 0
        grad_norm = []
        # Keep the order of rng draws (steps first, then clients) for seed reproducibility.
        actual_steps = rng.geometric(p) if config['randomized'] else config['local_steps']

        clients_participating = sorted(rng.choice(config['num_clients'], num_clients_participating, replace=False))
        print(f"Clients participating: {clients_participating}")
        clients_participated_at_least_once.update(clients_participating)

        # Snapshot the global model, then zero it so it can accumulate the new average.
        prev_model = deepcopy(avg_model)
        for key in avg_model.state_dict():
            avg_model.state_dict()[key].zero_()

        for cid in clients_participating:
            client_loader = client_loaders[cid]
            train_model.load_state_dict(prev_model.state_dict())

            if use_proxskip_client:
                optimizer = optimizers.ProxSkipClient(train_model.named_parameters(), lr=lr, weight_decay=0, client_state=client_states[cid], p = p, dual_lr=config['dual_lr'])
            else:
                optimizer = torch.optim.SGD(train_model.parameters(),lr=lr, weight_decay=0)

            step = 0
            while step < actual_steps:
                steps_before_pass = step
                for batch in client_loader:
                    if config['non_iid'] != 'dirichlet':
                        img, label = batch["img"].to(device), batch["label"].to(device)
                    else:
                        img, label = batch
                        img = img.to(device)
                        label = label.to(device)

                    label_pred = train_model(img)
                    unregularized_loss = criterion(label_pred, label)
                    # Regularize the weights that are actually being trained;
                    # a penalty on the aggregate model would contribute no
                    # gradient to train_model.
                    loss = unregularized_loss + l2_regularizer/2*get_model_norm(train_model)**2/num_clients_participating

                    optimizer.zero_grad()
                    loss.backward()

                    if config['grad_clip_value'] > 0:
                        nn.utils.clip_grad_norm_(train_model.parameters(), config['grad_clip_value'])

                    # Optimizer step with pruning detail
                    is_last_step = step == actual_steps-1
                    prune_this_step = 'local' in config['prox_loc'] \
                        or ('comm_mod' in config['prox_loc'] and is_last_step)
                    if use_proxskip_client:
                        step_sparsity = target_sparsity if prune_this_step else None
                        global_mod_sparsity = target_sparsity if 'global_mod' in config['prox_loc'] and step == 0 else 0

                        optimizer.step(save_par_hat=is_last_step,
                                       sparsity=step_sparsity,
                                       global_mod_sparsity=global_mod_sparsity)
                    else:
                        optimizer.step()
                        if prune_this_step:
                            optimizers.top_k_unstructured(train_model.parameters(), target_sparsity, groups=False)

                    # Log values (first local step of each client only)
                    if step == 0:
                        with torch.no_grad():
                            grad_norm.append(get_grad_norm(train_model).detach().cpu().numpy())
                        unregularized_train_loss += unregularized_loss.item()
                        train_loss += loss.item()
                    step += 1
                    if step >= actual_steps:
                        break

                # A full pass without a single batch would loop forever.
                if step == steps_before_pass:
                    raise RuntimeError(f"Client {cid} has an empty data loader; cannot take local steps.")

            with torch.no_grad():
                # Sum up the parameters every round
                for param_name, param in train_model.named_parameters():
                    avg_model.state_dict()[param_name].add_(param.data, alpha=1/num_clients_participating)

        # Can be used to make ProxSkip get the control variates from the actual number of steps taken in the last round
        previous_steps = actual_steps
        train_loss = train_loss/num_clients_participating
        unregularized_train_loss = unregularized_train_loss/num_clients_participating
        wandb.log({'loss':train_loss, 'unregularized loss':unregularized_train_loss}, step=comm_round, commit=True)

        with torch.no_grad():
            if "comm" in config['prox_loc'] or 'local' in config['prox_loc']:
                communication_cost += 1 - target_sparsity
            else:
                communication_cost += 1

            # Averaging already done above while accumulating into avg_model.

            if 'global' in config['prox_loc'] and ('global_mod' not in config['prox_loc'] or config['optimizer'] == 'FedAvg'):
                optimizers.top_k_unstructured(avg_model.parameters(), target_sparsity, groups=False)

            if "ProxSkip" in config['optimizer']:
                # Plain ProxSkip scales by the expected number of steps 1/p;
                # other variants use the number of steps actually taken.
                control_steps = 1/p if config['optimizer'] == "ProxSkip" else previous_steps
                # Update control of participating clients
                for cid in clients_participating:
                    optimizer = optimizers.ProxSkipClient(avg_model.named_parameters(), lr=lr, weight_decay=0, client_state=client_states[cid], p = p, dual_lr=config['dual_lr'])
                    optimizer.update_control(control_steps)

                # Report average norm
                states_initialised = [client_states[cid] for cid in clients_participated_at_least_once]
                # NOTE(review): the sum runs over every client that has ever
                # participated but is divided by the per-round participant
                # count; confirm whether len(states_initialised) is intended.
                control_avg_norm = sum(get_control_norm(state['control']).item() for state in states_initialised) / num_clients_participating
                control_sum = get_control_norm(get_control_sum(states_initialised)).item()
                # Logged under a second key with the same value.
                avg_norm_control_vars = control_avg_norm
            else:
                control_avg_norm, control_sum, avg_norm_control_vars = 0, 0, 0

            eval_model = deepcopy(avg_model)
            if config['prox_loc'] == 'final':
                optimizers.top_k_unstructured(eval_model.parameters(), target_sparsity, groups=False)

            # Measured before the unconditional pruning below.
            sparsity = get_sparsity(eval_model)

            # Prune the model before evaluation for fair comparison
            optimizers.top_k_unstructured(eval_model.parameters(), target_sparsity, groups=False)
            params_norm = get_model_norm(eval_model).item()

            if (comm_round + 1) % config['eval_every'] == 0 or comm_round == 0:
                _, accuracy = test(config, eval_model, test_loader, device)

                avg_grad_norm = sum(grad_norm)/len(grad_norm)

                wandb.log({'test':accuracy, 'sparsity':sparsity,
                            'params_norm':params_norm, 'avg_grad_norm':avg_grad_norm,
                            'control_avg_norm': control_avg_norm, 'control_sum': control_sum, "avg_norm_control":avg_norm_control_vars,
                            "communication_cost": communication_cost}, step=comm_round+1, commit=False)
                print(f"Comm Rounds {comm_round + 1}, Avg Train Loss: {train_loss:.4g}, Test Accuracy: {100*(accuracy):.1f}%, Sparsity: {sparsity:.2f}, LR: {lr:.1g}, Steps {int(actual_steps)}, Sum h {control_sum:.2g}, Avg h {control_avg_norm:.2g}, cost {communication_cost:.2f}")

def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` would turn every non-empty string (including ``"False"``)
    into ``True``. This converter accepts true/t/yes/y/1 and false/f/no/n/0
    (case-insensitive); the empty string maps to ``False``, as ``bool("")``
    does.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    # Command-line entry point: either run a single experiment from the given
    # arguments or attach to an existing wandb sweep.
    parser = argparse.ArgumentParser(description="DNN experiments")
    parser.add_argument('-t','--target_sparsity', help="Manual Mode: Provide a target sparsity",type=float, default=0)
    parser.add_argument('-s','--sweepID',help="Run a wandb sweep. Sweep ID needs to be provided.")
    parser.add_argument('--project', default="Debug")
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--lr_schedule', default="constant", choices=['constant', 'root', 'halving'])
    parser.add_argument('--lr_halving', default=100, type=int)
    parser.add_argument('--dual_lr', type=float, default=1)
    parser.add_argument('--local_steps', type=int, default=5)
    parser.add_argument('--rounds', type=int, default=10)
    parser.add_argument('--num_clients', type=int, default=10)
    parser.add_argument('--fraction_fit', type=float, default=1)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--eval_every', type=int, default=1)
    parser.add_argument('--randomized', action='store_true')
    parser.add_argument('--disable_wandb', action='store_true')
    parser.add_argument('--prox_loc', type=str, default="comm_mod", choices=['global',"local","comm_mod",'final'])
    parser.add_argument('--optimizer', type=str, default="ProxSkip", choices=["ProxSkip", "FedAvg"])
    parser.add_argument('--model', type=str, default="ResNet18", choices=['CNN',"ResNet18","ConvNet6", "Linear"])
    parser.add_argument('--dropout', type=float,default=0.0)
    parser.add_argument('--non_iid', type=str, choices=['iid', 'isik', 'dirichlet'], help="Set to an int >0 to use the partitioning from the Isik 2023 paper", default='dirichlet')
    parser.add_argument('--non_iid_param', default=0.3, type=float, help="Either dirichlet alpha or set to an int >0 to use the partitioning from the Isik 2023 paper")
    # str2bool instead of bool so that "--flag False" really yields False.
    parser.add_argument('--bn_track_running_stats', type=str2bool, help="Only relevant for ResNet18. True is not implemented atm.", default=False)
    parser.add_argument('--transforms', type=str2bool, help="Use transforms in pipeline.", default=True)
    parser.add_argument('--grad_clip_value', type=float, default=10, help="Maximum norm of the gradient. 0 means no gradient clipping.")

    args = parser.parse_args()

    if args.sweepID:
        # The sweep agent calls train(None); settings then come from wandb.config.
        wandb.agent(args.sweepID, partial(train, None), project=args.project)
    else:
        train(vars(args))

        
