'''Evaluate an ensemble of saved checkpoints on a WILDS dataset with PyTorch (accuracy, macro F1, ECE).'''
from __future__ import print_function
import sys
sys.path.append('..')
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from sklearn.metrics import f1_score
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, densenet121, resnet34
from utils.dataloader import WildsDataset
from utils.utils import ece_score
from wilds.datasets.iwildcam_dataset import IWildCamDataset
import os, glob
import argparse
import pdb
from models import *
from torch.autograd import Variable
import numpy as np
import random

# Command-line interface of the evaluation script. The parsed namespace is
# also handed as a whole to WildsDataset / get_loader further down, so options
# that look unused here may still be read there.
parser = argparse.ArgumentParser(description='cSG-MCMC CIFAR10 Ensemble')
parser.add_argument('--dir', type=str, default=None, required=True,
                    help='directory containing the checkpoints; results are written there too (required)')
parser.add_argument('--data_path', type=str, default='data', metavar='PATH',
                    help='path to datasets location (default: data)')
parser.add_argument('--data_type', type=str, default='iwildcam',
                    help='Dataset type of WILDS (default: iwildcam)')
parser.add_argument('--test_type', type=str, default='test',
                    help='Test dataset type of WILDS: test -> OOD, idtest -> ID (default: test)')
# No default: when omitted, device_id is None and is passed as-is to .cuda().
parser.add_argument('--device_id', type=int, help='device id to use')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size (default: 128)')
# NOTE(review): n_model is not read anywhere in the visible code; every
# checkpoint matched by the glob pattern is used regardless of this value.
parser.add_argument('--n_model', type=int, default=12, metavar='N',
                    help='number of models (default: 12)')
# NOTE(review): the visible model construction always passes pretrained=False;
# confirm whether this flag is consumed elsewhere.
parser.add_argument('--pretrained', action='store_true',
                    help='Use ResNet50, pretrained with ImageNet')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--ERM', action='store_true')
# --best / --bestwo restrict evaluation to a single "best" checkpoint file.
parser.add_argument('--best', action='store_true')
parser.add_argument('--bestwo', action='store_true')
args = parser.parse_args()
device_id = args.device_id
use_cuda = torch.cuda.is_available()

# Seed every RNG the script touches so repeated runs see the same randomness.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Data
# Build the dataset wrapper and its loaders. Only the third loader returned by
# get_loader (bound to testloader) is used by the visible code.
dataset = WildsDataset(args.data_type, args, test_type=args.test_type)
trainloader, _, testloader = dataset.get_loader(args) 

# Model
# Default loss; replaced below for 'ogb-molpcba'.
criterion = nn.CrossEntropyLoss()

print('==> Building model..')
# Pick the architecture from the dataset name. Weights are not initialised
# from ImageNet here (pretrained=False); they are filled in later from the
# checkpoints via load_state_dict.
# NOTE(review): the --data_type default is 'iwildcam', but this list contains
# 'iwilds'; with the default the script falls through to the ResNet18 branch.
# Confirm this matches the architecture the checkpoints were trained with.
if args.data_type in ['iwilds', 'rxrx1', 'waterbirds', 'celebA']:
    net = resnet50(pretrained=False)  #* architecture only, no pre-trained weights
    net.fc = torch.nn.Linear(net.fc.in_features, dataset.target_dim) #* change output dimension
elif args.data_type in ['fmow', 'camelyon17']:
    net = densenet121(pretrained=False)
    net.classifier = torch.nn.Linear(net.classifier.in_features, dataset.target_dim) #* change output dimension
elif args.data_type == 'ogb-molpcba':
    # GINVirtual and Evaluator are presumably provided by this star import.
    from models.gnn import *
    net = GINVirtual()
    criterion = nn.BCEWithLogitsLoss()
    # NOTE(review): metric is not used anywhere in the visible code.
    metric = Evaluator('ogbg-molpcba')
else:
    # Fallback for every other dataset name; ResNet18 presumably comes from
    # the models package (already star-imported at the top of the file).
    from models import *
    net = ResNet18(num_classes=dataset.target_dim)

if use_cuda:
    net.cuda(device_id)
    # NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice,
    # which may differ between runs; PyTorch's reproducibility notes suggest
    # benchmark=False when determinism is the goal. Confirm which is intended.
    cudnn.benchmark = True
    cudnn.deterministic = True

def get_accuracy(truth, pred):
    """Return the fraction of positions at which ``truth`` and ``pred`` agree.

    Both arguments must be sequences of equal length (checked with an
    assertion). The result is ``matches / len(truth)``, so an empty input
    raises ``ZeroDivisionError``.
    """
    assert len(truth)==len(pred)
    n_items = len(truth)
    # Each agreeing pair contributes 1.0; disagreeing pairs contribute nothing.
    n_match = sum(1.0 for expected, got in zip(truth, pred) if expected == got)
    return n_match / n_items

def test():
    """Run ``net`` (with its currently loaded weights) over ``testloader``.

    Prints the mean per-batch loss and the accuracy, then returns a tuple
    ``(probabilities, labels, metadata)``:

    * probabilities -- softmax outputs of every batch concatenated along dim 0
      (left on whatever device the model ran on),
    * labels -- all targets stacked into one CPU tensor,
    * metadata -- all per-sample metadata rows stacked into one CPU tensor.
    """
    net.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    batch_probs = []
    labels = []
    metas = []
    with torch.no_grad():
        for inputs, targets, metadata in testloader:
            if use_cuda:
                inputs = inputs.cuda(device_id)
                targets = targets.cuda(device_id)
            # Labels and metadata are kept per sample on the CPU.
            labels.extend(targets.cpu().data)
            metas.extend(metadata.cpu().data)
            logits = net(inputs)
            batch_probs.append(F.softmax(logits, dim=1))
            loss_sum += criterion(logits, targets).data.item()
            # torch.max over dim 1 returns (values, indices); keep the indices.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets.data).cpu().sum()

    # The reported loss is averaged over batches, not over samples.
    mean_loss = loss_sum/len(testloader)
    accuracy_pct = 100. * n_correct.item() / n_seen
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        mean_loss, n_correct, n_seen, accuracy_pct))
    return torch.cat(batch_probs, 0), torch.stack(labels), torch.stack(metas)

# Collect the checkpoints to ensemble.
# NOTE(review): the default pattern '<data_type>_model_*.pt' also matches
# '<data_type>_model_best_val.pt' and '<data_type>_model_bestwo_val.pt' when
# they sit in the same directory; confirm they are meant to be averaged in.
pred_list = []
model_paths = glob.glob(args.dir + f'/{args.data_type}_model_*.pt')

if args.best:
    model_paths = glob.glob(args.dir + f'/{args.data_type}_model_best_val.pt')

# When both flags are given, --bestwo takes precedence over --best.
if args.bestwo:
    model_paths = glob.glob(args.dir + f'/{args.data_type}_model_bestwo_val.pt')

# glob returns matches in arbitrary order; sorting fixes the order in which
# the probabilities are summed, so the averaged result is reproducible.
model_paths = sorted(model_paths)
num_model = len(model_paths)
print(f"Number of models :{len(model_paths)}")
if num_model == 0:
    # Fail with a clear message instead of a ZeroDivisionError further down.
    raise FileNotFoundError(
        f"no checkpoints found in {args.dir!r} for data_type {args.data_type!r}")

for model_path in model_paths:
    # Load onto the CPU first so checkpoints saved from a different GPU index
    # still load; load_state_dict copies the values into net's own parameters.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    pred, truth_res, metadatas = test()
    pred_list.append(pred)

# Ensemble prediction: mean of the per-model softmax probabilities.
fake = sum(pred_list)/num_model
values, pred_label = torch.max(fake,dim = 1)
pred_res = list(pred_label.cpu().data)
acc = get_accuracy(truth_res, pred_res)
f1 = f1_score(truth_res, pred_res, average='macro')
# Dataset-specific evaluation supplied by the wrapped dataset object.
eval_result, eval_result_str = dataset.dataset.eval(torch.stack(pred_res), truth_res, metadatas)

print(f'Accuracy :{acc:.4f}, Macro F1 score: {f1:.4f}')
print(f'{eval_result_str}')

##########################################################################
# Calibration (ECE) per metadata group and over the whole test set.
# A group is one unique metadata row; its string form is used as the dict key.
group_keys = torch.unique(metadatas, dim=0).numpy().tolist()
group_keys = list(map(str, group_keys))
group_dict = {key: {'preds': [], 'trues': []} for key in group_keys}

# Move the ensemble probabilities to the CPU once rather than once per sample.
for f, y, m in zip(fake.cpu().numpy(), truth_res, metadatas):
    k = str(m.tolist())
    group_dict[k]['preds'].append(f)
    group_dict[k]['trues'].append(y.item())

# Concatenate every group's samples for the overall ECE. Each group must be
# visited through its own entry; indexing with the leftover loop variable `k`
# would add the same group repeatedly.
total_preds = []
total_trues = []
for group in group_dict.values():
    total_preds += group['preds']
    total_trues += group['trues']
total_preds, total_trues = np.array(total_preds), np.array(total_trues)

group_list = list(group_dict.values())
ece_list = [
    ece_score(np.array(group['preds']), np.array(group['trues']))
    for group in group_list
]

print(f"Worst ECE: {max(ece_list):.5f}, Average ECE: {np.mean(ece_list):.5f}")
print(f"ECE: {ece_score(total_preds, total_trues):.5f}")

# Store plain Python floats so the result dict stays JSON-serialisable even if
# ece_score returns numpy scalars.
eval_result['worst_ece'] = float(max(ece_list))
eval_result['average_ece'] = float(np.mean(ece_list))

import json

# Name the result file after the checkpoint set that was evaluated. --bestwo
# is tested first because it also overrides --best when the checkpoints are
# selected, so the file name always reflects the models actually loaded.
if args.bestwo:
    json_name = 'eval_result_bestwoacc.json'
elif args.best:
    json_name = 'eval_result_bestacc.json'
else:
    json_name = 'eval_result.json'

# json_name already carries the '.json' suffix; do not append it a second time.
with open(os.path.join(args.dir, json_name), 'w') as f:
    json.dump(eval_result, f)
