import numpy as np
import pandas as pd
import torch
import torch.nn as tnn
from tqdm.auto import tqdm, trange
from gnnboundary import *
import torch.nn.functional as F
from typing import Literal
import os


class EmbeddingSpaceTrainer:
    """Optimise a single learnable embedding vector against a discriminator.

    The trainer owns one ``torch.nn.Parameter`` (``target_embedding``) and
    repeatedly feeds it to ``discriminator``; the resulting logits and
    probabilities are scored by ``criterion`` (plus an optional budget
    penalty) and the embedding is updated by ``optimizer``.

    Contract of the collaborators, as used by this class:

    * ``discriminator`` is called with one keyword argument (``embeds`` or
      ``embeds_last``, selected by ``key``) and must return a mapping with
      ``"probs"`` and ``"logits"`` entries. ``evaluate`` additionally calls
      ``discriminator.eval()`` and ``discriminator.out(...)``.
    * ``criterion`` receives a dict with ``logits`` and ``probs``, each with a
      leading dimension of size one added, and returns a scalar loss tensor.
    * ``budget_penalty``, when truthy, is called with the embedding and its
      result is added to the loss, scaled by the budget weight.
    """

    def __init__(self,
                 embedding_space_dim,
                 key: str,
                 lr,
                 discriminator,
                 criterion,
                 scheduler,
                 optimizer,
                 dataset,
                 budget_penalty=None,
                 **kwargs):
        """Create the embedding and wire up optimizer and scheduler.

        Args:
            embedding_space_dim: Size (or shape) passed to ``torch.randn`` for
                the initial embedding.
            key: ``"embeds"`` routes the embedding to the discriminator's
                ``embeds`` keyword; any other value routes it to
                ``embeds_last``.
            lr: Learning rate, used only when ``optimizer`` is a class.
            discriminator: Callable scoring the embedding (see class docs).
            criterion: Callable turning discriminator outputs into a loss.
            scheduler: Scheduler factory called as
                ``scheduler(optimizer, **kwargs)``, or ``None``.
            optimizer: Optimizer class (instantiated over the embedding) or an
                already constructed optimizer instance (used as is).
            dataset: Stored on the instance; not otherwise used here.
            budget_penalty: Optional callable penalising the embedding.
            **kwargs: Forwarded to the scheduler factory.
        """
        # Seeds the *global* torch RNG, so every trainer starts from the same
        # initial embedding.
        torch.manual_seed(1223)
        self.target_embedding = tnn.Parameter(torch.randn(embedding_space_dim))
        self.discriminator = discriminator
        self.criterion = criterion
        self.budget_penalty = budget_penalty
        # NOTE(review): an optimizer *instance* was built before the embedding
        # above existed, so it presumably does not track it -- confirm callers
        # only pass optimizer classes.
        if isinstance(optimizer, type):
            self.optimizer = optimizer([self.target_embedding], lr=lr)
        else:
            self.optimizer = optimizer
        self.scheduler = scheduler(self.optimizer, **kwargs) if scheduler is not None else None
        self.dataset = dataset
        self.iteration = 0
        self.key = key

    def train(self, iterations,
              show_progress=True,
              target_probs: dict[int, tuple[float, float]] = None,
              w_budget_init=1,
              w_budget_dec=0.99,
              ):
        """Run up to ``iterations`` optimisation steps on the embedding.

        Args:
            iterations: Maximum number of optimisation steps.
            show_progress: Whether to display the tqdm progress bar.
            target_probs: Optional mapping ``class index -> (min_p, max_p)``.
                Training stops early as soon as every listed class
                probability lies inside its closed interval.
            w_budget_init: Initial (and minimum) budget penalty weight.
            w_budget_dec: Multiplicative factor applied to the budget penalty
                weight on every step, floored at ``w_budget_init``.
        """
        budget_penalty_weight = w_budget_init
        # The discriminator keyword that receives the embedding is fixed for
        # the whole run, so resolve it once.
        input_name = "embeds" if self.key == "embeds" else "embeds_last"

        # Context manager guarantees the bar is closed on early exit as well.
        with tqdm(range(iterations), initial=self.iteration, total=self.iteration + iterations,
                  disable=not show_progress) as bar:
            for _ in bar:
                self.optimizer.zero_grad()

                # Forward pass: evaluate the embedding with the discriminator.
                disc_out = self.discriminator(**{input_name: self.target_embedding})
                probs = disc_out["probs"]

                # Early stopping: every requested class probability is within
                # its target interval (an empty/None mapping never stops).
                if target_probs and all(
                    min_p <= probs[cls].item() <= max_p
                    for cls, (min_p, max_p) in target_probs.items()
                ):
                    break

                # NOTE(review): for a non-negative w_budget_init and
                # w_budget_dec <= 1 the max() pins the weight at
                # w_budget_init; it only changes (grows) for w_budget_dec > 1.
                budget_penalty_weight = max(w_budget_init, budget_penalty_weight * w_budget_dec)

                # The criterion receives both tensors with a leading dimension
                # of size one.
                criterion_input = dict(
                    logits=disc_out["logits"].unsqueeze(0),
                    probs=probs.unsqueeze(0),
                )
                loss = self.criterion(criterion_input)

                # Apply budget penalty if specified (out-of-place, so the
                # tensor returned by the criterion is never mutated).
                if self.budget_penalty:
                    loss = loss + self.budget_penalty(self.target_embedding) * budget_penalty_weight

                # Backpropagation
                loss.backward()
                self.optimizer.step()

                if self.scheduler is not None:
                    self.scheduler.step()

                # Logging
                bar.set_postfix({'loss': loss.item(), 'budget_penalty_weight': budget_penalty_weight})
                self.iteration += 1

    @torch.no_grad()
    def evaluate(self,):
        """Return the softmax over ``discriminator.out(target_embedding)``.

        Switches the discriminator to eval mode and does not restore the
        previous mode.
        """
        # NOTE(review): this calls ``discriminator.out`` directly regardless
        # of ``self.key`` -- confirm ``out`` consumes the same space the
        # embedding lives in.
        self.discriminator.eval()
        disc_out = self.discriminator.out(self.target_embedding)
        return F.softmax(disc_out, dim=-1)

    def save_embedding(self, path):
        """Serialise the current embedding to ``path`` with ``torch.save``."""
        torch.save(self.target_embedding, path)

    def load_embedding(self, path):
        """Replace the embedding with the tensor stored at ``path``.

        The optimizer is re-pointed at the loaded embedding so that training
        can continue from it.
        """
        self._install_embedding(torch.load(path))

    def reset_embedding(self):
        """Replace the embedding with a fresh standard-normal sample.

        The sample has the same shape, dtype and device as the current
        embedding, and the optimizer is re-pointed at it.
        """
        current = self.target_embedding
        fresh = torch.randn(current.shape, dtype=current.dtype, device=current.device)
        self._install_embedding(fresh)

    def save_and_sample_multiple(self, n_embeddings, save_dir, **kwargs):
        """Train and save ``n_embeddings`` embeddings, one file per sample.

        Each sample is trained with ``self.train(**kwargs)``, written to
        ``save_dir/embedding_{i}.pt`` and then the embedding is re-initialised
        for the next sample. The iteration counter and any scheduler state
        carry over between samples.
        """
        os.makedirs(save_dir, exist_ok=True)

        for i in range(n_embeddings):
            # Train first so the file holds the optimised embedding, then
            # start the next sample from a new random point.
            self.train(**kwargs)
            self.save_embedding(os.path.join(save_dir, f"embedding_{i}.pt"))
            self.reset_embedding()

    def _install_embedding(self, values):
        """Make ``values`` the trainable embedding and keep the optimizer in sync."""
        previous = self.target_embedding
        replacement = tnn.Parameter(values.detach().clone())

        # Swap the parameter inside the existing lists so anything holding a
        # reference to those lists keeps seeing the live parameter.
        for group in self.optimizer.param_groups:
            tracked = group["params"]
            for position, param in enumerate(tracked):
                if param is previous:
                    tracked[position] = replacement

        # Per-parameter optimizer state (momentum buffers etc.) belongs to the
        # old values and must not leak into the new embedding.
        self.optimizer.state.pop(previous, None)
        self.target_embedding = replacement