import os
import random
import time
import argparse
import gc
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import DataLoader as PyGDataLoader
import json
import pandas as pd
from PIL import Image
from torchvision import transforms

# Import models and training tools
from utils.train_utils import (
    train_vision_model, 
    test_vision_model, 
    train_gnn_model, 
    test_gnn_model, 
    release_resources,
)

from models import (
    ConvNeXtV2Model, 
    ViTModel, 
    SwinModel, 
    ResNetModel, 
    GCNModel, 
    GINModel, 
    GATModel, 
    GPSModel,
    ConvNeXtV2Encoder,
    ViTEncoder,
    SwinEncoder,
    ResNetEncoder,
    GCNEncoder,
    GINEncoder,
    GATEncoder,
    GPSEncoder,
)
from GraphAbstract.topology_classification_generator import *
from GraphAbstract.spectral_gap_generator import *
from GraphAbstract.bridge_generator import (
    BridgeCountDataset,
    generate_bridge_count_dataset,
    generate_bridge_count_image_dataset,
)  
from GraphAbstract.symmetry_generator import(
    SymmetryDataset,
    generate_symmetry_dataset,
    generate_symmetry_image_dataset,
)
from utils.tsne_utils import *

def set_seed(seed):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU and
    CUDA generators, then switches cuDNN to deterministic, non-autotuned mode
    so repeated runs with the same seed behave the same.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators (CPU, current CUDA device, all CUDA devices)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN's autotuning speed for run-to-run reproducibility
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class GraphImageDataset(Dataset):
    """Image dataset indexed by a CSV of rendered graph images.

    The CSV must contain an ``image_path`` column (relative to ``root_dir``)
    and a ``label`` column. Each item is ``(image, label)``, where ``image`` is
    an RGB PIL image passed through ``transform`` when one is given.
    """

    def __init__(self, csv_file, root_dir, transform=None, indices=None):
        """Image dataset handler.

        Args:
            csv_file: Path to the CSV file containing image paths and labels.
            root_dir: Directory that ``image_path`` entries are relative to.
            transform: Optional callable applied to each loaded image.
            indices: Optional list of positional row indices to keep (e.g. one
                split); the kept rows are renumbered from 0.
        """
        graph_df = pd.read_csv(csv_file)

        # Restrict to a specific split / subset of rows when requested
        if indices is not None:
            graph_df = graph_df.iloc[indices].reset_index(drop=True)

        self.graph_df = graph_df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.graph_df)

    def __getitem__(self, idx):
        row = self.graph_df.iloc[idx]
        img_path = os.path.join(self.root_dir, row['image_path'])
        # Close the underlying file deterministically; convert() fully loads the
        # pixel data into a new image, so it stays valid after the with-block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = row['label']

        if self.transform:
            image = self.transform(image)

        return image, label

def get_transforms(image_size=224):
    """Build the image preprocessing pipelines for training and evaluation.

    Training: ``RandomResizedCrop(image_size)`` and ``RandomHorizontalFlip``.
    Evaluation: ``Resize(int(image_size * 1.14))`` then ``CenterCrop(image_size)``.
    Both pipelines finish with ``ToTensor`` and a channel ``Normalize`` using
    the usual ImageNet mean/std (presumably to match pretrained backbones).

    Returns:
        ``(train_transform, test_transform)``, each a ``transforms.Compose``.
    """
    def imagenet_normalize():
        # A fresh Normalize per pipeline, mirroring two independent definitions
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_steps = [
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_normalize(),
    ]
    eval_steps = [
        transforms.Resize(int(image_size * 1.14)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        imagenet_normalize(),
    ]
    return transforms.Compose(train_steps), transforms.Compose(eval_steps)

def create_experiment_dir(args):
    """Create (if missing) the results and model directories for a run.

    Both live under the current working directory:
    ``results/<dataset>/<tag>`` and ``models/<dataset>/<tag>``, where ``tag``
    encodes the GNN depth/width, dropout rate and graph layout.

    Args:
        args: Namespace providing ``dataset``, ``gnn_num_layers``,
            ``gnn_hidden_dim``, ``dropout_rate`` and ``layout``.

    Returns:
        Dict with keys ``'exp_dir'`` and ``'model_dir'`` (absolute paths).
    """
    tag = (
        f"layers{args.gnn_num_layers}_hidden{args.gnn_hidden_dim}"
        f"_drop{args.dropout_rate}_layout{args.layout}"
    )
    cwd = os.getcwd()
    paths = {
        key: os.path.join(cwd, base, args.dataset, tag)
        for key, base in (('exp_dir', 'results'), ('model_dir', 'models'))
    }

    for path in paths.values():
        os.makedirs(path, exist_ok=True)

    print(f"Experiment Directory: {paths['exp_dir']}")
    print(f"Model Directory: {paths['model_dir']}")

    return paths


def _resolve_task(task_type):
    """Map a task name to its generators and whether it is a classification task.

    Returns:
        ``(generate_dataset_fun, generate_image_dataset_fun, classification)``.

    Raises:
        ValueError: If ``task_type`` is not one of the supported tasks.
    """
    if task_type == 'topology_classification':
        return generate_topology_dataset, generate_topology_image_dataset, True
    if task_type == 'symmetry_classification':
        return generate_symmetry_dataset, generate_symmetry_image_dataset, True
    if task_type == 'spectral_gap_regression':
        return generate_spectral_gap_dataset, generate_spectral_gap_image_dataset, False
    if task_type == 'bridge_count_regression':
        return generate_bridge_count_dataset, generate_bridge_count_image_dataset, False
    raise ValueError(f"Unknown task_type: {task_type!r}")


def _build_gnn_model(args, input_dim, num_classes):
    """Instantiate the GNN architecture named by ``args.gnn_model_name``.

    Raises:
        ValueError: If the name is missing or not a known GNN architecture.
    """
    constructors = {'gcn': GCNModel, 'gin': GINModel, 'gat': GATModel, 'gps': GPSModel}
    try:
        model_cls = constructors[args.gnn_model_name]
    except KeyError:
        raise ValueError(
            f"Unknown or missing gnn_model_name: {args.gnn_model_name!r}") from None
    return model_cls(input_dim=input_dim, hidden_dim=args.gnn_hidden_dim,
                     num_classes=num_classes, dropout_rate=args.dropout_rate,
                     num_layers=args.gnn_num_layers)


def _build_vision_model(args, num_classes):
    """Instantiate the vision architecture named by ``args.vision_model_name``.

    Raises:
        ValueError: If the name is missing or not a known vision architecture.
    """
    name = args.vision_model_name
    if name == 'convnext':
        return ConvNeXtV2Model(num_classes, model_variant="tiny", dropout_rate=args.dropout_rate)
    if name == 'vit':
        return ViTModel(num_classes, model_variant="base", dropout_rate=args.dropout_rate)
    if name == 'swin':
        return SwinModel(num_classes, dropout_rate=args.dropout_rate)
    if name == 'resnet':
        return ResNetModel(num_classes, dropout_rate=args.dropout_rate)
    raise ValueError(f"Unknown or missing vision_model_name: {name!r}")


def _build_vision_optimizer(model, args):
    """Create Adam with separate learning rates for the backbone and the head.

    The ResNet head is ``model.model.fc`` (backbone = other params of
    ``model.model``); for convnext/vit/swin the head is
    ``model.model.classifier`` (backbone = all params without 'classifier').
    """
    if args.vision_model_name == 'resnet':
        backbone = [p for n, p in model.model.named_parameters() if 'fc' not in n]
        head = model.model.fc.parameters()
    else:  # convnext, vit, swin
        backbone = [p for n, p in model.named_parameters() if 'classifier' not in n]
        head = model.model.classifier.parameters()
    return optim.Adam([
        {'params': backbone, 'lr': args.vision_lr},
        {'params': head, 'lr': args.classifier_lr},
    ], weight_decay=args.vision_weight_decay)


def _train_with_early_stopping(model, train_loader, val_loader, optimizer, device,
                               classification, args):
    """Train for up to ``args.epoch`` epochs, stopping after ``args.patience`` stale epochs.

    Accuracy is maximised for classification; MAE is minimised for regression.

    Returns:
        ``(best_val_metric, best_model_state)`` where ``best_model_state`` is an
        independent snapshot of the weights at the best epoch, or ``None`` if
        the validation metric never improved (e.g. zero epochs or NaN metrics).
    """
    if args.model_type == 'gnn':
        train_fn, eval_fn = train_gnn_model, test_gnn_model
    else:
        train_fn, eval_fn = train_vision_model, test_vision_model

    # Start from -inf (not 0.0) so the first epoch is captured even when the
    # validation accuracy is exactly 0; otherwise no best state would exist.
    best_val_metric = float('-inf') if classification else float('inf')
    best_model_state = None
    patience_counter = 0
    metric_label = "Acc" if classification else "MAE"

    print("Starting training...")

    for epoch in range(1, args.epoch + 1):
        train_loss, train_metric = train_fn(model, train_loader, optimizer, device, classification)
        val_metric, _, _ = eval_fn(model, val_loader, device, classification)

        print(f"Epoch {epoch}/{args.epoch}: Train Loss: {train_loss:.4f}, "
              f"Train {metric_label}: {train_metric:.4f}, Val {metric_label}: {val_metric:.4f}")

        if classification:
            improved = val_metric > best_val_metric
        else:
            improved = val_metric < best_val_metric

        if improved:
            best_val_metric = val_metric
            # state_dict() tensors share storage with the live parameters, so a
            # shallow dict copy would be overwritten by later optimizer steps;
            # clone every tensor to freeze this snapshot.
            best_model_state = {
                k: (v.detach().clone() if torch.is_tensor(v) else v)
                for k, v in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= args.patience:
            print(f"Early stopping: No improvement for {patience_counter} consecutive epochs.")
            break

    return best_val_metric, best_model_state


def _json_default(obj):
    """JSON fallback for NumPy/torch values: convert anything exposing ``tolist``."""
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _append_result_json(results_file, record):
    """Append ``record`` to the JSON list stored in ``results_file``.

    A missing or unparsable file is treated as an empty history; a non-list
    payload is kept as the first entry. The new content is serialized fully
    before the file is opened, so a serialization error cannot truncate it.
    """
    history = []
    if os.path.exists(results_file):
        try:
            with open(results_file, 'r') as f:
                history = json.load(f)
        except json.decoder.JSONDecodeError:
            history = []
    if not isinstance(history, list):
        history = [history]

    history.append(record)
    text = json.dumps(history, indent=4, default=_json_default)
    with open(results_file, 'w') as f:
        f.write(text)


def run_single_model_experiment(dataset_name, layout, dirs, args, device, task_type):
    """Train one model and evaluate it across all difficulty test sets.

    Steps:
        1. Generate/load the graph datasets under ``./<dataset_name>/<layout>``.
        2. Render the graphs to images.
        3. Build the GNN or vision model selected by ``args.model_type``.
        4. Train with early stopping on the validation metric.
        5. Restore and save the best-validation weights.
        6. Evaluate on each test set returned by the dataset generator
           (only ``args.test_difficulty`` when it names one of them).
        7. Append one JSON record per test set. Classification tasks also
           get class-level metrics and a t-SNE plot.

    Args:
        dataset_name: Dataset identifier; also the root data directory name.
        layout: Graph drawing layout used to render images.
        dirs: Mapping with ``'exp_dir'`` and ``'model_dir'`` output paths.
        args: Parsed command-line namespace (hyperparameters, model choice;
            ``args.model_name`` must already be set).
        device: ``torch.device`` used for training and evaluation.
        task_type: Task name (topology/symmetry classification,
            spectral-gap/bridge-count regression).

    Returns:
        The results dict of the last evaluated test set, or ``None`` if no
        test set was evaluated.

    Raises:
        ValueError: For an unknown task type, model type or model name.
    """
    # Use the explicit parameters (not args) so the signature means what it says.
    root_dir = f"./{dataset_name}/{layout}"
    generate_dataset_fun, generate_image_dataset_fun, classification = _resolve_task(task_type)

    # 1. Load or generate all datasets (including test sets of varying difficulty)
    train_dataset, val_dataset, test_datasets = generate_dataset_fun(root_dir, seed=args.seed)

    # 2. Generate image representations
    image_dir, dataset_csv = generate_image_dataset_fun(
        root_dir=root_dir,
        layout=layout,
        image_size=args.image_size,
        seed=args.seed,
    )

    # 3. Load split information (the CSV provides 'split' and 'label' columns)
    df = pd.read_csv(dataset_csv)

    # 4. Get number of classes / output dimension
    class_names = None
    if classification:
        num_classes = len(set(df['label'].tolist()))
        print(f"Dataset contains {num_classes} classes")
        if hasattr(train_dataset, 'class_names'):
            class_names = train_dataset.class_names
        else:
            class_names = [f"Class {i}" for i in range(num_classes)]
    else:
        num_classes = 1  # Regression task: a single scalar output

    # 5. Model preparation
    test_transform = None
    if args.model_type == 'gnn':
        train_loader = PyGDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        val_loader = PyGDataLoader(val_dataset, batch_size=args.batch_size)

        # Node feature dimension; fall back to 1 when graphs carry no 'x'
        sample_data = train_dataset[0]
        input_dim = sample_data.x.size(1) if hasattr(sample_data, 'x') and sample_data.x is not None else 1

        model = _build_gnn_model(args, input_dim, num_classes).to(device)
        optimizer = optim.Adam(model.parameters(), lr=args.gnn_lr, weight_decay=args.gnn_weight_decay)

    elif args.model_type == 'vision':
        train_transform, test_transform = get_transforms(args.image_size)

        train_indices = df[df['split'] == 'train'].index.tolist()
        val_indices = df[df['split'] == 'val'].index.tolist()

        train_image_dataset = GraphImageDataset(
            csv_file=dataset_csv,
            root_dir=image_dir,
            transform=train_transform,
            indices=train_indices)
        val_image_dataset = GraphImageDataset(
            csv_file=dataset_csv,
            root_dir=image_dir,
            transform=test_transform,
            indices=val_indices)

        train_loader = DataLoader(train_image_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
        val_loader = DataLoader(val_image_dataset, batch_size=args.batch_size, pin_memory=True)

        model = _build_vision_model(args, num_classes).to(device)
        optimizer = _build_vision_optimizer(model, args)
    else:
        raise ValueError(f"Unknown model_type: {args.model_type!r}")

    # 6. Train the model
    best_val_metric, best_model_state = _train_with_early_stopping(
        model, train_loader, val_loader, optimizer, device, classification, args)

    # Restore the best-validation weights before saving and testing
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("Warning: validation metric never improved; using the final model weights.")

    model_save_path = os.path.join(dirs['model_dir'], f"{args.model_type}_{args.model_name}.pt")
    torch.save(model.state_dict(), model_save_path)
    print(f"Best model saved to: {model_save_path}")

    # 7. Evaluate on the test sets. Honour --test_difficulty when it names one
    # of the generated test sets; otherwise evaluate all of them.
    test_items = list(test_datasets.items())
    requested = getattr(args, 'test_difficulty', None)
    if requested:
        if requested in test_datasets:
            test_items = [(requested, test_datasets[requested])]
        else:
            print(f"Requested test difficulty {requested!r} not found in {list(test_datasets)}; "
                  "evaluating all test sets")

    metric_key = 'accuracy' if classification else 'MAE'
    result_dir = os.path.join(dirs['exp_dir'], "baseline_results")
    os.makedirs(result_dir, exist_ok=True)

    results = None
    for difficulty, test_dataset in test_items:
        print(f"\n{'='*20} Evaluating on {difficulty} test set {'='*20}")

        if args.model_type == 'gnn':
            test_loader = PyGDataLoader(test_dataset, batch_size=args.batch_size)
            test_metric, test_preds, test_labels = test_gnn_model(model, test_loader, device, classification)
        else:
            # Vision test images are located by their 'test_<difficulty>' split tag
            test_indices = df[df['split'] == f'test_{difficulty}'].index.tolist()
            if not test_indices:
                print(f"Image information for the {difficulty} test set not found, please check if the dataset was correctly generated")
                continue

            image_test_dataset = GraphImageDataset(
                csv_file=dataset_csv,
                root_dir=image_dir,
                transform=test_transform,
                indices=test_indices)
            test_loader = DataLoader(image_test_dataset, batch_size=args.batch_size, pin_memory=True)
            test_metric, test_preds, test_labels = test_vision_model(model, test_loader, device, classification)

        # 8. Build the result record
        results = {
            'dataset': dataset_name,
            'model_type': args.model_type,
            'model_name': args.model_name,
            'hyperparameters': vars(args),
            'task_type': task_type,
            'test_difficulty': difficulty,
            'performance': {
                f'test_{metric_key}': float(test_metric),
                f'best_val_{metric_key}': float(best_val_metric),
            },
        }

        if classification:
            analysis_dir = os.path.join(dirs['exp_dir'], f"{args.model_type}_{args.model_name}_{difficulty}_analysis")
            os.makedirs(analysis_dir, exist_ok=True)

            print(f"\nStarting detailed performance analysis for {difficulty} test set...")

            # Class-level metrics are attached BEFORE saving so they are persisted.
            # The confusion matrix is returned but not currently used.
            class_report, _ = evaluate_with_class_metrics(test_preds, test_labels, class_names)
            results['class_performance'] = class_report

        # 9. Save results (appended to this configuration's JSON history)
        results_file = os.path.join(result_dir, f"{args.model_type}_{args.model_name}_{difficulty}.json")
        _append_result_json(results_file, results)
        print(f"\nResults for {difficulty} test set saved to: {results_file}")

        # 10. Feature extraction and t-SNE visualization (classification only)
        if classification:
            print(f"\nStarting feature extraction and t-SNE visualization for {difficulty} test set...")
            if args.model_type == 'vision':
                features, feature_labels = extract_vision_features(model, test_loader, device)
            else:
                features, feature_labels = extract_gnn_features(model, test_loader, device)

            create_tsne_visualization(
                features,
                feature_labels,
                class_names,
                args.model_name,
                difficulty,
                os.path.join(analysis_dir, "tsne_visualization.pdf")
            )

            print(f"\nDetailed analysis for {difficulty} test set saved to: {analysis_dir}")

    return results


def main():
    """Parse command-line arguments and run a single train/evaluate experiment.

    Validates that the architecture flag matching ``--model_type`` is given,
    seeds all RNGs, selects CUDA when available, creates the output
    directories, sets ``args.model_name`` and calls
    ``run_single_model_experiment``.
    """
    parser = argparse.ArgumentParser()
    # Data, task and training-loop parameters
    parser.add_argument('--dataset', type=str, default='topology_dataset')
    parser.add_argument('--layout', type=str, default='spring')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--dropout_rate', type=float, default=0.5)
    parser.add_argument(
        '--task_type',
        type=str,
        default='topology_classification',
        choices=[
            'topology_classification',
            'symmetry_classification',
            'spectral_gap_regression',
            'bridge_count_regression',
        ],
    )

    # Model selection parameters
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['vision', 'gnn'],
                        help='Specify the model type to run: vision, gnn')
    parser.add_argument('--vision_model_name', type=str,
                        choices=['convnext', 'vit', 'swin', 'resnet'],
                        help='Specify the architecture of the vision model')
    parser.add_argument('--gnn_model_name', type=str,
                        choices=['gcn', 'gin', 'gat', 'gps'],
                        help='Specify the architecture of the GNN model')
    parser.add_argument('--test_difficulty', type=str, default=None,
                        choices=['ID', 'Near-OOD', 'Far-OOD'],
                        help='Level of distribution shift for test set (ID=In-Distribution, Near-OOD=Moderate shift, Far-OOD=Significant shift)')

    # Architecture and optimisation hyperparameters. The fusion_* options are
    # parsed (and recorded with the saved hyperparameters) but not otherwise
    # referenced in this script.
    parser.add_argument('--gnn_num_layers', type=int, default=3)
    parser.add_argument('--gnn_hidden_dim', type=int, default=128)
    parser.add_argument('--image_size', type=int, default=224)
    parser.add_argument('--vision_lr', type=float, default=5e-6)
    parser.add_argument('--gnn_lr', type=float, default=1e-2)
    parser.add_argument('--vision_weight_decay', type=float, default=1e-4)
    parser.add_argument('--gnn_weight_decay', type=float, default=1e-4)
    parser.add_argument('--fusion_weight_decay', type=float, default=1e-4)
    parser.add_argument('--classifier_lr', type=float, default=1e-3)
    parser.add_argument('--fusion_lr', type=float, default=1e-3)
    parser.add_argument('--fusion_method', type=str, default='concat',
                        choices=['concat', 'attention', 'weighted'], help='Feature fusion method')
    parser.add_argument('--fusion_hidden_dim', type=int, default=512, help='Hidden dimension for fusion')

    args = parser.parse_args()

    # argparse cannot express "required only for this --model_type"; fail fast
    # with a usage error instead of crashing later with no model constructed.
    if args.model_type == 'vision' and args.vision_model_name is None:
        parser.error("--vision_model_name is required when --model_type is 'vision'")
    if args.model_type == 'gnn' and args.gnn_model_name is None:
        parser.error("--gnn_model_name is required when --model_type is 'gnn'")

    set_seed(args.seed)
    print(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dirs = create_experiment_dir(args)

    # Single name used for checkpoint/result file names
    if args.model_type == 'vision':
        args.model_name = args.vision_model_name
    elif args.model_type == 'gnn':
        args.model_name = args.gnn_model_name

    print("="*50)
    print(f"Starting experiment on {args.dataset} using {args.model_type}-{args.model_name}")

    if args.test_difficulty:
        print(f"Using {args.test_difficulty} test set")
    print("="*50)

    run_single_model_experiment(args.dataset, args.layout, dirs, args, device, args.task_type)


if __name__ == "__main__":
    main()
