
import logging
import re
import warnings
from torch.utils.data import DataLoader, Subset
import wandb
import torch.nn.functional as F
from cdc.utils.torch_clustering import PyTorchKMeans

from cdc.data.memory import MemoryBank, fill_memory_bank

# Redirect warnings into the logging system (the 'py.warnings' logger) so they
# appear alongside the rest of the run's log output.
logging.captureWarnings(True)
# Always show DeprecationWarnings raised from this module or its submodules.
# The module pattern is matched against the module name from its start; it must
# accept the bare name itself, so the tail is "(\.|$)" rather than a lone ".",
# which would demand at least one extra character and never match this module.
warnings.filterwarnings('always', category=DeprecationWarning,
                        module=r'^{0}(\.|$)'.format(re.escape(__name__)))

import argparse
import os
import torch
from cdc.args import parse_cfg, get_model, get_strong_transformations,\
    get_val_transformations, get_standard_transformations,\
    get_train_dataloader, get_val_dataloader,\
    get_train_dataset,get_val_dataset, get_optimizer,\
    get_criterion, adjust_learning_rate
from cdc.utils.evaluate_utils import get_predictions, \
    hungarian_evaluate_hard, hign_conf_evaluation, save_top_images, save_bottom_images, \
        create_feature_extractor, visualize_tsne_with_confidence, generate_and_plot_tsne, kmeans_cluster_and_visualize_by_true_label

from cdc.methods.scan_train import scan_train, scan_train_sample, init_head_singlelayer_bias
from cdc.data.waterbirds_dataset import compute_worst_acc
from cdc.methods.dyn_train import SampleMasterTracker, train_cali_sample

FLAGS = argparse.ArgumentParser(description='SCAN Model')
FLAGS.add_argument('--config_env', default='scripts/scan/env.yaml', help='Location of path config file')
FLAGS.add_argument('--config_exp', default='scripts/scan/cifar20/scan2_ini_bias_a20_sample_stabilityloss_t1_seed5.yaml', help='Location of experiments config file')
FLAGS.add_argument('--seed', default=5, type=int, help='Random seed for Python, NumPy and PyTorch')

# W&B is forced to offline mode, so no credentials are needed for a run.
# Never hard-code an API key in source: if online logging is ever wanted,
# export WANDB_API_KEY in the shell environment instead.
# SECURITY: a key was previously committed on this line and must be revoked.
os.environ["WANDB_MODE"] = "offline"

def _seed_everything(seed):
    """Seed every RNG the pipeline touches and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, NumPy and PyTorch (CPU and
            all CUDA devices).
    """
    import random
    import numpy as np

    print("seed: ", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(True)


def _mine_neighbors(cfg, model, transformations, topk=30):
    """Embed the train split and save its top-k nearest neighbours to disk.

    The neighbour indices and distances returned by
    ``MemoryBank.mine_nearest_neighbors`` are written with ``np.save`` to
    ``cfg['topk_neighbors_train_path']`` and ``cfg['topk_neighbors_train_dist']``.

    Args:
        cfg: Parsed experiment config.
        model: Model used by ``fill_memory_bank`` to embed the samples.
        transformations: Transforms applied to the train split (the caller
            passes the validation transforms, i.e. no augmentation).
        topk: Number of neighbours to mine per sample.
    """
    import numpy as np
    from cdc.data.collate import collate_custom

    base_dataset = get_train_dataset(cfg, transformations, split=cfg['data']['split'])
    base_dataloader = torch.utils.data.DataLoader(base_dataset, num_workers=cfg['data']['num_workers'],
                                                  batch_size=1000, pin_memory=True, collate_fn=collate_custom,
                                                  drop_last=False, shuffle=False)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  cfg['backbone']['feat_dim'],
                                  cfg['backbone']['nclusters'],
                                  cfg['backbone']['temperature'])
    memory_bank_base.cuda()
    fill_memory_bank(base_dataloader, model, memory_bank_base)
    distance, indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
    print(acc)
    np.save(cfg['topk_neighbors_train_path'], indices)
    np.save(cfg['topk_neighbors_train_dist'], distance)


def _evaluate_easy(cfg, val_dataloader, model):
    """Predict on the validation set and score it with ``easy=True`` Hungarian matching.

    Returns:
        ``(predictions, features, clustering_stats, indices)`` exactly as
        produced by ``get_predictions(..., return_features=True)`` and
        ``hungarian_evaluate_hard``.
    """
    predictions, features = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats, indices = hungarian_evaluate_hard(cfg, cfg['scan_dir'], 0, 0,
                                                        predictions, title=cfg['cluster_eval']['plot_title'],
                                                        compute_confusion_matrix=False, easy=True)
    return predictions, features, clustering_stats, indices


def _init_cluster_head_kmeans(cfg, model, features):
    """Set the first cluster-head layer's weight to k-means centroids of ``features``.

    Each feature row is z-scored, then L2-normalised, then clustered with
    k-means++ into ``cfg['backbone']['nclusters']`` groups. The centroids become
    the layer's weight and its bias is zeroed. The head is reached through
    ``model.module`` (presumably a ``DataParallel`` wrapper; TODO confirm).
    """
    head = model.module.cluster_head[0]
    features_zscore = (features - features.mean(1).reshape(-1, 1)) / features.std(1).reshape(-1, 1)
    features_zscore = F.normalize(features_zscore, dim=1)
    kmeans = PyTorchKMeans(init='k-means++', n_clusters=cfg['backbone']['nclusters'], verbose=False)
    kmeans.fit_predict(features_zscore)
    torch.nn.init.zeros_(head.bias)
    head.weight.data = kmeans.cluster_centers_.clone()


def _evaluate_final(cfg, val_dataloader, predictions, indices):
    """Run the final Hungarian evaluation and write a confusion-matrix plot.

    The validation dataset's ``classes`` are used as labels. If that attribute
    is missing, or the evaluation fails with those labels, the evaluation is
    retried with numeric names ``'0'..'nclusters-1'``.

    Returns:
        ``(clustering_stats, indices)`` from ``hungarian_evaluate_hard``.
    """
    confusion_matrix_file = os.path.join(cfg['scan_dir'], 'confusion_matrix.png')

    def _run(class_names):
        return hungarian_evaluate_hard(cfg, cfg['scan_dir'], cfg['max_epochs'], 0, predictions,
                                       title=cfg['cluster_eval']['plot_title'],
                                       class_names=class_names,
                                       compute_confusion_matrix=True,
                                       confusion_matrix_file=confusion_matrix_file,
                                       _indices=indices)

    try:
        return _run(val_dataloader.dataset.classes)
    except Exception as err:  # no bare except: let KeyboardInterrupt/SystemExit through
        print('Final evaluation with dataset class names failed ({!r}); '
              'retrying with numeric class names'.format(err))
        return _run([str(i) for i in range(cfg['backbone']['nclusters'])])


def main():
    """Run clustering training end to end.

    The steps are:
      1. Seed the RNGs and build the datasets and the model.
      2. Mine train-set nearest neighbours.
      3. Initialise the cluster head: k-means centroids first, then
         ``init_head_singlelayer_bias``.
      4. Optionally resume from ``cfg['scan_checkpoint']``.
      5. Train with ``scan_train_sample`` and a ``SampleMasterTracker``.
      6. Re-evaluate the best checkpoint (``cfg['scan_best_model']``) on the
         validation set.
    """
    args = FLAGS.parse_args()
    _seed_everything(args.seed)

    cfg = parse_cfg(args.config_env, args.config_exp)
    print(cfg)

    # Data
    print('Get dataset and dataloaders')
    strong_transformations = get_strong_transformations(cfg)
    standard_transformations = get_standard_transformations(cfg)
    val_transformations = get_val_transformations(cfg)

    # Model
    print('Get model')
    model = get_model(cfg, cfg['pretext']['enable'])
    print(model)

    print('Mining Neighbors ...')
    # Neighbours are mined on the train split with the (non-augmenting) val transforms.
    _mine_neighbors(cfg, model, val_transformations)

    train_dataset = get_train_dataset(cfg, {'val': val_transformations,
                                            'standard': standard_transformations,
                                            'augment': strong_transformations},
                                      split=cfg['data']['split'],
                                      to_neighbors_dataset=True)
    train_dataloader = get_train_dataloader(cfg, train_dataset)
    val_dataset = get_val_dataset(cfg, val_transformations)
    val_dataloader = get_val_dataloader(cfg, val_dataset)
    print('Strong transforms:', strong_transformations)
    print('Standard transforms:', standard_transformations)
    print('Validation transforms:', val_transformations)
    print('Train samples %d - Val samples %d' % (len(train_dataset), len(val_dataset)))

    # Optimizer
    print('Get optimizer')
    optimizer = get_optimizer(cfg, model)
    print(optimizer)

    # Loss function
    print('Get loss')
    criterion = get_criterion(cfg)
    criterion.cuda()
    print(criterion)

    # wandb.watch requires an active run, and nothing in this file calls
    # wandb.init(). If a run was started elsewhere, it is still watched.
    if wandb.run is not None:
        wandb.watch(model, log="all")

    # Evaluate the loaded model, then initialise the cluster head in two stages
    # (k-means weights, then bias) and re-evaluate after each stage.
    predictions, features, clustering_stats, indices = _evaluate_easy(cfg, val_dataloader, model)
    print(clustering_stats)

    _init_cluster_head_kmeans(cfg, model, features)
    predictions, features, clustering_stats, indices = _evaluate_easy(cfg, val_dataloader, model)
    print('ini:', clustering_stats)

    init_head_singlelayer_bias(cfg, model, features,
                               k=cfg.get('k', 10), alpha=cfg.get('alpha', 1.0))
    predictions, features, clustering_stats, indices = _evaluate_easy(cfg, val_dataloader, model)
    print('ini-bias: ', clustering_stats)

    # Checkpoint
    best_acc = -1
    if os.path.exists(cfg['scan_checkpoint']):
        print('Restart from checkpoint {}'.format(cfg['scan_checkpoint']))
        checkpoint = torch.load(cfg['scan_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        # Restore the best score so that resuming does not overwrite a better
        # scan_best_model. Checkpoints written before this key existed fall
        # back to -1 (the old behaviour).
        best_acc = checkpoint.get('best_acc', -1)
    else:
        print('No checkpoint file at {}'.format(cfg['scan_checkpoint']))
        start_epoch = 0

    log_path = os.path.join(cfg['scan_dir'], 'training_log.log')
    tracker = SampleMasterTracker(cfg, num_samples=len(train_dataloader.dataset),
                                  delta_thresh=cfg.get('t', 0.1), window=cfg.get('w', 3))

    # Main loop
    print('Starting main loop')
    for epoch in range(start_epoch, cfg['max_epochs']):
        print('Epoch %d/%d' % (epoch + 1, cfg['max_epochs']))
        lr = adjust_learning_rate(cfg, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        print('Train ...')
        # The most recent validation predictions are passed to
        # scan_train_sample as-is; how they are used is defined there.
        scan_train_sample(cfg, train_dataloader, model, criterion, optimizer,
                          predictions[0]['predictions'], tracker=tracker, stabilityloss=True)

        if (epoch + 1) % cfg['cluster_eval']['eval_freq'] == 0:
            print('Make prediction on validation set ...')
            predictions = get_predictions(cfg, val_dataloader, model)
            clustering_stats, indices = hungarian_evaluate_hard(cfg, cfg['scan_dir'], epoch, 0, predictions,
                                                                title=cfg['cluster_eval']['plot_title'],
                                                                compute_confusion_matrix=False, _indices=indices)
            print(clustering_stats)

        num_train = len(train_dataloader.dataset)
        num_val = len(val_dataloader.dataset)
        print("train samples: ", num_train)
        print("val samples: ", num_val)

        # The file is opened per epoch inside `with`, so it is flushed and
        # closed every epoch (the handle used to leak). On epochs without
        # evaluation, clustering_stats is the most recently computed one.
        with open(log_path, 'a') as log_file:
            log_file.write(f'Epoch {epoch+1} - Validation prediction\n')
            # Removed samples are counted against the dataset size;
            # len(train_dataloader) would be the number of batches.
            log_file.write(f'Train data removed - {len(tracker.removed)}/{num_train}\n')
            log_file.write(f'EVA: {clustering_stats}\n\n')
            log_file.write(f'Train samples: {num_train}\n')
            log_file.write(f'Val samples: {num_val}\n\n')

        # Update the best score before writing the resume checkpoint, so that
        # checkpoint records the current best_acc. best_acc is stored as a plain
        # float so it stays loadable with torch.load(weights_only=True).
        print('Checkpoint ...')
        if best_acc < clustering_stats['ACC']:
            best_acc = clustering_stats['ACC']
            torch.save({'optimizer': optimizer.state_dict(),
                        'model': model.state_dict(),
                        'epoch': epoch + 1,
                        'best_acc': float(best_acc)},
                       cfg['scan_best_model'])
        torch.save({'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'epoch': epoch + 1,
                    'best_acc': float(best_acc)},
                   cfg['scan_checkpoint'])

    # Evaluate and save the final model
    print('Evaluate best model at the end')
    checkpoint = torch.load(cfg['scan_best_model'], map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    predictions, _ = get_predictions(cfg, val_dataloader, model, return_features=True)
    clustering_stats, indices = _evaluate_final(cfg, val_dataloader, predictions, indices)
    print(clustering_stats)

    if cfg['data']['dataset'] in ["waterbirds"]:
        compute_worst_acc(model, val_dataloader, cfg['method'])


if __name__ == "__main__":
    main()