import os

os.environ['OMP_NUM_THREADS'] = '1'

import numpy as np

import sys, os, joblib, wandb, argparse
sys.path.append("..")

from functools import partial

from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn

from copy import deepcopy


import optimizers, femnist_dataset

# Directory prefix for the cached "scaler.joblib" and "clf.joblib" artifacts (value redacted).
# It is concatenated directly with file names, so it must end with a path separator.
OBJECT_PATH = "REMOVED"

def load_data(num_clients=3220):
    """Load the FEMNIST train and test splits and keep the first ``num_clients`` training clients.

    Both splits are built with the pre-fitted scaler stored at
    ``OBJECT_PATH + "scaler.joblib"``. The test split is left untouched.

    Args:
        num_clients: Number of leading training clients to keep.

    Returns:
        ``(train_data, test_data)``. The per-client ``images`` / ``labels`` lists of
        ``train_data`` are truncated, and its pooled ``X`` / ``y`` tensors are
        rebuilt so they only contain the retained clients.
    """
    train_path = "REMOVED"
    test_path = "REMOVED"
    scaler = joblib.load(OBJECT_PATH + "scaler.joblib")

    test_data = femnist_dataset.FEMNISTDataset(test_path, load_tensors=True, scaler=scaler)
    train_data = femnist_dataset.FEMNISTDataset(train_path, load_tensors=True, scaler=scaler)

    kept_images = train_data.images[:num_clients]
    kept_labels = train_data.labels[:num_clients]
    train_data.images, train_data.labels = kept_images, kept_labels

    # Re-pool only the retained clients so X / y stay consistent with images / labels.
    train_data.X, train_data.y = torch.concatenate(kept_images), torch.concatenate(kept_labels)

    return train_data, test_data

def get_classifier(config, train_data, *, l2_regularizer=1e-4):
    """Return the L2-regularized logistic-regression reference solution.

    With all 3220 clients and the default regularization strength, the classifier
    fitted ahead of time is loaded from ``OBJECT_PATH + "clf.joblib"``. Otherwise a
    ``LogisticRegression`` is fitted on ``train_data.X`` / ``train_data.y``.

    Args:
        config: Run configuration; only ``config['num_clients']`` is read.
        train_data: Training set exposing pooled features ``X`` and labels ``y``.
        l2_regularizer: Strength ``lam`` of the penalty ``lam / 2 * ||w||^2`` added to
            the mean training loss. The cached classifier was obtained for the
            default value, so it is only reused when this value is unchanged.

    Returns:
        A fitted classifier; callers read its ``coef_`` as the reference weights.
    """
    if config['num_clients'] == 3220 and l2_regularizer == 1e-4:
        return joblib.load(OBJECT_PATH + "clf.joblib")

    # sklearn minimizes C * sum(loss) + ||w||^2 / 2. With C = 1 / (lam * n) this equals
    # (1 / lam) * (mean(loss) + lam / 2 * ||w||^2), so it has the same minimizer as
    # the mean loss plus the lam / 2 * ||w||^2 penalty.
    n_samples = train_data.X.shape[0]
    clf = LogisticRegression(C=1/l2_regularizer/n_samples, penalty="l2", solver="lbfgs",
                             tol=1e-15, max_iter=int(1e5), fit_intercept=False)
    clf.fit(train_data.X, train_data.y)
    return clf

def evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Batches are moved to ``device`` before the forward pass. The pass runs under
    ``torch.no_grad()``, so evaluation never builds an autograd graph.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(inputs, labels)`` batches exposing ``loader.dataset``.
        device: Device the model lives on.

    Returns:
        Accuracy as a Python float.
    """
    correct = 0
    with torch.no_grad():
        for img, labels in loader:
            img, labels = img.to(device), labels.to(device)
            correct += (model(img).argmax(1) == labels).type(torch.float).sum().item()
    return correct / len(loader.dataset)


def train(config):
    """Run one federated sparse-training experiment and log metrics to wandb.

    Each communication round:
      * every client runs ``actual_steps`` local steps on its single batch, using
        ProxSkip or plain SGD, and optionally prunes according to
        ``config['prox_loc']``;
      * the client models are averaged and optionally pruned again;
      * the average is copied back to every client.

    After each round the distance between the averaged model and the reference
    solution from ``get_classifier`` is tracked.

    Args:
        config: Run configuration (the parsed CLI arguments), or ``None`` when
            launched by a wandb sweep agent. The effective values are read back
            from ``wandb.config`` after ``wandb.init``.
    """
    wandb.init(config=config, project=config['project'] if config else None, allow_val_change=True)

    # Disable randomized adaptation at the moment: randomization is tied to the
    # optimizer choice and overrides whatever 'randomized' was configured.
    print("Randomization option ignored. Determined by optimizer choice.")
    wandb.config.update({'randomized' : wandb.config.optimizer in ('ProxSkip', 'ProxSkip_mod_steps')}, allow_val_change=True)
    config = dict(wandb.config.items())

    if 'seed' in config:
        rng = np.random.default_rng(config['seed'])
        torch.manual_seed(config['seed'])
    else:
        rng = np.random.default_rng()

    # repeat_exp = {"FedAvg":
    #               {
    #                   "global": (39, 12),
    #                   "local": (25, 3),
    #                   "comm_mod": (25, 3),
    #                   "final": (25, 3),
    #               },
    #               "ProxSkip": 
    #               {
    #                   "global": (12,44),
    #                   "local": (31, 70),
    #                   "comm_mod": (6.7, 25),
    #                   "final": (8, 3),
    #               },
    # }

    # config['lr'], config['local_steps'] = repeat_exp[config['optimizer']][config['prox_loc']]
    # print(f"lr and local_steps overwritten to {config['lr']} and {config['local_steps']}.")

    train_data, test_data = load_data(config['num_clients'])

    # Hard coded, because the solution was obtained for this value
    l2_regularizer = 1e-4

    # Fall back to CPU so the script still runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clf = get_classifier(config, train_data)
    x_sol = torch.tensor(clf.coef_, device=device, requires_grad=False)

    batch_size = config['batch_size']
    client_loaders = train_data.get_client_loaders(batch_size)
    test_loader = test_data.get_full_loader(1024)

    prox_loc = config['prox_loc']
    is_proxskip = config['optimizer'] in ("ProxSkip", "ProxSkip_mod_steps")

    with torch.no_grad():
        # One shared (optionally all-zero) initialization, copied to every client.
        init_model = nn.Linear(785, 10, dtype=torch.float64, bias=False)
        if config['init_zeros']:
            torch.nn.init.zeros_(init_model.weight)
        criterion = nn.CrossEntropyLoss(reduction='sum')

        num_models = len(client_loaders)
        models = [deepcopy(init_model).to(device) for _ in range(num_models)]
        num_samples = sum(len(loader.dataset) for loader in client_loaders)

        params = next(models[0].parameters()).data.detach()
        initial_distance = torch.norm(x_sol - params, 'fro')**2

        lr = config['lr']
        comm_rounds = config['rounds']
        p = 1/config['local_steps']
        target_sparsity = config['target_sparsity']

        # Per-client optimizer state; created once and handed to each round's optimizer.
        client_states = [{'control': {}, "par_hat": {}} for _ in range(num_models)]

        # Accuracy of the initial model
        correct = evaluate_accuracy(models[0], test_loader, device)

        measured_sparsity = 1 - (torch.count_nonzero(params.round(decimals=10))/params.numel()).item()
        wandb.log({'distance':1, 'sparsity':measured_sparsity, 'test':correct, 'communication_cost':0}, step=0, commit=False)

    previous_steps = config['local_steps']
    communication_cost = 0

    for rnd in range(int(comm_rounds)):
        train_loss, unregularized_train_loss = 0, 0
        grad_norm = []
        actual_steps = rng.geometric(p) if config['randomized'] else config['local_steps']
        last_step = actual_steps - 1
        # Factor passed to the ProxSkip client on the first local step:
        # - "ProxSkip": the expected number of local steps, 1/p;
        # - "ProxSkip_mod_steps": the number of steps actually taken last round.
        # (The original compared the list literal ['optimizer'] to a string, which
        # was always False, so plain ProxSkip silently used previous_steps.)
        control_p_factor = 1/p if config['optimizer'] == "ProxSkip" else previous_steps

        for cid, client_loader in enumerate(client_loaders):
            model = models[cid]

            # The L2 term is added to the loss explicitly below, hence weight_decay=0.
            if is_proxskip:
                optimizer = optimizers.ProxSkipClient(model.named_parameters(), lr=lr, weight_decay=0,
                                                      client_state=client_states[cid], p=p)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0)

            for step in range(int(actual_steps)):
                for img, label in client_loader:
                    # Pruning inside the optimizer step assumes one batch per client.
                    if len(client_loader) != 1:
                        raise ValueError("Pruning is only implemented for single batch")
                    img, label = img.to(device), label.to(device)

                    label_pred = model(img)
                    # Summed over all clients (with equal weights) this gives
                    # mean cross-entropy + l2_regularizer / 2 * ||w||^2.
                    unregularized_loss = criterion(label_pred, label)/num_samples
                    loss = unregularized_loss + l2_regularizer/2*torch.norm(model.weight, 'fro')**2/num_models

                    optimizer.zero_grad()
                    loss.backward()

                    # Optimizer step with pruning detail
                    prune_now = 'local' in prox_loc or ('comm_mod' in prox_loc and step == last_step)
                    if is_proxskip:
                        optimizer.step(actual_steps=control_p_factor if step == 0 else None,
                                       save_par_hat=(step == last_step),
                                       sparsity=target_sparsity if prune_now else None,
                                       global_mod_sparsity=target_sparsity if 'global_mod' in prox_loc and step == 0 else 0)
                    else:
                        optimizer.step()
                        if prune_now:
                            optimizers.top_k_unstructured(model.parameters(), target_sparsity, groups=False)

                    # Log values at the round's first local step only
                    if step == 0:
                        with torch.no_grad():
                            grad_norm.append(torch.norm(model.weight.grad).cpu().numpy())
                        unregularized_train_loss += unregularized_loss.item()
                        train_loss += loss.item()

            # Substring test: also true for every 'comm_mod*' location.
            if 'comm' in prox_loc:
                optimizers.top_k_unstructured(model.parameters(), target_sparsity, groups=False)

        # Can be used to make ProxSkip get the control variates from the actual number of steps taken in the last round
        previous_steps = actual_steps
        wandb.log({'loss':train_loss, 'unregularized loss':unregularized_train_loss}, step=rnd, commit=True)

        with torch.no_grad():
            if "comm" in prox_loc or 'local' in prox_loc:
                communication_cost += 1 - target_sparsity
            else:
                communication_cost += 1

            # Average all client models into client 0's model, in place.
            avg_model = models[0]
            for param_name, param in avg_model.named_parameters():
                param_sum = sum(m.state_dict()[param_name] for m in models)
                param.copy_(param_sum / num_models)

            if is_proxskip:
                control_avg_norm = sum(torch.norm(state['control']['weight']).item() for state in client_states) / num_models
                control_sum = torch.norm(sum(state['control']['weight'] for state in client_states)).item()
            else:
                control_avg_norm, control_sum = 0, 0

            if 'global' in prox_loc and ('global_mod' not in prox_loc or config['optimizer'] == 'FedAvg'):
                optimizers.top_k_unstructured(avg_model.parameters(), target_sparsity, groups=False)

            # Every client starts the next round from the aggregate.
            models = [avg_model] + [deepcopy(avg_model).to(device) for _ in range(1, num_models)]

            eval_model = deepcopy(avg_model)
            if prox_loc in ('global_mod', 'final', 'comm_mod_global_mod'):
                optimizers.top_k_unstructured(eval_model.parameters(), target_sparsity, groups=False)

            # `params` shares storage with eval_model's weight. Sparsity is measured
            # before the evaluation pruning below; params_norm / distance are measured
            # after it (assuming top_k_unstructured prunes in place).
            params = next(eval_model.parameters()).data.detach()
            sparsity = (1 - torch.count_nonzero(params.round(decimals=10))/params.numel()).item()

            # Prune the model before evaluation for fair comparison
            optimizers.top_k_unstructured(eval_model.parameters(), target_sparsity, groups=False)
            params_norm = torch.norm(params, 'fro').item()
            distance = (torch.norm(x_sol - params, 'fro')**2 / initial_distance).item()

        if (rnd + 1) % config['eval_every'] == 0 or rnd == 0:
            correct = evaluate_accuracy(eval_model, test_loader, device)

            if is_proxskip:
                with torch.no_grad():
                    avg_norm_control_vars = sum([torch.norm(client['control']['weight']).cpu().item() for client in client_states])/num_models
                wandb.log({"avg_norm_control":avg_norm_control_vars}, commit=False, step=rnd+1)
            avg_grad_norm = sum(grad_norm)/len(grad_norm)

            wandb.log({'test':correct, 'distance':distance, 'sparsity':sparsity,
                        'params_norm':params_norm, 'avg_grad_norm':avg_grad_norm,
                        'control_avg_norm': control_avg_norm, 'control_sum': control_sum,
                        "communication_cost": communication_cost}, step=rnd+1, commit=False)
            print(f"Comm Rounds {rnd + 1}, Avg Train Loss: {train_loss:.4g}, Test Accuracy: {100*(correct):.1f}%, Distance: {distance:.5f}, Sparsity: {sparsity:.2f}, LR: {lr:.1g}, Steps {int(actual_steps)}, Sum h {control_sum:.2g}, Avg h {control_avg_norm:.2g}, cost {communication_cost:.2f}")

def build_parser():
    """Return the command-line parser for this experiment script.

    ``--seed`` deliberately has no default. ``train`` only seeds the RNGs when the
    config contains a ``'seed'`` key, and a ``None`` value would be forwarded to
    ``torch.manual_seed`` and fail.
    """
    parser = argparse.ArgumentParser(description="Logistic Regression experiments")
    parser.add_argument('-t','--target_sparsity', help="Manual Mode: Provide a target sparsity",type=float, default=0)
    parser.add_argument('-s','--sweepID',help="Run a wandb sweep. Sweep ID needs to be provided.")
    parser.add_argument('--project', default="Debug")
    parser.add_argument('--lr', type=float, default=10)
    parser.add_argument('--local_steps', type=int, default=5)
    parser.add_argument('--rounds', type=int, default=10)
    parser.add_argument('--num_clients', type=int, default=3220)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--eval_every', type=int, default=1)
    parser.add_argument('--randomized', action='store_true',
                        help="Ignored: train() derives randomization from --optimizer")
    parser.add_argument('--init_zeros', action='store_true')
    parser.add_argument('--prox_loc', type=str, default="comm_mod", choices=['global',"local","comm","comm_mod",'global_mod','local_global','comm_mod_global','comm_global_mod','final', 'comm_mod_global_mod', 'local_global_mod','comm_global'])
    parser.add_argument('--optimizer', type=str, default="ProxSkip", choices=["ProxSkip", "FedAvg", "ProxSkip_mod_steps"])
    # SUPPRESS: the attribute is only present when the flag is given.
    parser.add_argument('--seed', type=int, default=argparse.SUPPRESS,
                        help="Seed for the NumPy and torch RNGs (manual mode only)")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    if args.sweepID:
        # The sweep agent supplies the config, so train() receives config=None.
        wandb.agent(args.sweepID, partial(train, None), project=args.project)
    else:
        train(vars(args))

        
