# -*- coding: utf-8 -*-

import argparse
import os
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
# import sklearn.metrics as sm
# import pandas as pd
# import sklearn.metrics as sm
import random
import numpy as np

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from torch.utils.tensorboard import SummaryWriter

from resnet import get_model
# from load_corrupted_data import CIFAR10, CIFAR100
# from datasets import CelebA
from data_loader import prepare_data
from arguments import get_arguments


args = get_arguments()
# Only target the GPU when one is actually present; build_model() applies the
# same availability check, so the model and the batches end up on one device.
use_cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")


print()
print(args)

def build_model():
    """Create the network via ``get_model(args)`` and use the GPU if present.

    Returns:
        The object returned by ``get_model(args)``.  When CUDA is available,
        ``.cuda()`` has been called on it and cuDNN benchmark mode is enabled.
    """
    net = get_model(args)

    # Nothing else to configure on a CPU-only host.
    if not torch.cuda.is_available():
        return net

    net.cuda()
    torch.backends.cudnn.benchmark = True
    return net

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor of class indices, one entry per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``: the percentage of
        samples whose target is among the ``k`` highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) after the transpose, best class first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout, so
        # ``view(-1)`` fails for k > 1; ``reshape`` copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def adjust_learning_rate(optimizer, epoch):
    """Apply the step-decay schedule from ``args`` to every parameter group.

    Starting from ``args.lr``, the rate is multiplied by 0.1 once for each
    entry of ``args.lr_decay`` that ``epoch`` has reached or passed.  When
    ``args.lr_decay`` is ``None`` the base rate is used unchanged.
    """
    new_lr = args.lr
    milestones = args.lr_decay if args.lr_decay is not None else ()
    for milestone in milestones:
        factor = 0.1 if epoch >= milestone else 1.0
        new_lr = new_lr * factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
        
        
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda

    Args:
        x: batch of inputs; the first dimension is the batch dimension.
        y: targets, indexable along the same batch dimension.
        alpha: Beta(alpha, alpha) parameter for the mixing coefficient.
            A non-positive value disables mixing (lambda is fixed to 1).
        use_cuda: when true, the permutation index is moved to the device
            that holds ``x``; when false it stays on the CPU.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and ``y_b`` is
        ``y[perm]``.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # The permutation is always drawn on the CPU, then moved next to ``x``
    # instead of unconditionally to CUDA, so the default arguments also work
    # for CPU tensors and on machines without a GPU.
    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both mixup targets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


        
class GroupEMA:
    """Multiplicative-weights tracker over a fixed number of groups.

    ``group_weights`` starts uniform.  Each ``update`` scales every weight by
    ``exp(step_size * loss)`` of its group and renormalises the weights to
    sum to one.
    """

    def __init__(self, size, step_size=0.01):
        """Create uniform weights for ``size`` groups.

        Args:
            size: number of groups.
            step_size: multiplier applied to the losses inside the
                exponential update.
        """
        self.step_size = step_size
        weights = torch.ones(size) / size
        # Place the weights on the GPU only when one exists, so the class can
        # also be constructed on CPU-only hosts.
        if torch.cuda.is_available():
            weights = weights.cuda()
        self.group_weights = weights

    def update(self, group_loss, group_count):
        """Update the group weights and return the re-weighted loss.

        Args:
            group_loss: 1-D tensor with one loss value per group.
            group_count: accepted for interface compatibility; not used.

        Returns:
            ``group_loss @ group_weights`` computed with the updated weights.
            Gradients flow through ``group_loss`` only; the weights are
            built from a detached copy of the losses.
        """
        detached_loss = group_loss.detach()
        # Keep the weights on the same device as the incoming losses.
        weights = self.group_weights.to(detached_loss.device)
        weights = weights * torch.exp(self.step_size * detached_loss)
        self.group_weights = weights / weights.sum()

        weighted_loss = group_loss @ self.group_weights

        return weighted_loss


def test(model, test_loader, writer, epoch):
    """Evaluate ``model`` on ``test_loader`` and log per-group metrics.

    Each batch is read as a mapping with keys ``'x'`` (inputs), ``'y'``
    (target label) and ``'a'`` (bias attribute).  Samples are split into four
    groups with index ``2 * y + a``; loss, target accuracy and bias accuracy
    are printed and written to ``writer`` per group, for the worst group, and
    over the whole set.

    Args:
        model: network called as ``model(inputs)``; switched to eval mode.
        test_loader: iterable of batches; ``len(test_loader.dataset)`` is the
            denominator of the overall accuracies.
        writer: object exposing ``add_scalar(tag, value, step)``.
        epoch: step value passed to ``writer``.

    Returns:
        The target accuracy (in percent) of the group where it is lowest.

    Raises:
        ValueError: if ``args.num_classes`` is neither 2 nor 4, or if no
            sample falls into any of the groups 0-3.
    """
    if args.num_classes not in (2, 4):
        raise ValueError(
            f'test() supports num_classes 2 or 4, got {args.num_classes!r}')

    model.eval()

    ys = []
    bs = []
    test_losses = []
    corrects = []
    corrects_bias = []

    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch['x'].to(device), batch['y'].to(device)
            y_hat = model(inputs)
            test_loss = F.cross_entropy(y_hat, targets, reduction='none')
            _, predicted = y_hat.cpu().max(1)
            if args.num_classes == 2:
                correct = predicted.eq(batch['y'])
                correct_bias = predicted.eq(batch['a'])
            else:
                # Four outputs: the prediction is decoded as 2 * y + a.
                correct = (predicted // 2).eq(batch['y'])
                correct_bias = (predicted % 2).eq(batch['a'])

            test_losses.append(test_loss.cpu())
            corrects.append(correct)
            corrects_bias.append(correct_bias)
            ys.append(batch['y'])
            bs.append(batch['a'])

    test_losses = torch.cat(test_losses)
    corrects = torch.cat(corrects)
    corrects_bias = torch.cat(corrects_bias)
    ys = torch.cat(ys)
    bs = torch.cat(bs)

    group = ys * 2 + bs

    print('')
    # ``None`` rather than a sentinel of 100: with a strict ``<`` against 100
    # the worst-group record was never filled in when every group was perfect.
    worst = None
    worst_bias = None
    for i in range(4):
        members = np.where(group == i)[0]
        group_len = len(members)
        if group_len == 0:
            # An empty group has no defined accuracy; skip it instead of
            # dividing by zero.
            print(f'Test set - group {i}: no samples, skipped\n')
            continue

        loss = test_losses[members].mean().item()
        correct = corrects[members].sum().item()
        correct_bias = corrects_bias[members].sum().item()
        group_acc = 100. * correct / group_len
        group_acc_bias = 100. * correct_bias / group_len

        if worst is None or group_acc < worst['accuracy']:
            worst = {
                'accuracy': group_acc,
                'accuracy_bias': group_acc_bias,
                'loss': loss,
                'correct': correct,
                'correct_bias': correct_bias,
                'len': group_len,
            }
        if worst_bias is None or group_acc_bias < worst_bias['accuracy_bias']:
            worst_bias = {
                'accuracy_bias': group_acc_bias,
                'correct_bias': correct_bias,
                'len': group_len,
            }

        writer.add_scalar(f'valid/accuracy_group{i}', group_acc, epoch)
        writer.add_scalar(f'valid/accuracy_bias_group{i}', group_acc_bias, epoch)
        print(f'Test set - group {i}: Average loss: {loss:.4f}, Accuracy: {correct}/{group_len}({group_acc:.4f}%)')
        print(f'Test set - group {i}: Bias Accuracy: {correct_bias}/{group_len}({group_acc_bias:.4f}%)\n')

    if worst is None:
        raise ValueError('no evaluation sample belongs to groups 0-3')

    worst_accuracy = worst['accuracy']
    worst_bias_accuracy = worst_bias['accuracy_bias']
    writer.add_scalar('valid/accuracy_worst_group', worst_accuracy, epoch)
    writer.add_scalar('valid/bias_accuracy_worst_group', worst_bias_accuracy, epoch)
    print(f'Test set - worst group: Average loss: {worst["loss"]:.4f}, Accuracy: {worst["correct"]}/{worst["len"]}({worst_accuracy:.4f}%)\n')
    print(f'Test set - worst group: Bias Accuracy: {worst["correct_bias"]}/{worst["len"]}({worst["accuracy_bias"]:.4f}%)\n')
    print(f'Test set - worst bias group: Bias Accuracy: {worst_bias["correct_bias"]}/{worst_bias["len"]}({worst_bias_accuracy:.4f}%)\n')

    # Overall metrics across all groups.
    total = len(test_loader.dataset)
    loss = test_losses.mean().item()
    correct = corrects.sum().item()
    correct_bias = corrects_bias.sum().item()
    overall_acc = 100. * correct / total
    overall_acc_bias = 100. * correct_bias / total
    writer.add_scalar('valid/accuracy_average', overall_acc, epoch)
    writer.add_scalar('valid/accuracy_bias_average', overall_acc_bias, epoch)
    print(f'Test set: Average loss: {loss:.4f}, Accuracy: {correct}/{total} ({overall_acc:.4f}%)\n')
    print(f'Test set: Bias Accuracy: {correct_bias}/{total} ({overall_acc_bias:.4f}%)\n')

    return worst_accuracy




def train(train_grouped_loaders, model, optimizer, epoch):
    """Run one training epoch over group-balanced batches.

    Every iteration draws one batch from each group loader, concatenates them
    in group order, and optimises a blend of cross-entropy and a pairwise
    contrastive term computed between the group chunks of the normalised
    features.

    Args:
        train_grouped_loaders: container indexable by ``0 .. len - 1`` that
            yields one data loader per group; batches are mappings with keys
            ``'x'``, ``'y'`` and ``'a'``.
        model: network called as ``model(inputs, return_feat=True)`` and
            returning ``(logits, features)``.
        optimizer: optimizer stepped every ``args.grad_accumulation``
            iterations.
        epoch: zero-based epoch index, used for progress output only.

    Returns:
        A 5-tuple of per-iteration averages: cross-entropy loss followed by
        the contrastive loss of groups 0, 1, 2 and 3.

    Note:
        Gradients left over when ``args.val_iteration`` is not a multiple of
        ``args.grad_accumulation`` are not applied; they are cleared by the
        ``zero_grad`` at the start of the next call.
    """
    print('\nEpoch: %d' % epoch)

    train_loss = 0
    supcon_loss_0 = 0
    supcon_loss_1 = 0
    supcon_loss_2 = 0
    supcon_loss_3 = 0

    criterion = nn.CrossEntropyLoss(reduction='none')

    num_groups = len(train_grouped_loaders)

    # One independent iterator per group so every step sees all groups.
    group_loader_iters = [iter(train_grouped_loaders[g]) for g in range(num_groups)]

    model.train()
    optimizer.zero_grad()

    for batch_idx in range(args.val_iteration):

        batches = []
        for g in range(num_groups):
            try:
                batch = next(group_loader_iters[g])
            except StopIteration:
                # This group's loader is exhausted: restart it.  Only
                # StopIteration is handled so genuine loading errors surface.
                group_loader_iters[g] = iter(train_grouped_loaders[g])
                batch = next(group_loader_iters[g])
            batches.append(batch)

        inputs = torch.cat([batch['x'] for batch in batches]).to(device)
        targets = torch.cat([batch['y'] for batch in batches]).to(device)
        biases = torch.cat([batch['a'] for batch in batches]).to(device)

        y_hat, feat = model(inputs, return_feat=True)
        cost_y = criterion(y_hat, targets)

        # L2-normalise so the dot products below are cosine similarities.
        feat = feat / (feat.norm(dim=1, keepdim=True) + 1e-8)

        # The slicing below hard-codes four groups and assumes each group
        # contributed the same number of rows -- TODO confirm against
        # prepare_data.
        bs = feat.size(0) // num_groups
        sim_01 = torch.mm(feat[:bs], feat[bs:bs*2].t())
        sim_02 = torch.mm(feat[:bs], feat[bs*2:bs*3].t())
        sim_13 = torch.mm(feat[bs:bs*2], feat[bs*3:].t())
        sim_23 = torch.mm(feat[bs*2:bs*3], feat[bs*3:].t())

        sim_01 = torch.exp(sim_01 / args.temperature_supcon)
        sim_02 = torch.exp(sim_02 / args.temperature_supcon)
        sim_13 = torch.exp(sim_13 / args.temperature_supcon)
        sim_23 = torch.exp(sim_23 / args.temperature_supcon)

        # Per-anchor normalisers: each group is compared with two other groups.
        denom_0 = torch.sum(sim_01, dim=1, keepdim=True) + torch.sum(sim_02, dim=1, keepdim=True)
        denom_1 = torch.sum(sim_01.t(), dim=1, keepdim=True) + torch.sum(sim_13, dim=1, keepdim=True)
        denom_2 = torch.sum(sim_02.t(), dim=1, keepdim=True) + torch.sum(sim_23, dim=1, keepdim=True)
        denom_3 = torch.sum(sim_13.t(), dim=1, keepdim=True) + torch.sum(sim_23.t(), dim=1, keepdim=True)

        # Pairs (0, 1) and (2, 3) are the ones pulled together.
        loss_0 = -torch.log(sim_01 / (denom_0 + 1e-8) + 1e-8)
        loss_1 = -torch.log(sim_01.t() / (denom_1 + 1e-8) + 1e-8)
        loss_2 = -torch.log(sim_23 / (denom_2 + 1e-8) + 1e-8)
        loss_3 = -torch.log(sim_23.t() / (denom_3 + 1e-8) + 1e-8)

        loss_supcon = (loss_0.mean() + loss_1.mean() + loss_2.mean() + loss_3.mean()) / 4

        prec_train = accuracy(y_hat.detach(), targets.detach(), topk=(1,))[0]
        prec_bias_train = accuracy(y_hat.detach(), biases.detach(), topk=(1,))[0]

        loss = loss_supcon * args.lambda_supcon + cost_y.mean() * (1 - args.lambda_supcon)

        # Scale BEFORE backward so the accumulated gradient is the mean over
        # the accumulation window; dividing afterwards never reached the
        # gradients.
        (loss / args.grad_accumulation).backward()
        if (batch_idx + 1) % args.grad_accumulation == 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss += cost_y.mean().item()
        supcon_loss_0 += loss_0.mean().item()
        supcon_loss_1 += loss_1.mean().item()
        supcon_loss_2 += loss_2.mean().item()
        supcon_loss_3 += loss_3.mean().item()

        if (batch_idx + 1) % 50 == 0:
            # The loop runs args.val_iteration steps, so that is the total.
            print('Epoch: [%d/%d]\t'
                  'Iters: [%d/%d]\t'
                  'Loss: %.4f\t'
                  'Prec@1 %.2f\t'
                  'Prec_bias@1 %.2f' % (
                      (epoch + 1), args.epochs, batch_idx + 1, args.val_iteration, (train_loss / (batch_idx + 1)),
                      prec_train, prec_bias_train))

    # Guard against a zero-iteration epoch (the loop variable would be unbound).
    num_iters = max(args.val_iteration, 1)
    return (train_loss / num_iters,
            supcon_loss_0 / num_iters,
            supcon_loss_1 / num_iters,
            supcon_loss_2 / num_iters,
            supcon_loss_3 / num_iters)


train_grouped_loaders, valid_loader, test_loader = prepare_data(args)
# create model
model = build_model()

if args.optimizer == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
elif args.optimizer == 'adam':
    # NOTE(review): args.weight_decay is not passed to Adam -- confirm this
    # is intentional.
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
else:
    raise NotImplementedError(f'unsupported optimizer: {args.optimizer!r}')

group_weight_ema = GroupEMA(size=4, step_size=0.01)

# Checkpoints and TensorBoard logs live under <root>/<dataset>/<run name>.
ckpt_dir = os.path.join('results', args.dataset, args.name)
log_dir = os.path.join('summary', args.dataset, args.name)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

writer = SummaryWriter(log_dir)

def main():
    """Train for ``args.epochs`` epochs, validating and checkpointing each one.

    After every epoch the training losses are written to ``writer`` and the
    model is evaluated on ``valid_loader``.  Whenever the value returned by
    ``test`` matches or beats the best seen so far, the model and optimizer
    states are saved to ``ckpt_dir`` as ``epoch_<n>.pth``.
    """
    best_acc = 0
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch)

        train_loss, sc_0, sc_1, sc_2, sc_3 = train(
            train_grouped_loaders, model, optimizer, epoch)
        writer.add_scalar('train/train_loss', train_loss, epoch)
        for group_id, group_loss in enumerate((sc_0, sc_1, sc_2, sc_3)):
            writer.add_scalar(f'train/supcon_loss_group{group_id}', group_loss, epoch)

        valid_acc = test(model, valid_loader, writer, epoch)

        # Only checkpoint epochs that are at least as good as the best so far.
        if valid_acc < best_acc:
            continue

        best_acc = valid_acc
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, os.path.join(ckpt_dir, f'epoch_{epoch+1}.pth'))

    print('best accuracy:', best_acc)


# Start the training loop only when this file is executed as a script.
if __name__ == '__main__':
    main()

