import os
import argparse
import gc
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score, cohen_kappa_score
from model.S3Net import S3Net 
from SleepDataLoader import SleepDataLoader


def evaluate_s3net_model():
    """Evaluate saved S3Net checkpoints fold by fold and report pooled metrics.

    Command-line options: ``--device`` (``'cpu'`` or a CUDA device index),
    ``--weights_dir``, ``--data_path`` and ``--results_dir``.

    For each of the 10 folds the checkpoint
    ``<weights_dir>/fold_<i>/best_model.pth`` is loaded (folds without a
    checkpoint are reported and skipped) and the model is run over the
    validation tensors returned by ``SleepDataLoader.getFold(i)``. The
    predictions of all evaluated folds are pooled to compute accuracy,
    macro and per-class F1, Cohen's kappa, a classification report and a
    confusion matrix. Results are printed and a summary is written to
    ``<results_dir>/results.txt``.

    Returns:
        None. Returns early, without writing the results file, when no
        checkpoint was found for any fold.
    """
    parser = argparse.ArgumentParser(description='S3Net evaluation')
    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--weights_dir', type=str, default='./results')
    parser.add_argument('--data_path', type=str, default='./data_s3')
    parser.add_argument('--results_dir', type=str, default='./eval_results')
    args = parser.parse_args()

    # Fall back to the CPU when it is requested explicitly or CUDA is absent.
    if args.device != 'cpu' and torch.cuda.is_available():
        device = f'cuda:{args.device}'
    else:
        device = 'cpu'
    print(f"Using device: {device}")

    os.makedirs(args.results_dir, exist_ok=True)

    batch_size = 32
    fold = 10
    classes = ['W', 'N1', 'N2', 'N3', 'REM']
    # Explicit label ids keep every metric aligned with `classes` even when a
    # class never occurs in the evaluated folds (e.g. some checkpoints are
    # missing); without them scikit-learn infers the label set from the data.
    label_ids = list(range(len(classes)))
    loader = SleepDataLoader(args.data_path)

    all_preds = []
    all_trues = []
    fold_scores = []

    for i in range(fold):
        print(f'Evaluating fold {i}...')

        model_path = os.path.join(args.weights_dir, f"fold_{i}", "best_model.pth")
        if not os.path.exists(model_path):
            print(f"Model not found: {model_path}")
            continue

        # Load data. The tensors are left where the loader put them and only
        # the current batch is moved to `device` inside the loop below, so the
        # whole validation split does not have to fit in device memory.
        _, _, _, val_x, val_y, val_stft = loader.getFold(i)
        val_dataset = TensorDataset(val_x, val_y, val_stft)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        # Load model
        model = S3Net(
            num_classes=len(classes),
            in_chans=10,
            embed_dim=64,
            depths=[2, 4, 2],
            num_heads=[2, 2, 2],
            window_size=7,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            patch_norm=False,
            num_experts=2
        ).to(device)

        # NOTE(review): torch.load unpickles the file; only point
        # --weights_dir at checkpoints from a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        # Evaluate
        preds_fold = []
        trues_fold = []
        with torch.no_grad():
            for inputs, targets, stft in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                stft = stft.to(device)
                logits, _ = model(inputs, stft)
                preds_fold.extend(logits.argmax(dim=1).cpu().tolist())
                # Targets are reduced with argmax over dim 1, i.e. they are
                # treated as per-class score / one-hot rows.
                trues_fold.extend(targets.argmax(dim=1).cpu().tolist())

        acc = accuracy_score(trues_fold, preds_fold)
        print(f'Fold {i} accuracy: {acc:.4f}')
        fold_scores.append(acc)

        all_preds.extend(preds_fold)
        all_trues.extend(trues_fold)

        # Clean up: drop every reference to this fold's tensors first, collect
        # garbage, and only then ask the CUDA allocator to release its cache.
        del model, val_loader, val_dataset, val_x, val_y, val_stft
        gc.collect()
        torch.cuda.empty_cache()

    if not fold_scores:
        print("No valid model weights found!")
        return

    # Calculate metrics on the predictions pooled over all evaluated folds
    overall_acc = accuracy_score(all_trues, all_preds)
    f1_macro = f1_score(all_trues, all_preds, labels=label_ids, average='macro')
    kappa = cohen_kappa_score(all_trues, all_preds)
    f1_per_class = f1_score(all_trues, all_preds, labels=label_ids, average=None)
    report = classification_report(
        all_trues, all_preds, labels=label_ids, target_names=classes, digits=4
    )
    cm = confusion_matrix(all_trues, all_preds, labels=label_ids)

    # Print results
    print('='*60)
    print('EVALUATION RESULTS')
    print('='*60)
    print(f'Fold accuracies: {[f"{acc:.4f}" for acc in fold_scores]}')
    print(f'Mean accuracy: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}')
    print(f'Overall accuracy: {overall_acc:.4f}')
    print(f'F1-macro: {f1_macro:.4f}')
    print(f'Cohen\'s Kappa: {kappa:.4f}')

    # Per-class F1 scores
    print('\nPer-class F1 scores:')
    for name, score in zip(classes, f1_per_class):
        print(f'  {name}: {score:.4f}')

    # Classification report
    print('\nClassification Report:')
    print(report)

    # Confusion matrix (rows: true class, columns: predicted class)
    print('\nConfusion Matrix:')
    print('     ', '  '.join(f'{cls:>4}' for cls in classes))
    for cls_name, row in zip(classes, cm):
        print(f'{cls_name:>4}', '  '.join(f'{val:>4}' for val in row))

    # Save results to file
    results_file = os.path.join(args.results_dir, 'results.txt')
    with open(results_file, 'w', encoding='utf-8') as f:
        f.write('EVALUATION RESULTS\n')
        f.write('='*60 + '\n')
        f.write(f'Overall accuracy: {overall_acc:.4f}\n')
        f.write(f'F1-macro: {f1_macro:.4f}\n')
        f.write(f'Cohen\'s Kappa: {kappa:.4f}\n\n')
        f.write('Per-class F1 scores:\n')
        for name, score in zip(classes, f1_per_class):
            f.write(f'  {name}: {score:.4f}\n')
        f.write('\nClassification Report:\n')
        f.write(report)
        f.write('\nConfusion Matrix:\n')
        f.write(str(cm))

    print(f'\nResults saved to: {results_file}')


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    evaluate_s3net_model()