from typing import List, Tuple
from gnnboundary import *
import os

from gnnboundary.criteria.embedding_criterion_extension import ExtendedEmbeddingCriterion
from scripts.default_configs import *
from scripts.experiments_logging import initialize_logger

# Module-level logger shared by the functions below. initialize_logger comes from
# scripts.experiments_logging; presumably it logs to "embedding.log" — confirm there.
logger = initialize_logger("embedding.log")

class EmbeddingSpaceTrainer:
    """Optimise a free embedding vector so a discriminator classifies it as desired.

    A single embedding vector is treated as a learnable parameter. It is fed to
    ``discriminator`` under the keyword selected by ``key``. The trainer then
    minimises ``criterion`` (plus an optional budget penalty) on the
    discriminator's logits and probabilities.

    Args:
        embedding_space_dim: Length of the embedding vector.
        key: ``"embeds"`` passes the vector as ``embeds=``. Any other value
            passes it as ``embeds_last=``.
        lr: Learning rate. Used only when ``optimizer`` is a class.
        discriminator: Callable returning a mapping with ``"logits"`` and
            ``"probs"`` entries.
        criterion: Callable taking ``dict(logits=..., probs=...)`` (each
            unsqueezed to a leading batch dimension of 1) and returning a
            scalar loss tensor.
        scheduler: LR scheduler class, instantiated as
            ``scheduler(optimizer, **kwargs)``, or ``None``.
        optimizer: Optimizer class, instantiated over the embedding with
            ``lr``, or an optimizer instance used as-is. NOTE(review): the
            embedding is created inside ``__init__``, so a pre-built instance
            cannot already hold it. Confirm that path is intended.
        dataset: Stored on the instance; not used by this class.
        min, max: Bounds of the uniform distribution used to initialise (and
            reset) the embedding.
        budget_penalty: Optional callable applied to the embedding. Its result,
            scaled by the budget weight, is added to the loss.
        seed: Seed applied before drawing the initial embedding and on reset.
        **kwargs: Forwarded to the scheduler constructor.
    """

    def __init__(self,
                 embedding_space_dim,
                 key: str,
                 lr,
                 discriminator,
                 criterion,
                 scheduler,
                 optimizer,
                 dataset,
                 min=0,
                 max=1,
                 budget_penalty=None,
                 seed=1223,
                 **kwargs):
        self.seed = seed
        # Kept so reset_embedding() redraws from the same distribution as here.
        self._init_range = (min, max)
        torch.manual_seed(self.seed)
        self.target_embedding = tnn.Parameter(torch.FloatTensor(embedding_space_dim).uniform_(min, max))
        self.discriminator = discriminator
        self.criterion = criterion
        self.budget_penalty = budget_penalty
        self.optimizer = optimizer([self.target_embedding], lr=lr) if isinstance(optimizer, type) else optimizer
        self.scheduler = scheduler(self.optimizer, **kwargs) if scheduler is not None else None
        self.dataset = dataset
        self.iteration = 0
        self.key = key

    def _discriminate(self):
        """Run the discriminator on the current embedding under the configured keyword."""
        arg_name = "embeds" if self.key == "embeds" else "embeds_last"
        return self.discriminator(**{arg_name: self.target_embedding})

    def train(self, iterations,
              show_progress=True,
              target_probs: dict[int, tuple[float, float]] = None,
              w_budget_init=1,
              w_budget_dec=0.99,
              ):
        """Run up to ``iterations`` optimisation steps on the embedding.

        Args:
            iterations: Maximum number of steps in this call.
            show_progress: Show a tqdm progress bar.
            target_probs: Optional ``{class_index: (min_p, max_p)}``. If every
                listed class probability lies within its inclusive range,
                training stops before taking the next step. ``None`` or an
                empty dict disables the early exit.
            w_budget_init: Initial (and lower-bound) weight of the budget
                penalty.
            w_budget_dec: Multiplicative factor applied to the budget weight
                each step, clamped below at ``w_budget_init``.

        Returns:
            True if the target probabilities were reached, otherwise False.
        """
        budget_penalty_weight = w_budget_init
        bar = tqdm(range(iterations), initial=self.iteration, total=self.iteration + iterations,
                   disable=not show_progress)
        for _ in bar:
            self.optimizer.zero_grad()

            disc_out = self._discriminate()
            probs = disc_out["probs"]

            # Stop as soon as every requested class probability is inside its range.
            if target_probs and all(
                min_p <= probs[cls].item() <= max_p
                for cls, (min_p, max_p) in target_probs.items()
            ):
                return True

            # NOTE(review): the weight starts at w_budget_init. With w_budget_dec <= 1
            # this max() therefore keeps it constant at w_budget_init, and nothing
            # here ever raises it.
            budget_penalty_weight = max(w_budget_init, budget_penalty_weight * w_budget_dec)

            # Criteria receive a batch dimension of 1.
            loss = self.criterion(dict(logits=disc_out["logits"].unsqueeze(0), probs=probs.unsqueeze(0)))

            if self.budget_penalty:
                # Out-of-place add: don't mutate the tensor returned by the criterion.
                loss = loss + self.budget_penalty(self.target_embedding) * budget_penalty_weight

            loss.backward()
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            bar.set_postfix({'loss': loss.item(), 'budget_penalty_weight': budget_penalty_weight})
            self.iteration += 1

        return False

    @torch.no_grad()
    def evaluate(self):
        """Return the discriminator's ``"probs"`` for the current embedding.

        Switches the discriminator to eval mode and leaves it there. ``train()``
        does not switch it back.
        """
        self.discriminator.eval()
        return self._discriminate()["probs"]

    def save_embedding(self, path):
        """Save the embedding Parameter to ``path`` with ``torch.save``."""
        torch.save(self.target_embedding, path)

    def load_embedding(self, path):
        """Load an embedding written by ``save_embedding`` into the existing Parameter.

        Values are copied in place, so the optimizer created in ``__init__``
        keeps updating the same Parameter object. Raises RuntimeError if the
        stored tensor cannot be copied into the current embedding's shape.
        """
        loaded = torch.load(path, map_location=self.target_embedding.device)
        with torch.no_grad():
            self.target_embedding.copy_(loaded)
        self.target_embedding.requires_grad_(True)

    def reset_embedding(self):
        """Re-seed and redraw the embedding from the initial uniform(min, max) distribution.

        The same seed is re-applied, so this restores the values drawn in
        ``__init__``. They are written into the existing Parameter so the
        optimizer keeps tracking it.
        """
        low, high = self._init_range
        torch.manual_seed(self.seed)
        fresh = torch.empty(self.target_embedding.shape, dtype=torch.float32, device="cpu").uniform_(low, high)
        with torch.no_grad():
            self.target_embedding.copy_(fresh)
        self.target_embedding.requires_grad_(True)

    def save_and_reset(self, path):
        """Save the current embedding to ``path``, then reset it."""
        self.save_embedding(path)
        self.reset_embedding()

    def save_multiple(self, path, num):
        """Save ``num`` embeddings to ``f"{path}_{i}"``, resetting after each save.

        NOTE(review): the directory ``path`` is created, but the files are
        written next to it (``<path>_0``, ``<path>_1``, ...) rather than inside
        it. Confirm whether ``os.path.join(path, ...)`` was intended. No
        training happens between saves, so every file after the first holds
        the reset embedding.
        """
        os.makedirs(path, exist_ok=True)
        for i in range(num):
            self.save_and_reset(f"{path}_{i}")


def save_multiple_embeddings_to_csv(dataset, num, embed_space_dim, model, path, cls_1, cls_2, prob_range=(0.4, 0.6),
                                    min=0, max=1, iterations=20000, lr=0.002):
    """Train ``num`` seeded boundary embeddings for (cls_1, cls_2) and save the converged ones.

    One ``EmbeddingSpaceTrainer`` is trained per seed in ``range(num)``. A seed
    counts as converged when both class probabilities land inside
    ``prob_range``. Converged embeddings are written, one row each, to
    ``<path>/<dataset.name>/<lo>-<hi>/<cls_1>-<cls_2>.csv``.

    Args:
        dataset: Dataset object; its ``name`` attribute is used in the output path.
        num: Number of seeds to try.
        embed_space_dim: Embedding length (the model's hidden size).
        model: Discriminator passed to each trainer.
        path: Root output directory.
        cls_1, cls_2: Class indices whose decision boundary is targeted.
        prob_range: Inclusive probability interval both classes must reach.
        min, max: Uniform initialisation range for the embeddings.
        iterations: Maximum optimisation steps per seed.
        lr: SGD learning rate.

    If no seed converges, a warning is logged and no CSV is written.
    """
    out_dir = os.path.join(path, dataset.name, f"{prob_range[0]}-{prob_range[1]}")
    os.makedirs(out_dir, exist_ok=True)

    embeddings = []
    for seed in range(num):
        trainer = EmbeddingSpaceTrainer(
            embedding_space_dim=embed_space_dim,
            lr=lr,
            key="embeds_last",
            discriminator=model,
            criterion=WeightedCriterion([
                dict(key="logits", criterion=CrossEntropyBoundaryCriterion(
                    class_a=cls_1, class_b=cls_2
                ), weight=20),
            ]),
            optimizer=torch.optim.SGD,
            scheduler=torch.optim.lr_scheduler.ExponentialLR,
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
            seed=seed,
            gamma=1,
            min=min,
            max=max,
        )

        converged = trainer.train(iterations=iterations,
                                  target_probs={cls_1: prob_range, cls_2: prob_range},
                                  w_budget_init=1.05,
                                  w_budget_dec=0.95)
        if converged:
            # Detached copy: only the values are needed, not the autograd Parameter.
            embeddings.append(trainer.target_embedding.detach().clone())

    if not embeddings:
        # torch.stack() rejects an empty list; report instead of crashing.
        logger.warning("No embedding reached probability range %s for classes %s and %s; nothing saved",
                       prob_range, cls_1, cls_2)
        return

    stacked = torch.stack(embeddings)
    pd.DataFrame(stacked.cpu().numpy()).to_csv(os.path.join(out_dir, f"{cls_1}-{cls_2}.csv"), index=False)


def create_boundary_graph(dataset, model, embedding_path, cls_1=4, cls_2=5):
    """Train a GraphSampler to produce graphs on the cls_1 / cls_2 decision boundary.

    The loss combines a boundary criterion on the logits with an embedding
    criterion. The latter uses the target embeddings read from the CSV at
    ``embedding_path`` (one embedding per row, as written by
    ``save_multiple_embeddings_to_csv``).

    Returns:
        The trained ``Trainer``.
    """
    target_embeddings = torch.from_numpy(pd.read_csv(embedding_path).to_numpy()).to(torch.float32)

    # Objects are built in the same order as they would be inside a single
    # Trainer(...) call, so any RNG use during construction is unchanged.
    sampler = GraphSampler(
        max_nodes=20,
        temperature=0.2,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=True,
    )
    criterion = WeightedCriterion([
        dict(key="logits", criterion=DynamicBalancingBoundaryCriterion((cls_1, cls_2)), weight=2),
        dict(key="embeds_last",
             criterion=ExtendedEmbeddingCriterion(target_embedding=target_embeddings,
                                                  euclidean=True, reduction="mean"),
             weight=1),
    ])
    sgd = torch.optim.SGD(sampler.parameters(), lr=0.1)
    trainer = Trainer(
        sampler=sampler,
        discriminator=model,
        criterion=criterion,
        optimizer=sgd,
        scheduler=torch.optim.lr_scheduler.ExponentialLR(sgd, gamma=1),
        dataset=dataset,
        budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
    )

    # Seeded after construction: any randomness used by the constructors above
    # is not governed by this seed.
    torch.manual_seed(12)
    trainer.train(
        iterations=10000,
        target_probs={cls_1: (0.4, 0.6), cls_2: (0.4, 0.6)},
        target_size=50,
        w_budget_init=1,
        w_budget_inc=1.05,
        w_budget_dec=0.99,
        k_samples=32,
    )
    return trainer

def run_embedding_trainer(dataset, num_embeddings: int, cls_pairs: List[Tuple[int, int]],
                          prob_range: Tuple[float, float], min_init_range: float, max_init_range: float):
    """Load the dataset's trained GCN and fit boundary embeddings for each class pair.

    Args:
        dataset: Dataset object; ``dataset.name`` selects model kwargs and checkpoint.
        num_embeddings: Number of seeds tried per class pair.
        cls_pairs: Class-index pairs whose decision boundaries are targeted.
        prob_range: Inclusive probability interval both classes must reach.
        min_init_range, max_init_range: Uniform initialisation range for embeddings.

    Results are written under the ``"embeddings"`` directory by
    ``save_multiple_embeddings_to_csv``.
    """
    model_kwargs = get_model_kwargs(dataset, dataset.name)
    model = GCNClassifier(**model_kwargs)
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only hosts.
    # load_state_dict then copies the values into the model's own parameters.
    model.load_state_dict(torch.load(CKPT_PATHS[dataset.name], map_location="cpu"))

    embed_space_dim = model_kwargs["hidden_channels"]

    logger.info("Starting embedding training")
    for cls_1, cls_2 in cls_pairs:
        save_multiple_embeddings_to_csv(dataset, num_embeddings, embed_space_dim,
                                        model, "embeddings",
                                        cls_1, cls_2, prob_range,
                                        min=min_init_range, max=max_init_range)
        logger.info("Finished embedding training for classes %d and %d", cls_1, cls_2)

    logger.info("Finished embedding training")


