import logging
import os
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, cast  # noqa: F401
import nevergrad as ng
import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from torchmetrics import Accuracy
from fusion_bench.compat.method import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import ModelPool
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.models.wrappers.layer_wise_fusion import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)
from fusion_bench.utils.data import load_tensor_from_file
from fusion_bench.utils.type import TorchModelType
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
from .entropy_loss import entropy_loss
from .utils import get_memory_usage
import copy
import torch.nn as nn
import torch.nn.functional as F
if TYPE_CHECKING:
    from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
from torchmetrics.classification.accuracy import MulticlassAccuracy
import numpy as np
log = logging.getLogger(__name__)

class InfiniteDataLoader:
    """
    Endless iterator built on top of a (finite) data loader.

    Handy when training is budgeted in optimisation steps rather than epochs:
    once the wrapped loader runs out, a fresh iterator is created from it and
    iteration simply continues.

    Attributes:
        data_loader (DataLoader): The wrapped iterable that is restarted on exhaustion.
        data_iter (iterator): The iterator currently being consumed.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        # A private sentinel distinguishes "exhausted" from any real item.
        exhausted = object()
        item = next(self.data_iter, exhausted)
        if item is exhausted:
            # Start a new pass over the wrapped loader.
            self.data_iter = iter(self.data_loader)
            item = next(self.data_iter)
        return item

# NOTE(review): a class with this exact name is already defined earlier in
# this file; this second definition rebinds the name. Consider dropping one.
class InfiniteDataLoader:
    """
    Wrap a data loader so that iterating over it never stops.

    Useful when only the number of steps matters, not the number of epochs:
    whenever the wrapped loader is exhausted, iteration restarts from a fresh
    iterator over the same loader.

    Attributes:
        data_loader (DataLoader): The loader being wrapped.
        data_iter (iterator): The live iterator over ``data_loader``.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            # End of the current pass: rebuild the iterator and keep going.
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)

class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    """
    Implements the Layer-Wise AdaMerging Algorithm.

    This class merges the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.

    Two adaptation routines are provided:

    * `test_time_adaptation` optimises the merging weights with Adam on a
      cross-entropy loss over labelled validation batches, once per task.
    * `test_time_adaptation_pac` searches the merging weights with a nevergrad
      optimizer, scoring each candidate by the sampled error rate plus a
      PAC-Bayes complexity term.

    Subclasses must implement `compute_logits` (and `get_shuffled_test_loader_iter`).
    """

    # The program that this algorithm is running on.
    _program: "FabricModelFusionProgram"

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

    def _initial_layer_wise_weight(self, pretrained_model, num_models: int):
        """
        Build the initial merging weights.

        Uses `config.init_values` when `config.weights` is None, otherwise loads
        the tensor stored at the path given by `config.weights`.

        Args:
            pretrained_model: The pretrained model; its parameters with
                `requires_grad=True` determine the number of layers.
            num_models (int): Number of fine-tuned models being merged.

        Returns:
            The initial layer-wise weights.

        Raises:
            ValueError: If `config.weights` is neither None nor a string path.
        """
        if self.config.weights is None:
            trainable = [p for p in pretrained_model.parameters() if p.requires_grad]
            return get_layer_wise_weights(
                num_models=num_models,
                num_layers=len(trainable),
                init_values=self.config.init_values,
            )
        if isinstance(self.config.weights, str):
            # self.config.weights is a path to a saved tensor
            return load_tensor_from_file(self.config.weights)
        raise ValueError(f"Unsupported weights format: {self.config.weights}")

    def _wrap_merged_model(self, layer_wise_weight, pretrained_model, finetuned_models):
        """Wrap the models in a `LayerWiseMergedModel` using the clamp/tie/strict options from the config."""
        return LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )

    def _infinite_validation_loader(self, validation_sets, task: str):
        """
        Look up the validation set of `task` and wrap it in an endless, shuffled loader.

        The key tried first is the task name with '-' replaced by '_'; if it is
        missing, the same name prefixed with 'tanganke/' is tried.

        Args:
            validation_sets: Mapping from task key to validation dataset.
            task (str): The task (model) name.

        Returns:
            tuple: `(validation_set, iterator)` where the iterator never raises
            `StopIteration` for a non-empty dataset.
        """
        task_name = task.replace('-', '_')
        try:
            validation_set = validation_sets[task_name]
        except KeyError:
            validation_set = validation_sets[f'tanganke/{task_name}']
        val_loader = DataLoader(
            validation_set,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self.fabric is not None:
            val_loader = self.fabric.setup_dataloaders(val_loader)
        return validation_set, iter(InfiniteDataLoader(val_loader))

    def _as_merge_weight(self, values, shape):
        """
        Convert a flat NumPy weight vector into a frozen `nn.Parameter` on the fabric device.

        The tensor is moved to the device *before* being wrapped so that the
        result is still an `nn.Parameter` when it is assigned to `merge_weight`.
        """
        tensor = self.fabric.to_device(torch.from_numpy(values.reshape(shape)))
        return nn.Parameter(tensor, requires_grad=False)

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        layer_wise_weight = self._initial_layer_wise_weight(
            pretrained_model, num_models=len(modelpool.model_names)
        )

        module = self._wrap_merged_model(
            layer_wise_weight, pretrained_model, finetuned_models
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Nothing is written unless this is the global-zero process and the
        `save_merging_weights` config entry is set.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool, validation_sets, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
            validation_sets: Mapping from task key to the labelled validation
                dataset used for adaptation.

        Returns:
            If `config.weights` is set, the merged (unloaded) model.
            Otherwise a tuple `(models, merging_weights)` of two dicts keyed by
            task name: the merged (unloaded) model adapted for that task and a
            detached copy of its merging weights.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()

        with self.profile("test-time adaptation"):
            modules = self.test_time_adaptation(module, validation_sets)
        if self.config.get("save_merging_weights", False):
            # NOTE(review): adaptation runs on per-task deep copies, so
            # `module.merge_weight` still holds the initial weights here.
            # Persisting the per-task weights would change the saved file
            # format, so this is left as-is — confirm what consumers expect.
            self.save_merging_weights(
                self.config.save_merging_weights, module.merge_weight
            )
        merging_weights = {
            task: adapted.merge_weight.clone().detach()
            for task, adapted in modules.items()
        }
        merged_models = {
            task: adapted.merge_and_unload() for task, adapted in modules.items()
        }
        return merged_models, merging_weights

    def run_pac(self, modelpool: ModelPool, validation_sets, delta, lambd, n, **kwargs):
        """
        Run the PAC-Bayes variant of the Layer-Wise AdaMerging Algorithm.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
            validation_sets: Mapping from task key to labelled validation dataset.
            delta: Confidence parameter forwarded to the PAC-Bayes term.
            lambd: Prior variance forwarded to the PAC-Bayes term.
            n: Number of weight samples drawn per candidate.

        Returns:
            If `config.weights` is set, the merged (unloaded) model built from
            those weights. Otherwise a dict mapping each task name to the
            merging-weight tensor found for it.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        if self.config.weights is not None:
            # skip the test-time adaptation; the wrapped model has to be built
            # here because nothing else in this method constructs it.
            with self.profile("construct the wrapped model"):
                module = self.construct_layer_wise_merged_model(modelpool)
            return module.merge_and_unload()

        with self.profile("test-time adaptation"):
            merging_weights = self.test_time_adaptation_pac(
                validation_sets, delta, lambd, n
            )
        return merging_weights

    def sampled_models(self, modelpool, merging_weights, task_name):
        """
        Merge the models with explicitly given weights, leaving out one task's model.

        Args:
            modelpool: The model pool containing the pretrained and fine-tuned models.
            merging_weights: NumPy array of merging weights (converted with
                `torch.from_numpy`).
            task_name: Name of the fine-tuned model to exclude from the merge.

        Returns:
            The merged (unloaded) model.
        """
        log.info("Fusing models using given weights.")
        self.modelpool = modelpool

        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names if name != task_name
        ]
        task_wise_weight = torch.from_numpy(merging_weights)
        module = self._wrap_merged_model(
            task_wise_weight, pretrained_model, finetuned_models
        )

        print(module.merge_weight)
        return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module: The model module.
            images (Tensor): The input images.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]", validation_sets):
        """
        Adapt the merging weights separately for every task in the model pool.

        For each task a deep copy of `module` is optimised with Adam on the
        cross-entropy loss over labelled batches from that task's validation set.

        Args:
            module (LayerWiseMergedModel): The merged model to start from; it is
                copied, not modified.
            validation_sets: Mapping from task key to labelled validation dataset.

        Returns:
            dict: Task name -> adapted (fabric-wrapped) merged model.

        Raises:
            ValueError: If `config.optimizer` is not "adam".
        """
        self.on_test_time_adaptation_start()
        modules = {}
        for task in self.modelpool.model_names:
            # configure optimizer
            if self.config.optimizer != "adam":
                raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
            module_clone = copy.deepcopy(module)
            optimizer = torch.optim.Adam([module_clone.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module_clone, optimizer = self.fabric.setup(module_clone, optimizer)

            module_clone.train()
            module_clone.merge_weights()

            _, val_loader = self._infinite_validation_loader(validation_sets, task)

            num_steps = 1 if self.is_debug_mode else self.config.max_steps
            for step_idx in (
                pbar := tqdm(
                    range(num_steps),
                    ("[DEBUG MODE] " if self.is_debug_mode else "")
                    + "AdaMerging Test-time adaptation",
                    dynamic_ncols=True,
                )
            ):
                with self.profile("data loading"):
                    batch = next(val_loader)
                with self.profile("forward pass"):
                    # assumes each batch is an (images, labels) pair
                    images, targets = batch
                    logits = self.compute_logits(module_clone, images, task)
                    loss = F.cross_entropy(logits, targets)
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

                with self.profile("optimizer step"):
                    optimizer.step()
                    optimizer.zero_grad()
                with self.profile("merging weights"):
                    module_clone.merge_weights()

                # Report the weights that are actually being trained (the
                # clone), not the untouched template `module`.
                metrics = {
                    "train/loss": loss.item(),
                    "train/weight_max": module_clone.merge_weight.max().item(),
                    "train/weight_min": module_clone.merge_weight.min().item(),
                    "train/weight_mean": module_clone.merge_weight.mean().item(),
                }
                self.fabric.log_dict(metrics, step=step_idx)
                pbar.set_postfix(metrics)
            modules[task] = module_clone
        log.info(get_memory_usage("after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return modules

    @torch.no_grad()
    def test_time_adaptation_pac(self, validation_sets, delta, lambd, n):
        """
        Search merging weights per task with a gradient-free optimizer and a PAC-Bayes objective.

        For each task, the merged model is built from the pretrained model and
        all fine-tuned models *except* the one named after the task. A nevergrad
        `NGOpt` optimizer proposes a mean weight vector; `n` weight vectors are
        sampled around it, their error rate over one pass of the task's
        validation set is averaged, and `sqrt(B_RE_single(...) / 2)` is added
        to form the value reported back to the optimizer.

        Args:
            validation_sets: Mapping from task key to labelled validation dataset.
            delta: Confidence parameter passed to `B_RE_single`.
            lambd: Prior variance passed to `B_RE_single`.
            n: Number of weight samples evaluated per candidate.

        Returns:
            dict: Task name -> detached merging-weight tensor recommended by
            the optimizer.
        """
        self.on_test_time_adaptation_start()
        merging_weights = {}
        for task in self.modelpool.model_names:
            pretrained_model = self.modelpool.load_model("_pretrained_")
            finetuned_models = [
                self.modelpool.load_model(name)
                for name in self.modelpool.model_names
                if task != name
            ]

            # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
            layer_wise_weight = self._initial_layer_wise_weight(
                pretrained_model, num_models=len(finetuned_models)
            )
            module_clone = self._wrap_merged_model(
                layer_wise_weight, pretrained_model, finetuned_models
            )

            weights_shape = module_clone.merge_weight.size()
            num_weights = module_clone.merge_weight.numel()

            module_clone = self.fabric.to_device(module_clone)

            # Every weight starts at (and the prior is centred on) 1 / size of
            # the first dimension of the merge-weight tensor.
            init_value = 1 / weights_shape[0]
            instrum = ng.p.Array(init=[init_value] * num_weights)
            optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=self.config.max_steps)
            module_clone.merge_weights()

            validation_set, val_loader = self._infinite_validation_loader(
                validation_sets, task
            )
            num_examples = len(validation_set)

            # These do not change between optimisation steps.
            posterior_var = np.array([0.05] * num_weights)
            prior_mu = np.array([init_value] * num_weights)

            # NOTE(review): the loop runs a fixed 300 steps while the optimizer
            # was given budget=config.max_steps — confirm which is intended.
            num_steps = 1 if self.is_debug_mode else 300
            for step_idx in (
                pbar := tqdm(
                    range(num_steps),
                    ("[DEBUG MODE] " if self.is_debug_mode else "")
                    + "AdaMerging Test-time adaptation",
                    dynamic_ncols=True,
                )
            ):
                suggestion = optimizer.ask()
                posterior_mu = suggestion.value
                b_re = B_RE_single(posterior_mu, prior_mu, posterior_var, lambd, num_examples, delta)

                # Draw all n weight samples up front so every batch is scored
                # with the same set of samples.
                weights_sample = [
                    self._as_merge_weight(
                        sample_from_gaussian(posterior_mu, posterior_var), weights_shape
                    )
                    for _ in range(n)
                ]

                loss = 0
                count = 0
                # `<` rather than `!=`: if the loader ever yields more items
                # than the dataset length in total (e.g. padded or sharded
                # batches), an equality test would never become false.
                while count < num_examples:
                    with self.profile("data loading"):
                        batch = next(val_loader)
                    # assumes each batch is an (images, labels) pair
                    images, targets = batch
                    count += images.size(0)
                    for weights in weights_sample:
                        module_clone.merge_weight = weights
                        module_clone.merge_weights()
                        logits = self.compute_logits(module_clone, images, task)
                        predictions = torch.argmax(logits, dim=1)
                        correct = (predictions == targets).sum().item()
                        loss += correct / n
                loss = (1. - loss / count) + np.sqrt(b_re / 2)
                print(loss)
                print(count)
                optimizer.tell(suggestion, loss)

                metrics = {
                    "train/loss": loss,
                    "train/weight_max": module_clone.merge_weight.max().item(),
                    "train/weight_min": module_clone.merge_weight.min().item(),
                    "train/weight_mean": module_clone.merge_weight.mean().item(),
                }
                self.fabric.log_dict(metrics, step=step_idx)
                pbar.set_postfix(metrics)

            recommendation = optimizer.provide_recommendation()
            # Use the fabric device directly instead of relying on a batch
            # variable left over from the evaluation loop.
            module_clone.merge_weight = self._as_merge_weight(
                recommendation.value, weights_shape
            )
            module_clone.merge_weights()
            merging_weights[task] = module_clone.merge_weight.clone().detach()

        log.info(get_memory_usage("after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return merging_weights

def KLdiv(q, p):
    """
    KL divergence between two Bernoulli distributions with parameters q and p.

    Both arguments are clipped into [1e-8, 1 - 1e-8] first so that neither the
    logarithms nor the divisions hit 0.

    Args:
        q: Parameter of the first distribution.
        p: Parameter of the second distribution.

    Returns:
        q * log(q / p) + (1 - q) * log((1 - q) / (1 - p)) on the clipped values.
    """
    tiny = 1e-8
    q_safe = np.clip(q, tiny, 1 - tiny)
    p_safe = np.clip(p, tiny, 1 - tiny)
    success_part = q_safe * np.log(q_safe / p_safe)
    failure_part = (1 - q_safe) * np.log((1 - q_safe) / (1 - p_safe))
    return success_part + failure_part


def KLdiv_prime(pbar, p):
    """
    Derivative with respect to p of the Bernoulli KL divergence KL(pbar || p).

    Both arguments are clipped into [1e-8, 1 - 1e-8] to keep the divisions finite.

    Args:
        pbar: Parameter of the first distribution (held fixed).
        p: Parameter of the second distribution (differentiation variable).

    Returns:
        (1 - pbar) / (1 - p) - pbar / p on the clipped values.
    """
    tiny = 1e-8
    first = np.clip(pbar, tiny, 1 - tiny)
    second = np.clip(p, tiny, 1 - tiny)
    failure_ratio = (1 - first) / (1 - second)
    success_ratio = first / second
    return failure_ratio - success_ratio


def Newt(p, q, c):
    """
    One Newton-Raphson step (NumPy) towards a root of KLdiv(q, p) - c in p.

    Args:
        p: Current estimate.
        q: First argument of the KL divergence (held fixed).
        c: Target value of the KL divergence.

    Returns:
        The updated estimate, clipped into [1e-8, 1 - 1e-8].
    """
    tiny = 1e-8
    slope = KLdiv_prime(q, p)
    if np.abs(slope) < tiny:
        # Keep the sign but bound the magnitude away from zero so the
        # division below stays finite.
        slope = -tiny if slope < 0 else tiny
    residual = KLdiv(q, p) - c
    candidate = p - residual / slope
    return np.clip(candidate, tiny, 1 - tiny)


def approximate_BPAC_bound(train_accur, B_init, niter=5):
    """
    Refine a bound starting from (1 - train_accur) + sqrt(B_init / 2) with Newton steps.

    Args:
        train_accur: Training accuracy; 1 - train_accur is used as the error.
        B_init: Complexity term used both for the starting point and as the
            target value passed to `Newt`.
        niter (int): Number of `Newt` iterations to apply.

    Returns:
        1.0 if the starting point already exceeds 1.0, otherwise the refined
        value clipped into [0, 1].
    """
    error_rate = 1 - train_accur
    bound = np.sqrt(B_init / 2) + error_rate

    # A starting point above 1 is vacuous; no refinement is attempted.
    if bound > 1.0:
        return 1.0

    for _ in range(niter):
        bound = Newt(bound, error_rate, B_init)

    return np.clip(bound, 0.0, 1.0)

def sample_from_gaussian(w, s):
    """
    Draw one sample from a Gaussian with mean `w` and diagonal variance `s`.

    Uses the global NumPy random state.

    Args:
        w: Mean array.
        s: Per-coordinate variance (scalar or array broadcastable to `w`).

    Returns:
        numpy.ndarray: `w + sqrt(s) * eps` with `eps` standard normal noise of
        the same shape as `w`.
    """
    mean = np.asarray(w)

    # Standard normal noise, one value per coordinate of the mean.
    epsilon = np.random.standard_normal(mean.shape)

    # Reparameterization: `s` is a variance, so the noise is scaled by its
    # square root (the standard deviation), not by |s|.
    return mean + np.sqrt(s) * epsilon

def kl_divergence_gaussians(w, w0, s, lambd):
    """
    KL divergence KL(q || p) between two diagonal Gaussians, computed with NumPy.

    q has mean `w` and per-coordinate variance `s`; p has mean `w0` and the
    same scalar variance `lambd` in every coordinate.

    Args:
        w: Mean vector of q (numpy array).
        w0: Mean vector of p (numpy array).
        s: Variance vector of q (numpy array).
        lambd: Scalar variance of p.

    Returns:
        The KL divergence (float).
    """
    n_dims = w.shape[0]
    delta_mu = w - w0

    variance_ratio_sum = np.sum(s / lambd)
    scaled_sq_distance = np.dot(delta_mu, delta_mu) / lambd
    log_variance_ratio_sum = np.sum(np.log(lambd / s))

    total = variance_ratio_sum + scaled_sq_distance - n_dims + log_variance_ratio_sum
    return 0.5 * total

def B_RE(w, w0, s, lambd, b, c, m, delta):
    """
    PAC-Bayes complexity term, computed entirely with NumPy.

    Evaluates
        (KL + 2 * log(b * log(c / lambd)) + log(pi**2 * m / (6 * delta))) / (m - 1)
    where KL = kl_divergence_gaussians(w, w0, s, lambd).

    Args:
        w, w0, s, lambd: Inputs of the Gaussian KL divergence.
        b, c, m, delta: Scalars entering the logarithmic terms and the
            normalisation, as in the formula above.

    Returns:
        The value of the expression above (float).
    """
    divergence = kl_divergence_gaussians(w, w0, s, lambd)

    # The two logarithmic terms of the expression.
    lambda_log = 2 * np.log(b * np.log(c / lambd))
    confidence_log = np.log(np.pi**2 * m / (6 * delta))

    numerator = divergence + lambda_log + confidence_log
    return numerator / (m - 1)

def B_RE_single(w, w0, s, lambd, m, delta):
    """
    PAC-Bayes complexity term with a single logarithmic term, computed with NumPy.

    Evaluates
        (KL + log(m / delta)) / (m - 1)
    where KL = kl_divergence_gaussians(w, w0, s, lambd).

    Args:
        w, w0, s, lambd: Inputs of the Gaussian KL divergence.
        m: Scalar used in the logarithm and in the (m - 1) normalisation.
        delta: Scalar dividing m inside the logarithm.

    Returns:
        The value of the expression above (float).
    """
    divergence = kl_divergence_gaussians(w, w0, s, lambd)
    confidence_log = np.log(m / delta)

    numerator = divergence + confidence_log
    return numerator / (m - 1)