import argparse
import logging
import os
import pickle
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dataset import get_datase
from util import gaussian_noise, prob_round_tensor, skellam_noise
import torch

import warnings

warnings.filterwarnings("ignore")


def setup_log(name: str) -> logging.Logger:
    """Create (or reuse) the file logger for the current run.

    Records at INFO level and above are written to ``log/log_<name>.log``.
    The file is truncated on every call (``mode="w"``). The ``log`` directory
    is created when missing, and a handler already attached to this logger for
    the same file is replaced, so calling this twice with the same name does
    not duplicate every record.

    Args:
        name (str): Logger name; also used to build the log file name.

    Returns:
        logging.Logger: Logger object for the current run.
    """
    my_logger = logging.getLogger(name)
    my_logger.setLevel(logging.INFO)
    log_format = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s")

    log_dir = "log"
    # FileHandler does not create missing parent directories.
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"log_{name}.log")

    # getLogger returns the same object for the same name; drop a stale
    # handler for this file so records are not written twice.
    target = os.path.abspath(log_path)
    for handler in list(my_logger.handlers):
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            my_logger.removeHandler(handler)
            handler.close()

    log_handler = logging.FileHandler(log_path, mode="w")
    log_handler.setLevel(logging.INFO)
    log_handler.setFormatter(log_format)
    my_logger.addHandler(log_handler)
    return my_logger


def PCA(X, num_components):
    """Project ``X`` onto its leading principal directions.

    The columns of ``X`` are mean-centered, the unbiased sample covariance
    (divided by ``n - 1``) is eigendecomposed, and the eigenvectors with the
    largest eigenvalues are used as the projection basis.

    Args:
        X: 2-D tensor; rows are samples, columns are features.
        num_components: How many leading eigenvectors to keep. Values larger
            than the number of features simply keep all of them.

    Returns:
        tuple: ``(projected, basis)``. ``projected`` is the centered data
        expressed in the kept basis (one row per sample). ``basis`` holds the
        kept eigenvectors as columns, strongest first.
    """
    centered = X - X.mean(dim=0, keepdim=True)
    n_samples = centered.shape[0]
    covariance = centered.t().mm(centered) / (n_samples - 1)

    values, vectors = torch.linalg.eigh(covariance)
    # eigh returns eigenvalues in ascending order; reorder strongest first.
    order = values.argsort(descending=True)
    basis = vectors[:, order][:, :num_components]

    projected = torch.matmul(basis.t(), centered.t()).t()
    return projected, basis


def SVD(X):
    """Return the principal directions of ``X``, strongest first.

    Despite the name, no singular value decomposition is performed. The
    unbiased sample covariance of the column-centered data is
    eigendecomposed with ``torch.linalg.eigh``, and its eigenvectors are
    reordered by descending eigenvalue.

    Args:
        X: 2-D tensor; rows are samples, columns are features.

    Returns:
        Square tensor (features x features) whose columns are the covariance
        eigenvectors, ordered from largest to smallest eigenvalue.
    """
    centered = X - X.mean(dim=0, keepdim=True)
    dof = centered.shape[0] - 1  # unbiased (n - 1) normalisation
    covariance = centered.t().mm(centered) / dof

    values, vectors = torch.linalg.eigh(covariance)
    return vectors[:, values.argsort(descending=True)]


# Calibration table for the non-FL settings: maps a noise scale (the value
# passed as --sigma) to the eps it corresponds to.
#   eps      sigma
#   0.25     13.285
#   0.5      7.032
#   1        3.731
#   2        1.994
#   4        1.082
#   8        0.600
#   16       0.344
#   32       0.205
#   64       0.127
# Keys must match --sigma exactly (float equality).
eps_noise_mapping_local = dict(
    zip(
        (13.285, 7.032, 3.731, 1.994, 1.082, 0.600, 0.344, 0.205, 0.127),
        (0.25, 0.5, 1, 2, 4, 8, 16, 32, 64),
    )
)


# KDDCUP (d=117):
# one column per eps value: 0.25, 1, 2, 4, 8
# ----------------------------------------
# b        mu
# 4       225.934  18.426   6.145   2.613   1.377
# 6       113.018   9.865   3.727   1.802   1.028
# 8       107.256   9.43    3.599   1.756   1.008
# 10      106.901   9.403   3.592   1.753   1.006
# 12      106.879   9.401   3.591   1.753   1.006
# ----------------------------------------
# ----------------------------------------

# ACSIncome (d=816):
# one column per eps value: 0.25, 1, 2, 4, 8
# -----------------------------------------
# b        mu
# 4      1861.136  144.543   41.986   13.651    5.54
# 6      153.702  13.849   5.544   2.828   1.671
# 8      109.986  10.592   4.555   2.442   1.482
# 10     107.501  10.408   4.497   2.419   1.47
# 12     107.346  10.396   4.493   2.417   1.469
# ----------------------------------------
# ----------------------------------------

# gen (d=20531):
# one column per eps value: 4, 8, 16, 32, 64
# -----------------------------------------
# b        mu
# 8        5.589 3.52  2.331 1.585 1.095
# 10       4.653 3.003 2.014 1.38  0.957
# 12       4.595 2.97  1.993 1.367 0.948
# 14       4.591 2.968 1.992 1.366 0.948

# ----------------------------------------
# ----------------------------------------

# citeseer (d=3703):
# one column per eps value: 4, 8, 16, 32, 64
# -----------------------------------------
# b        mu
# 8      3.361 2.099 1.381 0.936 0.645
# 10     3.231 2.029 1.34  0.91  0.628
# 12     3.223 2.025 1.337 0.908 0.626
# 14     3.222 2.025 1.337 0.908 0.626
# ----------------------------------------
# ----------------------------------------


# eps values that the tabulated Skellam mu values were calibrated for; the
# i-th entry of every mu tuple below belongs to the i-th eps here.
_FL_EPS_LIST = (4, 8, 16, 32, 64)

# Skellam mu per eps, keyed by dataset name and then by the parameter b.
_FL_MU_TABLE = {
    "gen": {
        8: (13803, 4193, 1380, 472, 192),
        10: (13828, 4200, 1382, 472, 192),
        12: (13857, 4209, 1385, 473, 193),
        14: (13892, 4220, 1389, 475, 193),
        16: (13932, 4232, 1393, 476, 194),
        18: (13978, 4246, 1397, 477, 194),
        20: (14029, 4261, 1402, 479, 195),
    },
    "citeseer": {
        4: (2493, 757, 250, 86, 35),
        6: (2506, 762, 251, 86, 35),
        8: (2525, 767, 253, 87, 36),
        10: (2549, 775, 255, 88, 36),
        12: (2579, 784, 258, 89, 36),
        14: (2613, 794, 262, 90, 37),
    },
}


def eps_fl_mapping_local(mu, b, dataset):
    """Return the eps that an FL run with Skellam parameter ``mu`` corresponds to.

    Args:
        mu: Skellam noise parameter of the run. It must equal a tabulated
            value exactly; ``13857`` and ``13857.0`` are both accepted.
        b: Discretization parameter ``b`` (the script uses ``gamma = 2**b``);
            ``12`` and ``12.0`` are equivalent.
        dataset: Dataset name; only ``"gen"`` and ``"citeseer"`` are tabulated.

    Returns:
        int: One of 4, 8, 16, 32, 64.

    Raises:
        ValueError: If ``dataset``/``b`` has no table or ``mu`` is not in it.
    """
    try:
        # Float b (argparse uses type=float) hashes equal to the int keys.
        mu_list = _FL_MU_TABLE[dataset][b]
    except KeyError:
        raise ValueError(
            f"No Skellam mu table for dataset={dataset!r}, b={b!r}"
        ) from None
    try:
        mu_index = mu_list.index(mu)
    except ValueError:
        raise ValueError(
            f"mu={mu!r} is not tabulated for dataset={dataset!r}, b={b!r}; "
            f"expected one of {mu_list}"
        ) from None
    return _FL_EPS_LIST[mu_index]


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of a single PCA run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", help="dataset name", type=str, default="iris")
    parser.add_argument(
        "--clipping_norm", help="clipping norm for each row", type=float, default=2
    )
    parser.add_argument("--b", help="b for computing gamma", type=float, default=6)
    parser.add_argument("--random_seed", help="random seed", type=int, default=1234)
    parser.add_argument(
        "--sigma", help="sigma for the gaussian distribution", type=float, default=10000
    )
    parser.add_argument(
        "--setting",
        help="setting for the run: fl, centralized or local",
        type=str,
        default="fl",
        # Any other value would leave the noisy covariance undefined.
        choices=["fl", "centralized", "local"],
    )
    parser.add_argument(
        "--skellam_mu", help="mu for skellam_noise", type=float, default=7
    )
    # NOTE(review): the analysis below sweeps a fixed list of component
    # counts, so this option is currently not read anywhere.
    parser.add_argument(
        "--num_components", help="number of components", type=int, default=2
    )
    parser.add_argument("--gpu_id", help="gpu id", type=int, default=0)
    return parser.parse_args()


def _symmetrize_noisy_upper(gram, noise):
    """Add ``noise`` to the upper triangle (incl. diagonal) of ``gram`` and mirror it.

    Only the upper-triangular part of ``noise`` is used, so each off-diagonal
    pair receives the same noise and the result is symmetric.
    """
    noisy_upper = torch.triu(torch.triu(gram, diagonal=0) + noise, diagonal=0)
    return noisy_upper + noisy_upper.t() - torch.diag(torch.diag(noisy_upper))


def _noisy_covariance(X, args, device, logger):
    """Return a noise-perturbed estimate of the Gram matrix ``X^T X``.

    * ``fl``: the data is scaled by ``gamma = 2**b``, rounded with
      ``prob_round_tensor`` and scaled back. Skellam noise (``args.skellam_mu``)
      is then added to the upper triangle of the Gram matrix, which is mirrored.
    * ``centralized``: Gaussian noise (``args.sigma``) is added to the upper
      triangle of the exact Gram matrix, which is mirrored.
    * ``local``: Gaussian noise (``args.sigma``) is added to every entry of
      ``X``, and the Gram matrix of the noisy data is returned.
    """
    start = time.time()
    if args.setting == "fl":
        # 2. Discretize the data.
        gamma = 2**args.b
        discretized_x = prob_round_tensor(X * gamma, device) / gamma
        logger.info("Discretize data time: %s", time.time() - start)

        # 3. Add noise.
        start = time.time()
        gram = torch.matmul(discretized_x.t(), discretized_x)
        mu = args.skellam_mu  # for skellam value
        noise = skellam_noise(gram.shape, mu, device)
        logger.info(f"mu:{mu}, skellam_mu:{args.skellam_mu}")
        noisy_cov = _symmetrize_noisy_upper(gram, noise)
    elif args.setting == "centralized":
        gram = torch.matmul(X.t(), X)
        noise = gaussian_noise(gram.shape, args.sigma, device)
        noisy_cov = _symmetrize_noisy_upper(gram, noise)
    elif args.setting == "local":
        noise = gaussian_noise(X.shape, args.sigma, device)
        noisy_x = X + noise
        noisy_cov = torch.matmul(noisy_x.t(), noisy_x)
    else:
        raise ValueError(f"Unknown setting: {args.setting!r}")

    logger.info(
        "----->l2 distance between the noisy x and true x : %.5f",
        torch.norm(noisy_cov - torch.mm(X.t(), X)),
    )
    logger.info("====>Add noise time: %s", time.time() - start)
    return noisy_cov


def _sorted_eigh(matrix):
    """Eigendecompose a symmetric matrix; return (values, vectors) strongest first."""
    eigen_values, eigen_vectors = torch.linalg.eigh(matrix)
    order = torch.argsort(eigen_values, descending=True)
    return eigen_values[order], eigen_vectors[:, order]


def _quadratic_trace(basis, A):
    """Return ``trace(basis^T @ A @ basis)`` as a Python float."""
    return torch.trace(torch.matmul(torch.matmul(basis.t(), A), basis)).cpu().item()


def _lookup_eps(args, logger):
    """Return the eps tabulated for this run's noise parameters, or ``None``."""
    try:
        if args.setting == "fl":
            return eps_fl_mapping_local(args.skellam_mu, args.b, args.dataset)
        return eps_noise_mapping_local[args.sigma]
    except (KeyError, ValueError, UnboundLocalError) as err:
        # Depending on which entry is missing, the table lookups raise KeyError,
        # ValueError or UnboundLocalError (e.g. the default --sigma=10000 is not
        # tabulated). Recording a missing eps is better than crashing after the
        # expensive eigendecomposition.
        logger.warning("No eps tabulated for this run (%r); recording None.", err)
        return None


def _plot_projection(X, y, noisy_eigenvectors):
    """Scatter the data on the first two noisy and clean principal directions."""
    # The scatter plot is 2-D, so exactly two components are projected.
    plot_components = 2
    columns = ["PC1", "PC2"]
    # Assumes one label per sample; "target" is the hue column used below.
    labels = pd.DataFrame({"target": y.cpu().numpy().reshape(-1)})

    noisy_subset = noisy_eigenvectors[:, 0:plot_components]
    X_reduced = torch.matmul(noisy_subset.t(), X.t()).t()
    principal_df = pd.concat(
        [pd.DataFrame(X_reduced.cpu().numpy(), columns=columns), labels], axis=1
    )
    plt.figure(figsize=(6, 6))
    sns.scatterplot(
        data=principal_df,
        x="PC1",
        y="PC2",
        hue="target",
        s=60,
        palette="icefire",
    )

    # 8. Also show the result without noise.
    clean_principal_df = pd.concat(
        [
            pd.DataFrame(PCA(X, plot_components)[0].cpu().numpy(), columns=columns),
            labels,
        ],
        axis=1,
    )
    sns.scatterplot(
        data=clean_principal_df,
        x="PC1",
        y="PC2",
        hue="target",
        s=60,
        marker="+",
    )
    plt.show()


if __name__ == "__main__":
    # Run one experiment end to end. Example runs:
    #   python main.py --dataset iris --clipping_norm 1 --b 6 --random_seed 1234 --setting fl --skellam_mu 1.0 --num_components 2
    #   python main.py --dataset iris --clipping_norm 1 --b 6 --random_seed 1234 --setting centralized --skellam_mu 1.0 --num_components 2
    args = _parse_args()

    log_file = "_".join(
        [
            args.setting,
            args.dataset,
            str(args.random_seed),
            str(args.skellam_mu),
            str(args.sigma),
            str(args.b),
            str(args.clipping_norm),
        ]
    )

    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)

    logger = setup_log(log_file)

    IS_PLOT = False
    IS_ANALYSIS = True
    log_dir = "results"
    logger.info(args)
    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"

    # 1. Preprocess the data: include the normalization over column and clipping over rows
    baseline_time = time.time()
    X, y = get_datase(args.dataset, args.clipping_norm)
    X = torch.tensor(X, device=device)
    y = torch.tensor(y, device=device)
    logger.info("====>Load data time: %s", time.time() - baseline_time)

    baseline_time = time.time()
    file_name = "_".join(
        [
            args.setting,
            args.dataset,
            str(X.shape[0]),
            str(args.random_seed),
            str(args.skellam_mu),
            str(args.sigma),
            str(args.b),
            str(args.clipping_norm),
        ]
    )
    cache_path = f"{log_dir}/{file_name}.pkl"

    if not os.path.exists(cache_path):
        noisy_cov = _noisy_covariance(X, args, device, logger)

        # 4. Eigendecomposition of the noisy covariance.
        baseline_time = time.time()
        logger.info(noisy_cov.shape)
        sorted_eigenvalue, sorted_eigenvectors = _sorted_eigh(noisy_cov)

        # Caching is currently disabled:
        # with open(cache_path, "wb") as f:
        #     pickle.dump(
        #         {"eigenvalue": sorted_eigenvalue, "eigenvector": sorted_eigenvectors}, f
        #     )
    elif IS_ANALYSIS:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # result files produced by this project.
        with open(cache_path, "rb") as f:
            saved_results = pickle.load(f)
        sorted_eigenvalue = saved_results["eigenvalue"].to(device)
        sorted_eigenvectors = saved_results["eigenvector"].to(device)

    if IS_ANALYSIS:
        clean_sorted_eigenvectors = SVD(X)
        logger.info("====>SVD time: %s", time.time() - baseline_time)

        # Loop-invariant: scaled Gram matrix of the clean data, so
        # trace(V^T A V) is the mean squared norm of X projected onto V.
        A = torch.matmul(X.t(), X) / X.shape[0]
        eps = _lookup_eps(args, logger)

        # NOTE: slicing past the feature count keeps every column, so on
        # low-dimensional data the larger counts repeat the same result.
        for num_components in [100, 200, 400, 800, 1600, 3200]:
            baseline_time = time.time()
            logger.info("====>num_components: %s", num_components)
            noisy_eigenvector_subset = sorted_eigenvectors[:, 0:num_components]
            print("noisy shape", noisy_eigenvector_subset.shape)
            clean_eigenvector_subset = clean_sorted_eigenvectors[:, 0:num_components]

            q_f = _quadratic_trace(noisy_eigenvector_subset, A)
            q_f_clean = _quadratic_trace(clean_eigenvector_subset, A)
            logger.info("q_f:  %.8f", q_f)
            logger.info("q_clean: %.8f", q_f_clean)

            results = {
                "setting": [args.setting],
                "dataset": [args.dataset],
                "num_samples": [X.shape[0]],
                "num_components": [num_components],
                "features": [X.shape[1]],
                "random_seed": [args.random_seed],
                "skellam_mu": [args.skellam_mu],
                "sigma": [args.sigma],
                "eps": [eps],
                "b": [args.b],
                "clipping_norm": [args.clipping_norm],
                "q_f": [q_f],
                "q_clean": [q_f_clean],
            }
            df = pd.DataFrame(results)
            df.to_csv("pca_results.csv", mode="a", index=False, header=False)
            logger.info("====>Measure time: %s", time.time() - baseline_time)

        if IS_PLOT:
            _plot_projection(X, y, sorted_eigenvectors)
