from data.loaders.cd_dataset import CDDataset
from data.collate_fns.GMAN.cd import distance_collate_fn_CD as distance_collate_fn

from config_CD import cd_config

from model.set_trainer import train_epoch, test_epoch
from model.utils import OneHotEmbedder
from model.GMAN import GMAN
from model.GMAN_Ablation import GMAN_Ablation

from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')  # before pyplot!

import torch
from torch.utils.data import DataLoader
import torch.multiprocessing

import random
import numpy as np
import os

from utils import warmup_lr_scheduler, balance_classes
    

def _list_split_files(data_dir, split, cohort):
    """Return the paths of every entry in ``{data_dir}/{split}__CD__{cohort}``, sorted by name.

    ``os.listdir`` yields entries in arbitrary, filesystem-dependent order.
    Sorting makes the seeded balancing/shuffling done by the caller pick the
    same files on every machine.
    """
    split_dir = f"{data_dir}/{split}__CD__{cohort}"
    return [os.path.join(split_dir, name) for name in sorted(os.listdir(split_dir))]


def _make_cd_dataset(file_paths, exp_config, biom_one_hot_embedder, unit_one_hot_embedder, lab_code_one_hot_embedder):
    """Build a ``CDDataset`` over ``file_paths`` that uses the shared one-hot embedders."""
    return CDDataset(
        file_paths=file_paths,
        metadata_path=exp_config.metadata_path,
        exp_config=exp_config,
        biom_one_hot_embedder=biom_one_hot_embedder,
        unit_one_hot_embedder=unit_one_hot_embedder,
        lab_code_one_hot_embedder=lab_code_one_hot_embedder,
    )


def _build_model(exp_config, device):
    """Return ``GMAN_Ablation`` if any ablation option in ``exp_config`` is active, else ``GMAN``."""
    # Ablation options are optional attributes on the config; a missing one means "off".
    disable_deepset = getattr(exp_config, 'disable_deepset', False)
    disable_distance_embedding = getattr(exp_config, 'disable_distance_embedding', False)
    use_simple_aggregation = getattr(exp_config, 'use_simple_aggregation', False)
    feature_processor_type = getattr(exp_config, 'feature_processor_type', 'gnan')
    gnan_mode = getattr(exp_config, 'gnan_mode', 'per_group')

    use_ablation_model = (
        disable_deepset
        or disable_distance_embedding
        or use_simple_aggregation
        or feature_processor_type != 'gnan'
    )

    if use_ablation_model:
        # NOTE(review): the ablation model uses 2 DeepSet layers while GMAN uses 3;
        # confirm this asymmetry is intended, since it confounds ablation comparisons.
        return GMAN_Ablation(
            feature_groups=exp_config.feature_groups,
            out_channels=exp_config.out_channels,
            is_graph_task=exp_config.is_graph_task,
            batch_size=exp_config.batch_size,
            n_layers=exp_config.n_layers,
            hidden_channels=exp_config.hidden_channels,
            dropout=exp_config.dropout,
            device=device,
            normalize_rho=exp_config.normalize_rho,
            max_num_GNANs=exp_config.max_num_GNANs,
            biomarker_groups=exp_config.biomarker_groups,
            gnan_mode=gnan_mode,
            deepset_n_layers=2,
            # Ablation flags
            disable_deepset=disable_deepset,
            disable_distance_embedding=disable_distance_embedding,
            use_simple_aggregation=use_simple_aggregation,
            feature_processor_type=feature_processor_type,
            return_laplacian=True,
        )

    return GMAN(
        feature_groups=exp_config.feature_groups,
        out_channels=exp_config.out_channels,
        is_graph_task=exp_config.is_graph_task,
        batch_size=exp_config.batch_size,
        n_layers=exp_config.n_layers,
        hidden_channels=exp_config.hidden_channels,
        dropout=exp_config.dropout,
        device=device,
        normalize_rho=exp_config.normalize_rho,
        max_num_GNANs=exp_config.max_num_GNANs,
        biomarker_groups=exp_config.biomarker_groups,
        mix_feature_group_repres=False,
        return_laplacian=True,
        gnan_mode=gnan_mode,
        deepset_n_layers=3,
    )


def run_model(
        exp_config,
        gpu_train=True,
        debug_run=False,
        save_model=False,
        run_name=None,
        checkpoint_dir="",
    ):
    """Train a GMAN-family model on the CD patient/control task and report metrics each epoch.

    The train/val/test files are read from
    ``{exp_config.sequential_data_dir}/{split}__CD__{patient|control}``. Patient
    and control files are balanced per split with ``balance_classes`` and then
    shuffled. Each epoch runs one training epoch and evaluates on the val and
    test loaders. For the first 5 epochs the LR is driven by
    ``warmup_lr_scheduler``; after that it is reset once to ``exp_config.lr``
    and left to ``ReduceLROnPlateau``.

    Args:
        exp_config: Experiment configuration. Read for data paths, embedder
            sizes, model hyperparameters, ``lr``/``wd``, ``epochs``,
            ``batch_size`` and ``logging_dir``.
        gpu_train: Use CUDA when it is available.
        debug_run: If True, no TensorBoard writer is created and ``None`` is
            passed as ``writer``.
        save_model: If True, save a checkpoint every time the validation AUC
            improves.
        run_name: Sub-directory of ``exp_config.logging_dir`` for TensorBoard
            logs. Defaults to the historical hard-coded run name.
        checkpoint_dir: Directory for checkpoint files. The default ``""``
            writes them to the current working directory.
    """
    if run_name is None:
        run_name = "Group Ablations __ Immune Cell Lineages no fcal__seed12__"

    device = 'cuda' if torch.cuda.is_available() and gpu_train else 'cpu'

    data_dir = exp_config.sequential_data_dir
    splits = ("train", "val", "test")
    patient_files = {split: _list_split_files(data_dir, split, "patient") for split in splits}
    control_files = {split: _list_split_files(data_dir, split, "control") for split in splits}

    # All splits are balanced before any is shuffled. Keep this order so the
    # seeded RNG stream is consumed identically from run to run.
    split_files = {}
    for split in splits:
        patient_balanced, control_balanced = balance_classes(patient_files[split], control_files[split])
        split_files[split] = patient_balanced + control_balanced
    for split in splits:
        random.shuffle(split_files[split])
    train_files, val_files, test_files = (split_files[split] for split in splits)

    print()
    print(f"Train samples: {len(train_files)}, Val samples: {len(val_files)}, Test samples: {len(test_files)}")
    print()

    # One set of embedders is shared by all three datasets.
    embedders = (
        OneHotEmbedder(input_dim=exp_config.num_biom, output_dim=exp_config.num_biom_embed),
        OneHotEmbedder(input_dim=exp_config.num_units, output_dim=exp_config.num_units_embed),
        OneHotEmbedder(input_dim=exp_config.num_lab_ids, output_dim=exp_config.num_lab_ids_embed),
    )
    train_dataset = _make_cd_dataset(train_files, exp_config, *embedders)
    val_dataset = _make_cd_dataset(val_files, exp_config, *embedders)
    test_dataset = _make_cd_dataset(test_files, exp_config, *embedders)

    train_loader = DataLoader(train_dataset, batch_size=exp_config.batch_size, num_workers=8, shuffle=True, collate_fn=distance_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=64, num_workers=8, shuffle=False, collate_fn=distance_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=64, num_workers=8, shuffle=False, collate_fn=distance_collate_fn)

    model = _build_model(exp_config, device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=exp_config.lr, weight_decay=exp_config.wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.9,
        patience=100,
        verbose=True,
        min_lr=1e-8,
    )

    # Epoch offset for resumed runs; this script always starts from scratch.
    last_epoch = 0

    writer = SummaryWriter(log_dir=f"{exp_config.logging_dir}/{run_name}") if not debug_run else None

    print(f"Started training on {device}")
    print(f"Running experiment: {exp_config.exp_name}")
    print(f"Saving run as: {run_name}")
    print()
    print("Biomarker Groups :", exp_config.biomarker_groups)
    print()

    warmup_epochs = 5
    best_val_auc = float('-inf')
    loss_fn = torch.nn.BCEWithLogitsLoss()

    try:
        for epoch in range(exp_config.epochs):
            global_epoch = last_epoch + epoch

            print()
            print(f"Epoch {global_epoch}/{exp_config.epochs}")

            if epoch < warmup_epochs:
                warmup_lr_scheduler(epoch, optimizer)
            elif epoch == warmup_epochs:
                # Restore the base LR exactly once when warm-up ends. From here
                # on ReduceLROnPlateau owns the LR; resetting it every epoch
                # would silently undo each plateau decay.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = exp_config.lr

            train_loss, train_acc, train_auc, train_auprc = train_epoch(
                epoch=global_epoch,
                model=model,
                dloader=train_loader,
                loss_fn=loss_fn,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                writer=writer,
            )

            # NOTE(review): plateau detection is driven by the training loss,
            # not the validation loss. Confirm this is intended.
            scheduler.step(train_loss)

            print(f"Epoch {global_epoch}: Train Loss={train_loss:.4f}, Train Accuracy={train_acc:.4f}, Train AUC={train_auc:.4f}, Train AUPRC={train_auprc:.4f}")

            val_loss, val_acc, val_auc, val_auprc = test_epoch(
                epoch=global_epoch,
                model=model,
                dloader=val_loader,
                loss_fn=loss_fn,
                device=device,
                writer=writer,
            )

            print(f"Epoch {global_epoch}: Val Loss={val_loss:.4f}, Val Accuracy={val_acc:.4f}, Val AUC={val_auc:.4f}, Val AUPRC={val_auprc:.4f}")

            test_loss, test_acc, test_auc, test_auprc = test_epoch(
                epoch=global_epoch,
                model=model,
                dloader=test_loader,
                loss_fn=loss_fn,
                device=device,
                writer=writer,
            )

            # Model selection uses validation AUC only; test metrics are just reported.
            if save_model and best_val_auc < val_auc:
                best_val_auc = val_auc
                checkpoint_save_path = os.path.join(
                    checkpoint_dir,
                    f"epoch_{global_epoch}__AUC{val_auc:.4f}_ACC{val_acc:.4f}.pth",
                )
                torch.save({
                    'epoch': global_epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, checkpoint_save_path)

            print(f"Epoch {global_epoch}: Test Loss={test_loss:.4f}, Test Accuracy={test_acc:.4f}, Test AUC={test_auc:.4f}, Test AUPRC={test_auprc:.4f}")
    finally:
        # Flush and release TensorBoard event files even if training fails.
        if writer is not None:
            writer.close()

if __name__ == '__main__':
    import warnings
    warnings.simplefilter(action='ignore', category=FutureWarning)

    # Seed Python's, NumPy's and torch's (CPU + all CUDA devices) RNGs with
    # the configured seed, in that order.
    seed = cd_config.SEED
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_rng(seed)

    # Use the 'spawn' start method for worker processes; force=True overrides
    # any start method set earlier.
    torch.multiprocessing.set_start_method('spawn', force=True)

    run_model(
        exp_config=cd_config,
        gpu_train=True,
        debug_run=False,
        save_model=True,
    )