import os

os.environ['OMP_NUM_THREADS'] = '1'

import numpy as np
import pandas as pd
import wandb, argparse
from functools import partial

import os, glob

from sklearn.linear_model import Ridge, ElasticNet
from sklearn.preprocessing import MaxAbsScaler

from scipy.linalg import norm


from proxop import L1Norm

import objectives

# Directory holding blogData_train.csv and the blogData_test*.csv files read by
# load_data(). 'REMOVED' looks like a redacted placeholder: set it to a real
# directory (including a trailing path separator) before running.
DATA_PATH = 'REMOVED' 

class BlogPostFederated:
    """Per-client gradient oracle for a federated least-squares + ridge problem.

    Client ``i`` holds features ``X_i`` and targets ``y_i``. At point ``x_i``
    its gradient is ``X_i^T (X_i x_i - y_i) + mu / (2 * num_models) * x_i``,
    so the ridge term is split evenly across the clients.
    """

    def __init__(self, Xs, ys, mu):
        """Store the per-client data as NumPy arrays.

        Args:
            Xs: sequence of per-client feature objects exposing ``.values``
                (e.g. pandas DataFrames).
            ys: sequence of per-client targets exposing ``.values``, aligned
                with ``Xs``.
            mu: ridge regularisation strength shared by all clients.
        """
        self.Xs = list(map(lambda frame: frame.values, Xs))
        self.ys = list(map(lambda series: series.values, ys))
        self.num_models = len(self.Xs)
        self.num_samples = sum(block.shape[0] for block in self.Xs)
        self.mu = mu

    def grad_fl(self, x):
        """Return the stacked per-client gradients, shape ``(num_models, d)``.

        ``x`` is either one point of length ``d`` shared by every client, or
        one row per client; it is broadcast to ``(num_models, d)``.
        """
        points = np.broadcast_to(x, (self.num_models, x.shape[-1]))
        return np.asarray([
            feats.T @ (feats @ point - target) + (self.mu / 2 / self.num_models) * point
            for feats, point, target in zip(self.Xs, points, self.ys)
        ])
    
class ScaffNew:
    """ProxSkip / Scaffnew-style federated solver with a configurable prox placement.

    Clients take local gradient steps corrected by control variates ``h``.
    After ``num_steps`` local steps, the client iterates are averaged (one
    communication round). ``config['prox_loc']`` selects where the proximal
    operator is applied:

    - 'local': after every local step.
    - 'comm_mod' / 'comm': on the client iterates before / after the
      control-variate snapshot.
    - 'global' / 'global_mod': on the averaged iterate before / after the
      control-variate update.
    - 'final': only when evaluating.
    """

    def __init__(self, x_0, x_sol, prox=None):
        """Store the starting point, reference solution and proximal operator.

        Note: this calls ``load_data()`` to obtain the train / OOD-test data
        used for evaluation, so constructing an instance reads the CSV files.

        Args:
            x_0: initial iterate (copied).
            x_sol: reference solution used by the error / Lyapunov metrics (copied).
            prox: callable ``prox(x, step)``; ``None`` means the identity.
        """
        _, _, self.X_train, self.y_train, self.X_test, self.y_test, self.X_test_ood, self.y_test_ood = load_data()

        self.x_0 = x_0.copy()
        self.x_sol = x_sol.copy()
        # Normalisers for the logged metrics; filled in by solve() when
        # data_collection is enabled.
        self.error_start = 0
        self.control_variates_start = 0
        self.lya_start = 0
        self.prox = prox if prox is not None else (lambda x, y: x)

    @staticmethod
    def _density(v, num_features):
        """Fraction of entries of ``v`` that are non-zero after rounding to 15 decimals."""
        return int(norm(v.round(15), 0)) / num_features

    def solve(self, grad_est, num_models, d, mu, config, data_collection=True, comp_budget=None, seed=None):
        """Run the federated optimisation loop.

        Args:
            grad_est: object whose ``grad_fl(x)`` returns per-client gradients
                of shape ``(num_models, d)``.
            num_models: number of clients.
            d: dimension of the iterate.
            mu: regularisation parameter passed to the Ridge scorer
                (``alpha=mu/2``) and to ``objectives.obj_fun``.
            config: dict with keys 'p', 'rounds', 'prox_loc', 'optimizer',
                'eval_every', 'lr' and 'dual_lr'.
            data_collection: if True, compute metrics and log them to wandb
                every ``eval_every`` communication rounds.
            comp_budget: maximum total number of local steps; ``None`` or a
                non-positive value means no budget.
            seed: seed for the random number of local steps (ProxSkip).

        Returns:
            The final averaged iterate, or ``None`` if an evaluation saw a
            training loss above 1e10 or NaN (divergence).
        """
        rng = np.random.default_rng(seed)

        p = config['p']
        comms = config['rounds']
        prox_loc = config['prox_loc']
        is_proxskip = config['optimizer'] == "ProxSkip"
        control_vars = is_proxskip
        eval_every = config['eval_every']
        lr = config['lr']
        dual_lr = config['dual_lr']
        num_features = self.X_train.shape[1]
        # Previously `comp_budget <= 0` raised TypeError for the default None.
        no_budget = comp_budget is None or comp_budget <= 0

        def draw_num_steps():
            # ProxSkip communicates after a Geometric(p) number of local steps,
            # FedAvg after a fixed int(1/p) local steps.
            return rng.geometric(p) if is_proxskip else int(1 / p)

        # 'local' applies the prox after every local step with step lr; every
        # other placement applies it once per round, so the step is lr / p.
        prox_step = lr if prox_loc == 'local' else lr / p

        def prox(v):
            return self.prox(v, prox_step)

        self.communication_rounds = 0
        h = np.zeros((num_models, d))
        x = self.x_0.copy()

        if data_collection:
            x_sol = self.x_sol

            # Ridge is only used as a scorer: it is fitted once so that
            # `score` works, and `coef_` is overwritten before every call.
            reg = Ridge(fit_intercept=False, alpha=mu/2, tol=1e-3)
            reg.fit(self.X_train, self.y_train)

            # Same unit as the in-loop metric (fraction of zero coordinates);
            # previously this logged a raw non-zero count under the same key.
            sparsity = 1 - self._density(prox(x), num_features)
            loss = objectives.obj_fun(self.X_train, self.y_train, x, mu)

            self.error_start = norm(x - x_sol)**2

            h_opt = grad_est.grad_fl(np.broadcast_to(x_sol, (num_models, d)))
            self.control_variates_start = norm(h - h_opt, 'fro')**2
            self.lya_start = num_models*norm(x - x_sol)**2 + lr**2/p**2*norm(h - h_opt, 'fro')**2

            reg.coef_ = x
            train_score = reg.score(self.X_train, self.y_train)
            test_ood_score = reg.score(self.X_test_ood, self.y_test_ood)

            # Error, control-variate distance and Lyapunov value are logged
            # normalised by their starting values, hence 1 here.
            wandb.log({"error": 1, "loss": loss, "sparsity": sparsity,
                       "control_variates": 1, "lyapunov": 1,
                       "train_score": train_score, "test_ood_score": test_ood_score},
                      step=self.communication_rounds)

        local_step_counter = 0
        total_steps = 0
        communication_cost = 0
        num_steps = draw_num_steps()
        while self.communication_rounds < comms and (no_budget or total_steps < comp_budget):
            local_step_counter += 1
            total_steps += 1

            g = grad_est.grad_fl(x)
            x = x - lr*(g - h)
            if prox_loc == 'local':
                x = np.apply_along_axis(prox, 1, x)

            if local_step_counter < num_steps:
                continue

            # ---- Communication round ----
            if prox_loc == 'comm_mod':
                x = np.apply_along_axis(prox, 1, x)

            if control_vars:
                hat_x = x.copy()

            if prox_loc == 'comm':
                x = np.apply_along_axis(prox, 1, x)

            # Average fraction of non-zero entries in the vectors sent by clients.
            communication_cost += np.apply_along_axis(self._density, 1, x, num_features).mean()

            x = x.mean(0)

            if prox_loc == 'global':
                x = prox(x)

            if control_vars:
                h += p/lr*dual_lr*(x - hat_x)

            if prox_loc == 'global_mod':
                x = prox(x)

            self.communication_rounds += 1
            local_step_counter = 0
            num_steps = draw_num_steps()

            if not data_collection or self.communication_rounds % eval_every != 0:
                continue

            # ---- Evaluation ----
            x_eval = x

            # NOTE(review): for prox_loc == 'final' this is measured before the
            # prox below is applied — confirm that is the intended quantity.
            sparsity = 1 - self._density(x_eval, num_features)
            if prox_loc == 'final':
                x_eval = prox(x_eval)

            loss = objectives.obj_fun(self.X_train, self.y_train, x_eval, mu)

            cv_dist = norm(h - h_opt, 'fro')**2
            distance = norm(x_eval - x_sol)**2 / self.error_start
            control_variates = cv_dist / self.control_variates_start
            lya = (num_models*norm(x_eval - x_sol)**2 + lr**2/p**2*cv_dist) / self.lya_start

            reg.coef_ = x_eval
            train_score = reg.score(self.X_train, self.y_train)
            test_ood_score = reg.score(self.X_test_ood, self.y_test_ood)

            wandb.log({"error": distance, "loss": loss, "sparsity": sparsity,
                       "control_variates": control_variates, "lyapunov": lya,
                       "train_score": train_score,
                       "communication_cost": communication_cost,
                       "test_ood_score": test_ood_score}, step=self.communication_rounds, commit=False)
            print(f"{self.communication_rounds}: dist {distance:.3g}, loss {loss:.3g}, sparsity {sparsity:.2f}, train {train_score:.3f}, test {test_ood_score:.3f}, lya {lya:.3g}, cost {communication_cost:.2f}")
            # Treat a NaN loss as divergence too (NaN > 1e10 is False).
            if np.isnan(loss) or loss > 1e10:
                return None

        return x

def load_data(data_path=None):
    """Load the blog CSVs, scale the features and split training rows into clients.

    Columns 0..279 are features and column 280 is the target. Training rows
    are grouped into clients by the values of the first 50 feature columns.
    A MaxAbsScaler fitted on the training features is applied to both the
    training and the out-of-distribution test features.

    Args:
        data_path: directory holding ``blogData_train.csv`` and the
            ``blogData_test*.csv`` files. Defaults to the module-level
            ``DATA_PATH``.

    Returns:
        Tuple ``(Xs, ys, X_train, y_train, X_test, y_test, X_test_ood, y_test_ood)``:

        - ``Xs`` / ``ys``: per-client feature DataFrames / target Series.
        - ``X_train`` / ``y_train``: their concatenation.
        - ``X_test`` / ``y_test``: ``None`` placeholders (unused split).
        - ``X_test_ood``: the scaled OOD test features, as an ndarray.
        - ``y_test_ood``: the OOD test targets, as an ndarray.

    Raises:
        FileNotFoundError: if no ``blogData_test*.csv`` file is found.
    """
    if data_path is None:
        data_path = DATA_PATH
    num_features = 280   # columns 0..279 are features, column 280 is the target
    num_group_cols = 50  # clients are keyed on the first 50 feature columns

    # os.path.join works whether or not data_path ends with a separator.
    df = pd.read_csv(os.path.join(data_path, "blogData_train.csv"), header=None)
    # sorted(): glob order is filesystem-dependent; keep the OOD set reproducible.
    csv_files = sorted(glob.glob(os.path.join(data_path, "blogData_test*.csv")))
    if not csv_files:
        raise FileNotFoundError(f"No blogData_test*.csv files found in {data_path!r}")
    df_ood = pd.concat([pd.read_csv(path, header=None) for path in csv_files])

    df_train, df_y_train = df.iloc[:, :num_features], df[num_features]
    df_test_ood, df_y_test_ood = df_ood.iloc[:, :num_features], df_ood[num_features]

    # Build the scaled frame directly instead of writing floats back into
    # possibly-integer columns via iloc (a deprecated dtype-incompatible setitem).
    source_scaler = MaxAbsScaler()
    df_train_fl = pd.DataFrame(source_scaler.fit_transform(df_train),
                               index=df_train.index, columns=df_train.columns)
    df_train_fl[num_features] = df_y_train  # target stays unscaled

    groups = df_train_fl.groupby(list(range(num_group_cols)))
    dfs = [group for _, group in groups]
    print(f"Number of clients: {len(dfs)}")
    Xs = [client_df.iloc[:, :num_features] for client_df in dfs]
    ys = [client_df[num_features] for client_df in dfs]

    # Centralized dataset
    X_train = pd.concat(Xs)
    y_train = pd.concat(ys)

    # No in-distribution test split is used.
    X_test, y_test = None, None
    X_test_ood = source_scaler.transform(df_test_ood)
    y_test_ood = df_y_test_ood.values

    return Xs, ys, X_train, y_train, X_test, y_test, X_test_ood, y_test_ood

def train_regression(config):
    """Run one regression experiment and log it to wandb.

    Steps:

    1. Fit a centralized reference solution: Ridge for TopK, ElasticNet for l1.
    2. Build the federated gradient oracle and the proximal operator.
    3. Run ``ScaffNew.solve`` with the settings in ``config``.

    Args:
        config: experiment settings (the argparse options of ``__main__``),
            or ``None`` when launched by a wandb sweep agent. In both cases
            the effective settings are read back from ``wandb.config``.

    Raises:
        ValueError: for l1 regularisation with an ``alpha`` other than 1e3 or
            an unsupported ``target_sparsity``.
    """
    wandb.init(config=config, project=config['project'] if config else None, allow_val_change=True)
    config = dict(wandb.config.items())

    # Disabled overrides, apparently tuned (lr, p) pairs and prox placements
    # for repeated runs. Kept for reference; uncomment to apply.
    # repeat_exp = {"FedAvg":
    #               {
    #                   "global": (0.00010, 0.051),
    #                   "local": (0.00010, 0.018),
    #                   "comm_mod": (0.00010, 0.018),
    #               },
    #               "ProxSkip": 
    #               {
    #                   "global": (0.000078, 0.0069),
    #                   "local": (0.000086, 0.0439),
    #                   "comm_mod" : (0.000067, 0.023)
    #               }
    #               }

    # config['lr'], config['p'] = repeat_exp[config['optimizer']][config['prox_loc']]
    # print(f"lr and p overwritten to {config['lr']} and {config['p']}.")

    # repeat_exp = {"TopK": 'final',
    #                "l1": 'local'
    #                }

    # config['prox_loc'] = repeat_exp[config['regularization']]
    # print(f"prox_loc overwritten to {config['prox_loc']}.")

    Xs, ys, X_train, y_train, X_test, y_test, X_test_ood, y_test_ood = load_data()

    d = X_train.shape[1]
    num_models = len(Xs)

    alpha = config['alpha']
    if config['regularization'] == 'TopK':
        reg = Ridge(fit_intercept=False, alpha=alpha, tol=1e-3)
    else:
        if alpha != 1e3:
            raise ValueError("l1 regularisation only done for alpha 1e3")
        # l1 penalty weight for each supported target sparsity.
        reg_coeff_dict = {0: 0, 0.8: 9e3, 0.9: 3.2e4, 0.95: 9.5e4}
        try:
            reg_coeff = reg_coeff_dict[config['target_sparsity']]
        except KeyError:
            raise ValueError(
                f"l1 regularisation only supports target_sparsity in {sorted(reg_coeff_dict)}"
            ) from None
        num_samples = X_train.shape[0]

        # sklearn's ElasticNet minimises
        #   1/(2n)||y - Xw||^2 + alpha_E*l1_ratio*||w||_1 + alpha_E*(1-l1_ratio)/2*||w||^2.
        # With alpha_E*l1_ratio = a and alpha_E*(1-l1_ratio) = b, multiplying by 2n gives
        #   ||y - Xw||^2 + reg_coeff*||w||_1 + alpha*||w||^2,
        # i.e. the Ridge objective above plus an l1 term.
        a = reg_coeff/2/num_samples
        b = alpha/num_samples
        alpha_E = a + b
        l1_ratio = a/(a+b)
        reg = ElasticNet(fit_intercept=False, alpha=alpha_E, l1_ratio=l1_ratio, tol=1e-10)
    reg.fit(X_train, y_train)
    x_sol = reg.coef_
    print(f"Training score: {reg.score(X_train, y_train):.3g}")
    print(f"OOD Score: {reg.score(X_test_ood, y_test_ood):.4g}")
    print(f"Sparsity x_sol: {1- norm(x_sol.round(15),0)/d:.3g}")

    mu = 2 * alpha
    print(f"Train Loss: {objectives.obj_fun(X_train, y_train, x_sol, mu):.4g}")

    grad_est = BlogPostFederated(Xs, ys, mu)

    x_0 = np.zeros_like(x_sol)

    if config['target_sparsity'] == 0:
        prox = None
    else:
        # Number of coordinates TopK keeps. The epsilon stops floating-point
        # error from truncating e.g. (1 - 0.8) * 280 = 55.999... down to 55.
        k = int((1 - config['target_sparsity']) * d + 1e-9)
        k = min(d, max(k, 0))
        prox = lambda x, y: objectives.topK_prox(x, y, k)

    if config['regularization'] == 'l1':
        # Soft-thresholding with threshold reg_coeff * step / (2 * num_models).
        prox = lambda x, y: L1Norm().prox(x, reg_coeff*y/num_models/2)

    proxskip = ScaffNew(x_0, x_sol, prox)
    x = proxskip.solve(grad_est, num_models, d, mu, config, comp_budget=config['comp_budget'], seed=config.get('seed'))

    if x is not None:
        print("Run succesful.")
        wandb.log({"done": True})
    else:
        print("Run failed.")
        wandb.log({"done": False})
    wandb.finish()

if __name__ == "__main__":
    # Command-line entry point: run a single experiment, or act as a wandb
    # sweep agent when --sweepID is given.
    parser = argparse.ArgumentParser(description="Regression experiments")
    parser.add_argument('-t', '--target_sparsity', help="Manual Mode: Provide a target sparsity", type=float, default=0)
    parser.add_argument('-s', '--sweepID', help="Run a wandb sweep. Sweep ID needs to be provided.")
    parser.add_argument('--project', help="wandb project name.", default="Debug")
    parser.add_argument('--lr', type=float, default=1e-5, help="Local step size.")
    parser.add_argument('--dual_lr', type=float, default=1, help="Multiplier on the control-variate update.")
    parser.add_argument('--alpha', type=float, default=1e3, help="Ridge (l2) regularisation strength.")
    parser.add_argument('-p', type=float, default=1,
                        help="Communication probability (ProxSkip); FedAvg communicates every int(1/p) local steps.")
    parser.add_argument('--rounds', type=int, default=10, help="Number of communication rounds.")
    parser.add_argument('--comp_budget', type=int, default=-1,
                        help="Maximum total number of local steps; 0 or a negative value means no budget.")
    parser.add_argument('--eval_every', type=int, default=1, help="Evaluate and log every this many communication rounds.")
    parser.add_argument('--seed', type=int, default=None,
                        help="Seed for the random number of local steps (ProxSkip). Default: unseeded.")
    parser.add_argument('--regularization', type=str, default='TopK', choices=['TopK', 'l1'])
    parser.add_argument('--prox_loc', type=str, default="global", choices=['global', "local", "comm", 'global_mod', 'comm_mod', 'final'])
    parser.add_argument('--optimizer', type=str, default="ProxSkip", choices=["ProxSkip", "FedAvg"])
    args = parser.parse_args()

    if args.sweepID:
        wandb.agent(args.sweepID, partial(train_regression, None), project=args.project)
    else:
        train_regression(vars(args))
        