import logging
from abc import abstractmethod
from typing import List, Mapping, Union  # noqa: F401

import lightning as L
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
import copy
from fusion_bench.compat.method import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import ModelPool
from fusion_bench.models.wrappers.task_wise_fusion import (
    TaskWiseMergedModel,
    get_task_wise_weights,
)
from itertools import cycle
import torch.nn.functional as F
import nevergrad as ng
from fusion_bench.tasks.clip_classification import get_classnames_and_templates

# Module-level logger; used for progress/diagnostic messages in this file.
log: logging.Logger = logging.getLogger(__name__)

class InfiniteDataLoader:
    """
    Endlessly cycle over an iterable (typically a ``DataLoader``).

    Each time the wrapped iterator is exhausted, a fresh iterator is created
    from ``data_loader`` and the next item is taken from it. This suits
    training loops driven by a step count instead of an epoch count. If the
    wrapped iterable yields nothing at all, ``StopIteration`` propagates.

    Attributes:
        data_loader (DataLoader): The iterable being cycled.
        data_iter (iterator): The currently active iterator over ``data_loader``.
    """

    # Marker returned by ``next(..., default)`` when the current pass is exhausted.
    _EXHAUSTED = object()

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        item = next(self.data_iter, self._EXHAUSTED)
        if item is self._EXHAUSTED:
            # Start a new pass over the underlying iterable.
            self.data_iter = iter(self.data_loader)
            item = next(self.data_iter)
        return item


def entropy_loss(logits: Tensor) -> Tensor:
    """
    Mean Shannon entropy of the softmax distributions given by ``logits``.

    The softmax is taken over the last dimension and the per-row entropies are
    averaged. A constant ``1e-8`` inside the log keeps zero probabilities finite.

    Args:
        logits (Tensor): Unnormalized scores; the class axis is the last one.

    Returns:
        Tensor: Scalar tensor holding the averaged entropy.
    """
    probabilities = logits.softmax(dim=-1)
    per_row_entropy = -(probabilities * (probabilities + 1e-8).log()).sum(dim=-1)
    return per_row_entropy.mean()


class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
    """Task-wise AdaMerging with per-task supervised adaptation and a PAC-Bayes variant.

    Subclasses must implement :meth:`compute_logits` (and
    :meth:`get_shuffled_test_loader_iter` if they use it). Config keys read:
    ``weights``, ``init_values``, ``clamp_weights``, ``tie_weights``,
    ``strict``, ``optimizer``, ``lr``, ``max_steps``, ``batch_size``,
    ``num_workers``. Optional keys: ``devices``, ``fast_dev_run``,
    ``save_merging_weights``, ``posterior_var`` (PAC posterior variance,
    default 0.05) and ``pac_steps`` (PAC search steps / Nevergrad budget,
    default 300).
    """

    # Lightning Fabric instance. Created in ``__init__`` only when CUDA is
    # available; otherwise it stays ``None`` and no device placement is done.
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
            self._fabric.launch()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _to_device(self, obj):
        """Move ``obj`` to the Fabric device; return it unchanged when Fabric is unset."""
        if self._fabric is None:
            return obj
        return self._fabric.to_device(obj)

    def _load_task_wise_weight(self, num_models: int) -> Tensor:
        """Build the initial task-wise merging weights from ``self.config.weights``.

        * ``None``: ``get_task_wise_weights(num_models, config.init_values)``.
        * a path ending in ``.pt``: loaded with ``torch.load`` onto the CPU.
        * a path ending in ``.np`` / ``.npy``: loaded with ``np.load``.
        * anything else: converted with ``torch.tensor(list(...), float32)``.

        Raises:
            ValueError: unsupported file extension, or a value that cannot be
                converted to a float tensor.
        """
        weights = self.config.weights
        if weights is None:
            return get_task_wise_weights(
                num_models=num_models,
                init_values=self.config.init_values,
            )
        if isinstance(weights, str):
            # NOTE(review): torch.load unpickles the file; only use trusted paths.
            if weights.endswith(".pt"):
                return torch.load(weights, map_location="cpu").detach_()
            if weights.endswith((".np", ".npy")):
                return torch.from_numpy(np.load(weights)).detach_()
            raise ValueError(f"Unsupported file format: {weights}")
        try:
            return torch.tensor(list(weights), dtype=torch.float32)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Unsupported weights format: {weights}") from err

    @staticmethod
    def _get_validation_set(validation_sets, task: str):
        """Return the validation set for ``task``.

        The key is the task name with ``-`` replaced by ``_``; if it is missing,
        the ``tanganke/``-prefixed key is tried (a ``KeyError`` propagates if
        neither exists).
        """
        task_name = task.replace("-", "_")
        try:
            return validation_sets[task_name]
        except KeyError:
            return validation_sets[f"tanganke/{task_name}"]

    def _make_val_loader(self, validation_set):
        """Create a shuffled DataLoader over ``validation_set`` that cycles forever."""
        val_loader = DataLoader(
            validation_set,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            val_loader = self._fabric.setup_dataloaders(val_loader)
        return iter(InfiniteDataLoader(val_loader))

    def _make_pbar(self, num_steps: int):
        """Progress bar over ``num_steps`` steps (a single step under ``fast_dev_run``)."""
        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            num_steps = 1
        return tqdm(
            range(num_steps),
            "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        )

    @staticmethod
    def _to_weight_parameter(values, shape, device) -> nn.Parameter:
        """Turn a flat numpy vector into a frozen ``nn.Parameter`` of ``shape`` on ``device``."""
        tensor = torch.from_numpy(np.asarray(values).reshape(shape)).to(device)
        return nn.Parameter(tensor, requires_grad=False)

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------

    @torch.no_grad()
    def construct_task_wise_merged_model(self, modelpool: ModelPool):
        """Wrap the pretrained model and all fine-tuned models in a ``TaskWiseMergedModel``.

        Initial merging weights come from :meth:`_load_task_wise_weight`.
        """
        task_wise_weight = self._load_task_wise_weight(
            num_models=len(modelpool.model_names)
        )

        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        return TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )

    def sampled_models(self, modelpool, merging_weights, task_name):
        """Merge every fine-tuned model except ``task_name`` using ``merging_weights``.

        Args:
            modelpool: pool providing ``_pretrained_`` and the fine-tuned models.
            merging_weights: numpy array or tensor of merging weights (e.g. a
                value from the dict returned by :meth:`run_pac`).
            task_name: model name to leave out of the merge.

        Returns:
            The merged model produced by ``merge_and_unload()``.
        """
        log.info("Fusing models using given weights.")
        self.modelpool = modelpool

        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name)
            for name in modelpool.model_names
            if name != task_name
        ]
        # as_tensor accepts numpy arrays (shares memory, like from_numpy) as well
        # as tensors; models are loaded on the CPU, so keep the weights there too
        task_wise_weight = torch.as_tensor(merging_weights).detach().cpu()
        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        log.info("Merge weights: %s", module.merge_weight)
        return module.merge_and_unload()

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------

    def run(self, modelpool: ModelPool, validation_sets):
        """Run task-wise AdaMerging.

        If ``config.weights`` is set, the models are merged with those weights and
        the single merged model is returned. Otherwise each task adapts its own
        copy of the merged model (see :meth:`test_time_adaptation`) and a tuple
        ``(merged_models, merging_weights)`` of per-task dicts is returned.
        When ``config.save_merging_weights`` is set, the per-task adapted
        weights dict is saved there with ``torch.save``.
        """
        log.info("Fusing models using task-wise adaptive merging.")
        self.modelpool = modelpool

        module = self.construct_task_wise_merged_model(modelpool)
        log.info("Merge weights: %s", module.merge_weight)
        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()

        adapted_modules = self.test_time_adaptation(module, validation_sets)
        merging_weights = {
            task: adapted.merge_weight.clone().detach()
            for task, adapted in adapted_modules.items()
        }
        if self.config.get("save_merging_weights", False):
            # Save the *adapted* weights: `module` itself is never trained
            # (each task adapts a deep copy), so its weights are the initial ones.
            torch.save(merging_weights, self.config.save_merging_weights)

        merged_models = {
            task: adapted.merge_and_unload()
            for task, adapted in adapted_modules.items()
        }
        return merged_models, merging_weights

    def run_pac(self, modelpool: ModelPool, validation_sets, delta, lambd, n):
        """Run the PAC-Bayes guided merging-weight search.

        If ``config.weights`` is set, the adaptation is skipped and the model
        merged with those weights is returned. Otherwise returns the dict of
        per-task weights from :meth:`test_time_adaptation_pac`.
        """
        log.info("Fusing models using task-wise adaptive merging.")
        self.modelpool = modelpool

        if self.config.weights is not None:
            # skip the test-time adaptation
            module = self.construct_task_wise_merged_model(modelpool)
            return module.merge_and_unload()
        return self.test_time_adaptation_pac(validation_sets, delta, lambd, n)

    # ------------------------------------------------------------------
    # Hooks for subclasses
    # ------------------------------------------------------------------

    def on_test_time_adaptation_start(self):
        """Hook called once before adaptation starts; no-op by default."""
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        pass

    @abstractmethod
    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        pass

    # ------------------------------------------------------------------
    # Adaptation loops
    # ------------------------------------------------------------------

    def test_time_adaptation(self, module: TaskWiseMergedModel, validation_sets):
        """Adapt a separate deep copy of ``module`` for every task.

        Each copy's ``merge_weight`` is optimized with Adam for
        ``config.max_steps`` steps on cross-entropy over that task's validation
        set.

        Returns:
            dict mapping each task name to its adapted module.

        Raises:
            ValueError: if ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()
        modules = {}
        for task in self.modelpool.model_names:
            module_clone = copy.deepcopy(module)
            # configure optimizer
            if self.config.optimizer == "adam":
                optimizer = torch.optim.Adam(
                    [module_clone.merge_weight], lr=self.config.lr
                )
            else:
                raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

            if self._fabric is not None:
                module_clone, optimizer = self._fabric.setup(module_clone, optimizer)

            module_clone.train()
            module_clone.merge_weights()

            pbar = self._make_pbar(self.config.max_steps)
            val_loader = self._make_val_loader(
                self._get_validation_set(validation_sets, task)
            )

            for _ in pbar:
                batch = next(val_loader)
                logits = self.compute_logits(module_clone, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                _, targets = batch  # assumes batch is (images, labels)
                loss = F.cross_entropy(logits, targets)
                # retain_graph kept from the original loop; the graph is rebuilt by
                # merge_weights() every step, so it is probably unnecessary — TODO confirm
                if self._fabric is not None:
                    self._fabric.backward(loss, retain_graph=True)
                else:
                    loss.backward(retain_graph=True)

                optimizer.step()
                optimizer.zero_grad()
                module_clone.merge_weights()
            modules[task] = module_clone
        return modules

    def test_time_adaptation_pac(self, validation_sets, delta, lambd, n):
        """Gradient-free, PAC-Bayes guided search for leave-one-out merging weights.

        For each task, a ``TaskWiseMergedModel`` is built from the pretrained
        model and every *other* fine-tuned model. Nevergrad proposes a posterior
        mean ``mu``. The objective is computed from ``n`` samples drawn from
        ``N(mu, posterior_var)`` and evaluated over one pass of the task's
        validation set. It equals the samples' mean 0-1 error plus
        ``sqrt(B_RE_single(...) / 2)``, where the prior is
        ``N(1 / weights_shape[0], lambd)``.

        Args:
            validation_sets: mapping from task name to validation dataset.
            delta: confidence parameter passed to ``B_RE_single``.
            lambd: prior variance passed to ``B_RE_single``.
            n: number of posterior samples per objective evaluation.

        Returns:
            dict mapping each task name to the recommended merging weights.

        Raises:
            ValueError: if a task's validation set is empty.
        """
        self.on_test_time_adaptation_start()
        merging_weights = {}
        # variance of the Gaussian posterior over merging weights
        posterior_var_value = self.config.get("posterior_var", 0.05)
        # number of ask/tell rounds; also the Nevergrad budget, so the optimizer
        # plans for exactly the number of evaluations it receives
        num_steps = self.config.get("pac_steps", 300)
        for task in self.modelpool.model_names:
            pretrained_model = self.modelpool.load_model("_pretrained_")
            # leave-one-out: merge every fine-tuned model except the one for `task`
            finetuned_models = [
                self.modelpool.load_model(name)
                for name in self.modelpool.model_names
                if name != task
            ]
            task_wise_weight = self._load_task_wise_weight(
                num_models=len(finetuned_models)
            )

            module_clone = TaskWiseMergedModel(
                task_wise_weight=task_wise_weight,
                pretrained_model=pretrained_model,
                finetuned_models=finetuned_models,
                clamp_weights=self.config.clamp_weights,
                tie_weights=self.config.tie_weights,
                strict=self.config.strict,
            )

            weights_shape = module_clone.merge_weight.size()
            num_weights = module_clone.merge_weight.numel()
            # 1 / weights_shape[0]: Nevergrad's starting point and the prior mean
            uniform_weight = 1 / weights_shape[0]

            module_clone = self._to_device(module_clone)
            device = module_clone.merge_weight.device
            instrum = ng.p.Array(init=[uniform_weight] * num_weights)
            optimizer = ng.optimizers.NGOpt(
                parametrization=instrum, budget=num_steps
            )
            module_clone.merge_weights()

            pbar = self._make_pbar(num_steps)
            validation_set = self._get_validation_set(validation_sets, task)
            num_examples = len(validation_set)
            if num_examples == 0:
                raise ValueError(f"Validation set for task {task!r} is empty.")
            val_loader = self._make_val_loader(validation_set)

            posterior_var = np.full(num_weights, posterior_var_value, dtype=float)
            prior_mu = np.full(num_weights, uniform_weight, dtype=float)

            for _ in pbar:
                suggestion = optimizer.ask()
                posterior_mu = suggestion.value
                b_re = B_RE_single(
                    posterior_mu, prior_mu, posterior_var, lambd, num_examples, delta
                )
                # draw all posterior samples up front so every batch is scored
                # with the same n weight vectors
                weights_sample = [
                    self._to_weight_parameter(
                        sample_from_gaussian(posterior_mu, posterior_var),
                        weights_shape,
                        device,
                    )
                    for _ in range(n)
                ]

                accuracy_sum = 0.0
                count = 0
                # evaluation only: no autograd graph is needed
                with torch.no_grad():
                    # one pass over the validation set (the loader cycles forever);
                    # `<` guarantees termination even if batch sizes do not sum
                    # exactly to the dataset size
                    while count < num_examples:
                        batch = next(val_loader)
                        _, targets = batch  # assumes batch is (images, labels)
                        count += batch[0].size(0)
                        for weights in weights_sample:
                            module_clone.merge_weight = weights
                            module_clone.merge_weights()
                            logits = self.compute_logits(module_clone, batch, task)
                            predictions = torch.argmax(logits, dim=1)
                            accuracy_sum += (predictions == targets).sum().item() / n
                # mean 0-1 error over the posterior samples plus the PAC-Bayes penalty
                objective = (1.0 - accuracy_sum / count) + np.sqrt(b_re / 2)
                log.info("task=%s objective=%s examples=%d", task, objective, count)
                optimizer.tell(suggestion, objective)

            recommendation = optimizer.provide_recommendation()
            module_clone.merge_weight = self._to_weight_parameter(
                recommendation.value, weights_shape, device
            )
            module_clone.merge_weights()
            merging_weights[task] = module_clone.merge_weight.clone().detach()
        return merging_weights

def KLdiv(q, p):
    """Bernoulli KL divergence ``kl(q || p)``, elementwise on scalars or arrays.

    Both arguments are clipped to ``[1e-8, 1 - 1e-8]`` so the logarithms stay
    finite at the endpoints.
    """
    eps = 1e-8
    q_c = np.clip(q, eps, 1 - eps)
    p_c = np.clip(p, eps, 1 - eps)
    success_part = q_c * np.log(q_c / p_c)
    failure_part = (1 - q_c) * np.log((1 - q_c) / (1 - p_c))
    return success_part + failure_part


def KLdiv_prime(pbar, p):
    """Derivative of the Bernoulli KL ``kl(pbar || p)`` with respect to ``p``.

    Equals ``(1 - pbar) / (1 - p) - pbar / p``; both inputs are clipped to
    ``[1e-8, 1 - 1e-8]`` first.
    """
    eps = 1e-8
    q_c, p_c = (np.clip(x, eps, 1 - eps) for x in (pbar, p))
    return (1 - q_c) / (1 - p_c) - q_c / p_c


def Newt(p, q, c):
    """One Newton-Raphson step (NumPy) towards the root in ``p`` of ``KLdiv(q, p) = c``.

    A derivative whose magnitude is below ``1e-8`` is replaced by ``+/-1e-8``
    (keeping its sign) to avoid dividing by zero. The updated ``p`` is clipped
    to ``[1e-8, 1 - 1e-8]``. The ``np.abs(...) < eps`` test expects a scalar ``p``.
    """
    eps = 1e-8
    slope = KLdiv_prime(q, p)
    if np.abs(slope) < eps:
        slope = eps if slope >= 0 else -eps
    residual = KLdiv(q, p) - c
    return np.clip(p - residual / slope, eps, 1 - eps)


def approximate_BPAC_bound(train_accur, B_init, niter=5):
    """Approximate the PAC-Bayes-kl upper bound on the true error.

    The search starts from ``(1 - train_accur) + sqrt(B_init / 2)``, the
    Pinsker-inequality bound. If that exceeds 1, it returns ``1.0``.
    Otherwise it takes ``niter`` Newton steps (:func:`Newt`) towards the ``p``
    with ``KLdiv(1 - train_accur, p) = B_init`` and returns the result
    clipped to ``[0, 1]``.
    """
    empirical_error = 1 - train_accur
    estimate = empirical_error + np.sqrt(B_init / 2)

    if estimate > 1.0:
        return 1.0

    for _ in range(niter):
        estimate = Newt(estimate, empirical_error, B_init)

    return np.clip(estimate, 0.0, 1.0)

def sample_from_gaussian(w, s):
    """Draw one sample from ``N(w, diag(s))``, where ``s`` holds *variances*.

    ``s`` has the same meaning as in :func:`kl_divergence_gaussians`, so the
    sample is drawn from the posterior whose KL term enters the PAC-Bayes
    bound. The standard deviation is therefore ``sqrt(s)``. The original code
    used ``sqrt(s**2)``, i.e. it treated ``s`` as a standard deviation,
    contradicting its own docstring.

    Args:
        w: mean array (any shape).
        s: variance array broadcastable to ``w`` (e.g. same shape, or a scalar).

    Returns:
        numpy array with the shape of ``w``: ``w + sqrt(s) * eps``, eps ~ N(0, I).

    Raises:
        ValueError: if any variance in ``s`` is negative.
    """
    w = np.asarray(w, dtype=float)
    s = np.asarray(s, dtype=float)
    if np.any(s < 0):
        raise ValueError("Variances in `s` must be non-negative.")

    # Standard normal noise, one draw per entry of w (same RNG consumption as
    # a flat randn(w.size) call), then reparameterize.
    epsilon = np.random.randn(w.size).reshape(w.shape)
    return w + np.sqrt(s) * epsilon

def kl_divergence_gaussians(w, w0, s, lambd):
    """
    KL divergence ``KL(q || p)`` with ``q = N(w, diag(s))`` and ``p = N(w0, lambd * I)``.

    Parameters:
    - w: Mean of q (numpy array of any shape; it is flattened).
    - w0: Mean of p; an array with as many entries as ``w``, or a scalar.
    - s: Variances of q; an array with as many entries as ``w``, or a scalar.
    - lambd: Scalar variance of p.

    Returns:
    - KL divergence (float)

    Previously a scalar ``s`` contributed ``s / lambd`` once to the trace
    term instead of once per dimension, and multi-dimensional means broke
    ``np.dot``. Broadcasting both to the flattened mean fixes this; 1-D
    inputs give the same result as before.
    """
    mu_q = np.ravel(np.asarray(w, dtype=float))
    mu_p = np.broadcast_to(np.ravel(np.asarray(w0, dtype=float)), mu_q.shape)
    var_q = np.broadcast_to(np.ravel(np.asarray(s, dtype=float)), mu_q.shape)

    dim = mu_q.size
    trace_term = np.sum(var_q / lambd)
    mean_diff = mu_q - mu_p
    quadratic_term = np.dot(mean_diff, mean_diff) / lambd
    log_det_term = np.sum(np.log(lambd / var_q))
    return 0.5 * (trace_term + quadratic_term - dim + log_det_term)

def B_RE(w, w0, s, lambd, b, c, m, delta):
    """
    PAC-Bayes complexity term with extra terms for the choice of prior variance.

    Computes ``(KL + 2*log(b*log(c/lambd)) + log(pi^2 * m / (6*delta))) / (m - 1)``,
    where KL is :func:`kl_divergence_gaussians` ``(w, w0, s, lambd)``.
    NOTE(review): the ``b``/``c`` term looks like a union bound over a grid of
    prior variances ``lambd`` — confirm against the paper this follows.

    Parameters:
    - w, w0, s, lambd: Inputs for the KL divergence (see kl_divergence_gaussians).
    - b, c: Scalars of the prior-variance term; requires ``b * log(c / lambd) > 0``.
    - m: Number of samples (must exceed 1).
    - delta: Confidence parameter.

    Returns:
    - PAC-Bayes bound term (float)
    """
    kl_term = kl_divergence_gaussians(w, w0, s, lambd)

    log_term_1 = 2 * np.log(b * np.log(c / lambd))
    log_term_2 = np.log(np.pi**2 * m / (6 * delta))

    # Diagnostics go through the module logger instead of unconditional prints.
    log.debug("KL term: %s", kl_term)
    log.debug("Log term 1: %s", log_term_1)
    log.debug("Log term 2: %s", log_term_2)

    pac_bound = (kl_term + log_term_1 + log_term_2) / (m - 1)

    log.debug("PAC-Bayes Bound: %s", pac_bound)

    return pac_bound

def B_RE_single(w, w0, s, lambd, m, delta):
    """
    PAC-Bayes complexity term for a single, fixed prior.

    Returns ``(KL(N(w, diag(s)) || N(w0, lambd * I)) + log(m / delta)) / (m - 1)``.

    Parameters:
    - w, w0, s, lambd: Inputs for the KL divergence (see kl_divergence_gaussians).
    - m: Number of samples (the bound divides by ``m - 1``).
    - delta: Confidence parameter.

    Returns:
    - PAC-Bayes bound term (float)
    """
    complexity = kl_divergence_gaussians(w, w0, s, lambd) + np.log(m / delta)
    return complexity / (m - 1)