# -*- coding: utf-8 -*-

import argparse
import os
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
# import sklearn.metrics as sm
# import pandas as pd
# import sklearn.metrics as sm
import random
import numpy as np

from pytorch_transformers import AdamW, WarmupLinearSchedule

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from torch.utils.tensorboard import SummaryWriter

from resnet import get_model
# from load_corrupted_data import CIFAR10, CIFAR100
# from datasets import CelebA
from data_loader import prepare_data
from arguments import get_arguments

# NOTE(review): presumably set to avoid file-descriptor exhaustion when
# DataLoader workers share tensors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

args = get_arguments()
# Use the GPU only when one is present. This matches the availability check in
# build_model(); a hard-coded "cuda" device makes every `.to(device)` fail on
# CPU-only machines.
use_cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")


print()
print(args)

def build_model():
    """Construct the network described by the global ``args``.

    When CUDA is available, the model is moved to the GPU and cuDNN
    benchmark mode is enabled. Otherwise it is returned unchanged on the CPU.
    """
    net = get_model(args)
    if not torch.cuda.is_available():
        return net
    net.cuda()
    torch.backends.cudnn.benchmark = True
    return net

def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each requested k.

    Args:
        output: score tensor with one row per sample; top-k classes are taken
            along dim 1.
        target: tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        List with one 0-d tensor per entry of ``topk``. Each holds the
        percentage (0-100) of samples whose target is among the top-k scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # After the transpose, row j holds every sample's j-th best class.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so
        # `.view(-1)` raises for k > 1. `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def adjust_learning_rate(optimizer, epoch):
    """Apply step decay to every parameter group of ``optimizer``.

    Starting from ``args.lr``, the rate is multiplied by 0.1 once for each
    entry of ``args.lr_decay`` that ``epoch`` has reached (``epoch >= entry``).
    When ``args.lr_decay`` is None, the base rate is used unchanged.
    """
    new_lr = args.lr
    milestones = args.lr_decay if args.lr_decay is not None else ()
    for milestone in milestones:
        new_lr *= 0.1 if epoch >= milestone else 1.0
    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
        
        
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Mix each sample with a randomly chosen partner from the same batch.

    Args:
        x: input batch; dimension 0 is the batch dimension.
        y: targets aligned with ``x`` along dimension 0.
        alpha: concentration of the Beta(alpha, alpha) mixing weight. When
            ``alpha <= 0`` no mixing happens (``lam == 1``).
        use_cuda: kept for backward compatibility. The permutation is now placed
            on ``x.device``, so the call works whether or not ``x`` is on a GPU.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and
        ``y_b = y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    # Draw the permutation from the CPU generator, as before, so seeded runs
    # reproduce. Then move it next to the data instead of calling `.cuda()`
    # unconditionally, which fails on CPU-only machines (use_cuda defaults to True).
    index = torch.randperm(batch_size).to(x.device)

    # Index only the batch dimension so inputs of any rank (including 1-D) work.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend ``criterion`` evaluated against both mixup targets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


        
class GroupEMA:
    """Multiplicative-weights re-weighting of per-group losses.

    ``group_weights`` starts uniform. Each ``update`` multiplies it by
    ``exp(step_size * group_loss)`` and renormalises it to sum to one. For a
    positive ``step_size``, groups with larger loss therefore gain weight.
    """

    def __init__(self, size, step_size=0.01, device=None):
        """
        Args:
            size: number of groups (length of the weight vector).
            step_size: multiplier applied to the losses inside the exponential.
            device: where to keep the weights. Defaults to CUDA when available
                (the previous hard-coded behaviour) and to the CPU otherwise.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.step_size = step_size
        self.group_weights = torch.ones(size, device=device) / size

    def update(self, group_loss, group_count):
        """Update the group weights and return the re-weighted loss.

        Args:
            group_loss: 1-D tensor of per-group losses, of length ``size``.
            group_count: per-group sample counts. Currently unused; kept for
                interface compatibility.

        Returns:
            ``group_loss @ group_weights`` computed with the updated weights.
            Gradients flow through ``group_loss`` only.
        """
        # Detach so the weight update stays out of the autograd graph
        # (replaces the legacy `.data` access).
        step = torch.exp(self.step_size * group_loss.detach())
        self.group_weights = self.group_weights * step
        self.group_weights = self.group_weights / self.group_weights.sum()

        return group_loss @ self.group_weights


def inference(model, test_loader):
    """Run ``model`` over ``test_loader`` and print per-group accuracy.

    Each batch is a dict with these keys:
      * ``'x'``: inputs whose third axis stacks token ids, attention mask
        and segment ids.
      * ``'y'``: labels.
      * ``'a'``: an attribute label used for grouping.

    Samples are grouped by ``group = y * 2 + a``, and the accuracy of every
    group is printed.

    Args:
        model: callable that accepts ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and ``labels`` keyword arguments and returns a
            sequence whose element ``[1]`` is the logits.
        test_loader: iterable of batch dicts as described above.

    Returns:
        CPU tensor of predicted class indices for every sample, in loader
        order (1-D, assuming logits of shape (batch, classes)).
    """
    model.eval()

    ys, attrs, preds, corrects = [], [], [], []

    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch['x'].to(device), batch['y'].to(device)
            input_ids = inputs[:, :, 0]
            input_masks = inputs[:, :, 1]
            segment_ids = inputs[:, :, 2]
            y_hat = model(
                input_ids=input_ids,
                attention_mask=input_masks,
                token_type_ids=segment_ids,
                labels=targets,
            )[1]  # [1] returns logits
            # `max(1)` already yields one index per sample. The former
            # `.squeeze()` turned a size-1 final batch into a 0-d tensor,
            # which torch.cat below rejects.
            _, predicted = y_hat.cpu().max(1)

            ys.append(batch['y'])
            attrs.append(batch['a'])
            preds.append(predicted)
            corrects.append(predicted.eq(batch['y']))

    ys = torch.cat(ys)
    attrs = torch.cat(attrs)
    preds = torch.cat(preds)
    corrects = torch.cat(corrects)

    num_groups = 6 if 'mnli' in args.dataset else 4
    group = ys * 2 + attrs

    print('')
    for g in range(num_groups):
        in_group = group == g
        total = int(in_group.sum().item())
        if total == 0:
            # A group can be absent from a split; avoid ZeroDivisionError.
            print(f'Test set - group {g}: Accuracy: 0/0(n/a)')
            continue
        correct = corrects[in_group].sum().item()
        group_acc = 100. * correct / total
        print(f'Test set - group {g}: Accuracy: {correct}/{total}({group_acc:.4f}%)')

    return preds



train_loader, _, valid_loader, test_loader = prepare_data(args)
# create model
model = build_model()

import pandas as pd
# NOTE(review): hard-coded, machine-specific path; consider exposing it
# through get_arguments().
metadata_df = pd.read_csv('/home/june/datasets/civilcomments_v1.0/all_data_with_identities_backtrans.csv')

# Row positions of each split in the CSV. main() pairs these with the model's
# predictions, which assumes the loaders iterate each split in CSV order —
# TODO confirm.
train_indices = np.where(metadata_df['split'] == 'train')[0]
valid_indices = np.where(metadata_df['split'] == 'val')[0]
test_indices = np.where(metadata_df['split'] == 'test')[0]

num_groups = 6 if 'mnli' in args.dataset else 4

ckpt_dir = os.path.join('results', args.dataset, args.name)
log_dir = os.path.join('summary', args.dataset, args.name)

# map_location lets a checkpoint saved from GPU tensors load onto whatever
# `device` this run uses.
state_dict = torch.load(
    os.path.join(ckpt_dir, f'epoch_{args.epochs}.pth'),
    map_location=device,
)['model']
model.load_state_dict(state_dict)

os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# Not written to in this script, but constructing it creates the log directory's
# event file, as before.
writer = SummaryWriter(log_dir)

# Identity columns of the metadata CSV reported by main().
IDENTITY_ATTRS = ['male', 'female', 'LGBTQ', 'christian', 'muslim', 'other_religions', 'black', 'white']


def _print_identity_counts(preds, split_indices):
    """Print two counts per identity column for one split.

    For every column in ``IDENTITY_ATTRS``, this prints two numbers:
      * the sum of ``preds`` over rows with identity >= 0.5 and
        toxicity >= 0.5;
      * the sum of ``1 - preds`` over rows with identity >= 0.5 and
        toxicity < 0.5.

    For 0/1 predictions where 1 means toxic (assumed), these are the correctly
    classified toxic and non-toxic comments in that identity subgroup.

    Args:
        preds: CPU tensor of predictions aligned with ``split_indices``.
        split_indices: positional row indices of the split in ``metadata_df``.
    """
    preds = preds.numpy()
    rows = metadata_df.iloc[split_indices]
    # Loop-invariant: the toxicity mask is the same for every identity column.
    is_toxic = np.array(rows['toxicity'] >= 0.5).astype('long')
    for attr_name in IDENTITY_ATTRS:
        has_attr = np.array(rows[attr_name] >= 0.5).astype('long')
        print((has_attr * is_toxic * preds).sum(), (has_attr * (1 - is_toxic) * (1 - preds)).sum())


def main():
    """Evaluate the loaded checkpoint on the validation and test splits."""
    valid_preds = inference(model, valid_loader)
    _print_identity_counts(valid_preds, valid_indices)

    print(' ')

    test_preds = inference(model, test_loader)
    _print_identity_counts(test_preds, test_indices)


if __name__ == '__main__':
    main()
