"""Evaluate BFR on spurious correlations datasets."""

import torch
import torchvision
from pydantic.typing import all_literal_values
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import os
import tqdm
import argparse
import sys
from collections import defaultdict
import json
from functools import partial
import pickle
import time

from utils import Logger, AverageMeter, set_seed, evaluate, add_label_noise
from data_utils import get_dataset, get_transform, MySubset, split_dataset
from utils import get_last_layer_retraining_args, get_model, get_classifier, get_feature_extractor, featurize_dataset, get_grouper, get_predictions
from bfr import BFRTrainer, SubpopDataset, IndexedDataset
from baselines import Trainer
from model_helpers import get_uncertain_datasets, Proj_Model, KL_GamWei
from wb_data import WaterBirdsDataset, get_loader, get_transform_cub
from copy import deepcopy

# Restrict which CUDA devices this process can see. The device list is hard-coded.
# NOTE(review): this runs after `import torch`; presumably it still takes effect
# because no CUDA call has been issued yet -- confirm on the target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 2'
# Run configuration; the rest of the script reads fields such as args.dataset,
# args.base_dir, args.model and args.seed from it.
args = get_last_layer_retraining_args()

# Build the train/test transforms and load the dataset without downloading it.
train_transform, test_transform = get_transform(args.dataset, args)
dataset = get_dataset(dataset=args.dataset, root_dir=args.root_dir, download=False)

# Only the train split uses the training transform; val and test share the test transform.
train_dataset = dataset.get_subset("train", transform=train_transform)
all_val_dataset = dataset.get_subset("val", transform=test_transform)
test_dataset = dataset.get_subset("test", transform=test_transform)

# The checkpoint is expected at <args.base_dir>/final_checkpoint.pt.
ckpt_path = os.path.join(args.base_dir, "final_checkpoint.pt")

# Load model
n_classes = train_dataset.n_classes
model = get_model(args.model, n_classes)
# strict=False: keys that are missing from / unexpected in the checkpoint do not raise.
# NOTE(review): this can silently leave layers at their initial weights -- confirm intended.
model.load_state_dict(torch.load(ckpt_path),strict=False)
model.cuda()
model.eval()
# Independent copy of the loaded model; it is the one passed to
# get_uncertain_datasets further down the script.
model_copy = deepcopy(model)
model_copy.cuda()
model_copy.eval()
# Split the loaded model into its classifier and its feature extractor.
classifier = get_classifier(model)
featurizer = get_feature_extractor(model)

# Per-split feature cache files, stored in args.base_dir and named after args.model.
# NOTE(review): the file names do not encode the dataset or checkpoint, so a stale
# cache could be picked up after either changes -- verify before reusing base_dir.
train_cache_file = os.path.join(args.base_dir, f'train_cache_{args.model}.pkl')
val_cache_file = os.path.join(args.base_dir, f'val_cache_{args.model}.pkl')
test_cache_file = os.path.join(args.base_dir, f'test_cache_{args.model}.pkl')

def load_or_save_cache(cache_file, dataset, featurizer, args):
    """Return ``(x, y, metadata)`` for *dataset*, using a pickle cache on disk.

    If *cache_file* already exists, its contents are unpickled and returned
    without touching *dataset* or *featurizer*. Otherwise the dataset is
    featurized via ``featurize_dataset(dataset, featurizer, args)``, the
    resulting triple is pickled to *cache_file*, and the triple is returned.

    Args:
        cache_file: Path of the pickle file holding the cached triple.
        dataset: Dataset to featurize on a cache miss.
        featurizer: Feature extractor forwarded to ``featurize_dataset``.
        args: Run configuration forwarded to ``featurize_dataset``.

    Returns:
        A 3-tuple ``(x, y, metadata)``.
    """
    # Cache hit: reuse what an earlier run stored.
    if os.path.exists(cache_file):
        print(f"Loading cache from {cache_file}")
        with open(cache_file, 'rb') as handle:
            feats, labels, meta = pickle.load(handle)
        return feats, labels, meta

    # Cache miss: compute the features once and persist them for later runs.
    feats, labels, meta = featurize_dataset(dataset, featurizer, args)
    with open(cache_file, 'wb') as handle:
        pickle.dump((feats, labels, meta), handle)
    return feats, labels, meta

# Turn every split into (features, labels, metadata). Unless --no_cache style
# behaviour is requested via args.no_cache, the per-split pickle caches are used.
print("Featurizing datasets ...")
if not args.no_cache:
    train_x, train_y, train_metadata = load_or_save_cache(
        train_cache_file, train_dataset, featurizer, args)
    all_val_x, all_val_y, all_val_metadata = load_or_save_cache(
        val_cache_file, all_val_dataset, featurizer, args)
    test_x, test_y, test_metadata = load_or_save_cache(
        test_cache_file, test_dataset, featurizer, args)
else:
    # Caching disabled: always recompute and never read or write the cache files.
    train_x, train_y, train_metadata = featurize_dataset(
        train_dataset, featurizer, args)
    all_val_x, all_val_y, all_val_metadata = featurize_dataset(
        all_val_dataset, featurizer, args)
    test_x, test_y, test_metadata = featurize_dataset(
        test_dataset, featurizer, args)
print(f"Train set shape: {train_x.shape}")

# Build a Waterbirds validation loader and ask get_uncertain_datasets for index
# sets computed with the untouched model copy. Only top_uncertain_indices is
# consumed afterwards (as val_idx3).
# NOTE(review): the Waterbirds dataset is used regardless of args.dataset, and
# val_idx3 is later used to index the featurized val split -- this presumably
# relies on both validation sets sharing the same sample order; confirm.
target_resolution = (224, 224)
test_transform = get_transform_cub(
    target_resolution=target_resolution,
    train=False,
    augment_data=False,
)
testset_dict = {
    'wb_val': WaterBirdsDataset(
        basedir=args.data_dir,
        split="val",
        transform=test_transform,
    ),
}

loader_kwargs = dict(
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    reweight_places=None,
)
test_loader_dict = {}
for test_name, testset_v in testset_dict.items():
    test_loader_dict[test_name] = get_loader(
        testset_v,
        train=False,
        reweight_groups=None,
        reweight_classes=None,
        **loader_kwargs,
    )

(
    selected_indices,
    top_uncertain_indices,
    low_uncertain_indices,
    remaining_indices,
    top_threshold_indices,
    acc_uncertain_indices,
    inacc_certain_indices,
    au_ic_indices,
    ac_au_indices,
) = get_uncertain_datasets(
    model_copy,
    test_loader_dict["wb_val"],
    topk_ratio=0.2,
    return_indices_only=True,
)
val_idx3 = np.array(top_uncertain_indices)

# Split the validation set into two index sets: val_idx1 is used as the
# validation subset below, val_idx2 as an optional separate target set.
val_idx1, val_idx2 = split_dataset(all_val_dataset,
                                       args.val_subsample)  # seed is automatically applied for reproducibility
print(f"Using a subsample of val set of size {len(val_idx1)}")

# Re-seed after the split so everything below starts from args.seed.
set_seed(args.seed)

# Map each split's metadata to group ids.
grouper = get_grouper(dataset, args.multiple_groupers)
train_g = grouper.metadata_to_group(train_metadata)
all_val_g = grouper.metadata_to_group(all_val_metadata)
test_g = grouper.metadata_to_group(test_metadata)

# Validation subset: the val_idx1 rows of the full validation features/labels/metadata/groups.
val_x, val_y, val_metadata, val_g = all_val_x[val_idx1], all_val_y[val_idx1], all_val_metadata[val_idx1], all_val_g[val_idx1]

if args.train_val:
    # Use the top high uncertain val set as the target set
    train_x, train_y, train_g = all_val_x[val_idx3], all_val_y[val_idx3], all_val_g[val_idx3]
    # Keep the metadata aligned with the rows selected above. Without this,
    # train_metadata would still describe the original train split while
    # train_x/train_y/train_g describe the uncertain validation subset, and the
    # SubpopDataset built from them would pair samples with the wrong metadata.
    train_metadata = all_val_metadata[val_idx3]
    # Report how many selected samples fall into each group.
    print('uncertainty set group distribution')
    tmp_values, tmp_cnts = np.unique(train_g, return_counts=True)
    for v, c in zip(tmp_values, tmp_cnts):
        print(f"group {v}: {c}")

# No precomputed group-index structures are passed for any split.
train_group_indices = val_group_indices = test_group_indices = None

# Bundle features, labels, group ids and metadata for each split.
subpop_splits = [
    SubpopDataset(train_x, train_y, train_g, train_metadata,
                  all_group_indices=train_group_indices),
    SubpopDataset(val_x, val_y, val_g, val_metadata,
                  all_group_indices=val_group_indices),
    SubpopDataset(all_val_x, all_val_y, all_val_g, all_val_metadata,
                  all_group_indices=val_group_indices),
    SubpopDataset(test_x, test_y, test_g, test_metadata,
                  all_group_indices=test_group_indices),
]

# The dataset needs to be wrapped to include indices in each batch
train_dataset, val_dataset, all_val_dataset, test_dataset = [
    IndexedDataset(split) for split in subpop_splits
]




# A separate target set is built only when targets are enabled and the
# validation set was actually subsampled (val_subsample neither None nor 1).
# NOTE(review): the BFRTrainer call later in this script passes train_dataset as
# target_dataset, so the target_dataset chosen here appears to be unused -- confirm.
use_separate_target = (
    not args.no_target
    and args.val_subsample is not None
    and args.val_subsample != 1
)
if use_separate_target:
    print(f"Retraining on a separate target set of size {len(val_idx2)}")
    target_x = all_val_x[val_idx2]
    target_y = all_val_y[val_idx2]
    target_metadata = all_val_metadata[val_idx2]
    target_g = all_val_g[val_idx2]
    if args.multiple_groupers:
        target_group_indices = grouper.metadata_to_group_indices(target_metadata)
    else:
        target_group_indices = None
    target_dataset = IndexedDataset(
        SubpopDataset(target_x, target_y, target_g, target_metadata,
                      all_group_indices=target_group_indices)
    )
else:
    # use the val set as the target set, but this can result in overfitting to the val set
    target_x, target_y, target_metadata = val_x, val_y, val_metadata
    target_dataset = val_dataset
    print(f"Using val set of size {len(val_dataset)} as the target set")

# No experiment logger is used in this script; code guarded by
# `writer is not None` is therefore skipped.
writer = None

# Re-seed right before building the model so its initialisation is reproducible.
set_seed(args.seed)
# Rebinds `model`: from here on it is a Proj_Model built from the feature
# dimension (train_x.shape[1]) and the number of classes, not the loaded backbone.
model = Proj_Model(train_x.shape[1], n_classes)

train_results = None

# The full validation set is the first dataset argument, while `train_dataset`
# (the uncertain validation subset when args.train_val is set) is passed as the target set.
# NOTE(review): the role of the first positional dataset is defined by BFRTrainer -- confirm.
trainer = BFRTrainer(args, model, all_val_dataset, target_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset, writer=writer)

# Wall-clock timing is only measured and reported in verbose mode.
if args.verbose:
    start_time = time.time()
train_results, best_w, best_model = trainer.solve()
if args.verbose:
    end_time = time.time()
    print(f"Training time: {end_time - start_time:.2f} seconds with iterations {args.max_outer_iter}")
# Headline numbers reported by the trainer: train accuracy plus the val/test values stored under the '*_wg' keys.
best_train, best_val, best_test = train_results['best_train_acc'], train_results['best_val_wg'], train_results['best_test_wg']
print(f"Best train: {best_train}, Best val: {best_val}, Best test: {best_test}")

# Score the best model on the test split with the dataset's own evaluation routine.
y_preds, y_true = get_predictions(best_model, test_dataset)
official_results, official_results_str = dataset.eval(y_preds, y_true, test_metadata)

print("Test results:", official_results_str)
# Prefer the adjusted average when it is reported
# (For waterbirds because the val/test sets are balanced).
avg_key = 'adj_acc_avg' if 'adj_acc_avg' in official_results.keys() else 'acc_avg'
avg_test_acc = official_results[avg_key]
wg_test_acc = official_results['acc_wg']
hparams = vars(args)

# One JSON record per run: headline metrics, seed, hyper-parameters, raw trainer output.
result_json = {
    'val_acc': best_val,
    'avg_test_acc': avg_test_acc,
    'wg_test_acc': wg_test_acc,
    'seed': args.seed,
    'hparams': hparams,
    'all_results': train_results,
}
if train_results is not None:
    # Also keep the trainer's own test metrics next to the official ones.
    result_json.update(
        avg_test_acc_unofficial=train_results['best_test_acc'],
        wg_test_acc_unofficial=train_results['best_test_wg'],
    )

os.makedirs(args.output_dir, exist_ok=True)
mode = 'a+'
# Append, so repeated runs accumulate one JSON line each in results.json.
results_path = os.path.join(args.output_dir, 'results.json')
with open(results_path, mode) as f:
    f.write(f"{json.dumps(result_json)}\n")

# Marker file written once the results have been recorded.
done_path = os.path.join(args.output_dir, 'done')
with open(done_path, 'w') as f:
    f.write('done')

# Saved unconditionally (a former `args.save_stats` guard is commented out):
# the best weights and the training-set group ids, each pickled as a NumPy array.
weights_file = os.path.join(args.output_dir, 'weights.pkl')
with open(weights_file, 'wb') as f:
    pickle.dump(best_w.cpu().numpy(), f)

group_file = os.path.join(args.output_dir, 'group_stats.pkl')
with open(group_file, 'wb') as f:
    pickle.dump(train_g.cpu().numpy(), f)

# Mirror both artifacts to the experiment logger when one is configured.
if writer is not None:
    for tag, payload in (("group_stats", train_g), ("weights", best_w)):
        writer.log({tag: payload})

# Save the label-noise indices, but only if an earlier step defined them.
# The name is resolved *before* the file is opened: opening in 'wb' mode first
# and then failing on the undefined name would leave an empty, unloadable
# noise_indices.pkl in the output directory. Keeping the try body to the bare
# lookup also prevents unrelated NameErrors from being swallowed.
try:
    noise_payload = noise_indices
except NameError:
    # No label noise was injected in this run; nothing to save.
    pass
else:
    # save noise indices
    noise_file = os.path.join(args.output_dir, 'noise_indices.pkl')
    with open(noise_file, 'wb') as f:
        pickle.dump(noise_payload, f)



