import os
import json
import logging
import torch
import torch.nn as nn
from statistics import mean
from torch.utils.data import DataLoader
import seaborn as sns; sns.set_theme()
import matplotlib.pyplot as plt

from llm_router.data.utils import RoutingDataCollator

logger = logging.getLogger(__name__)

def print_trainable_parameters(model):
    """
    Log the number of trainable parameters in ``model``.

    Counts every parameter yielded by ``model.named_parameters()`` and the
    subset with ``requires_grad=True``, then logs both counts and the
    trainable percentage at INFO level through the module logger.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    """
    params = [param for _, param in model.named_parameters()]
    all_param = sum(param.numel() for param in params)
    trainable_params = sum(param.numel() for param in params if param.requires_grad)
    # A model without any parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    logger.info(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
    
def filter_knn_indices(knn_indices, batch_idx):
    """
    Remove each query's own index from its row of nearest-neighbour indices.

    For row ``r`` of ``knn_indices`` paired with ``batch_idx[r]``: if the
    query's own index occurs in the row, its first occurrence is dropped;
    otherwise the row's last entry is dropped. Either way every row shrinks
    by exactly one element. The last entry is presumably the farthest
    neighbour; that depends on how the caller orders the row.

    Args:
        knn_indices: 2-D integer tensor, one row of neighbour indices per query.
        batch_idx: 1-D integer tensor with the query's own index for each row.

    Returns:
        ``torch.long`` tensor of the filtered rows.
    """
    def _without_self(neighbours, own_idx):
        # No self-match: trim the tail so the row length stays consistent.
        if own_idx not in neighbours:
            return neighbours[:-1]
        cut = neighbours.index(own_idx)
        return neighbours[:cut] + neighbours[cut + 1:]

    rows = [
        _without_self(neighbours, own_idx)
        for neighbours, own_idx in zip(knn_indices.tolist(), batch_idx.tolist())
    ]
    return torch.tensor(rows).long()

def sample_from_gmm(pi_logits, mu, log_var, num_samples=1):
    """
    Draw samples from a per-row 1-D Gaussian mixture.

    Each of the B rows defines a K-component mixture. For every row,
    ``num_samples`` component ids are drawn from ``softmax(pi_logits)``. A
    Gaussian draw is then taken from each chosen component's mean and
    standard deviation ``exp(0.5 * log_var)``.

    Args:
        pi_logits:   (B, K) unnormalized mixture-weight logits.
        mu:          (B, K) component means.
        log_var:     (B, K) component log-variances.
        num_samples: number of draws per row.

    Returns:
        (B, num_samples) tensor of samples.
    """
    batch, n_comp = pi_logits.shape
    full_shape = (num_samples, batch, n_comp)

    weights = torch.softmax(pi_logits, dim=-1)
    # (num_samples, B) -> (num_samples, B, 1) so it can index the component axis.
    picks = torch.distributions.Categorical(weights).sample((num_samples,)).unsqueeze(-1)

    sigma = torch.exp(0.5 * log_var)
    chosen_mu = mu.unsqueeze(0).expand(*full_shape).gather(2, picks)
    chosen_sigma = sigma.unsqueeze(0).expand(*full_shape).gather(2, picks)

    noise = torch.randn_like(chosen_mu)
    draws = chosen_mu + chosen_sigma * noise  # (num_samples, B, 1)
    return draws.squeeze(-1).transpose(0, 1)

def plot_loss(losses, x_label, save_name):
    """
    Plot a loss curve against its index and save it as an image.

    Each call creates its own figure and axes. Calling ``sns.lineplot`` without
    ``ax`` would draw on ``plt.gca()``, i.e. on whatever figure happens to be
    active, and that figure would then be closed. The figure is always closed,
    even if saving fails.

    Args:
        losses: sequence of loss values.
        x_label: x-axis label (e.g. "Epoch" or "Batch").
        save_name: output image path; missing parent directories are created.
    """
    save_dir = os.path.dirname(save_name)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        sns.lineplot(losses, ax=ax)
        ax.set_xlabel(x_label)
        ax.set_ylabel("Loss")
        fig.savefig(save_name, bbox_inches='tight')
    finally:
        plt.close(fig)
    
class Trainer:
    """
    Simple supervised training loop for a model that returns its own loss.

    The model is called as ``model(batch)`` and must return a scalar loss
    tensor. It must expose ``name``, which is used as the checkpoint file
    stem. It is switched between train/eval mode, so it is assumed to be a
    ``torch.nn.Module``. NOTE(review): confirm for all callers.

    ``training_config`` is read for: ``batch_size``, ``lr``, ``lr_decay_rate``,
    ``epochs``, ``grad_norm`` (``None`` disables clipping), ``log_freq``,
    ``save_freq`` and ``output_dir``.
    """

    def __init__(self, model, trainset, valset, training_config):
        self.model = model
        self.trainset = trainset
        self.valset = valset
        self.training_config = training_config

    def save(self, output_dir):
        """Write the model ``state_dict`` to ``<output_dir>/<model.name>.pth``, creating the directory."""
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(output_dir, f"{self.model.name}.pth"))

    def load(self, output_dir, map_location=None):
        """
        Load ``<output_dir>/<model.name>.pth`` into the model.

        Args:
            output_dir: directory previously passed to ``save``.
            map_location: forwarded to ``torch.load``. For example, pass
                ``"cpu"`` to load a GPU-saved checkpoint on a CPU-only machine.
                ``None`` keeps the previous behaviour.
        """
        state_dict = torch.load(
            os.path.join(output_dir, f"{self.model.name}.pth"), map_location=map_location
        )
        self.model.load_state_dict(state_dict)

    @torch.no_grad()
    def validate(self):
        """
        Return the mean per-batch loss over ``valset``.

        The model is put in eval mode for the pass, so layers such as dropout
        behave as at inference time. Its previous mode is restored afterwards.

        Raises:
            statistics.StatisticsError: if ``valset`` yields no batches.
        """
        val_dataloader = DataLoader(self.valset, batch_size=self.training_config.batch_size, shuffle=False, collate_fn=RoutingDataCollator())
        was_training = self.model.training
        self.model.eval()
        try:
            losses = [self.model(batch).item() for batch in val_dataloader]
        finally:
            self.model.train(was_training)
        return mean(losses)

    def _train_one_epoch(self, dataloader, optimizer, batch_loss_curve):
        """Run one optimisation pass over ``dataloader``; return the epoch's mean batch loss."""
        config = self.training_config
        self.model.train()
        losses = []
        for i, batch in enumerate(dataloader):
            optimizer.zero_grad()
            loss = self.model(batch)
            loss.backward()
            # Clip between backward() and step(). Clipping after step() has no
            # effect: the update is already applied and zero_grad() discards
            # the gradients on the next iteration.
            if config.grad_norm is not None:
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=config.grad_norm)
            optimizer.step()
            losses.append(loss.item())
            if config.log_freq > 0 and i % config.log_freq == 0:
                # Running mean over the batches seen so far in this epoch.
                running = mean(losses)
                logger.info(f"batch {i} loss: {running}")
                batch_loss_curve.append(running)
        return mean(losses)

    def train(self):
        """
        Train for ``epochs`` epochs with Adam and exponential LR decay.

        After every epoch this method does the following:
        - Log the mean training loss and the current LR.
        - Re-plot ``epoch_loss.png``, ``batch_loss.png`` and ``eval_loss.png``
          in ``output_dir``.
        - Run ``validate``.
        - Save to ``output_dir/model/best`` whenever the validation loss improves.
        - Save to ``output_dir/model/epoch<N>`` every ``save_freq`` epochs
          (when ``save_freq > 0``).
        - Step the LR scheduler.

        After the last epoch, the model is saved to ``output_dir/model/last``.
        """
        config = self.training_config
        # The plots are written before any checkpoint save, so the output
        # directory must exist from the start.
        os.makedirs(config.output_dir, exist_ok=True)
        train_dataloader = DataLoader(self.trainset, batch_size=config.batch_size, shuffle=True, collate_fn=RoutingDataCollator())
        optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, config.lr_decay_rate)

        best_loss = float("inf")
        batch_loss_curve = []
        epoch_loss_curve = []
        eval_loss_curve = []
        for epoch in range(1, config.epochs + 1):
            epoch_loss = self._train_one_epoch(train_dataloader, optimizer, batch_loss_curve)
            logger.info(f"epoch {epoch} loss: {epoch_loss} lr: {lr_scheduler.get_last_lr()}")
            epoch_loss_curve.append(epoch_loss)
            plot_loss(epoch_loss_curve, "Epoch", os.path.join(config.output_dir, "epoch_loss.png"))
            plot_loss(batch_loss_curve, "Batch", os.path.join(config.output_dir, "batch_loss.png"))

            val_loss = self.validate()
            eval_loss_curve.append(val_loss)
            plot_loss(eval_loss_curve, "Epoch", os.path.join(config.output_dir, "eval_loss.png"))
            if val_loss < best_loss:
                self.save(os.path.join(config.output_dir, "model/best"))
                best_loss = val_loss
            if config.save_freq > 0 and epoch % config.save_freq == 0:
                self.save(os.path.join(config.output_dir, f"model/epoch{epoch}"))
            lr_scheduler.step()
        self.save(os.path.join(config.output_dir, "model/last"))