import copy  # noqa: F401
import json
import logging
import os
import random
from typing import Callable, Dict, Iterable, Optional, Union  # noqa: F401

import lightning as L
import numpy as np
import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm.auto import tqdm

import fusion_bench.utils.instantiate
from fusion_bench.method import BaseAlgorithm
from fusion_bench.mixins import LightningFabricMixin
from fusion_bench.modelpool import BaseModelPool
from fusion_bench.programs import BaseHydraProgram
from fusion_bench.taskpool import BaseTaskPool
from fusion_bench.utils import import_object, instantiate, timeit_context
from fusion_bench.utils.hydra_utils import get_hydra_output_dir
from fusion_bench.utils.json import print_json  # noqa: F401
from fusion_bench.utils.rich_utils import print_bordered, print_config_tree

log = logging.getLogger(__name__)

class FabricModelFusionProgram(
    LightningFabricMixin,
    BaseHydraProgram,
):
    """
    Hydra program that loads a model pool, a fusion method and a task pool and
    reports PAC-Bayes statistics for the merged models.

    See `run` for the supported method families.
    """

    method: BaseAlgorithm
    modelpool: BaseModelPool
    taskpool: Optional[BaseTaskPool] = None

    _config_mapping = BaseHydraProgram._config_mapping | {
        "_method": "method",
        "_modelpool": "modelpool",
        "_taskpool": "taskpool",
        "_fabric": "fabric",
        "fast_dev_run": "fast_dev_run",
        "seed": "seed",
    }

    # Hyper-parameters of the PAC-Bayes evaluation performed in `run`, exposed as
    # class attributes so subclasses can override them.
    _PAC_DELTA = 0.025  # `delta` passed to B_RE_single and method.run_pac
    _PAC_LAMBDA = 0.05  # scalar prior variance `lambd`
    _PAC_NUM_SAMPLES = 5  # posterior samples averaged per evaluation
    _POSTERIOR_S = 0.05  # per-coordinate `s` of the posterior Gaussian
    # Sample size `m` used in the bound for the scaling-factor methods.
    # NOTE(review): hard-coded rather than the validation-set size -- confirm intended.
    _SCALING_FACTOR_BOUND_M = 100
    # Candidate scaling factors 0.05, 0.10, ..., 1.00.
    _SCALING_FACTOR_GRID = tuple(round(i * 0.05, 2) for i in range(1, 21))

    def __init__(
        self,
        method: DictConfig,
        modelpool: DictConfig,
        taskpool: Optional[DictConfig] = None,
        *,
        fabric: Optional[DictConfig] = None,
        print_config: bool = True,
        dry_run: bool = False,
        report_save_path: Optional[str] = None,
        merged_model_save_path: Optional[str] = None,
        merged_model_save_kwargs: Optional[DictConfig] = None,
        fast_dev_run: bool = False,
        seed: Optional[int] = None,
        print_function_call: bool = True,
        **kwargs,
    ):
        """
        Store the program configuration and optionally print it.

        Args:
            method: Config of the fusion algorithm; instantiated in `run`.
            modelpool: Config of the model pool; instantiated in `run`.
            taskpool: Optional config of the task pool; instantiated in `run` when given.
            fabric: Optional Lightning Fabric config, stored as ``self._fabric``.
            print_config: If True, print the config tree (method, modelpool, taskpool first).
            dry_run: If True, exit the process with status 0 after initialization.
            report_save_path: Path of the JSON report written by `run` (skipped if None).
            merged_model_save_path: Path used by `save_merged_model`; may contain "{log_dir}".
            merged_model_save_kwargs: Extra keyword arguments for ``modelpool.save_model``.
            fast_dev_run: Stored as ``self.fast_dev_run``.
            seed: If given, `run` calls ``L.seed_everything(seed)``.
            print_function_call: Assigned to ``fusion_bench.utils.instantiate.PRINT_FUNCTION_CALL``.
            **kwargs: Forwarded to the parent ``__init__``.
        """
        self._method = method
        self._modelpool = modelpool
        self._taskpool = taskpool
        self._fabric = fabric
        self.report_save_path = report_save_path
        self.merged_model_save_path = merged_model_save_path
        self.merged_model_save_kwargs = merged_model_save_kwargs
        self.fast_dev_run = fast_dev_run
        self.seed = seed
        super().__init__(**kwargs)
        fusion_bench.utils.instantiate.PRINT_FUNCTION_CALL = print_function_call

        if print_config:
            print_config_tree(
                self.config,
                print_order=["method", "modelpool", "taskpool"],
            )
        if dry_run:
            log.info("The program is running in dry-run mode. Exiting.")
            # The builtin `exit` is injected by the `site` module and may be
            # missing (e.g. `python -S`); `sys.exit` is always available.
            import sys

            sys.exit(0)

    def _instantiate_and_setup(
        self, config: DictConfig, compat_load_fn: Optional[str] = None
    ):
        R"""
        Instantiates and sets up an object based on the provided configuration.

        This method performs the following steps:
        1. Checks if the configuration dictionary contains the key "_target_".
        2. If "_target_" is not found (for v0.1.x), attempts to instantiate the object using a compatible load function if provided.
           - Logs a warning if "_target_" is missing.
           - If `compat_load_fn` is provided, imports the function and uses it to instantiate the object.
           - If `compat_load_fn` is not provided, raises a ValueError.
        3. If "_target_" is found (for v.0.2.0 and above), attempts to import and instantiate the object using the `instantiate` function.
           - Ensures the target can be imported.
           - Uses the `instantiate` function with `_recursive_` set based on the configuration.
        4. Sets the `_program` attribute of the instantiated object to `self` if the object has this attribute.
        5. Sets the `_fabric_instance` (and, for v0.1.x compatibility, `_fabric`) attribute of the instantiated object to `self.fabric` if the object has this attribute and `self.fabric` is not None.
        6. Returns the instantiated and set up object.
        """
        if "_target_" not in config:
            log.warning(
                "No '_target_' key found in config. Attempting to instantiate the object in a compatible way."
            )
            if compat_load_fn is None:
                raise ValueError(
                    "No load function provided. Please provide a load function to instantiate the object."
                )
            load_fn = import_object(compat_load_fn)
            if rank_zero_only.rank == 0:
                print_bordered(
                    OmegaConf.to_yaml(config),
                    title="instantiate compat object",
                    style="magenta",
                    code_style="yaml",
                )
            obj = load_fn(config)
        else:
            # Import the target first: this validates that it is importable.
            import_object(config._target_)
            obj = instantiate(
                config,
                _recursive_=config.get("_recursive_", False),
            )
        if hasattr(obj, "_program"):
            obj._program = self
        if hasattr(obj, "_fabric_instance") and self.fabric is not None:
            obj._fabric_instance = self.fabric
        if hasattr(obj, "_fabric") and self.fabric is not None:
            # for v0.1.x compatibility
            obj._fabric = self.fabric
        return obj

    def save_merged_model(self, merged_model):
        """
        Save the merged model with ``self.modelpool.save_model``.

        ``self.merged_model_save_path`` may contain "{log_dir}", which is replaced
        by ``self.log_dir`` when that is not None. The parent directory is created
        if needed and ``self.merged_model_save_kwargs`` (if any) is forwarded.
        Nothing is saved when no save path is configured.
        """
        if self.merged_model_save_path is None:
            print("No save path specified for the merged model. Skipping saving.")
            return

        save_path: str = self.merged_model_save_path
        if "{log_dir}" in save_path and self.log_dir is not None:
            save_path = save_path.format(log_dir=self.log_dir)

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        save_kwargs = (
            self.merged_model_save_kwargs
            if self.merged_model_save_kwargs is not None
            else {}
        )
        with timeit_context(f"Saving the merged model to {save_path}"):
            self.modelpool.save_model(
                merged_model,
                save_path,
                **save_kwargs,
            )

    def evaluate_merged_model(
        self,
        taskpool: BaseTaskPool,
        merged_model: Union[nn.Module, Dict, Iterable],
        *args,
        **kwargs,
    ):
        """
        Evaluates the merged model using the provided task pool.

        Depending on the type of the merged model:
        - `nn.Module`: evaluated with ``taskpool.evaluate_test``.
        - `dict`: every `nn.Module` value is evaluated with ``taskpool.evaluate``
          and stored under its key; any other value is copied unchanged (metadata).
        - other iterables: each element is evaluated recursively.
        - anything else raises `ValueError`.

        Args:
            taskpool: The task pool used for evaluating the merged model.
            merged_model: The merged model(s) to evaluate.
            *args: Additional positional arguments forwarded to the taskpool call.
            **kwargs: Additional keyword arguments forwarded to the taskpool call.

        Returns:
            A report for an `nn.Module`, a dict keyed like the input for a `dict`,
            or a list of reports for an iterable.
        """
        if isinstance(merged_model, nn.Module):
            # NOTE(review): this branch calls `evaluate_test` while the dict branch
            # calls `evaluate`; confirm which taskpool method is intended.
            return taskpool.evaluate_test(merged_model, *args, **kwargs)
        if isinstance(merged_model, dict):
            return {
                key: (
                    taskpool.evaluate(item, *args, **kwargs)
                    if isinstance(item, nn.Module)
                    else item  # metadata
                )
                for key, item in merged_model.items()
            }
        if isinstance(merged_model, Iterable):
            return [
                self.evaluate_merged_model(taskpool, m, *args, **kwargs)
                for m in tqdm(merged_model, desc="Evaluating models")
            ]
        raise ValueError(f"Invalid type for merged model: {type(merged_model)}")

    def evaluate_merged_model_ada(
        self,
        taskpool: BaseTaskPool,
        merged_model: Union[nn.Module, Dict, Iterable],
        *args,
        **kwargs,
    ):
        """
        Evaluate `merged_model` with ``taskpool.evaluate_test_ada`` and return its report.

        Extra positional and keyword arguments are forwarded unchanged.
        """
        return taskpool.evaluate_test_ada(merged_model, *args, **kwargs)

    def run(self):
        """
        Executes the model fusion program.

        1. Seed ``torch``, ``random`` and ``numpy`` with 42, then call
           ``L.seed_everything(self.seed)`` when a seed is configured.
        2. On global rank 0, link the Hydra output directory into the log dir.
        3. Instantiate the model pool, the method and (if configured) the task pool.
        4. Dispatch on the method name (``config.name``, else ``config._target_``):
           - contains ``"adamerging"``: `_run_adamerging_pac_bayes`;
           - contains ``"TaskArithmetic"`` or ``"TiesMerging"``:
             `_run_scaling_factor_pac_bayes`.
           Both branches use the task pool.
        """
        # Fixed default seeds; an explicitly configured `seed` overrides them below.
        torch.manual_seed(42)
        random.seed(42)
        np.random.seed(42)
        fabric = self.fabric
        if self.seed is not None:
            L.seed_everything(self.seed)
        if fabric.global_rank == 0:
            self._link_hydra_output()

        log.info("Running the model fusion program.")
        # setup the modelpool, method, and taskpool
        log.info("loading model pool")
        self.modelpool = self._instantiate_and_setup(
            self._modelpool,
            compat_load_fn="fusion_bench.compat.modelpool.load_modelpool_from_config",
        )
        log.info("loading method")
        self.method = self._instantiate_and_setup(
            self._method,
            compat_load_fn="fusion_bench.compat.method.load_algorithm_from_config",
        )
        if self._taskpool is not None:
            log.info("loading task pool")
            self.taskpool = self._instantiate_and_setup(
                self._taskpool,
                compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
            )

        method_name = self._get_method_name()
        log.info(f"Method: {method_name}")
        is_adamerging = "adamerging" in method_name
        is_scaling_factor = (
            "TaskArithmetic" in method_name or "TiesMerging" in method_name
        )
        if is_adamerging:
            self._run_adamerging_pac_bayes()
        if is_scaling_factor:
            self._run_scaling_factor_pac_bayes()
        if not (is_adamerging or is_scaling_factor):
            log.warning(
                f"No PAC-Bayes evaluation is defined for method '{method_name}'; nothing was run."
            )

    def _get_method_name(self) -> str:
        """Return ``method.config.name`` if present, otherwise ``method.config._target_``."""
        config = self.method.config
        if "name" in config:
            return config.name
        return config._target_

    def _sample_and_evaluate(
        self, task_name, posterior_mu, posterior_var, n, build_model, *, with_test
    ):
        """
        Monte-Carlo estimate of a task's accuracy under the weight posterior.

        Draws `n` weight vectors with ``sample_from_gaussian(posterior_mu, posterior_var)``,
        builds a model from each via ``build_model(weights)`` and averages the
        ``"accuracy"`` reported by ``taskpool.evaluate_validation_ada`` (and, if
        `with_test`, by ``taskpool.evaluate_test_ada``) for `task_name`.

        Returns:
            ``(mean_val_accuracy, mean_test_accuracy)``; the test value stays 0
            when `with_test` is False.
        """
        val_accuracy = 0
        test_accuracy = 0
        for _ in range(n):
            weights = sample_from_gaussian(posterior_mu, posterior_var)
            merged_model = build_model(weights)
            results = self.taskpool.evaluate_validation_ada(merged_model, task_name)
            val_accuracy += results[task_name]["accuracy"] / n
            if with_test:
                results = self.taskpool.evaluate_test_ada(merged_model, task_name)
                test_accuracy += results[task_name]["accuracy"] / n
        return val_accuracy, test_accuracy

    def _pac_bayes_task_report(
        self, task_name, posterior_mu, posterior_var, b_re, n, build_model
    ):
        """
        Build the per-task report dict for one posterior.

        Keys: ``test_error`` / ``val_error`` (1 - mean sampled accuracy),
        ``pac_bayes`` (from ``approximate_BPAC_bound(val_accuracy, b_re)``),
        ``b_re`` and ``upper_bound`` (``val_error + sqrt(b_re / 2)``).
        """
        val_accuracy, test_accuracy = self._sample_and_evaluate(
            task_name, posterior_mu, posterior_var, n, build_model, with_test=True
        )
        return {
            "test_error": 1.0 - test_accuracy,
            "val_error": 1.0 - val_accuracy,
            "pac_bayes": approximate_BPAC_bound(val_accuracy, b_re, niter=5),
            "b_re": b_re,
            "upper_bound": 1.0 - val_accuracy + np.sqrt(b_re / 2),
        }

    def _run_adamerging_pac_bayes(self):
        """
        PAC-Bayes evaluation for AdaMerging-style methods.

        Learns per-task merging weights with ``method.run_pac``. It places a
        diagonal Gaussian posterior around them, with ``s = _POSTERIOR_S`` per
        entry and prior mean ``1 / weights_shape[0]``. For each task it then
        reports the sampled errors and bounds (see `_pac_bayes_task_report`).
        The report is printed and saved with `_save_report`.
        """
        delta = self._PAC_DELTA
        lambd = self._PAC_LAMBDA
        n = self._PAC_NUM_SAMPLES

        validation_sets = self.taskpool.validation_sets()
        # NOTE(review): `m` is the size of the *first* validation set and is used
        # for every task's bound; confirm all validation sets have the same size.
        first_task = next(iter(validation_sets.keys()))
        m = len(validation_sets[first_task])
        log.info(f"Validation set size used in the bound: m = {m}")
        merging_weights = self.method.run_pac(
            self.modelpool, validation_sets, delta, lambd, n
        )
        # For off-the-shelf bound evaluation use instead:
        # merging_weights = self.method.run(self.modelpool, validation_sets)

        pac_bayes = {}
        for task_name, task_weights in merging_weights.items():
            weights_shape = task_weights.size()
            # Flatten the weight tensor into the posterior mean vector.
            posterior_mu = task_weights.detach().view(-1).cpu().numpy()
            num_weights = posterior_mu.size
            posterior_var = np.full(num_weights, self._POSTERIOR_S)
            prior_mu = np.full(num_weights, 1 / weights_shape[0])
            b_re = B_RE_single(posterior_mu, prior_mu, posterior_var, lambd, m, delta)

            def build_model(weights, _shape=weights_shape, _task=task_name):
                # Restore the original weight layout before building the model.
                return self.method.sampled_models(
                    self.modelpool, weights.reshape(_shape), _task
                )

            pac_bayes[task_name] = self._pac_bayes_task_report(
                task_name, posterior_mu, posterior_var, b_re, n, build_model
            )

        print(pac_bayes)
        self._save_report(pac_bayes)

    def _run_scaling_factor_pac_bayes(self):
        """
        PAC-Bayes evaluation for TaskArithmetic / TiesMerging-style methods.

        For every task, each factor in `_SCALING_FACTOR_GRID` is scored by
        ``(1 - mean sampled validation accuracy) + sqrt(B_RE / 2)``, and the
        lowest-scoring factor is kept. The chosen factor is then evaluated with
        `_pac_bayes_task_report`, and the report is printed and saved.
        """
        delta = self._PAC_DELTA
        lambd = self._PAC_LAMBDA
        n = self._PAC_NUM_SAMPLES
        m = self._SCALING_FACTOR_BOUND_M

        # Return value unused; presumably called for its side effect of
        # preparing the validation data -- TODO confirm it is still needed.
        self.taskpool.validation_sets()
        task_names = self.taskpool.task_names
        # Prior mean 1 / (len(modelpool) - 1); presumably the "- 1" excludes the
        # pretrained model -- confirm.
        prior_mu = np.array([1 / (len(self.modelpool) - 1)])
        posterior_var = np.array([self._POSTERIOR_S])

        def build_model_for(task_name):
            # `method.run` receives the sampled scalar scaling factor.
            return lambda weights: self.method.run(
                self.modelpool, weights[0], task_name
            )

        merging_weights = {}
        for task_name in task_names:
            build_model = build_model_for(task_name)
            best_score = 1e8
            best_weight = None
            for weight in self._SCALING_FACTOR_GRID:
                posterior_mu = np.array([weight])
                b_re = B_RE_single(
                    posterior_mu, prior_mu, posterior_var, lambd, m, delta
                )
                val_accuracy, _ = self._sample_and_evaluate(
                    task_name,
                    posterior_mu,
                    posterior_var,
                    n,
                    build_model,
                    with_test=False,
                )
                score = (1 - val_accuracy) + np.sqrt(b_re / 2)
                if score < best_score:
                    best_score = score
                    best_weight = weight
            merging_weights[task_name] = best_weight

        # For an off-the-shelf baseline, pick the factor that maximizes plain
        # validation accuracy of `self.method.run(self.modelpool, weight, task_name)`
        # instead of the bound-based score above.

        pac_bayes = {}
        for task_name, scaling_factor in merging_weights.items():
            posterior_mu = np.array([scaling_factor])
            b_re = B_RE_single(posterior_mu, prior_mu, posterior_var, lambd, m, delta)
            pac_bayes[task_name] = self._pac_bayes_task_report(
                task_name,
                posterior_mu,
                posterior_var,
                b_re,
                n,
                build_model_for(task_name),
            )

        print(pac_bayes)
        self._save_report(pac_bayes)

    def _save_report(self, report):
        """Write `report` as JSON to ``self.report_save_path`` (no-op when it is None)."""
        if self.report_save_path is None:
            return
        report_dir = os.path.dirname(self.report_save_path)
        # `os.makedirs("")` raises, so only create a directory when the path has one.
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)
        # Context manager guarantees the file is flushed and closed.
        with open(self.report_save_path, "w") as f:
            json.dump(report, f)

    @rank_zero_only
    def _link_hydra_output(self):
        """
        Creates a symbolic link to the Hydra output directory within the log directory.

        If `self.log_dir` is not None, this method:
        1. Retrieves the Hydra output directory using `get_hydra_output_dir()`.
        2. Creates the log directory if it does not already exist.
        3. Creates a symbolic link named "hydra_output_<basename_of_hydra_output_dir>"
           within the log directory, pointing to the Hydra output directory.

        Note:
            - Nothing is linked if the Hydra output directory cannot be determined.
            - Failures to create the link are logged as warnings, not raised.
        """
        if self.log_dir is None:
            return
        try:
            hydra_output_dir = get_hydra_output_dir()
        except Exception as e:
            # Best effort (presumably e.g. not running under Hydra): skip linking.
            log.debug(f"Could not determine the Hydra output directory: {e}")
            hydra_output_dir = None
        if hydra_output_dir is None:
            return

        os.makedirs(self.log_dir, exist_ok=True)
        link_path = os.path.join(
            self.log_dir,
            "hydra_output_" + os.path.basename(hydra_output_dir),
        )
        try:
            os.symlink(hydra_output_dir, link_path, target_is_directory=True)
        except OSError as e:
            log.warning(f"Failed to create symbolic link: {e}")


def KLdiv(q, p):
    """
    Binary (Bernoulli) KL divergence KL(q || p).

    Computes q * log(q / p) + (1 - q) * log((1 - q) / (1 - p)) element-wise on
    scalars or numpy arrays. Both arguments are first clipped to
    [1e-8, 1 - 1e-8] so the logarithms stay finite.
    """
    tiny = 1e-8
    q_c = np.clip(q, tiny, 1 - tiny)
    p_c = np.clip(p, tiny, 1 - tiny)
    q_bar = 1 - q_c
    p_bar = 1 - p_c
    return q_c * np.log(q_c / p_c) + q_bar * np.log(q_bar / p_bar)


def KLdiv_prime(pbar, p):
    """
    Derivative of the binary KL divergence KL(pbar || p) with respect to p.

    Returns (1 - pbar) / (1 - p) - pbar / p, after clipping both inputs to
    [1e-8, 1 - 1e-8] to avoid division by zero.
    """
    tiny = 1e-8
    pbar_c, p_c = (np.clip(v, tiny, 1 - tiny) for v in (pbar, p))
    return (1 - pbar_c) / (1 - p_c) - pbar_c / p_c


def Newt(p, q, c):
    """
    One Newton-Raphson step (numpy, scalar inputs) for solving KLdiv(q, p) = c in p.

    The step uses KLdiv_prime(q, p) as the derivative. If its magnitude is below
    1e-8, it is replaced by +/-1e-8 with the same sign (0 counts as positive) to
    avoid dividing by zero. The updated p is clipped to [1e-8, 1 - 1e-8].
    """
    tiny = 1e-8
    slope = KLdiv_prime(q, p)
    if np.abs(slope) < tiny:
        slope = -tiny if slope < 0 else tiny
    step = (KLdiv(q, p) - c) / slope
    return np.clip(p - step, tiny, 1 - tiny)

def approximate_BPAC_bound(train_accur, B_init, niter=5):
    """
    Numerically invert the binary-KL PAC-Bayes bound.

    Starts from (1 - train_accur) + sqrt(B_init / 2) and refines it with
    `niter` Newton steps (see `Newt`) towards a p satisfying
    KLdiv(1 - train_accur, p) = B_init.

    Returns:
        1.0 if the starting point already exceeds 1, otherwise the refined
        estimate clipped to [0, 1].
    """
    empirical_error = 1 - train_accur
    estimate = empirical_error + np.sqrt(B_init / 2)
    if estimate > 1.0:
        return 1.0
    for _ in range(niter):
        estimate = Newt(estimate, empirical_error, B_init)
    return np.clip(estimate, 0.0, 1.0)

def sample_from_gaussian(w, s):
    """
    Draw one sample ``w + |s| * eps`` with ``eps ~ N(0, I)`` from numpy's global RNG.

    Parameters:
    - w: Mean, array-like of shape [d] (lists and scalars are accepted).
    - s: Array-like broadcastable to [d]. Its magnitude is used as the
      per-coordinate *standard deviation* of the sample.

    Returns:
    - numpy array of shape [d] (shape [1] for a scalar `w`).

    NOTE(review): `kl_divergence_gaussians` / `B_RE_single` interpret `s` as a
    *variance*. With e.g. s = 0.05, samples drawn here have variance 0.0025 while
    the KL term assumes variance 0.05. If `s` is meant to be a variance here too,
    the scale should be ``np.sqrt(s)``. It is left unchanged to keep existing
    results reproducible -- confirm which is intended.
    """
    w = np.asarray(w)
    s = np.asarray(s)
    # One standard-normal draw per coordinate (consumes w.size values of the RNG).
    epsilon = np.random.randn(w.size)
    # sqrt(s**2) == |s|, i.e. `s` acts as a standard deviation (see NOTE above).
    return w + np.sqrt(s**2) * epsilon

def kl_divergence_gaussians(w, w0, s, lambd):
    """
    KL divergence KL(q || p) between two diagonal Gaussians, computed with numpy.

    q = N(w, diag(s)) and p = N(w0, lambd * I), so that

        KL = 0.5 * (sum(s / lambd) + ||w - w0||^2 / lambd - d + sum(log(lambd / s)))

    with d = w.shape[0].

    Parameters:
    - w: Mean vector of q (numpy array).
    - w0: Mean vector of p (numpy array).
    - s: Variance vector of q (numpy array).
    - lambd: Scalar variance of p.

    Returns:
    - KL divergence (float).
    """
    num_dims = w.shape[0]
    delta_mu = w - w0
    log_det_ratio = np.sum(np.log(lambd / s))
    scaled_sq_dist = np.dot(delta_mu, delta_mu) / lambd
    trace = np.sum(s / lambd)
    return 0.5 * (trace + scaled_sq_dist - num_dims + log_det_ratio)

def B_RE_single(w, w0, s, lambd, m, delta):
    """
    PAC-Bayes complexity term B_RE = (KL(Q || P) + log(m / delta)) / (m - 1).

    Q = N(w, diag(s)) is the posterior and P = N(w0, lambd * I) the prior; the KL
    divergence is computed by `kl_divergence_gaussians`.

    Parameters:
    - w: Posterior mean (numpy array).
    - w0: Prior mean (numpy array).
    - s: Posterior variance vector (numpy array).
    - lambd: Scalar prior variance.
    - m: Number of samples the bound is stated for (presumably the size of the
      evaluation set -- confirm at call sites).
    - delta: Confidence parameter; enters as log(m / delta).

    Returns:
    - The complexity term B_RE (float). ``m <= 1`` makes the denominator
      non-positive and the result meaningless.
    """
    kl_term = kl_divergence_gaussians(w, w0, s, lambd)
    log_term = np.log(m / delta)
    pac_bound = (kl_term + log_term) / (m - 1)
    # Debug-level logging instead of unconditional prints: this is called once
    # per grid point / task and used to flood stdout.
    log.debug(
        "KL term: %s, log term: %s, PAC-Bayes bound: %s", kl_term, log_term, pac_bound
    )
    return pac_bound