import yaml
import json
import argparse
import logging
import pickle
from datetime import datetime
import pickle
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from scipy.stats import ortho_group

from pathlib import Path
# from src.tasks import TASKS
from src.compressors_l_infty import COMPRESSORS
from src.utils import seed_all

import numpy as np
# Configure logging with an ERROR threshold: the logging.info progress
# messages emitted further down are filtered out at this level.
logging.basicConfig(level=logging.ERROR)

# Relative locations of the configuration files and of the output folders.
CONFIG_PATH = Path("configs")
RESULTS_PATH = Path("results")
TENSORBOARD_PATH = Path("tensorboard")

def load_configs(args, dataset_config):
    """Assemble the experiment configuration for a single compressor.

    The compressor parameters are read from ``compressors.yaml`` inside
    ``CONFIG_PATH`` and selected by ``args.compressor``.

    Args:
        args: Object exposing a ``compressor`` attribute (the compressor name).
        dataset_config: Dataset description; passed through unchanged.

    Returns:
        dict with the keys ``"task"`` (always an empty dict), ``"dataset"``
        (``dataset_config``) and ``"compressor"`` (a dict holding the
        compressor ``"name"`` and its ``"params"`` entry from the YAML file).
    """
    selected = args.compressor

    with open(CONFIG_PATH / "compressors.yaml") as config_file:
        params_by_compressor = yaml.safe_load(config_file)

    return {
        "task": {},
        "dataset": dataset_config,
        "compressor": {
            "name": selected,
            # Indexing (not .get) so a name missing from the file fails loudly.
            "params": params_by_compressor[selected],
        },
    }

def create_mean_estimation_datasets(name, config):
    """Generate synthetic per-client samples for distributed mean estimation.

    All randomness comes from NumPy's global random state (``np.random`` and,
    through it, ``scipy.stats.ortho_group``), so seed that state beforehand
    for reproducible data.

    Args:
        name: Generator to use; one of ``"gaussian_err"``, ``"gaussian"``,
            ``"uniform"`` or ``"unit_vector"``.
        config: Mapping with the generator parameters ``"m"`` (number of
            clients), ``"n"`` (samples per client), ``"d"`` (dimension),
            ``"het"`` (controls how far client means are spread) and
            ``"noise"`` (controls the per-sample perturbation). ``"B"``
            (scale of the shared mean) is read only by ``"gaussian_err"``
            and ``"uniform"``.

    Returns:
        dict mapping each client index ``0 .. m-1`` to an ``(n, d)`` array
        of samples.

    Raises:
        ValueError: If ``name`` is not one of the supported generators.
    """
    known_names = ("gaussian_err", "gaussian", "uniform", "unit_vector")
    if name not in known_names:
        # Fail with a clear message instead of falling through every branch
        # and hitting an unbound local at the return statement.
        raise ValueError(
            f"Unknown dataset name {name!r}; expected one of {known_names}"
        )

    m, n, d = config["m"], config["n"], config["d"]
    het, noise = config["het"], config["noise"]

    if name == "gaussian_err":
        # Shared mean: random direction rescaled to Euclidean norm B.
        true_mean = np.random.normal(0, 1, size=d)
        true_mean = config["B"] * true_mean / np.linalg.norm(true_mean)
        # Per client: Gaussian shift of the mean (scaled by het), then
        # Gaussian sample noise with standard deviation `noise`.
        return {
            i: true_mean
            + np.random.normal(0, 1, size=d) * het
            + np.random.normal(0, noise, size=(n, d))
            for i in range(m)
        }

    if name == "gaussian":
        # Client mean: random 0/1 vector pushed through one shared random
        # orthogonal matrix, scaled by het / sqrt(d).
        ortho_matrix = ortho_group.rvs(dim=d)
        return {
            i: het
            * (ortho_matrix @ np.random.binomial(1, 0.5, size=d))
            / np.sqrt(d)
            + np.random.normal(0, noise, size=(n, d))
            for i in range(m)
        }

    if name == "uniform":
        coordinate_means = config["B"] * np.random.uniform(0, 1, size=d)
        return {
            i: coordinate_means
            + 4 * het * np.random.uniform(-1, 1, size=d)
            + 4 * noise * np.random.uniform(-1, 1, size=(n, d))
            for i in range(m)
        }

    # name == "unit_vector"
    ortho_matrix = ortho_group.rvs(dim=d)
    true_vector_idx = np.random.choice(range(d))
    true_vector = ortho_matrix[:, true_vector_idx]
    # The remaining d-1 columns span the orthogonal complement of true_vector.
    other_vectors = ortho_matrix[:, np.arange(d) != true_vector_idx]

    sqrt_het = np.sqrt(het)
    sqrt_noise = np.sqrt(noise)
    client_data = {}
    for i in range(m):
        # Client-specific deviation of norm sqrt(het) inside the complement.
        coeffs = np.random.normal(0, 1, size=d - 1)
        coeffs = sqrt_het * coeffs / np.linalg.norm(coeffs)
        # One unit-norm noise direction per sample.
        noise_vars = np.random.normal(0, 1, size=(n, d))
        noise_vars /= np.linalg.norm(noise_vars, axis=1).reshape(-1, 1)
        client_mean = (1 - sqrt_het - sqrt_noise) * true_vector + other_vectors @ coeffs

        client_data[i] = (1 - sqrt_noise) * client_mean + sqrt_noise * noise_vars

    return client_data



def set_compressor_configs_d_512(compressor_config : dict, m:int, d:int, seed:int, num_reps:dict = None):
    """Instantiate a compressor with this script's hard-coded settings.

    Keyword arguments are layered in this order, later layers winning:
    ``m`` and ``d``; the parameters from ``compressor_config["params"]``
    (if any); a fixed per-compressor override; ``num_reps``; ``seed``.

    Args:
        compressor_config: Dict with the compressor ``"name"`` and its
            ``"params"`` (may be ``None`` or empty).
        m: Number of clients, forwarded to the compressor.
        d: Dimension, forwarded to the compressor.
        seed: Forwarded as ``seed`` to the "hadamard" and "sparsereg"
            compressors only.
        num_reps: Optional mapping from compressor name to a repetition
            count; consulted only for the compressors that take ``num_reps``.

    Returns:
        The object produced by ``COMPRESSORS[name](**kwargs)``.
    """
    compressor_name = compressor_config["name"]
    extra_params = compressor_config["params"]

    compressor_kwargs = {"m": m, "d": d}
    if extra_params is not None and len(extra_params) > 0:
        compressor_kwargs = {**compressor_kwargs, **extra_params}

    # Fixed (keyword, value) override per compressor.
    # Original author's note: max compressor bits obtained for PermK.
    fixed_overrides = {
        "randk": ("K", 58),
        "topk": ("K", 58),
        "randkspatial": ("K", 58),
        "randkproj": ("K", 58),
        "induced": ("K", 29),
        "rotatedquant": ("levels", 25),
        "rotatedcorrelatedquant": ("levels", 23),
        "sparsereg": ("K", 200),
        "noisy_sign": ("sigma", 1000.0),
    }
    if compressor_name in fixed_overrides:
        keyword, value = fixed_overrides[compressor_name]
        compressor_kwargs[keyword] = value

    takes_num_reps = ("noisy_sign", "hadamard", "onebitavg", "sparsereg", "permk", "drive")
    if num_reps is not None and compressor_name in takes_num_reps:
        compressor_kwargs["num_reps"] = num_reps[compressor_name]

    if compressor_name in ("hadamard", "sparsereg"):
        compressor_kwargs["seed"] = seed

    return COMPRESSORS[compressor_name](**compressor_kwargs)

from argparse import Namespace

# Compressors compared in this experiment.
compressor_unit_norm = [
    "sparsereg",
    "onebitavg",
    "drive",
    "randk",
    "randkproj",
    "rotatedquant",
]
seeds = [1, 2, 3]
seed = seeds[0]
dataset_config = {
    "name": "gaussian",
    "params": {"m": 100, "n": 100, "het": 10.0, "noise": 0.01, "d": 512, "B": 100.0},
}
compressor_results = {}
compressor_bits = {}
compressor_dict = {}
# Repetition counts handed to the compressors that accept `num_reps`.
num_reps = {"sparsereg": 305, "drive": 4, "permk": 15, "hadamard": 5, "noisy_sign": 5, "onebitavg": 72}

# Dry run: build every compressor once and record what num_bits_float() reports.
for compressor_name in compressor_unit_norm:
    args = Namespace(
        task="mean_estimation",
        dataset=dataset_config["name"],
        seed=seed,
        compressor=compressor_name,
    )
    config = load_configs(args, dataset_config=dataset_config)
    compressor = set_compressor_configs_d_512(
        compressor_config=config["compressor"],
        m=config["dataset"]["params"]["m"],
        d=config["dataset"]["params"]["d"],
        num_reps=num_reps,
        seed=seed,
    )

    compressor_bits[compressor_name] = compressor.num_bits_float()
    compressor_dict[compressor_name] = compressor

# Bit counts rounded down to a multiple of 100; compressor_bits was filled in
# the order of compressor_unit_norm, so the values line up with that list.
compressor_bit_values = np.floor(np.array(list(compressor_bits.values())) / 100) * 100
# Ratio of the largest rounded budget to each compressor's own budget.
# NOTE: not read again in the visible part of this script.
num_reps_new = dict(zip(compressor_unit_norm, list(compressor_bit_values.max() / compressor_bit_values)))


# Experiment: for every heterogeneity level and seed, generate "unit_vector"
# client data, let each compressor estimate the mean of the client means, and
# record the squared L2 and the L-infinity error together with the bit count.
results = {}
# Alternative sweeps that were tried:
#   het_levels = 2**np.linspace(-3,-2,3)
#   het_levels = [0.1, 1.0, 10.0, 0.01, 0.001, 100.0]
het_levels = [0.01, 0.1, 0.2, 0.3, 0.4]
for het in tqdm(het_levels):
    # The "gaussian" and "uniform" generators accept the same parameter set.
    dataset_config = {
        "name": "unit_vector",
        "params": {"m": 100, "n": 100, "het": het, "noise": 0.01, "d": 512, "B": 100.0},
    }

    results[het] = {}
    for seed in tqdm(seeds):
        # Seed with the seed of THIS repetition. Seeding from `args.seed`
        # here would read the Namespace left over from an earlier iteration,
        # so repetitions would reuse a previous seed.
        seed_all(seed)

        # Create client datasets
        all_client_data = create_mean_estimation_datasets(
            name=dataset_config["name"], config=dataset_config["params"]
        )
        logging.info("Generated all clients")

        # The per-client means and their average depend only on the data,
        # so compute them once per seed rather than once per compressor.
        client_means = {i: all_client_data[i].mean(axis=0) for i in range(len(all_client_data))}
        client_arr = np.stack(list(client_means.values()), axis=0)
        true_mean = client_arr.mean(axis=0)

        for compressor_name in tqdm(compressor_unit_norm):
            run_log = results[het].setdefault(
                compressor_name,
                {"mean": [], "l2_error": [], "true_mean": [], "comm_bits": [], "l_infty_error": []},
            )

            args = Namespace(
                task="mean_estimation",
                dataset=dataset_config["name"],
                seed=seed,
                compressor=compressor_name,
            )

            # Get config
            config = load_configs(args, dataset_config=dataset_config)
            logging.info("Config loaded")

            # Initialize compressor
            compressor = set_compressor_configs_d_512(
                compressor_config=config["compressor"],
                m=config["dataset"]["params"]["m"],
                d=config["dataset"]["params"]["d"],
                num_reps=num_reps,
                seed=seed,
            )
            logging.info("Compressor initialized")

            logging.info("Starting task")

            # Run task. Each compressor gets its own copy of the client means
            # so that one modifying its input in place cannot affect the next.
            decoded_mean = compressor.compress(client_arr.copy())
            comm_bits = compressor.num_bits_float()
            l2_error = np.linalg.norm(true_mean - decoded_mean) ** 2
            l_infty_error = np.max(np.abs(true_mean - decoded_mean))

            run_log["mean"].append(decoded_mean)
            run_log["l2_error"].append(l2_error)
            run_log["true_mean"].append(true_mean)
            run_log["comm_bits"].append(comm_bits)
            run_log["l_infty_error"].append(l_infty_error)

    # Average the scalar metrics over the seeds of this heterogeneity level.
    for compressor_name in compressor_unit_norm:
        for key in ["l2_error", "comm_bits", "l_infty_error"]:
            results[het][compressor_name][f"mean_{key}"] = sum(results[het][compressor_name][key]) / len(seeds)

# Persist the full results dictionary.
results_folder = Path("results/mean_est/unit_norm")
results_folder.mkdir(parents=True, exist_ok=True)
with open(results_folder / "cosine_distance.pickle", "wb") as f:
    pickle.dump(results, f)