import os
import time
import glob
import pdb
import numpy as np
import torch.utils.data as data
import utils
import torch
import torch.nn as nn
from torch.nn.modules.module import Module
from sklearn.metrics import roc_curve,auc,precision_recall_curve, precision_score, recall_score

def wait_for_files(folder, step, num_files, timeout=10800, *, poll_interval=30):  # 3 hours timeout
    """Block until ``folder`` contains at least ``num_files`` result files for ``step``.

    A file counts when its name ends with ``_{step}.pt`` or ``_{step}.npy``. The
    folder is re-listed every ``poll_interval`` seconds; a folder that does not
    exist yet is treated as containing no files.

    Args:
        folder: Directory to poll.
        step: Step number embedded in the expected file-name suffixes.
        num_files: Minimum number of matching files required to return.
        timeout: Maximum number of seconds to wait (default 3 hours).
        poll_interval: Seconds to sleep between polls (default 30).

    Raises:
        SystemExit: with exit code 1 if the files have not appeared within
            ``timeout`` seconds.
    """
    start_time = time.time()
    expected_suffixes = (f"_{step}.pt", f"_{step}.npy")

    print("Waiting for files")
    print(step)
    print(folder)

    while True:
        # Flush OS write buffers before listing (the original shelled out to
        # `sync`); os.sync only exists on Unix, so skip it elsewhere.
        if hasattr(os, "sync"):
            os.sync()

        try:
            files = [f for f in os.listdir(folder) if f.endswith(expected_suffixes)]
        except FileNotFoundError:
            # The producers may not have created the folder yet.
            files = []
        print(files)

        if len(files) >= num_files:
            print(f"Found {len(files)} files for step {step}. Proceeding...")
            return

        if time.time() - start_time > timeout:
            print(f"Required number of files ({num_files}) were not available within {timeout} seconds.")
            # Same effect as sys.exit(1); the bare `exit` builtin is an
            # interactive-shell helper that may be absent (e.g. `python -S`).
            raise SystemExit(1)

        time.sleep(poll_interval)



def load_and_calculate_std(ensemble_path, step, abn_idx, *, map_location="cuda:0", num_cols=200):
    """Return the element-wise std across all ensemble tensors saved for ``step``.

    Every ``*_{step}.pt`` file in ``ensemble_path`` is loaded with ``torch.load``.
    The tensors are stacked along a new leading axis. The population standard
    deviation (``unbiased=False``) is taken across that axis, and the result is
    reshaped to ``(abn_idx, num_cols)``.

    Args:
        ensemble_path: Directory holding one saved tensor per ensemble member.
        step: Step number embedded in the file names.
        abn_idx: Number of rows of the reshaped result.
        map_location: Device the tensors are loaded onto (default ``"cuda:0"``).
        num_cols: Number of columns of the reshaped result (default 200).

    Returns:
        A tensor of shape ``(abn_idx, num_cols)``, or ``None`` if no matching
        files exist.
    """
    pattern = os.path.join(ensemble_path, f"*_{step}.pt")
    # glob order is filesystem-dependent; sort for a reproducible stacking order.
    files = sorted(glob.glob(pattern))

    if not files:
        print(f"No matching files found for step {step} in {ensemble_path}")
        return None

    print(f"Found {len(files)} files for standard deviation calculation at step {step}")

    tensors = [torch.load(f, map_location=map_location) for f in files]

    # Shape: (num_models, *tensor_shape); all members must have equal shapes.
    stacked_tensors = torch.stack(tensors, dim=0)

    std_en = torch.std(stacked_tensors, dim=0, unbiased=False)
    return std_en.reshape(abn_idx, num_cols)

def evaluate_prediction(net, config, train_loader, path, seed, run, step, model_file=None,
                        *, device="cuda", group_size=10):
    """Run ``net`` over ``train_loader`` and save group-averaged frame predictions.

    Each loader item's input (``item[0]``) is fed to ``net``. The ``"frame"``
    outputs of every ``group_size`` consecutive items are concatenated along
    dim 0 and averaged over dim 0. The per-group means are then concatenated and
    written with ``torch.save`` to ``{path}/{run}_{seed}_{step}.pt``. Items in a
    trailing partial group (fewer than ``group_size``) are dropped.

    Args:
        net: Model called as ``net(batch)`` and returning a dict with a
            ``"frame"`` tensor. It is switched to eval mode and ``net.flag`` is
            set to ``"Test"``.
        config: Unused; kept for interface compatibility.
        train_loader: Loader iterated exactly ``len(train_loader.dataset)`` times,
            so it presumably uses ``batch_size=1`` — TODO confirm with callers.
        path: Output directory, created if missing.
        seed: Used only in the output file name.
        run: Used only in the output file name.
        step: Used only in the output file name.
        model_file: Optional state-dict checkpoint loaded into ``net`` first.
        device: Device the inputs are moved to. The default ``"cuda"`` is the
            current CUDA device, which matches the original ``.cuda()`` calls.
        group_size: Number of consecutive items averaged together (default 10).
    """
    with torch.no_grad():
        net.eval()
        net.flag = "Test"

        if model_file is not None:
            # load_state_dict copies into net's own parameters, so this only
            # controls where the checkpoint is staged.
            net.load_state_dict(torch.load(model_file, map_location=device))

        os.makedirs(path, exist_ok=True)

        load_iter = iter(train_loader)
        all_predictions = []  # one averaged tensor per complete group
        pending = []          # "frame" outputs of the current, unfinished group

        for i in range(len(train_loader.dataset)):
            _data = next(load_iter)[0].to(device)
            pending.append(net(_data)["frame"])

            if (i + 1) % group_size == 0:
                all_predictions.append(torch.cat(pending, dim=0).mean(0))
                pending = []

        frame_predict = torch.cat(all_predictions)

        output_path = os.path.join(path, f"{run}_{seed}_{step}.pt")
        torch.save(frame_predict, output_path)

        print(f"Predictions saved at {output_path}")

def calculate_mean(ensemble_path, step, wandb_viz, *, gt_path="frame_label/gt-ucf.npy"):
    """Average the ensemble's saved predictions for ``step`` and score them.

    Every ``*_{step}.npy`` file in ``ensemble_path`` is loaded. The arrays must
    share one shape; they are averaged element-wise and compared with the labels
    in ``gt_path``. The function prints and logs the area under the
    precision-recall curve (``"En_AP"``). It also reports precision and recall at
    thresholds 0.1, 0.2, ..., 0.9 (``"EN_precision_<t>"`` / ``"EN_recall_<t>"``).

    Args:
        ensemble_path: Directory holding one prediction array per ensemble member.
        step: Step number embedded in the file names.
        wandb_viz: Object exposing ``run.log(dict)``; the metrics are logged
            through it. Pass ``None`` to skip logging.
        gt_path: ``.npy`` file with ground-truth labels, aligned element-wise
            with the predictions.

    Returns:
        The metrics dict, or ``None`` if no prediction files were found.
    """
    pattern = os.path.join(ensemble_path, f"*_{step}.npy")
    # glob order is filesystem-dependent; sort for a reproducible summation order.
    files = sorted(glob.glob(pattern))

    if not files:
        print(f"No matching files found for step {step} in {ensemble_path}")
        return None

    print(f"Found {len(files)} files for ensemble mean calculation at step {step}")

    # Stack to (num_models, N), then average across models -> (N,).
    stacked_arrays = np.stack([np.load(f) for f in files], axis=0)
    avg_predict = np.mean(stacked_arrays, axis=0)

    frame_gt = np.load(gt_path)

    # NOTE: trapezoidal area under the PR curve, not sklearn's step-wise
    # average_precision_score; kept as-is so reported numbers stay comparable.
    precision_curve, recall_curve, _ = precision_recall_curve(frame_gt, avg_predict)
    ap_score = auc(recall_curve, precision_curve)
    print(f"AP: {ap_score:.4f}")

    # k / 10 yields the exact nearest doubles to 0.1 ... 0.9. np.arange(0.1, 1.0, 0.1)
    # accumulates error (0.30000000000000004, 0.7000000000000001), which wrongly
    # excludes scores equal to 0.3 or 0.7 from the positive class.
    thresholds = np.arange(1, 10) / 10
    log_dict_ = {"En_AP": ap_score}

    print("Threshold\tPrecision\tRecall")
    for thresh in thresholds:
        y_pred = (avg_predict >= thresh).astype(int)
        precision = precision_score(frame_gt, y_pred, zero_division=0)
        recall = recall_score(frame_gt, y_pred, zero_division=0)
        print(f"{thresh:.1f}\t\t{precision:.4f}\t\t{recall:.4f}")
        log_dict_[f'EN_precision_{thresh:.1f}'] = precision
        log_dict_[f'EN_recall_{thresh:.1f}'] = recall

    if wandb_viz is not None:
        wandb_viz.run.log(log_dict_)
    return log_dict_


def evaluate_test(net, config, test_loader, path, seed, run, step, model_file=None,
                  *, device="cuda", gt_path="frame_label/gt-ucf.npy",
                  group_size=10, frames_per_snippet=16):
    """Run ``net`` on ``test_loader``, save frame-level scores and return their AP.

    For every ``group_size`` consecutive loader items, the ``"frame"`` outputs
    are concatenated along dim 0 and averaged over dim 0. Each resulting score is
    repeated ``frames_per_snippet`` times (``np.repeat``). The pieces are
    concatenated, saved with ``np.save`` to ``{path}/{run}_{seed}_{step}.npy``,
    and scored against the labels in ``gt_path``. Items in a trailing partial
    group are dropped.

    Args:
        net: Model called as ``net(batch)`` and returning a dict with a
            ``"frame"`` tensor. It is switched to eval mode and ``net.flag`` is
            set to ``"Test"``.
        config: Unused; kept for interface compatibility.
        test_loader: Loader whose items unpack into ``(data, label, name)``. It is
            iterated exactly ``len(test_loader.dataset)`` times, so it presumably
            uses ``batch_size=1`` — TODO confirm with callers.
        path: Output directory, created if missing.
        seed: Used only in the output file name.
        run: Used only in the output file name.
        step: Used only in the output file name.
        model_file: Optional state-dict checkpoint loaded into ``net`` first.
        device: Device the inputs are moved to. The default ``"cuda"`` is the
            current CUDA device, which matches the original ``.cuda()`` calls.
        gt_path: ``.npy`` file with ground-truth labels, aligned element-wise
            with the saved scores.
        group_size: Number of consecutive items averaged together (default 10).
        frames_per_snippet: Times each averaged score is repeated (default 16).

    Returns:
        Area under the precision-recall curve, computed with sklearn's ``auc``
        over the ``precision_recall_curve`` points.
    """
    with torch.no_grad():
        net.eval()
        net.flag = "Test"

        if model_file is not None:
            # load_state_dict copies into net's own parameters, so this only
            # controls where the checkpoint is staged.
            net.load_state_dict(torch.load(model_file, map_location=device))

        os.makedirs(path, exist_ok=True)

        load_iter = iter(test_loader)
        # Loaded up front so a missing label file fails before inference runs.
        frame_gt = np.load(gt_path)

        chunks = []   # per-group score arrays, concatenated once after the loop
        pending = []  # "frame" outputs of the current, unfinished group

        for i in range(len(test_loader.dataset)):
            _data, _label, _name = next(load_iter)
            pending.append(net(_data.to(device))["frame"])

            if (i + 1) % group_size == 0:
                group_mean = torch.cat(pending, dim=0).mean(0).cpu().numpy()
                chunks.append(np.repeat(group_mean, frames_per_snippet))
                pending = []

        # One concatenation instead of re-copying the growing array each group.
        frame_predict = np.concatenate(chunks)

        precision, recall, _ = precision_recall_curve(frame_gt, frame_predict)
        ap_score = auc(recall, precision)
        print(f"AP: {ap_score:.4f}")

        output_path = os.path.join(path, f"{run}_{seed}_{step}.npy")
        np.save(output_path, frame_predict)

        print(f"Predictions saved at {output_path}")
        return ap_score
