# code for exploreaug training pipeline with DiffMorpher and classifier
import sys
import os
base_dir = os.path.dirname(os.path.abspath(__file__))
# Append the project root and several nested source folders to sys.path (in
# this order) so modules inside them can be imported without a package prefix.
sys.path.extend(
    os.path.join(base_dir, *parts)
    for parts in (
        (),
        ("data",),
        ("data", "datasets"),
        ("models",),
        ("models", "cartography"),
        ("models", "cartography", "cartography"),
    )
)
import logging
import torch
import random
import glob
import re
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset
from models.cartography.cartography.classification.run_cv import train as train_classifier
from models.cartography.cartography.classification.run_cv import validate as validate_classifier
from utils.generate_utils import dm_interpolated_images
from utils.sample_utils import select_high_entropy_indices
from utils.sample_utils import explore_latent_space_becs, get_topk_latents
from models.cartography.cartography.selection.cv_filter import run_filter
from utils.visual_utils import visualize_latent_space_with_umap
from models.DiffMorpher.converter import align_latent_spaces



def set_seed(seed):
    """Seed the global RNGs of PyTorch, NumPy and Python's ``random`` with ``seed``."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)


# Seed the global RNGs once at import time with a fixed value.
set_seed(42)

def print(*args, **kwargs):
    """Send ``print`` output from this module to ``logging.info`` instead of stdout.

    Shadows the builtin so every ``print`` call in this module is logged via
    the root logger. Arguments are converted with ``str`` and joined with
    ``sep`` (``None`` or missing means a single space, as for the builtin).
    Other builtin keywords (``end``, ``file``, ``flush``) are accepted and
    ignored.
    """
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    logging.info(sep.join(map(str, args)))

class ExploreAugPipeline:
    """Iterative "explore and augment" training loop.

    Trains a classifier, selects border samples from the training set,
    generates DiffMorpher interpolations from them, adds a subset of the
    generated images (chosen by ``select_high_entropy_indices``) to the
    training set, and retrains, for a given number of rounds.

    Args:
        classifier: model to train; moved to ``device``.
        dm: DiffMorpher pipeline used for interpolation; moved to ``device``.
        data: data object exposing ``dataloaders`` and ``datasets`` dicts
            (``'train'`` / ``'validation'`` keys) and a ``train_dataloader()``
            method that is called after new samples are added.
        config: experiment config; reads ``method``, ``k_per_class``,
            ``entropy_weight``, ``w``, ``add_number`` and ``model.*``.
        writer: forwarded unchanged to ``train_classifier``.
        logdir: root output directory for checkpoints, data maps, plots and
            generated images.
        num: run identifier used as a sub-directory name (str or int).
        subnum: stored on the instance; not read by this class.
        device: device the models are moved to and passed to helpers.
    """

    # Generated image names end in "_<class_id>.png"; the suffix is the label.
    _LABEL_SUFFIX_RE = re.compile(r"_(\d+)\.png$")

    def __init__(self, classifier, dm, data, config, writer, logdir, num, subnum, device):
        self.classifier = classifier.to(device)
        self.dm = dm.to(device)
        self.data = data
        self.config = config
        self.writer = writer
        self.logdir = logdir
        self.num = num
        self.subnum = subnum
        self.device = device

    def _run_dir(self, category, epoch_idx):
        """Return ``<logdir>/<category>/<num>/<epoch_idx>``.

        ``num`` and ``epoch_idx`` are stringified because ``os.path.join``
        rejects ints. Every per-iteration output location goes through here,
        so the code that writes a directory and the code that reads it agree.
        """
        return os.path.join(self.logdir, category, str(self.num), str(epoch_idx))

    def early_stopping_train(self, epoch_idx=0, epoch_start=0):
        """Train the classifier with early stopping on the current train split.

        Checkpoints are written under
        ``<logdir>/classifier_checkpoints/<num>/<epoch_idx>``. The cartography
        branch of ``get_border_samples`` reads from the same location.

        Args:
            epoch_idx: augmentation-round index used in the checkpoint path.
            epoch_start: epoch offset forwarded to ``train_classifier``.

        Returns:
            ``(best_model, epoch_last)`` as returned by ``train_classifier``.
        """
        cfg = self.config.model.classifier
        best_model, epoch_last = train_classifier(
            self.classifier,
            self.data.dataloaders['train'],
            self.data.dataloaders['validation'],
            epochs=cfg.epochs,
            lr=cfg.lr,
            factor=cfg.factor,
            save_dir=self._run_dir("classifier_checkpoints", epoch_idx),
            history_predictions={},
            writer=self.writer,
            patience=cfg.patience,
            epoch_start=epoch_start,
        )
        return best_model, epoch_last

    def get_border_samples(self, epoch_idx=0):
        """Select border samples from the training set.

        ``config.method`` chooses the strategy:

        * ``"becs"``: ``explore_latent_space_becs`` on the train loader.
        * ``"cartography"``: ``run_filter`` over the checkpoints of round
          ``epoch_idx``, followed by ``get_topk_latents`` on the selected ids.

        Returns:
            ``(topk, topk_indices, topk_indices_dataset, topk_latents,
            topk_labels, all_latents, all_labels)``, where ``topk`` is
            ``len(topk_indices)``.

        Raises:
            ValueError: if ``config.method`` is neither of the above.
        """
        method = self.config.method
        if method == "becs":
            (_, topk_indices, topk_indices_dataset, topk_latents, topk_labels,
             all_latents, all_labels) = explore_latent_space_becs(
                self.classifier,
                self.data.dataloaders['train'],
                k_per_class=self.config.k_per_class,
                entropy_weight=self.config.entropy_weight,
                device=self.device,
            )
        elif method == "cartography":
            train_set = self.data.datasets['train']
            selected_ids = run_filter(
                self.config.model.Cartography,
                save_dir=self._run_dir("classifier_checkpoints", epoch_idx),
                plots_dir=self._run_dir("data_map", epoch_idx),
                data_numbers=len(train_set),
                data_name=train_set.data.__class__.__name__,
                model_name=self.config.model.classifier.target.split('.')[-1],
            )
            # Note: get_topk_latents returns the dataset indices FIRST.
            (topk_indices_dataset, topk_indices, topk_latents, topk_labels,
             all_latents, all_labels) = get_topk_latents(
                self.classifier,
                self.data.dataloaders['train'],
                selected_ids,
                self.device,
            )
        else:
            raise ValueError(
                f"Unknown border-sample method {method!r}; "
                "expected 'becs' or 'cartography'"
            )

        return (
            len(topk_indices),
            topk_indices,
            topk_indices_dataset,
            topk_latents,
            topk_labels,
            all_latents,
            all_labels,
        )

    def generate_interpolated_samples(self, topk_labels, topk_indices_dataset, epoch_idx):
        """Generate DiffMorpher interpolations for the selected border samples.

        Delegates to ``dm_interpolated_images`` with ``gen_dir`` set to
        ``<num>/<epoch_idx>``. ``add_filter_samples`` later globs
        ``<logdir>/generated_images/<num>/<epoch_idx>``; presumably the helper
        writes there (verify against ``dm_interpolated_images``).

        Returns:
            ``(latents_dm_encode, images_dm)`` as returned by the helper.
        """
        config_dm = self.config.model.DiffMorpher
        latents_dm_encode, images_dm = dm_interpolated_images(
            topk_indices_dataset=topk_indices_dataset,
            topk_labels=topk_labels,
            data=self.data,
            dm_pipeline=self.dm,
            classifier=self.classifier,
            device=self.device,
            config_dm=config_dm,
            logdir=self.logdir,
            gen_dir=os.path.join(str(self.num), str(epoch_idx)),
            number_pairs=config_dm.number_pairs,
        )
        return latents_dm_encode, images_dm

    def add_filter_samples(self, epoch_idx):
        """Add a subset of this round's generated images to the training set.

        Scans ``<logdir>/generated_images/<num>/<epoch_idx>/*/*.png``. Each
        file name must end in ``_<class_id>.png``, which gives its label;
        other files are skipped and their count is logged. Up to
        ``config.add_number`` images picked by ``select_high_entropy_indices``
        are added through ``add_samples`` with ``weight_factor=config.w``.
        ``data.train_dataloader()`` is then called.
        """
        generated_root = self._run_dir("generated_images", epoch_idx)
        w = self.config.w
        add_number = self.config.add_number

        # Sorted so candidate order (and thus tie-breaking in the selection)
        # does not depend on the filesystem's listing order.
        candidates = sorted(glob.glob(os.path.join(generated_root, "*", "*.png")))
        valid_image_paths, valid_categories, skipped_files = [], [], []
        for path in candidates:
            filename = os.path.basename(path)
            match = self._LABEL_SUFFIX_RE.search(filename)
            if match is None:
                skipped_files.append(filename)
                continue
            valid_image_paths.append(path)
            valid_categories.append(int(match.group(1)))

        if skipped_files:
            print(f"skipped {len(skipped_files)} generated files without a '_<label>.png' suffix")

        num_samples = min(int(add_number), len(valid_image_paths))
        if num_samples > 0:
            selected_indices = select_high_entropy_indices(
                self.classifier, valid_image_paths, num_samples, device=self.device
            )
            self.data.datasets['train'].data.add_samples(
                [valid_image_paths[i] for i in selected_indices],
                [valid_categories[i] for i in selected_indices],
                weight_factor=float(w),
            )

        # Return value is discarded; presumably this rebuilds
        # data.dataloaders['train'] over the grown dataset — verify.
        self.data.train_dataloader()
        print(f"updated train_dataset: {len(self.data.datasets['train'])}")

    def run(self, iterations=1):
        """Run the initial training plus ``iterations`` augmentation rounds.

        Each round performs these steps:

        1. Select border samples and plot the training latents with UMAP.
        2. Align the classifier/DiffMorpher latent spaces.
        3. Generate interpolations and plot them together with the training
           latents, using label -1 for the generated points.
        4. Add a subset of the generated images to the training set.
        5. Retrain, continuing the epoch counter.
        6. Log validation accuracy.
        """
        best_model, epoch_last = self.early_stopping_train()
        self.classifier = best_model

        for i in range(iterations):
            # For the cartography method these samples come from the
            # checkpoints saved under index i (round 0 = initial training).
            (_, topk_indices, topk_indices_dataset, _, topk_labels,
             all_latents, all_labels) = self.get_border_samples(str(i))
            round_tag = str(i + 1)
            plot_num = f"{self.num}_{i+1}"

            visualize_latent_space_with_umap(
                all_latents,
                all_labels,
                os.path.join(self.logdir, "topk_latent_space"),
                topk_indices=topk_indices,
                num=plot_num,
            )
            align_latent_spaces(self.classifier, self.dm, self.data.dataloaders['train'])
            latents_dm_encode, _ = self.generate_interpolated_samples(
                topk_labels, topk_indices_dataset, round_tag
            )

            # Generated latents get a sentinel label so the plot can highlight them.
            highlight_label = -1
            combined_latents = np.concatenate([all_latents, latents_dm_encode], axis=0)
            combined_labels = np.concatenate(
                [all_labels, np.full(len(latents_dm_encode), highlight_label)], axis=0
            )
            visualize_latent_space_with_umap(
                combined_latents,
                combined_labels,
                os.path.join(self.logdir, "latent_space_with_dm_test"),
                topk_indices=topk_indices,
                highlight_label=highlight_label,
                num=plot_num,
            )

            self.add_filter_samples(epoch_idx=round_tag)

            # Re-train the classifier on the extended dataset.
            best_model, epochs_this_round = self.early_stopping_train(
                round_tag, epoch_start=epoch_last
            )
            # NOTE(review): assumes train_classifier returns the number of
            # epochs run in this call (not an absolute epoch) — confirm.
            epoch_last = epochs_this_round + epoch_last
            self.classifier = best_model
            val_acc = validate_classifier(self.classifier, self.data.dataloaders['validation'], self.device)
            print(f"Validation accuracy after iteration {i+1}: {val_acc:.2f}%")
