# -*- coding: utf-8 -*-

import argparse
import os
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
# import sklearn.metrics as sm
# import pandas as pd
# import sklearn.metrics as sm
import random
import numpy as np

from pytorch_transformers import AdamW, WarmupLinearSchedule

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from torch.utils.tensorboard import SummaryWriter

from resnet import get_model
# from load_corrupted_data import CIFAR10, CIFAR100
# from datasets import CelebA
from data_loader import prepare_data
from arguments import get_arguments

# Use the 'file_system' strategy for tensors shared between processes
# (presumably to avoid file-descriptor limits with DataLoader workers -- TODO confirm).
torch.multiprocessing.set_sharing_strategy('file_system')

args = get_arguments()
# Use the GPU only when one is actually present, consistent with build_model(),
# instead of unconditionally targeting "cuda".
use_cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")


print()
print(args)

def build_model():
    """Create the network selected by the global ``args`` via ``get_model``.

    When CUDA is available the model's ``.cuda()`` is called and the cuDNN
    benchmark mode is switched on; otherwise the model is returned untouched.

    Returns:
        The object produced by ``get_model(args)``.
    """
    net = get_model(args)

    # Nothing else to do on a CPU-only host.
    if not torch.cuda.is_available():
        return net

    # The return value of .cuda() is deliberately ignored, as in the original
    # (assumes get_model returns an nn.Module, which is moved in place).
    net.cuda()
    torch.backends.cudnn.benchmark = True
    return net

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor indexed as (sample, class); the top-k classes are
            taken along dimension 1.
        target: tensor of class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk`` holding the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape() instead of view(): `correct` can inherit the transposed
        # (non-contiguous) layout of `pred`, for which view(-1) raises when k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def adjust_learning_rate(optimizer, epoch):
    """Set the step-decayed learning rate on every parameter group.

    Starts from the global ``args.lr`` and multiplies it by 0.1 once for each
    entry of ``args.lr_decay`` that ``epoch`` has reached. When
    ``args.lr_decay`` is ``None`` the base rate is used unchanged.

    Args:
        optimizer: object exposing ``param_groups`` (a sequence of dicts).
        epoch: current epoch number, compared against the decay epochs.
    """
    milestones = () if args.lr_decay is None else args.lr_decay

    new_lr = args.lr
    for milestone in milestones:
        # Same factor as 0.1 ** int(epoch >= milestone): 0.1 once reached, else 1.0.
        new_lr *= 0.1 if epoch >= milestone else 1.0

    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
        
        
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Args:
        x: input tensor whose first dimension is the batch dimension.
        y: targets, indexable by a permutation of the batch.
        alpha: parameter of the Beta(alpha, alpha) distribution from which the
            mixing coefficient is drawn; with ``alpha <= 0`` no mixing is done
            (``lam`` is 1).
        use_cuda: when true, the permutation index is moved to the device that
            holds ``x`` (previously it was moved to CUDA unconditionally, which
            failed on CPU-only hosts and for CPU inputs).

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and ``y_b`` is
        ``y[perm]``.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Draw the permutation with the CPU generator (as before) and only then
    # move it, so the random stream is unchanged.
    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the criterion evaluated against both target sets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second


        
class GroupEMA:
    """Multiplicatively updated weights over groups, used to re-weight group losses.

    The weights start uniform. Every ``update`` multiplies each weight by
    ``exp(step_size * group_loss)`` and renormalises them to sum to one, so
    groups with a larger loss receive a larger weight.
    """

    def __init__(self, size, step_size=0.01, device=None):
        """
        Args:
            size: number of groups.
            step_size: scale applied to the losses inside the exponential.
            device: where to keep the weights; defaults to CUDA when it is
                available and to the CPU otherwise (the weights used to be
                moved to CUDA unconditionally, which fails without a GPU).
        """
        self.step_size = step_size
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.group_weights = torch.ones(size, device=device) / size

    def update(self, group_loss, group_count):
        """Update the group weights and return the weighted loss.

        Args:
            group_loss: 1-D tensor with one loss value per group.
            group_count: accepted for interface compatibility; not used.

        Returns:
            ``group_loss @ group_weights`` computed with the updated weights.
            Gradients flow through ``group_loss`` only, since the weights are
            computed from a detached copy of the losses.
        """
        detached_loss = group_loss.detach()
        # Keep the weights on the same device as the incoming losses.
        weights = self.group_weights.to(detached_loss.device)
        weights = weights * torch.exp(self.step_size * detached_loss)
        self.group_weights = weights / weights.sum()

        weighted_loss = group_loss @ self.group_weights

        return weighted_loss


def test(model, test_loader, writer, epoch):
    """Evaluate ``model`` on ``test_loader`` and report per-group metrics.

    Each batch is a dict with keys ``'x'`` (inputs), ``'y'`` (targets) and
    ``'a'`` (bias attribute). Samples are assigned to group ``y * 2 + a``
    (this assumes a binary bias attribute -- TODO confirm against the data
    loader). For every non-empty group the mean cross-entropy, the target
    accuracy and the "bias accuracy" (prediction compared with ``a``) are
    printed and logged to ``writer``; the same is done for the worst group
    and for the whole evaluation set.

    Args:
        model: network to evaluate; put into eval mode here.
        test_loader: iterable of batch dicts as described above.
        writer: object with ``add_scalar(tag, value, step)``.
        epoch: step value passed to ``writer.add_scalar``.

    Returns:
        The target accuracy (in percent) of the group with the lowest target
        accuracy.

    Raises:
        ValueError: if ``args.num_classes`` is not 2, 3 or 4.
        RuntimeError: if no evaluated sample falls into an expected group.
    """
    model.eval()

    ys = []
    bs = []
    test_losses = []
    corrects = []
    corrects_bias = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            inputs, targets = batch['x'].to(device), batch['y'].to(device)
            if args.model.startswith('bert'):
                input_ids = inputs[:, :, 0]
                input_masks = inputs[:, :, 1]
                segment_ids = inputs[:, :, 2]
                y_hat = model(
                    input_ids=input_ids,
                    attention_mask=input_masks,
                    token_type_ids=segment_ids,
                    labels=targets,
                )[1]  # [1] returns logits
            else:
                # outputs.shape: (batch_size, num_classes)
                y_hat = model(inputs)
            test_loss = F.cross_entropy(y_hat, targets, reduction='none')
            _, predicted = y_hat.cpu().max(1)
            if args.num_classes in (2, 3):
                correct = predicted.eq(batch['y'])
                correct_bias = predicted.eq(batch['a'])
            elif args.num_classes == 4:
                # The 4-way output is decoded as (target, bias) = (pred // 2, pred % 2).
                correct = (predicted//2).eq(batch['y'])
                correct_bias = (predicted%2).eq(batch['a'])
            else:
                # Previously this fell through and failed later with an
                # unrelated unbound-variable error.
                raise ValueError(f'unsupported num_classes: {args.num_classes}')

            test_losses.append(test_loss.cpu())
            corrects.append(correct)
            corrects_bias.append(correct_bias)
            ys.append(batch['y'])
            bs.append(batch['a'])

    test_losses = torch.cat(test_losses)
    corrects = torch.cat(corrects)
    corrects_bias = torch.cat(corrects_bias)
    ys = torch.cat(ys)
    bs = torch.cat(bs)

    num_groups = 6 if 'mnli' in args.dataset else 4
    group = ys*2 + bs

    print('')
    # None means "no group seen yet". Starting from 100 with a strict '<'
    # left the worst_* variables unbound when every group reached 100%.
    worst_accuracy = None
    worst_bias_accuracy = None
    for i in range(num_groups):
        indices = np.where(group == i)[0]
        group_len = len(indices)
        if group_len == 0:
            # An empty group has no defined accuracy; skip it instead of
            # dividing by zero.
            print(f'Test set - group {i}: no samples, skipped\n')
            continue

        loss = test_losses[indices].mean().item()
        correct = corrects[indices].sum().item()
        correct_bias = corrects_bias[indices].sum().item()
        group_acc = 100. * correct / group_len
        group_acc_bias = 100. * correct_bias / group_len
        if worst_accuracy is None or group_acc < worst_accuracy:
            worst_accuracy = group_acc
            worst_accuracy_bias = group_acc_bias
            worst_loss = loss
            worst_correct = correct
            worst_correct_bias = correct_bias
            worst_len = group_len
        if worst_bias_accuracy is None or group_acc_bias < worst_bias_accuracy:
            worst_bias_accuracy = group_acc_bias
            worst_bias_correct = correct_bias
            worst_bias_len = group_len

        writer.add_scalar(f'valid/accuracy_group{i}', group_acc, epoch)
        writer.add_scalar(f'valid/accuracy_bias_group{i}', group_acc_bias, epoch)
        print(f'Test set - group {i}: Average loss: {loss:.4f}, Accuracy: {correct}/{group_len}({group_acc:.4f}%)')
        print(f'Test set - group {i}: Bias Accuracy: {correct_bias}/{group_len}({group_acc_bias:.4f}%)\n')

    if worst_accuracy is None:
        raise RuntimeError('no evaluated sample fell into any of the expected groups')

    writer.add_scalar('valid/accuracy_worst_group', worst_accuracy, epoch)
    writer.add_scalar('valid/bias_accuracy_worst_group', worst_bias_accuracy, epoch)
    print(f'Test set - worst group: Average loss: {worst_loss:.4f}, Accuracy: {worst_correct}/{worst_len}({worst_accuracy:.4f}%)\n')
    print(f'Test set - worst group: Bias Accuracy: {worst_correct_bias}/{worst_len}({worst_accuracy_bias:.4f}%)\n')
    print(f'Test set - worst bias group: Bias Accuracy: {worst_bias_correct}/{worst_bias_len}({worst_bias_accuracy:.4f}%)\n')

    # Normalise by the number of samples actually evaluated, which can differ
    # from len(test_loader.dataset) when the loader uses a sampler or drops
    # the last batch.
    num_samples = len(ys)
    loss = test_losses.mean().item()
    correct = corrects.sum().item()
    correct_bias = corrects_bias.sum().item()
    avg_acc = 100. * correct / num_samples
    avg_acc_bias = 100. * correct_bias / num_samples
    writer.add_scalar('valid/accuracy_average', avg_acc, epoch)
    writer.add_scalar('valid/accuracy_bias_average', avg_acc_bias, epoch)
    print(f'Test set: Average loss: {loss:.4f}, Accuracy: {correct}/{num_samples} ({avg_acc:.4f}%)\n')
    print(f'Test set: Bias Accuracy: {correct_bias}/{num_samples} ({avg_acc_bias:.4f}%)\n')

    return worst_accuracy




def train(train_loader, model, optimizer, epoch):
    """Run one training epoch with the group-weighted loss.

    Each batch is a dict with keys ``'x'`` (inputs), ``'y'`` (targets) and
    ``'a'`` (bias attribute). The per-sample cross-entropy is averaged within
    each group ``y * 2 + a`` and the group averages are combined by the
    module-level ``group_weight_ema``; that weighted loss is back-propagated.

    Args:
        train_loader: iterable of batch dicts as described above; must
            support ``len()``.
        model: network to train; put into train mode here.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index, used for progress output only.

    Returns:
        The mean over batches of the unweighted mean cross-entropy
        (0 when the loader yields no batch).
    """
    print('\nEpoch: %d' % epoch)

    train_loss = 0
    num_batches = 0

    criterion = nn.CrossEntropyLoss(reduction='none')
    num_groups = 6 if 'mnli' in args.dataset else 4
    # Column vector of group ids, built once on the device the batches are
    # moved to (instead of a hard-coded .cuda() on every iteration).
    group_ids = torch.arange(num_groups, device=device).unsqueeze(1).long()
    # Number of iterations per epoch; len(dataset) / batch_size is a float
    # that '%d' truncates and that ignores samplers.
    total_iters = len(train_loader)

    model.train()
    for batch_idx, batch in enumerate(train_loader):
        inputs, targets, biases = batch['x'].to(device), batch['y'].to(device), batch['a'].to(device)

        if args.model.startswith('bert'):
            input_ids = inputs[:, :, 0]
            input_masks = inputs[:, :, 1]
            segment_ids = inputs[:, :, 2]
            y_hat = model(
                input_ids=input_ids,
                attention_mask=input_masks,
                token_type_ids=segment_ids,
                labels=targets,
            )[1]  # [1] returns logits
        else:
            # outputs.shape: (batch_size, num_classes)
            y_hat = model(inputs)
        cost_y = criterion(y_hat, targets)
        prec_train = accuracy(y_hat.detach(), targets.detach(), topk=(1,))[0]
        # Agreement between the predicted class and the bias attribute.
        prec_bias_train = accuracy(y_hat.detach(), biases.detach(), topk=(1,))[0]

        # One-hot membership matrix of shape (num_groups, batch_size).
        group_idx = targets*2 + biases
        group_map = (group_idx == group_ids).float()
        group_count = group_map.sum(1)
        group_denom = group_count + (group_count == 0).float()  # avoid nans
        group_loss = (group_map @ cost_y.view(-1)) / group_denom

        weighted_loss = group_weight_ema.update(group_loss, group_count)

        loss = weighted_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += cost_y.mean().item()
        num_batches = batch_idx + 1

        if (batch_idx + 1) % 50 == 0:
            print('Epoch: [%d/%d]\t'
                  'Iters: [%d/%d]\t'
                  'Loss: %.4f\t'
                  'Prec@1 %.2f\t'
                  'Prec_bias@1 %.2f' % (
                      (epoch + 1), args.epochs, batch_idx + 1, total_iters, (train_loss / (batch_idx + 1)),
                      prec_train, prec_bias_train))

    # Guard against an empty loader, where batch_idx would be unbound.
    return train_loss / max(num_batches, 1)


# Build the data loaders; the second return value is not used in this file.
train_loader, _, valid_loader, test_loader = prepare_data(args)
# create model
model = build_model()

if args.model == 'bert':

    # Fixed fine-tuning hyper-parameters stored on args.
    # NOTE(review): max_grad_norm is set here but no gradient clipping is
    # visible in this file.
    args.max_grad_norm = 1.0
    args.adam_epsilon = 1e-8
    args.warmup_steps = 0

    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.lr,
                      eps=args.adam_epsilon)
    t_total = len(train_loader) * args.epochs
    print(f"\nt_total is {t_total}\n")
    # NOTE(review): no scheduler.step() call is visible in this file, and
    # adjust_learning_rate() overwrites the learning rate every epoch, so this
    # schedule appears to have no effect -- confirm whether that is intended.
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

else:

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        # NOTE(review): args.weight_decay is not passed to Adam -- confirm.
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    else:
        raise NotImplementedError(f'unsupported optimizer: {args.optimizer}')

    scheduler = None

# One weight per group; 6 groups for MNLI-style datasets, 4 otherwise.
num_groups = 6 if 'mnli' in args.dataset else 4
group_weight_ema = GroupEMA(size=num_groups, step_size=0.01)

ckpt_dir = os.path.join('results', args.dataset, args.name)
log_dir = os.path.join('summary', args.dataset, args.name)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

writer = SummaryWriter(log_dir)

def main():
    """Train for ``args.epochs`` epochs, validating and checkpointing after each.

    A checkpoint (model state dict plus the current group weights) is written
    to ``ckpt_dir`` whenever the worst-group validation accuracy is at least
    as good as the best seen so far; for datasets whose name contains
    ``'mnli'`` or equals ``'jigsaw'`` a checkpoint is written every epoch.
    The summary writer is closed on exit so buffered events reach disk even
    if training is interrupted by an exception.
    """
    best_acc = 0
    save_every_epoch = 'mnli' in args.dataset or args.dataset == 'jigsaw'
    try:
        for epoch in range(args.epochs):
            adjust_learning_rate(optimizer, epoch)
            train_loss = train(train_loader, model, optimizer, epoch)
            writer.add_scalar('train/train_loss', train_loss, epoch)

            valid_acc = test(model, valid_loader, writer, epoch)

            # '>=' so that a tie with the best accuracy still refreshes the checkpoint.
            improved = valid_acc >= best_acc
            if improved:
                best_acc = valid_acc
            if improved or save_every_epoch:
                state_dict = {'model': model.state_dict(), 'group_weights': group_weight_ema.group_weights}
                torch.save(state_dict, os.path.join(ckpt_dir, f'epoch_{epoch+1}.pth'))
    finally:
        writer.close()

    print('best accuracy:', best_acc)


# Start training only when the file is executed as a script. The module-level
# setup above (argument parsing, data loading, model creation) still runs on import.
if __name__ == '__main__':
    main()
