import pdb
import numpy as np
import torch.utils.data as data
import utils
from options import *
from config import *
from train_en import *
from xd_test_en import *
from model import *
from utils import WandbVisualizer

from dataset_loader_en import *
from tqdm import tqdm

if __name__ == "__main__":
    args = parse_args()
    if args.debug:
        pdb.set_trace()

    config = Config(args)
    worker_init_fn = None
   

    config.len_feature = 1024

    gpus = [0, 0, 0]  # Specify the GPUs you want to use
    seeds = [725, 45, 234]
    nets = []
    for i in range(3):
        seed = seeds[i]
        utils.set_seed(seed)
        net = WSAD(config.len_feature, flag="Train", a_nums=60, n_nums=60)
        net = net.to(f"cuda:{gpus[i]}")
        nets.append(net)

        print(f"Model {i} initialized on GPU {gpus[i]} with seed {seed}")

    normal_train_loader = data.DataLoader(
        XDVideo(root_dir = config.root_dir, mode = 'Train',modal = config.modal, num_segments = 200, len_feature = config.len_feature, is_normal = True, is_cluster = config.is_cluster, cluster_file = config.cluster_file, is_uncertainty = config.is_uncertainty),
            batch_size = 64,
            shuffle = True, num_workers = config.num_workers,
            worker_init_fn = worker_init_fn, drop_last = True)
    abnormal_train_loader = data.DataLoader(
        XDVideo(root_dir = config.root_dir, mode='Train', modal = config.modal, num_segments = 200, len_feature = config.len_feature, is_normal = False, is_cluster = config.is_cluster, cluster_file = config.cluster_file, is_uncertainty = config.is_uncertainty),
            batch_size = 64,
            shuffle = True, num_workers = config.num_workers,
            worker_init_fn = worker_init_fn, drop_last = True)
    test_loader = data.DataLoader(
        XDVideo(root_dir = config.root_dir, mode = 'Test', modal = config.modal, num_segments = config.num_segments, len_feature = config.len_feature),
            batch_size = 5,
            shuffle = False, num_workers = config.num_workers,
            worker_init_fn = worker_init_fn)

    test_info = {"step": [], "auc": [],"ap":[]}
    
    best_auc = 0

    criterions = [AD_Loss() for _ in range(3)]
    optimizers = [
        torch.optim.Adam(nets[i].parameters(), lr=config.lr[0],
                        betas=(0.9, 0.999), weight_decay=0.00005)
        for i in range(3)
    ]
    
    wandb_viz = WandbVisualizer(config=config, project="xd_aaai", run_name=config.run)
    wandb_viz.run.log_code()    
    frame_gt = np.load("frame_label/xd_gt.npy")

    num_videos = len(abnormal_train_loader.dataset.vid_list)
    num_segments = 200

    uncertainty_tensor = torch.zeros(num_videos, num_segments)


    for step in tqdm(
            range(1, config.num_iters + 1),
            total = config.num_iters,
            dynamic_ncols = True
        ):
        if step > 1 and config.lr[step - 1] != config.lr[step - 2]:
            for optimizer in optimizers:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = config.lr[step - 1]

        if (step - 1) % len(normal_train_loader) == 0:
            normal_loader_iter = iter(normal_train_loader)

        if (step - 1) % len(abnormal_train_loader) == 0:
            abnormal_loader_iter = iter(abnormal_train_loader)

            if config.is_uncertainty and step > 1:
                abnormal_train_loader.dataset.update_uncertainty(uncertainty_tensor)


        normal_batch = next(normal_loader_iter)
        abnormal_batch = next(abnormal_loader_iter)
        video_ids = abnormal_batch[2] 
        for i in range(3):  
            device = f"cuda:{gpus[i]}"
            train(nets[i], normal_batch, abnormal_batch, optimizers[i], criterions[i], wandb_viz, step, config.is_cluster, config.is_uncertainty, config.tau_a, config.un_th, config.alpha)
            with torch.no_grad():
                preds_list = []
                ainput = abnormal_batch[0]
                video_ids = abnormal_batch[2]  # indices or video ids
                
                for i in range(3):
                    net = nets[i]
                    net.eval()
                    ainput_device = ainput.to(next(net.parameters()).device)
                    res = net(ainput_device)
                    pred = res["frame"]  # shape: [B, T] or [B, 1, T]
                    preds_list.append(pred.cpu().numpy())
                    net.train()
                
                preds = np.stack(preds_list, axis=0)  # [3, B, T]
                std_pred = np.std(preds, axis=0)

        for b in range(len(video_ids)):
            vid = video_ids[b].item()  
            segment_uncertainty = std_pred[b]  
            uncertainty_tensor[vid] = torch.from_numpy(segment_uncertainty).float()

        if step % 100 == 0 and step > 10:
            all_preds = []
            all_metrics = []
            for i in range(3):
                nets[i].eval()
                preds = test(nets[i], config, wandb_viz, test_loader, test_info=None, step=step)
                all_preds.append(preds)
                nets[i].train()

            # Average predictions over models, assuming preds are numpy arrays or tensors
            avg_preds = np.mean(np.stack(all_preds, axis=0), axis=0)  # shape: [num_samples, ...]

            fpr,tpr,_ = roc_curve(frame_gt, avg_preds)
            auc_score = auc(fpr, tpr)
        
            precision, recall, th = precision_recall_curve(frame_gt, avg_preds,)
            ap_score = auc(recall, precision)       


            # Update test_info or your logging with avg_metrics
            test_info["step"].append(step)
            test_info["auc"].append(auc_score)
            test_info["ap"].append(ap_score)
            log_dict = {
            'roc_auc': auc_score,
            'pr_auc': ap_score,
            }

            thresholds = np.arange(0.1, 1.0, 0.1)

            for thresh in thresholds:
                y_pred = (avg_preds >= thresh).astype(int)
                precision = precision_score(frame_gt, y_pred, zero_division=0)
                recall = recall_score(frame_gt, y_pred, zero_division=0)
                log_dict[f'precision_{thresh:.1f}'] = precision
                log_dict[f'recall_{thresh:.1f}'] = recall


            wandb_viz.run.log(log_dict)

            # Save best models based on averaged metrics
            if step >= 500 and ap_score > best_auc:
                best_auc = ap_score
                utils.save_best_record(test_info, 
                    os.path.join(config.output_path, f"xd_best_auc_{config.run}_{config.seed}.txt"))
                for i in range(3):
                    torch.save(nets[i].state_dict(), os.path.join(args.model_path, f"{config.run}_{i}_{config.seed}.pkl"))
                
                utils.save_best_record(test_info, 
                    os.path.join(config.output_path, f"xd_best_ap_{config.run}_{config.seed}.txt"))

        if step == config.num_iters:
            for i in range(3):
                torch.save(nets[i].state_dict(), os.path.join(args.model_path, f"xd_{config.run}_model{i}_{step}.pkl"))



   