
import logging
import re
import warnings
from torch.utils.data import DataLoader, Subset
import wandb
import torch.nn.functional as F
from cdc.utils.torch_clustering import PyTorchKMeans
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize


# Route messages issued through the `warnings` module into the `logging`
# system instead of writing them straight to stderr.
logging.captureWarnings(True)
# Ask for DeprecationWarnings from modules whose name matches the pattern
# below to be shown every time they are triggered.
# NOTE(review): the trailing '.' in the pattern is an unescaped regex
# wildcard and requires one more character after the module name, so the
# pattern does not match this module's own name exactly — confirm whether
# r'^{0}\.' (sub-modules only) or r'^{0}' was intended.
warnings.filterwarnings('always', category=DeprecationWarning,
                        module=r'^{0}.'.format(re.escape(__name__)))
# NOTE(review): this emits a DeprecationWarning unconditionally every time the
# file is imported or run; it looks like a leftover smoke test for the filter
# configured above — confirm it is still wanted.
warnings.warn("This is a DeprecationWarning",category=DeprecationWarning)


import argparse
import os
import torch
from cdc.args import parse_cfg, get_model, get_strong_transformations,\
    get_val_transformations, get_standard_transformations,\
    get_train_dataloader, get_val_dataloader,\
    get_train_dataset,get_val_dataset, get_optimizer\
    #get_criterion, adjust_learning_rate
from cdc.utils.evaluate_utils import get_predictions, \
    hungarian_evaluate, hign_conf_evaluation, hungarian_evaluate_hard
from cdc.methods.calibrate_train import orth_train
from cdc.methods.cc_train import cc_train, cc_train_sample
from cdc.methods.scan_train import init_head_doublelayer_bias
from cdc.methods.dyn_train import SampleMasterTracker

# Command-line interface of the script: locations of the two YAML config files
# and the random seed used for every RNG.
FLAGS = argparse.ArgumentParser(description='CDC Model')
FLAGS.add_argument('--config_env', default='scripts/cc/env.yaml', help='Location of path config file')
FLAGS.add_argument('--config_exp', default='scripts/cc/cifar10/cc_lr00001_ini_bias_a10_sample_stabilityloss_t1_seed5.yaml',help='Location of experiments config file')
FLAGS.add_argument('--seed', default=5, type=int)

# SECURITY(review): an API key is committed to source control here. It should
# be rotated and supplied through the environment instead of the code.
# `setdefault` keeps the historical fallback but no longer overwrites a key
# the user has already exported in their shell.
os.environ.setdefault("WANDB_API_KEY", '2a4485eff00bb9efe7db48f5ca413f10466663b4')
# Runs are always recorded offline (no network traffic during training).
os.environ["WANDB_MODE"]="offline"

def spherical_kmeans_torch(X, n_clusters, n_iters=10):
    """Cluster the rows of ``X`` with spherical k-means (cosine similarity).

    The rows are L2-normalised, ``n_clusters`` of them are drawn at random as
    initial centroids, and ``n_iters`` assignment/update rounds are run. Every
    updated centroid is re-normalised to unit length; a cluster that receives
    no points keeps its previous centroid.

    Args:
        X: 2-D tensor ``[N, D]`` of input features. It is L2-normalised along
            ``dim=1`` here, so the caller does not need to normalise it.
        n_clusters: number of clusters ``K``; must satisfy ``1 <= K <= N``.
        n_iters: number of assignment/update rounds. ``0`` returns the random
            initial centroids together with the assignment to them.

    Returns:
        A tuple ``(cluster_labels, centroids)`` where ``cluster_labels`` is a
        ``[N]`` tensor of cluster indices and ``centroids`` is ``[K, D]``.
        The labels are the nearest-centroid assignment with respect to the
        returned centroids.

    Raises:
        ValueError: if ``n_clusters`` is not in ``[1, N]``.
    """
    X = F.normalize(X, dim=1)
    N, D = X.shape
    if not 0 < n_clusters <= N:
        raise ValueError(
            f"n_clusters must be in [1, {N}] for {N} samples, got {n_clusters}"
        )

    # Random initialisation: pick K distinct rows as the starting centroids.
    indices = torch.randperm(N)[:n_clusters]
    centroids = X[indices]  # [K, D]

    for _ in range(n_iters):
        # Rows and centroids are unit-length, so the dot product is the
        # cosine similarity.
        cluster_labels = torch.matmul(X, centroids.T).argmax(dim=1)  # [N]
        new_centroids = []
        for k in range(n_clusters):
            mask = cluster_labels == k
            if mask.any():
                mean = X[mask].mean(dim=0, keepdim=True)
                new_centroids.append(F.normalize(mean, dim=1))
            else:
                # Empty cluster: keep the previous centroid unchanged.
                new_centroids.append(centroids[k].unsqueeze(0))
        centroids = torch.cat(new_centroids, dim=0)

    # Final assignment so the returned labels match the returned centroids
    # (and so that n_iters == 0 still yields labels).
    cluster_labels = torch.matmul(X, centroids.T).argmax(dim=1)
    return cluster_labels, centroids


def initialize_with_spherical_kmeans_cc(model, features, cfg):
    """Initialise the instance and cluster projectors from spherical k-means.

    The first linear layer of both projectors receives the (orthogonalised)
    centroids of a ``D``-cluster spherical k-means on the standardised
    features; the last linear layer of the cluster projector receives the
    (orthogonalised) centroids of a ``K``-cluster spherical k-means on the
    resulting hidden activations. The corresponding biases are zeroed.

    Args:
        model: network exposing ``instance_projector`` and
            ``cluster_projector``, either directly or through a ``.module``
            attribute (e.g. when wrapped for data parallelism). Layers ``[0]``
            of both projectors and layer ``[2]`` of the cluster projector are
            overwritten in place.
        features: 2-D floating-point tensor ``[N, D]`` of backbone features.
        cfg: configuration mapping; ``cfg['backbone']['nclusters']`` gives the
            number of clusters ``K``.

    Returns:
        None. The model is modified in place.
    """
    K = cfg['backbone']['nclusters']
    D = features.shape[1]

    # Step 1: per-sample z-score standardisation followed by L2 normalisation.
    # The std is clamped so a constant feature row cannot produce NaNs.
    centered = features - features.mean(1, keepdim=True)
    std = features.std(1, keepdim=True).clamp_min(torch.finfo(features.dtype).eps)
    features_zscore = F.normalize(centered / std, dim=1)  # [N, D]

    # Step 2: first-layer weights W1 <- spherical k-means with D clusters on
    # the standardised features (they were previously computed but unused).
    _, W1 = spherical_kmeans_torch(features_zscore, n_clusters=D)

    # Step 3: hidden activations of the raw features through W1 + ReLU, then
    # second-layer weights W2 <- spherical k-means with K clusters.
    # NOTE(review): raw `features` are used here, mirroring the inline
    # initialisation in `main` — confirm this matches what the projector sees
    # at inference time.
    H = F.normalize(torch.relu(torch.mm(features, W1.T)), dim=1)
    _, W2 = spherical_kmeans_torch(H, n_clusters=K)

    # Step 4: orthogonalisation (optional, meant to improve discriminability).
    W1 = orth_train(W1, D, use_relu=True)  # [D, D]
    W2 = orth_train(W2, K, use_relu=True)  # [K, D]

    # Accept both a wrapped model (`model.module`) and a bare one.
    net = getattr(model, 'module', model)

    # Step 5: first layer of the instance projector.
    torch.nn.init.zeros_(net.instance_projector[0].bias)
    net.instance_projector[0].weight.data = W1.clone()

    # Step 6: both linear layers of the cluster projector.
    torch.nn.init.zeros_(net.cluster_projector[0].bias)
    net.cluster_projector[0].weight.data = W1.clone()

    torch.nn.init.zeros_(net.cluster_projector[2].bias)
    net.cluster_projector[2].weight.data = W2.clone()

    print(f"Spherical KMeans init done: W1={W1.shape}, W2={W2.shape}")

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _evaluate_hard(cfg, val_dataloader, model):
    """Predict on the validation set and score it with ``hungarian_evaluate_hard``.

    Returns ``(clustering_stats, indices, features)`` exactly as produced by
    ``get_predictions(..., return_features=True)`` and
    ``hungarian_evaluate_hard`` (called with epoch 0).
    """
    predictions, features = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats, indices = hungarian_evaluate_hard(
        cfg, cfg['cc_dir'], 0, 0, predictions,
        title=cfg['cluster_eval']['plot_title'],
        compute_confusion_matrix=False)
    return clustering_stats, indices, features


def _init_heads_with_kmeans(cfg, model, features):
    """Overwrite the projector heads of ``model`` with k-means derived weights.

    W1 comes from a ``PyTorchKMeans`` fit with ``features.shape[1]`` clusters
    on the standardised features, W2 from a fit with
    ``cfg['backbone']['nclusters']`` clusters on ``relu(features @ W1.T)``.
    Both are passed through ``orth_train`` before being copied into the model.
    """
    # Per-sample z-score followed by L2 normalisation.
    features_zscore = (features - features.mean(1, keepdim=True)) / features.std(1, keepdim=True)
    features_zscore = F.normalize(features_zscore, dim=1)  # [N, D]

    n_features = features.shape[1]
    kmeans_d = PyTorchKMeans(init='k-means++', n_clusters=n_features, verbose=False)
    kmeans_d.fit_predict(features_zscore)
    W1 = kmeans_d.cluster_centers_  # [D, D] per the original annotation

    H = torch.relu(torch.mm(features, W1.T))
    K = cfg['backbone']['nclusters']
    kmeans_k = PyTorchKMeans(init='k-means++', n_clusters=K, verbose=False)
    kmeans_k.fit_predict(H)
    W2 = kmeans_k.cluster_centers_  # [K, D] per the original annotation

    W1 = orth_train(W1, n_features, use_relu=True)
    W2 = orth_train(W2, K, use_relu=True)

    # First layer of the instance projector.
    torch.nn.init.zeros_(model.module.instance_projector[0].bias)
    model.module.instance_projector[0].weight.data = W1.clone()
    # Cluster projector: the same W1 as first layer, W2 (the class prototypes)
    # as second linear layer.
    torch.nn.init.zeros_(model.module.cluster_projector[0].bias)
    model.module.cluster_projector[0].weight.data = W1.clone()
    torch.nn.init.zeros_(model.module.cluster_projector[2].bias)
    model.module.cluster_projector[2].weight.data = W2.clone()


def _save_checkpoint(path, model, optimizer, epoch):
    """Write optimizer state, model state and the epoch counter to ``path``."""
    torch.save({'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'epoch': epoch},
               path)


def main():
    """Run the full pipeline: setup, head initialisation, training, final eval.

    Steps, in order:
      1. Parse CLI flags, seed all RNGs, load the configuration.
      2. Build datasets, dataloaders, model and optimizer.
      3. Evaluate the untouched model, initialise the projector heads from
         k-means, re-evaluate, run ``init_head_doublelayer_bias``, re-evaluate.
      4. Optionally resume from ``cfg['cc_checkpoint']``.
      5. Train with ``cc_train_sample`` for ``cfg['max_epochs']`` epochs,
         evaluating every ``cfg['cluster_eval']['eval_freq']`` epochs, saving
         a checkpoint each epoch and the best-ACC model separately.
      6. Reload the best model, evaluate it with a confusion matrix and append
         the result to the training log.
    """
    args = FLAGS.parse_args()

    seed = args.seed
    print("seed: ", seed)
    _seed_everything(seed)
    torch.set_grad_enabled(True)

    cfg = parse_cfg(args.config_env, args.config_exp)
    print(cfg)

    # Data
    print('Get dataset and dataloaders')
    strong_transformations = get_strong_transformations(cfg)
    standard_transformations = get_standard_transformations(cfg)
    val_transformations = get_val_transformations(cfg)

    train_dataset = get_train_dataset(cfg, {'standard': standard_transformations,
                                            'augment': strong_transformations},
                                      split=cfg['data']['split'], augmented=True)
    val_dataset = get_val_dataset(cfg, val_transformations)
    train_dataloader = get_train_dataloader(cfg, train_dataset)
    val_dataloader = get_val_dataloader(cfg, val_dataset)
    print('Strong transforms:', strong_transformations)
    print('Standard transforms:', standard_transformations)
    print('Validation transforms:', val_transformations)
    print('Train samples %d - Val samples %d' % (len(train_dataset), len(val_dataset)))

    # Model
    print('Get model')
    model = get_model(cfg, cfg['pretext']['enable'])
    print(model)

    # Optimizer
    print('Get optimizer')
    optimizer = get_optimizer(cfg, model)
    print(optimizer)

    # Loss function: no separate criterion object is used by this script; the
    # message is kept so the console output is unchanged.
    print('Get loss')

    # wandb
    # NOTE(review): no `wandb.init` call is visible in this file — confirm a
    # run is started elsewhere before this point.
    wandb.watch(model, log="all")

    # Evaluate the model as loaded.
    clustering_stats, indices, features = _evaluate_hard(cfg, val_dataloader, model)
    print(clustering_stats)

    # Initialise the projector heads from k-means and re-evaluate.
    _init_heads_with_kmeans(cfg, model, features)
    clustering_stats, indices, features = _evaluate_hard(cfg, val_dataloader, model)
    print('after ini: ', clustering_stats)

    # Second initialisation stage, then evaluate once more.
    k = cfg.get('k', 10)
    alpha = cfg.get('alpha', 1.0)
    init_head_doublelayer_bias(cfg, model, features, k=k, alpha=alpha)
    clustering_stats, indices, features = _evaluate_hard(cfg, val_dataloader, model)
    print('after ini: ', clustering_stats)

    # Checkpoint: resuming restores the model weights and the epoch counter
    # only (optimizer state is intentionally not restored).
    if os.path.exists(cfg['cc_checkpoint']):
        print('Restart from checkpoint {}'.format(cfg['cc_checkpoint']))
        checkpoint = torch.load(cfg['cc_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch']
    else:
        print('No checkpoint file at {}'.format(cfg['cc_checkpoint']))
        start_epoch = 0

    best_acc = 0
    log_path = os.path.join(cfg['cc_dir'], 'training_log.log')

    thresh = cfg.get('t', 0.1)
    window = cfg.get('w', 3)
    num_train_samples = len(train_dataloader.dataset)
    tracker = SampleMasterTracker(cfg, num_samples=num_train_samples,
                                  delta_thresh=thresh, window=window)

    # Main loop
    print('Starting main loop')
    for epoch in range(start_epoch, cfg['max_epochs']):
        print('Epoch %d/%d' % (epoch + 1, cfg['max_epochs']))

        # The learning rate is constant; it is only reported here.
        lr = cfg['optimizer']['lr']
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # `with` guarantees the log is flushed and closed even if training or
        # evaluation raises.
        with open(log_path, 'a') as log_file:
            log_file.write(f'Epoch {epoch+1} - Validation prediction\n')

            # Train
            print('Train ...')
            cc_train_sample(cfg, clustering_stats, train_dataloader, model, optimizer,
                            tracker, stabilityloss=True)

            # Evaluate
            if (epoch + 1) % cfg['cluster_eval']['eval_freq'] == 0:
                print('Make prediction on validation set ...')
                predictions = get_predictions(cfg, val_dataloader, model)
                clustering_stats, indices = hungarian_evaluate_hard(
                    cfg, cfg['cc_dir'], epoch, 0, predictions,
                    title=cfg['cluster_eval']['plot_title'],
                    compute_confusion_matrix=False, _indices=indices)
                print(clustering_stats)
                log_file.write(f'EVA: {clustering_stats}\n\n')

            # Report removed samples against the dataset size (the tracker is
            # built per sample), not against the number of batches.
            log_file.write(f'Train data removed - {len(tracker.removed)}/{num_train_samples}\n')

        # Checkpoint
        print('Checkpoint ...')
        _save_checkpoint(cfg['cc_checkpoint'], model, optimizer, epoch + 1)

        # NOTE(review): on epochs without evaluation `clustering_stats` is the
        # most recent earlier result, so this comparison uses stale numbers.
        if best_acc < clustering_stats['ACC']:
            _save_checkpoint(cfg['cc_best_model'], model, optimizer, epoch + 1)
            best_acc = clustering_stats['ACC']

    # Evaluate and save the final model. The best checkpoint must be loaded
    # BEFORE predicting, otherwise the last-epoch weights are what gets scored.
    print('Evaluate best model at the end')
    checkpoint = torch.load(cfg['cc_best_model'], map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    predictions = get_predictions(cfg, val_dataloader, model)

    eval_kwargs = dict(
        title=cfg['cluster_eval']['plot_title'],
        compute_confusion_matrix=True,
        confusion_matrix_file=os.path.join(cfg['cc_dir'], 'confusion_matrix.png'),
    )
    try:
        clustering_stats = hungarian_evaluate(
            cfg, cfg['cc_dir'], cfg['max_epochs'], 0, predictions,
            class_names=val_dataloader.dataset.classes, **eval_kwargs)
    except Exception as err:
        # Best-effort fallback (e.g. the dataset has no usable `classes`):
        # retry with numeric class names, but say why instead of hiding it.
        print(f'Evaluation with dataset class names failed ({err!r}); '
              'retrying with numeric class names')
        clustering_stats = hungarian_evaluate(
            cfg, cfg['cc_dir'], cfg['max_epochs'], 0, predictions,
            class_names=[str(i) for i in range(cfg['backbone']['nclusters'])],
            **eval_kwargs)
    print(clustering_stats)

    with open(log_path, 'a') as log_file:
        log_file.write(f'best EVA: {clustering_stats}\n')


if __name__ == "__main__":
    main()