import torch
from model.rgvt import init_graph_view_transformation, recurrent_gvt, Predictor
from model.gvt import GraphViewTransformation
from utils.training import PredictorTrainer, TrainerConfig
from load_dataset import load_dataset

class HyperparameterTuner:
    """Sweep the recurrent aggregation depth ``h`` of a pretrained ViewTransformer.

    For each ``h`` in ``1..max_hops`` the pretrained weights are loaded into a
    fresh ViewTransformer, its parameters are frozen, depth-``h`` node features
    are produced with ``recurrent_gvt`` and a ``Predictor`` head is trained on
    those features (linear probing) through ``PredictorTrainer``.
    """

    def __init__(self, max_hops, config, device):
        """
        Args:
            max_hops: Largest depth ``h`` to probe (inclusive).
            config: Trainer configuration handed to ``PredictorTrainer``.
            device: Device used to map the checkpoint and to place the
                ViewTransformer and predictor.
        """
        self.max_hops = max_hops
        self.config = config
        self.device = device
        self.trainer = PredictorTrainer(config)

    def tune_aggregator_method_combined(self, features, edge_index, labels, masks, args, A_rw=None, A_sym=None):
        """Linear-probe the frozen ViewTransformer at every depth ``h`` in ``1..max_hops``.

        Args:
            features: Node features forwarded to ``recurrent_gvt``.
            edge_index: Edge index forwarded to ``recurrent_gvt``.
            labels: Node label tensor; the number of classes is taken as
                ``labels.max() + 1``.
            masks: Split masks forwarded unchanged to ``PredictorTrainer.train``.
            args: Run configuration. ``args.checkpoint`` and ``args.predictor_type``
                are read here; the whole object is forwarded to
                ``init_graph_view_transformation`` and ``recurrent_gvt``.
            A_rw, A_sym: Optional precomputed adjacency matrices forwarded to
                ``recurrent_gvt``.

        Returns:
            dict mapping each ``h`` to ``{'val': ..., 'test': ...}`` as returned by
            the trainer. A hop that raises is recorded as
            ``{'val': 0.0, 'test': 0.0}`` and the sweep continues with the next hop.
        """
        import traceback

        # NOTE: torch.load unpickles the file -- only point args.checkpoint at
        # trusted checkpoints. Loaded once and reused for every hop.
        pretrained_state = torch.load(args.checkpoint, map_location=self.device)

        num_classes = labels.max().item() + 1
        is_mlp = args.predictor_type == 'mlp'
        # torch.cuda.memory_allocated only tracks CUDA allocations, so only
        # report it when the work actually runs on a CUDA device.
        report_cuda_memory = torch.cuda.is_available() and torch.device(self.device).type == 'cuda'

        linear_results = {}
        for h in range(1, self.max_hops + 1):
            # Bound up front so the ``finally`` cleanup works even when the hop
            # fails before these objects are created.
            view_transformer = predictor = x_agg = None
            try:
                print(f"\n--- Processing h={h} ---")

                # Fresh ViewTransformer per hop, restored from the in-memory state.
                view_transformer = init_graph_view_transformation(args, device=self.device)
                view_transformer.load_state_dict(pretrained_state)

                # STEP 1: Linear probing -- the aggregator itself is frozen.
                print(f"Step 1: Linear Probing for h={h}")
                view_transformer.eval()
                for param in view_transformer.parameters():
                    param.requires_grad = False

                # Features from recurrent_gvt with manual_depth=h (K=h iterations).
                x_agg = recurrent_gvt(
                    view_transformer, features, edge_index, args, manual_depth=h,
                    training=False, A_rw=A_rw, A_sym=A_sym
                )

                # Only the predictor head is trained.
                predictor = Predictor(x_agg.shape[1], num_classes, bias=True, is_mlp=is_mlp).to(self.device)
                best_linear_val, best_linear_test, _ = self.trainer.train(
                    predictor, x_agg, labels, masks, self.device,
                    desc=f"LinearProbe h={h}"
                )
                linear_results[h] = {'val': best_linear_val, 'test': best_linear_test}

                print(f"--- h={h} Summary ---")
                print(f"  Linear Probing: Val {100 * best_linear_val:.2f}%, Test {100 * best_linear_test:.2f}%")

            except Exception as e:
                # Deliberate best-effort: one failing depth (e.g. OOM) must not
                # abort the whole sweep, but keep the traceback for debugging.
                print(f"An error occurred during h={h}: {e}")
                traceback.print_exc()
                print("Setting results to 0.0 and continuing...")
                linear_results[h] = {'val': 0.0, 'test': 0.0}

            finally:
                # Release this hop's model/features on success AND failure so a
                # failed hop does not keep its tensors alive into the next one.
                del view_transformer, predictor, x_agg
                if report_cuda_memory:
                    print(f"Memory usage after h={h}: {torch.cuda.memory_allocated(self.device) / 1e6:.2f} MB")

        return linear_results

def evaluate_dataset_stage(dataset_name, args, device='cpu', results_dir=None):
    """Run the ViewTransformer depth sweep on one dataset and report the best depth.

    Args:
        dataset_name: Name passed to ``load_dataset`` (split 0 is used).
        args: Run configuration. ``args.max_depth`` sets the largest probed
            depth and ``args.learning_rate`` the trainer learning rate; the
            object is also forwarded to the tuner.
        device: Device the graph and models are moved to.
        results_dir: Directory for the detailed report (see
            ``print_results_summary``); ``None`` lets it pick a timestamped one.

    Returns:
        dict with ``'Agg_best_val'``, ``'Agg_best_test'`` and ``'best_agg_h'``.
        If the tuner returned no results, the scores are NaN and
        ``best_agg_h`` is ``None``.
    """
    print(f"\n=== Evaluating Dataset: {dataset_name} ===")

    # Load split 0 and move the graph (with int32 indices) to the target device.
    g = load_dataset(dataset_name, split_index=0)
    g = g.int().to(device)
    # Prefer the normalised features when the loader provides them.
    features = g.ndata["feat_norm"] if "feat_norm" in g.ndata else g.ndata["feat"]
    labels = g.ndata["label"]
    masks = [g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]]
    edge_index = torch.stack(g.edges(), dim=0).long()

    # torch.cuda.memory_allocated is CUDA-only; skip the report for other devices.
    if torch.cuda.is_available() and torch.device(device).type == 'cuda':
        print(f"Memory usage before tuning: {torch.cuda.memory_allocated(device) / 1e6:.2f} MB")

    config = TrainerConfig()
    config.lr = args.learning_rate
    tuner = HyperparameterTuner(args.max_depth, config, device)

    # Build the adjacency matrices once; every probed depth reuses them.
    A_rw, A_sym = GraphViewTransformation._build_adjacency_matrices(edge_index, features.size(0))

    agg_results = tuner.tune_aggregator_method_combined(
        features, edge_index, labels, masks, args, A_rw=A_rw, A_sym=A_sym
    )

    # Select the depth with the best validation score (ties -> smallest h,
    # since max keeps the first maximum in iteration order).
    if agg_results:
        best_agg_h = max(agg_results, key=lambda h: agg_results[h]['val'])
        best_agg_val = agg_results[best_agg_h]['val']
        best_agg_test = agg_results[best_agg_h]['test']
    else:
        best_agg_h = None
        best_agg_val = float('nan')
        best_agg_test = float('nan')

    print_results_summary(dataset_name, agg_results, best_agg_h, results_dir)

    return {
        'Agg_best_val': best_agg_val,
        'Agg_best_test': best_agg_test,
        'best_agg_h': best_agg_h,
    }


def print_results_summary(dataset_name, agg_results, best_agg_h, results_dir=None):
    """Write a detailed per-depth report to disk and print a short console summary.

    Args:
        dataset_name: Dataset name, used in the report title and file name
            (``<results_dir>/<dataset_name>_detailed.txt``).
        agg_results: Mapping ``h -> {'val': ..., 'test': ...}``; scores are
            multiplied by 100 and shown as percentages.
        best_agg_h: Key of the selected depth in ``agg_results``, or ``None``.
        results_dir: Output directory. When ``None`` a new
            ``results/<YYYYmmdd_HHMMSS>`` directory is used. Created if missing.
    """
    import os
    import datetime

    # One timestamp for both the directory name and the "Generated" line.
    now = datetime.datetime.now()
    if results_dir is None:
        results_dir = os.path.join('results', now.strftime('%Y%m%d_%H%M%S'))
    os.makedirs(results_dir, exist_ok=True)
    filename = os.path.join(results_dir, f'{dataset_name}_detailed.txt')

    has_best = bool(agg_results) and best_agg_h is not None

    # Header + best-result section.
    output_lines = [
        '=' * 80,
        f"DETAILED RESULTS SUMMARY FOR {dataset_name.upper()}",
        f"Generated: {now.strftime('%Y-%m-%d %H:%M:%S')}",
        '=' * 80,
        "\n🎯 BEST RESULTS COMPARISON",
        '=' * 60,
    ]
    if has_best:
        best = agg_results[best_agg_h]
        output_lines.append("📊 LINEAR PROBING:")
        output_lines.append(
            f"  ViewTransformer (h={best_agg_h}): Val {best['val']*100:6.2f}% | Test {best['test']*100:6.2f}%"
        )

    # Per-depth table; the selected depth is marked with a star.
    output_lines.append("\n\n📈 DETAILED HYPERPARAMETER TUNING RESULTS")
    output_lines.append('=' * 60)
    if agg_results:
        output_lines.append("\n[VIEWTRANSFORMER LINEAR PROBING] Hyperparameter Tuning:")
        output_lines.append(f"{'h':<3} | {'Val%':<7} | {'Test%':<7}")
        output_lines.append(f"{'-'*3} | {'-'*7} | {'-'*7}")
        for h, res in agg_results.items():
            marker = " ⭐" if h == best_agg_h else "   "
            output_lines.append(f"{h:<3} | {100 * res['val']:<7.2f} | {100 * res['test']:<7.2f}{marker}")
    else:
        output_lines.append("\n[VIEWTRANSFORMER LINEAR PROBING] No valid ViewTransformer models found!")

    output_lines.append(f"\n{'='*80}")

    # The report contains emoji: write UTF-8 explicitly instead of relying on
    # the platform default encoding (e.g. cp1252 cannot encode them).
    with open(filename, 'w', encoding='utf-8') as f:
        f.write('\n'.join(output_lines))

    print(f"\n🎯 QUICK SUMMARY FOR {dataset_name}:")
    if has_best:
        print(f"   LinearProbe: Val {best['val']*100:6.2f}% | Test {best['test']*100:6.2f}%")

    print(f"📄 Detailed results saved to: {filename}")
