import torch
import torch.nn as nn
import torch.nn.functional as F
from Utils.utils import load_dataset
from torch_geometric.utils import k_hop_subgraph
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

from models import GNN, MLP

import json
import math
from sklearn.metrics import accuracy_score, f1_score, balanced_accuracy_score
import numpy as np
import os
import csv
import time

from Utils.concepts import get_concepts
from Utils.gta import *
from Utils.ib import *

# Default compute device read by main(). The CLI entry point rebinds it from
# --device; when main() is imported and called directly this fallback lets it
# run on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def _inverse_frequency_weights(y_train, num_class):
    """Return per-class loss weights proportional to 1 / class frequency.

    Classes that do not occur in ``y_train`` get weight 0. The weights of the
    present classes are rescaled so they sum to the number of present classes.

    Args:
        y_train: 1-D tensor of non-negative integer training labels.
        num_class: Total number of classes (length of the returned tensor).

    Returns:
        Float tensor of shape ``(num_class,)`` on the same device as ``y_train``.
    """
    class_counts = torch.bincount(y_train, minlength=num_class).float()
    present = class_counts > 0
    weights = torch.zeros_like(class_counts)
    weights[present] = 1.0 / class_counts[present]
    num_present = present.sum().item()
    weights[present] = weights[present] / weights[present].sum() * num_present
    return weights


def _classification_metrics(out, y):
    """Return ``(accuracy, macro_f1, balanced_accuracy)`` for logits ``out``.

    Predictions are ``argmax`` of ``out`` along dim 1; both predictions and the
    integer labels ``y`` are moved to CPU before scoring with scikit-learn.
    """
    pred = torch.argmax(out, dim=1)
    y_true = y.cpu()
    y_pred = pred.cpu()
    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
        balanced_accuracy_score(y_true, y_pred),
    )


def main(dataset, model_type, top_k, beta, train_ratio, num_hop, num_mlp_layer,
         up_ratio=None, perturb_ratio=None, dropout=None):
    """Train and evaluate a concept-bottleneck classifier with IB gating.

    Features are cosine similarities between graph embeddings and concept
    embeddings. For each of 5 independent runs:

    1. Train ``ConceptBottleneckIB`` (learnable concept gates) and keep the
       gates from the epoch with the lowest validation loss (early stopping
       after more than 10 epochs without improvement).
    2. Select the ``top_k`` concepts with the largest gate values.
    3. Retrain a fresh MLP on only those concepts with class-weighted
       cross-entropy, early-stop on weighted validation loss, and score the
       best checkpoint on the test split.

    Prints the mean and standard deviation (in percent) of test accuracy,
    macro-F1 and balanced accuracy over the runs.

    Args:
        dataset: Dataset name passed to ``load_dataset`` and ``get_concepts``.
        model_type: Graph encoder type passed to ``graph_embedding``.
        top_k: Number of concepts kept after phase 1.
        beta: Information-bottleneck coefficient for ``ConceptBottleneckIB``.
        train_ratio: Training split setting passed to ``load_dataset``.
        num_hop: Currently unused; kept for interface compatibility.
        num_mlp_layer: Number of layers in both MLP classifiers.
        up_ratio: Optional OOD upsampling ratio passed to ``load_dataset``.
        perturb_ratio: Optional edge perturbation ratio passed to ``load_dataset``.
        dropout: Dropout of the phase-2 classifier. When ``None`` it falls back
            to a module-level ``dropout`` (bound by the CLI entry point), else 0.3.

    Raises:
        RuntimeError: if phase 1 never produces a validation loss below ``inf``
            (e.g. every loss is NaN), so no gates can be selected.
    """
    from torch.utils.data import TensorDataset

    if dropout is None:
        # The CLI entry point binds a module-level ``dropout``; fall back to it
        # so script runs keep honoring --dropout, and to 0.3 when imported.
        dropout = globals().get("dropout", 0.3)

    data = load_dataset(dataset, train_ratio, up_ratio, perturb_ratio).to(device)
    # Tensor .max() avoids Python's element-by-element max() over a tensor.
    y_values = data.y if isinstance(data.y, torch.Tensor) else torch.tensor(data.y)
    num_class = int(y_values.max()) + 1

    train_mask, valid_mask, test_mask = data.train_mask, data.valid_mask, data.test_mask

    # LOAD CONCEPT CANDIDATES
    concepts = get_concepts(dataset)

    # LOAD GRAPH-CONCEPT MAPPING FUNCTION
    graph_emb = graph_embedding(dataset, data, model_type=model_type, device=device)
    concept_emb = concept_embedding(concepts).to(device)
    target_similarity = cosine_similarity(graph_emb, concept_emb, scaled=False)

    # TRAINING/TEST DATA PREPARE: similarities are fixed input features, so they
    # are detached from any autograd graph.
    x_tensor = target_similarity.to(device).float().detach()

    if not isinstance(data.y, torch.Tensor):
        data.y = torch.tensor(data.y)
    y_tensor = data.y.to(device)
    x_tensor_train, x_tensor_valid, x_tensor_test = x_tensor[train_mask], x_tensor[valid_mask], x_tensor[test_mask]
    y_tensor_train, y_tensor_valid, y_tensor_test = y_tensor[train_mask], y_tensor[valid_mask], y_tensor[test_mask]

    # Full-batch loader: each epoch yields a single shuffled batch containing
    # every training example.
    train_loader = DataLoader(
        TensorDataset(x_tensor_train, y_tensor_train),
        batch_size=x_tensor_train.size(0),
        shuffle=True,
    )

    class FixedGateClassifier(nn.Module):
        """MLP that only sees the concept columns selected by a fixed 0/1 mask."""

        def __init__(self, mask, in_channels, out_channels, num_layers, dropout):
            super().__init__()
            self.mask = mask.detach()
            self.model = MLP(in_channels=in_channels, out_channels=out_channels,
                             num_layers=num_layers, dropout=dropout)

        def forward(self, C):
            # Keep only the columns whose mask entry is 1.
            return self.model(C[:, self.mask.bool()])

    class_weight = _inverse_frequency_weights(y_tensor_train, num_class)

    save_dir = "./cache"
    os.makedirs(save_dir, exist_ok=True)

    test_acc_list = []
    test_f1_list = []
    test_bacc_list = []
    for run in range(5):
        # Step 1: Train CBM + IB model with learnable gates
        classifier = MLP(in_channels=concept_emb.size(0), out_channels=num_class,
                         num_layers=num_mlp_layer, dropout=0.2).to(device)
        model = ConceptBottleneckIB(concept_emb, classifier=classifier,
                                    num_class=num_class, beta=beta).to(device)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

        print ("# Training phase 1")
        best_gates = None
        min_loss = math.inf
        no_improve = 0
        for epoch in range(400):
            model.train()
            for C, y in train_loader:
                optimizer.zero_grad()
                loss, _, _ = model(C, y, class_weight)
                # NOTE(review): retain_graph kept as in the original; presumably
                # needed if concept_emb carries an autograd graph shared across
                # iterations — confirm before removing.
                loss.backward(retain_graph=True)
                optimizer.step()

            model.eval()
            with torch.no_grad():
                # NOTE(review): validation loss is computed without class_weight,
                # unlike the training loss — confirm this is intended.
                loss_valid, _, gates_valid = model(x_tensor_valid, y_tensor_valid)

            if loss_valid < min_loss:
                min_loss = loss_valid
                best_gates = gates_valid
                no_improve = 0
            else:
                no_improve += 1
            if no_improve > 10:
                break

        if best_gates is None:
            raise RuntimeError("Phase 1 never reached a finite validation loss; no gates to select.")

        # Step 2: Extract top-k gates
        _, topk_indices = torch.topk(best_gates, k=top_k)
        fixed_mask = torch.zeros_like(best_gates)
        fixed_mask[topk_indices] = 1.0

        # Step 3: Retrain classifier on fixed concept subset
        fixed_classifier = FixedGateClassifier(
            mask=fixed_mask, in_channels=top_k, out_channels=num_class,
            num_layers=num_mlp_layer, dropout=dropout,
        ).to(device)
        optimizer_fixed = torch.optim.Adam(fixed_classifier.parameters(), lr=0.01)

        # PID and run index keep concurrent jobs and runs from clobbering each
        # other's checkpoint.
        best_model_path = os.path.join(
            save_dir, f"{dataset}_best_model_{time.time():.0f}_{os.getpid()}_{run}.pth")
        min_loss = math.inf
        no_improve = 0
        for epoch in range(400):
            fixed_classifier.train()
            for C, y in train_loader:
                optimizer_fixed.zero_grad()
                loss = F.cross_entropy(fixed_classifier(C), y, weight=class_weight)
                # Inputs and mask are detached, so the graph is rebuilt each step.
                loss.backward()
                optimizer_fixed.step()

            fixed_classifier.eval()
            with torch.no_grad():
                valid_loss = F.cross_entropy(fixed_classifier(x_tensor_valid), y_tensor_valid, weight=class_weight)

            if valid_loss < min_loss:
                # min_loss must track the best loss; otherwise every epoch counts
                # as an improvement and early stopping never triggers.
                min_loss = valid_loss
                torch.save(fixed_classifier.state_dict(), best_model_path)
                no_improve = 0
            else:
                no_improve += 1
            if no_improve > 10:
                break

        # Final evaluation on the best-validation checkpoint
        fixed_classifier.load_state_dict(torch.load(best_model_path, map_location=device))
        fixed_classifier.eval()
        with torch.no_grad():
            test_acc, test_f1, test_bacc = _classification_metrics(fixed_classifier(x_tensor_test), y_tensor_test)
        test_acc_list.append(test_acc * 100)
        test_f1_list.append(test_f1 * 100)
        test_bacc_list.append(test_bacc * 100)

    acc_mean, acc_std = np.mean(test_acc_list), np.std(test_acc_list)
    f1_mean, f1_std = np.mean(test_f1_list), np.std(test_f1_list)
    bacc_mean, bacc_std = np.mean(test_bacc_list), np.std(test_bacc_list)
    print ('test acc: {:.4f} += {:.4f}'.format(acc_mean, acc_std),
          'test f1: {:.4f} += {:.4f}'.format(f1_mean, f1_std),
          'test bacc: {:.4f} += {:.4f}'.format(bacc_mean, bacc_std))
    

if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(description="GNN Experiment Args")
    parser.add_argument("--dataset", type=str, help="Dataset name", default='cora')
    # Default None so the CPU fallback below is reachable; a hard-coded 'cuda'
    # default made the torch.cuda.is_available() branch dead code.
    parser.add_argument("--device", help="Torch device (default: cuda if available, else cpu)",
                        default=None)
    parser.add_argument("--model_type", help="Pretrained graph encoder type",
                        choices=['gcn', 'gt', 'gat'], default='gcn')
    parser.add_argument("--train_ratio", type=int, help="Percentage of training data",
                        choices=[1, 2, 5, 10, 20, 30], default=20)
    parser.add_argument("--top_k", type=int, help="Number of selected concepts",
                        choices=[30, 50, 70, 90], default=50)
    parser.add_argument("--beta", type=float, help="Beta coefficient", default=0.0005)
    parser.add_argument("--dropout", type=float, help="Dropout rate", default=0.3)
    parser.add_argument("--hidden", type=int, help="Hidden size of GNN encoder", default=64)
    parser.add_argument("--head", type=int, help="Number of attention heads of GNN encoder", default=4)
    parser.add_argument("--hop", type=int, help="Number of hops of GNN encoder", default=4)
    parser.add_argument("--num_mlp_layer", type=int, help="num_mlp_layer", default=2)

    # Only required for adversarial setting
    parser.add_argument("--perturb_ratio", type=float,
                        help="Edge perturbation ratio",
                        choices=[0.05, 0.1, 0.2, 0.3, 0.5], default=None)

    # Only required for OOD setting
    parser.add_argument("--up_ratio", type=int,
                        help="Upsampling ratio for OOD setting",
                        choices=[2, 3, 5, 10], default=None)

    args = parser.parse_args()

    # Echo the full run configuration.
    for key, value in vars(args).items():
        print(f"{key}: {value}")

    dataset, model_type, train_ratio, top_k = args.dataset, args.model_type, args.train_ratio, args.top_k
    # These are module-level bindings: main() relies on the module-level
    # ``dropout`` and ``device`` set here. hidden/head are parsed but not
    # passed to main().
    beta, dropout, hidden, head, hop = args.beta, args.dropout, args.hidden, args.head, args.hop
    num_mlp_layer = args.num_mlp_layer
    perturb_ratio, up_ratio = args.perturb_ratio, args.up_ratio

    # Rebinds the module-level ``device`` used by main().
    device = torch.device(args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))

    main(dataset, model_type, top_k, beta, train_ratio, hop, num_mlp_layer, up_ratio, perturb_ratio)
