import os
import os.path as osp
import yaml
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from gnn.GCN import GCN
from gnn.GIN import GIN
from evaluation.evaluate_auc import evaluate_auc, evaluate_auc_class
from .explainer_utils import create_edge_embeds, sample_graph
from utils import get_dataset
from classifier import GuidedExplainer, GuidedExplainerGIN
from classification_loss import polar_loss

from .distribution_analysis import DistributionAnalyzer

from evaluation.evaluate_f1 import evaluate_f1
from evaluation.evaluate_f1 import evaluate_bimod_binar

def reset_parameters(module, dataset):
    """Xavier-normal initialise an ``nn.Linear`` layer and zero its bias.

    Intended for ``nn.Module.apply``: modules that are not ``nn.Linear`` are
    left untouched.  For the ``"BA3"`` dataset the default Xavier gain is
    used; every other dataset uses a reduced gain of 0.1.

    Args:
        module: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
        dataset: Dataset name that selects the initialisation gain.
    """
    if not isinstance(module, nn.Linear):
        return

    # An empty kwargs dict falls back to xavier_normal_'s default gain.
    gain_kwargs = {} if dataset == "BA3" else {"gain": 0.1}
    nn.init.xavier_normal_(module.weight, **gain_kwargs)

    if module.bias is not None:
        nn.init.zeros_(module.bias)

class sssgexplainer():
    """Two-round self-supervised edge explainer for a frozen, pre-trained GNN.

    Pipeline driven by :meth:`train_test`:

    1. A base MLP explainer (``self.explainer``) scores edges from edge
       embeddings built by ``create_edge_embeds``.  Pre-trained weights are
       loaded from ``param/<dataset>/baseexplainer/<seed>/base_explainer.pth``
       when that file exists.
    2. Round 1: base scores are turned into pseudo-labels with
       skewness-adjusted quantile cut-offs, and a guided edge classifier
       (``GuidedExplainerGIN`` for ``FC``, ``GuidedExplainer`` otherwise) is
       trained on them.
    3. Round 2: fresh pseudo-labels are derived from the round-1 model's
       scores, and a second guided classifier, warm-started from the best
       round-1 weights, is trained on them.

    Both rounds select the checkpoint with the best validation AUC and stop
    early after ``PATIENCE`` epochs without improvement.
    """

    # Epochs without validation-AUC improvement tolerated before a guided
    # training round stops early.
    PATIENCE = 30

    def __init__(self, args, device):
        """Build the base edge explainer and its optimizer.

        Args:
            args: Namespace read for ``dataset``, ``hidden`` and ``lr``.
            device: Torch device every module is moved to.
        """
        self.dataset = args.dataset
        self.device = device
        self._load_dataset_config()

        # Input width 2 * hidden: presumably both endpoint embeddings of an
        # edge as produced by create_edge_embeds (TODO confirm).
        self.explainer = nn.Sequential(
            nn.Linear(args.hidden * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ).to(self.device)
        self.explainer.apply(lambda module: reset_parameters(module, self.dataset))

        self.sample_bias = 0.0
        self.optimizer_explainer = optim.Adam(self.explainer.parameters(), lr=args.lr)

    def _load_dataset_config(self):
        """Read ``configs/<dataset>.yaml`` and cache ``in_dim``, ``num_cls``, ``pool_type`` and ``use_jk``."""
        config_path = os.path.join("configs", f"{self.dataset}.yaml")
        with open(config_path, "r", encoding="utf-8") as f:
            dataset_config = yaml.safe_load(f)
        self.in_dim = dataset_config['in_dim']
        self.num_cls = dataset_config['num_cls']
        self.pool_type = dataset_config['pool_type']
        self.use_jk = dataset_config['use_jk']

    def _loss(self, pred, target, mask, batch_size, reg_coefs):
        """Cross-entropy plus mask-size and mask-entropy regularisers.

        ``mask`` is first squashed from [0, 1] into [0.01, 0.99] so the entropy
        term's logarithms stay finite.  Not called from the code in this class.

        Args:
            pred: Logits passed to ``F.cross_entropy``.
            target: Class targets for ``F.cross_entropy``.
            mask: Edge mask values, expected in [0, 1].
            batch_size: Normaliser for both regularisers.
            reg_coefs: ``(size_coef, entropy_coef)``.

        Returns:
            Scalar loss tensor.
        """
        scale = 0.99
        mask = mask * (2 * scale - 1.0) + (1.0 - scale)

        cce_loss = F.cross_entropy(pred, target, reduction='mean')
        size_loss = reg_coefs[0] * torch.sum(mask) / batch_size
        mask_ent_reg = -mask * torch.log(mask + 1e-8) - (1 - mask) * torch.log(1 - mask + 1e-8)
        mask_ent_loss = reg_coefs[1] * torch.sum(mask_ent_reg) / batch_size

        return cce_loss + size_loss + mask_ent_loss

    def train_test(self, args):
        """Load the GNN and base explainer, run both guided rounds, then evaluate.

        Args:
            args: Experiment namespace.  Required attributes: ``model``
                (``'GCN'`` or ``'GIN'``), ``nlayers``, ``hidden``,
                ``dropout``, ``round1_epochs``, ``round1_lr``,
                ``batch_size``, ``w``, ``seed``, ``explainer_name``.
                Optional: ``fc_only_pos``, ``alpha0``, ``c``,
                ``second_alpha0``, ``second_c``, ``second_phase_lr``,
                ``second_phase_epochs``, ``second_w``.  Missing or falsy
                optional values fall back to defaults / round-1 values.

        Returns:
            Test AUC of the round-2 explainer, or of the round-1 explainer
            when round 2 produced no pseudo-labelled data.

        Raises:
            ValueError: If ``args.model`` is unsupported or the training
                split is empty.
        """
        hidden = args.hidden
        round1_epochs = args.round1_epochs
        round1_lr = args.round1_lr

        print("Class hyperparameter information")
        print(args.__dict__)

        dataset = self.dataset
        data_path = f"data/{self.dataset}"
        train_dataset, val_dataset, test_dataset, _ = get_dataset(data_path, self.dataset)

        if self.dataset == "FC" and bool(getattr(args, 'fc_only_pos', False)):
            # Keep only graphs whose (first) label equals 1.
            train_dataset = [g for g in train_dataset if self._is_pos_graph(g)]
            val_dataset = [g for g in val_dataset if self._is_pos_graph(g)]
            test_dataset = [g for g in test_dataset if self._is_pos_graph(g)]
        if len(train_dataset) == 0:
            raise ValueError(f"Training split of {self.dataset!r} is empty; cannot build the GNN.")

        model = self._load_gnn(args, train_dataset)

        # Validation and test are each evaluated as one batch; max(1, ...) keeps
        # DataLoader valid when a split is empty after filtering.
        val_loader = DataLoader(val_dataset, batch_size=max(1, len(val_dataset)), shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=max(1, len(test_dataset)), shuffle=False)

        save_path = f'param/{self.dataset}/{args.explainer_name}'
        # Seed-specific checkpoint directory (creates save_path as well).
        seed_dir = osp.join(save_path, str(args.seed))
        os.makedirs(seed_dir, exist_ok=True)

        self._load_base_explainer(args.seed, val_loader, model)

        self.pseudo_alpha0 = self._float_arg(args, 'alpha0', 0.10)
        self.pseudo_skew_c = self._float_arg(args, 'c', 0.50)
        print(f"Quantile pseudo-labeling params: alpha0={self.pseudo_alpha0}, c={self.pseudo_skew_c}")

        # ---- Round 1: guided explainer trained on base-explainer pseudo-labels ----
        self.explainer_round1 = self._build_guided_explainer(hidden)
        self.explainer_round1.apply(lambda module: reset_parameters(module, self.dataset))
        self.optimizer_explainer_round1 = torch.optim.Adam(self.explainer_round1.parameters(), lr=round1_lr)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
        pseudo_labeled_data_list = self._generate_pseudo_labeled_dataset(train_loader, model)
        new_train_loader = DataLoader(pseudo_labeled_data_list, batch_size=args.batch_size, shuffle=False)

        _, best_round1_state = self._fit_guided(
            self.explainer_round1, self.optimizer_explainer_round1, new_train_loader, val_loader,
            epochs=round1_epochs, pos_weight=args.w, tag="round1",
            ckpt_path=osp.join(seed_dir, f"{dataset}_{args.explainer_name}_best_model.pth"),
        )

        print("\n===== AUC score!! =====")
        test_auc = evaluate_auc_class(test_loader, self.explainer_round1, device=self.device)
        print(f"AUC of test dataset: {test_auc:.4f}")

        # ---- Round 2: guided explainer trained on round-1 pseudo-labels ----
        print("\n===== round 2: Second Guided Explainer (fresh pseudo-labels from Guided scores) =====")
        # Round-2 quantile params default to the round-1 values.  They are passed
        # explicitly so self.pseudo_alpha0 / self.pseudo_skew_c stay unchanged.
        second_alpha0 = self._float_arg(args, 'second_alpha0', self.pseudo_alpha0)
        second_c = self._float_arg(args, 'second_c', self.pseudo_skew_c)
        print(f"[round 2] Quantile params: alpha0={second_alpha0}, c={second_c}")

        print("Collecting guided edge score distribution from round 1 model...")
        pseudo_labeled_round2 = self._generate_pseudo_labeled_dataset_from_guided(
            DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False),
            alpha0=second_alpha0, skew_c=second_c,
        )
        if not pseudo_labeled_round2:
            print("[round 2] No guided scores collected; skipping second guided stage.")
            return test_auc
        round2_loader = DataLoader(pseudo_labeled_round2, batch_size=args.batch_size, shuffle=False)

        self.explainer_round2 = self._build_guided_explainer(hidden)
        if best_round1_state is not None:
            # Warm-start from the best round-1 weights.
            self.explainer_round2.load_state_dict(best_round1_state)
        else:
            self.explainer_round2.apply(lambda module: reset_parameters(module, self.dataset))

        round2_lr = float(getattr(args, 'second_phase_lr', None) or round1_lr)
        round2_epochs = int(getattr(args, 'second_phase_epochs', None) or round1_epochs)
        w2 = float(getattr(args, 'second_w', None) or args.w)
        optimizer_round2 = torch.optim.Adam(self.explainer_round2.parameters(), lr=round2_lr)

        print(f"[round 2] Training with lr={round2_lr}, epochs={round2_epochs}, w={w2}")
        self._fit_guided(
            self.explainer_round2, optimizer_round2, round2_loader, val_loader,
            epochs=round2_epochs, pos_weight=w2, tag="round2",
            ckpt_path=osp.join(seed_dir, f"{dataset}_{args.explainer_name}_best_model_second.pth"),
        )

        print("\n===== AUC score!! =====")
        test_auc_r1 = evaluate_auc_class(test_loader, self.explainer_round1, device=self.device)
        print(f"AUC of test dataset: {test_auc_r1:.4f}")
        test_auc_r2 = evaluate_auc_class(test_loader, self.explainer_round2, device=self.device)
        print(f"AUC of test dataset (second guided): {test_auc_r2:.4f}")

        print("\n===== Bimodality / Binarization!! =====")
        evaluate_bimod_binar(test_loader, model, self.explainer, self.explainer_round1,
                             self.device, self.dataset, second_label="Round_1")
        evaluate_bimod_binar(test_loader, model, self.explainer, self.explainer_round2,
                             self.device, self.dataset, second_label="Round_2")

        print("\n===== F1 score / Fidelity!! =====")
        evaluate_f1(test_loader, model, self.explainer, self.explainer_round1,
                    self.device, self.dataset, second_label="Round_1")
        evaluate_f1(test_loader, model, self.explainer, self.explainer_round2,
                    self.device, self.dataset, second_label="Round_2")

        print("\n===== Visualization!! =====")
        DistributionAnalyzer.compare_base_vs_round1_distributions(
            test_loader, model, self.explainer, self.explainer_round1, self.device, self.dataset,
            args.explainer_name, save_suffix="first")
        DistributionAnalyzer.compare_base_vs_round1_distributions(
            test_loader, model, self.explainer, self.explainer_round2, self.device, self.dataset,
            args.explainer_name, phase_label="Round_2", save_suffix="second", phase_tag="Round_2")

        return test_auc_r2

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _float_arg(args, name, default):
        """Return ``float(args.<name>)``; use *default* if it is missing, falsy or not numeric."""
        try:
            return float(getattr(args, name, None) or default)
        except (TypeError, ValueError):
            return float(default)

    @staticmethod
    def _is_pos_graph(g):
        """Return True when graph ``g``'s label (first element if a tensor) equals 1."""
        y = getattr(g, 'y', None)
        if y is None:
            return False
        if isinstance(y, torch.Tensor):
            if y.numel() == 0:
                return False
            return int(y.reshape(-1)[0].item()) == 1
        try:
            return int(y) == 1
        except (TypeError, ValueError, OverflowError):
            return False

    @staticmethod
    def _gnn_checkpoint_path(model_name, dataset):
        """Path of the best-validation checkpoint for architecture ``model_name`` on ``dataset``."""
        return osp.join(f"param/{model_name}", f"{dataset}_{model_name}_best_val.pth")

    def _load_gnn(self, args, train_dataset):
        """Build the GNN named by ``args.model``, load its checkpoint and put it in eval mode.

        The checkpoint path is built from the architecture *name*
        (``args.model``), not from the instantiated module object.

        Raises:
            ValueError: If ``args.model`` is not ``'GCN'`` or ``'GIN'``.
        """
        model_name = args.model
        in_dim = train_dataset[0].x.shape[1]
        if model_name == 'GCN':
            model = GCN(in_dim, self.num_cls, args.nlayers, args.hidden, args.dropout,
                        self.pool_type, self.use_jk)
        elif model_name == 'GIN':
            model = GIN(in_dim, self.num_cls, args.nlayers, args.hidden, args.dropout,
                        self.pool_type)
        else:
            raise ValueError(f"Unsupported model {model_name!r}; expected 'GCN' or 'GIN'.")

        ckpt_path = self._gnn_checkpoint_path(model_name, self.dataset)
        model.load_state_dict(torch.load(ckpt_path, map_location=self.device))
        model.to(self.device)
        model.eval()
        return model

    def _load_base_explainer(self, seed, val_loader, model):
        """Load pre-trained base-explainer weights for ``seed`` if present and report val AUC."""
        base_explainer_path = f'param/{self.dataset}/baseexplainer/{seed}/base_explainer.pth'
        if not osp.exists(base_explainer_path):
            print(f"No pre-trained base explainer at {base_explainer_path}; "
                  f"using freshly initialised weights for pseudo-labeling.")
            return
        print(f"Loading pre-trained base explainer from {base_explainer_path}")
        self.explainer.load_state_dict(torch.load(base_explainer_path, map_location=self.device))

        val_auc = evaluate_auc(val_loader, self.explainer, model, self.device, training=False)
        print(f"Loaded base explainer validation AUC: {val_auc:.4f}")
        print(f"==> [base End] Best Explainer restored (Val AUC={val_auc:.4f})")

    def _build_guided_explainer(self, hidden):
        """Instantiate the guided edge classifier on ``self.device`` (GIN variant for ``FC``)."""
        guided_cls = GuidedExplainerGIN if self.dataset == "FC" else GuidedExplainer
        return guided_cls(self.in_dim, hidden_dim=hidden, out_channels=2).to(self.device)

    def _fit_guided(self, explainer, optimizer, loader, val_loader, epochs, pos_weight, tag, ckpt_path):
        """Train a guided explainer on pseudo-labelled edges with validation-AUC early stopping.

        Only edges whose ``edge_gt`` is not -1 contribute to ``polar_loss``.
        After every epoch the model is scored with ``evaluate_auc_class`` on
        ``val_loader``.  Each strict improvement is cloned, saved to
        ``ckpt_path`` and resets the patience counter.  Training stops after
        ``self.PATIENCE`` epochs without improvement.  If any state was
        selected, it is loaded back into ``explainer``.

        Returns:
            ``(best_val_auc, best_state)``.  ``best_state`` is ``None`` when
            the validation AUC never exceeded 0.
        """
        best_auc = 0.0
        best_state = None
        epochs_no_improve = 0

        for epoch in range(epochs):
            explainer.train()
            epoch_loss = 0.0

            for batch in loader:
                batch = batch.to(self.device)
                prob = explainer(batch)

                # reshape(-1) instead of squeeze(): stays 1-D for single-edge batches.
                labeled_edge_mask = batch.edge_gt.reshape(-1) != -1.0
                if not bool(labeled_edge_mask.any()):
                    continue

                loss = polar_loss(prob[labeled_edge_mask], batch.edge_gt[labeled_edge_mask],
                                  pos_weight=pos_weight)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            val_auc = evaluate_auc_class(val_loader, explainer, device=self.device)
            print(f"[{tag} | epoch {epoch}/{epochs}] loss={epoch_loss:.4f}, ValAUC={val_auc:.4f}")

            if val_auc > best_auc:
                best_auc = val_auc
                best_state = {k: v.clone() for k, v in explainer.state_dict().items()}
                os.makedirs(osp.dirname(ckpt_path), exist_ok=True)
                torch.save(best_state, ckpt_path)
                print(f"  >> [{tag}] Best Explainer updated (AUC={best_auc:.4f}) - saved to {ckpt_path}")
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= self.PATIENCE:
                    print(f"  >> [{tag}] Early stopping triggered "
                          f"(no improvement for {self.PATIENCE} epochs).")
                    break

        if best_state is not None:
            explainer.load_state_dict(best_state)
            print(f"==> [{tag} End] Best restored (Val AUC={best_auc:.4f})")
        return best_auc, best_state

    @staticmethod
    def _quantile_thresholds(scores, alpha0, skew_c):
        """Compute skewness-adjusted quantile thresholds ``(t0, t1)`` over edge scores.

        Scores are flattened and clamped into [0, 1].  The population
        skewness ``g1`` widens the tail on the side the distribution is
        skewed towards:
        ``alpha_pos = alpha0 * (1 + skew_c * max(g1, 0))`` and
        ``alpha_neg = alpha0 * (1 + skew_c * max(-g1, 0))``, both clipped to
        [1e-4, 0.49].  Then ``t0`` is the ``alpha_neg`` quantile and ``t1``
        the ``1 - alpha_pos`` quantile.

        Note: ``torch.quantile`` rejects very large inputs, so extremely large
        edge sets may need a different quantile backend.

        Args:
            scores: Non-empty tensor of edge scores (any shape).
            alpha0: Base cut-off fraction, intended in (0, 0.5).
            skew_c: Non-negative skewness adjustment factor.

        Returns:
            ``(t0, t1)`` as Python floats in [0, 1].
        """
        scores = torch.clamp(scores.float().reshape(-1), 0.0, 1.0)

        mu = float(scores.mean().item())
        sigma = float(scores.std(unbiased=False).item())
        if sigma < 1e-12:
            g1 = 0.0
        else:
            m3 = float(torch.mean((scores - mu) ** 3).item())
            g1 = m3 / (sigma ** 3 + 1e-12)

        alpha_pos = alpha0 * (1.0 + skew_c * max(g1, 0.0))
        alpha_neg = alpha0 * (1.0 + skew_c * max(-g1, 0.0))
        alpha_pos = float(min(max(alpha_pos, 1e-4), 0.49))
        alpha_neg = float(min(max(alpha_neg, 1e-4), 0.49))

        t0 = float(torch.quantile(scores, q=alpha_neg).item())
        t1 = float(torch.quantile(scores, q=1.0 - alpha_pos).item())
        return max(0.0, min(1.0, t0)), max(0.0, min(1.0, t1))

    @staticmethod
    def _threshold_labels(scores, t0, t1):
        """Map scores to pseudo-labels: 1.0 if ``>= t1``, 0.0 if ``<= t0``, else -1.0.

        When the thresholds overlap (``t0 >= t1``), an edge satisfying both
        tests is labelled 0.0, because the negative assignment is applied
        last.

        Returns:
            ``(pseudo_labels, labeled)``: a float tensor and a bool tensor,
            both shaped like ``scores``.
        """
        positive = scores >= t1
        negative = scores <= t0
        pseudo = torch.full_like(scores, -1.0, dtype=torch.float)
        pseudo[positive] = 1.0
        pseudo[negative] = 0.0
        return pseudo, positive | negative

    def _resolve_quantile_params(self, alpha0, skew_c):
        """Fill unspecified quantile params from ``self.pseudo_*`` (defaults 0.10 / 0.50)."""
        if alpha0 is None:
            alpha0 = getattr(self, 'pseudo_alpha0', 0.10)
        if skew_c is None:
            skew_c = getattr(self, 'pseudo_skew_c', 0.50)
        return float(alpha0), float(skew_c)

    def _label_loader(self, loader, score_fn, alpha0, skew_c, tag):
        """Two-pass pseudo-labelling shared by both rounds.

        Pass 1 gathers every edge score from ``score_fn`` to fix global
        quantile thresholds.  Pass 2 recomputes the scores batch by batch and
        turns each batch into one ``Data`` object with ``x``, ``edge_index``,
        ``edge_gt`` (pseudo-labels with a dimension inserted at position 1),
        ``labeled_mask`` and ``y``.

        Returns:
            List of ``Data`` objects (one per batch), or ``[]`` when no edge
            scores were produced.
        """
        with torch.no_grad():
            collected = [score_fn(batch.to(self.device)).detach().cpu().float().reshape(-1)
                         for batch in loader]
        if not collected:
            print(f"[{tag}] No scores collected; returning empty list.")
            return []
        all_scores = torch.cat(collected, dim=0)
        if all_scores.numel() == 0:
            # torch.quantile raises on empty input.
            print(f"[{tag}] No scores collected; returning empty list.")
            return []

        t0, t1 = self._quantile_thresholds(all_scores, alpha0, skew_c)

        pseudo_data = []
        n_total = n_pos = n_neg = 0
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(self.device)
                scores = score_fn(batch)
                pseudo, labeled = self._threshold_labels(scores, t0, t1)

                n_total += int(scores.numel())
                n_pos += int((pseudo == 1.0).sum().item())
                n_neg += int((pseudo == 0.0).sum().item())

                y = getattr(batch, 'y', None)
                pseudo_data.append(Data(
                    x=batch.x.to(self.device),
                    edge_index=batch.edge_index.to(self.device),
                    edge_gt=pseudo.unsqueeze(1).to(self.device),
                    labeled_mask=labeled.to(self.device),
                    y=y.to(self.device) if y is not None else None,
                ))

        print(f"[{tag}] thresholds t0={t0:.4f}, t1={t1:.4f} | edges={n_total}, "
              f"positive={n_pos}, negative={n_neg}, unlabeled={n_total - n_pos - n_neg}")
        return pseudo_data

    def _generate_pseudo_labeled_dataset(self, train_loader, model, alpha0=None, skew_c=None):
        """Pseudo-label training edges from the base explainer's mask scores (round 1).

        Args:
            train_loader: Loader over the training graphs; it is iterated twice.
            model: Frozen GNN whose third output is used as node embeddings.
            alpha0: Base quantile cut-off; defaults to ``self.pseudo_alpha0``
                (0.10 if unset).
            skew_c: Skewness adjustment factor; defaults to
                ``self.pseudo_skew_c`` (0.50 if unset).

        Returns:
            List of pseudo-labelled ``Data`` objects, one per batch, or ``[]``.
        """
        alpha0, skew_c = self._resolve_quantile_params(alpha0, skew_c)
        self.explainer.eval()
        model.eval()

        def base_scores(batch):
            # Base explainer edge scores for one batch (deterministic path: training=False).
            _, _, node_embeds = model(batch)
            edge_embeds = create_edge_embeds(batch.edge_index, node_embeds).unsqueeze(dim=0)
            sampling_weights = self.explainer(edge_embeds).squeeze(-1)
            return sample_graph(sampling_weights, self.device, training=False)

        return self._label_loader(train_loader, base_scores, alpha0, skew_c, tag="Pseudo-labeling")

    def _generate_pseudo_labeled_dataset_from_guided(self, train_loader, alpha0=None, skew_c=None):
        """Pseudo-label training edges from the round-1 guided explainer's scores (round 2).

        Args:
            train_loader: Loader over the training graphs; it is iterated twice.
            alpha0: Base quantile cut-off; defaults to ``self.pseudo_alpha0``
                (0.10 if unset).
            skew_c: Skewness adjustment factor; defaults to
                ``self.pseudo_skew_c`` (0.50 if unset).

        Returns:
            List of pseudo-labelled ``Data`` objects, one per batch, or ``[]``.
        """
        alpha0, skew_c = self._resolve_quantile_params(alpha0, skew_c)
        self.explainer_round1.eval()
        return self._label_loader(train_loader, self.explainer_round1, alpha0, skew_c, tag="round 2")
