import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, WeightedRandomSampler

import torchvision

from data_util import Dataset_

import argparse

import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from torch.utils.tensorboard import SummaryWriter
    
    
parser = argparse.ArgumentParser()

parser.add_argument('--name', default='temp', type=str)

# dataset configuration
parser.add_argument('--dataset', default='waterbirds', type=str, help='dataset waterbirds[default]')


# optimization configuration
parser.add_argument('--optimizer', default='sgd', type=str)
parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run')
parser.add_argument('--batch_size', '--batch-size', default=64, type=int, help='mini-batch size')
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum')
parser.add_argument('--lr_decay', nargs='+', type=int)
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, help='print frequency (default: 10)')

parser.add_argument('--group_dro', default=False, type=bool, help='group_dro')

parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--seed', type=int, default=1)


def prepare_arguments():
    """Parse the process command line with the module-level ``parser``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parsed = parser.parse_args()
    return parsed


def prepare_dataloader(data_dir='/home/user/datasets/waterbird_complete95_forest2water2'):
    """Build the Waterbirds training and evaluation DataLoaders.

    Reads the module-level ``args`` for ``group_dro``, ``batch_size`` and
    ``num_workers``.

    Args:
        data_dir: dataset root passed to ``Dataset_`` for both splits.

    Returns:
        ``(train_dataloader, eval_dataloader)``. When ``args.group_dro`` is set,
        the training loader samples with replacement so that the two label
        buckets (``target == 1`` and everything else) carry equal total weight.
        Otherwise it shuffles uniformly.
    """
    # Settings shared by both splits; only the split flag and augmentation differ.
    common = dict(data_name='Waterbirds',
                  data_dir=data_dir,
                  crop_long_edge=True,
                  resize_size=128,
                  hdf5_path=None,
                  normalize=True,
                  load_data_in_memory=False,
                  return_attr=True)
    train_dataset = Dataset_(train=True, random_flip=True, **common)
    eval_dataset = Dataset_(train=False, random_flip=False, **common)

    if args.group_dro:
        # NOTE(review): the container type of ``data.targets`` depends on Dataset_
        # (list / ndarray / CPU tensor). np.asarray makes ``== 1`` elementwise in
        # every case; on a plain list it would otherwise be a single False.
        targets = np.asarray(train_dataset.data.targets)
        group = (targets == 1).astype(int)

        # Each sample is weighted by 1 / (size of its bucket), so every non-empty
        # bucket has total weight 1. An empty bucket keeps weight 0.
        group_sample_count = np.bincount(group, minlength=2)
        weight = np.zeros(2)
        present = group_sample_count > 0
        weight[present] = 1. / group_sample_count[present]

        samples_weight = torch.from_numpy(weight[group])
        train_sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
    else:
        train_sampler = None

    num_workers = args.num_workers
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=(train_sampler is None),
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  sampler=train_sampler,
                                  drop_last=True,
                                  # DataLoader rejects persistent workers when num_workers == 0.
                                  persistent_workers=num_workers > 0)

    eval_dataloader = DataLoader(dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 sampler=None,
                                 drop_last=False)

    return train_dataloader, eval_dataloader


def prepare_model():
    """Create an ImageNet-pretrained ResNet-50 with a fresh 2-way linear head.

    When CUDA is available, the model is moved to the GPU and cuDNN autotuning
    (``cudnn.benchmark``) is enabled.

    Returns:
        The ``torchvision`` ResNet-50 module.
    """
    net = torchvision.models.resnet50(pretrained=True)
    net.fc = nn.Linear(net.fc.in_features, 2)

    if not torch.cuda.is_available():
        return net

    net.cuda()
    torch.backends.cudnn.benchmark = True
    return net


class GroupEMA:
    """Online group re-weighting for group DRO (exponentiated-gradient update).

    Keeps a probability vector ``group_weights`` over ``size`` groups. Each
    ``update`` multiplies the weights by ``exp(step_size * group_loss)``,
    renormalizes them to sum to 1, and returns the weighted loss. Despite the
    class name, no exponential moving average of the loss is maintained.
    """

    def __init__(self, size, step_size=0.01, device=None):
        """
        Args:
            size: number of groups.
            step_size: multiplier applied to the group loss inside the exponent.
            device: where ``group_weights`` lives. Defaults to CUDA when it is
                available and to CPU otherwise, so the class also works on
                CPU-only hosts.
        """
        self.step_size = step_size
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Start from the uniform distribution over groups.
        self.group_weights = torch.ones(size, device=device) / size

    def update(self, group_loss, group_count):
        """Advance the group weights and return the re-weighted loss.

        Args:
            group_loss: per-group loss tensor, expected 1-D with one entry per group.
            group_count: per-group sample counts. Accepted for interface
                compatibility but not used.

        Returns:
            ``group_loss @ group_weights``. Gradients flow through ``group_loss``
            only; the weight update uses a detached copy of the loss.
        """
        # Keep the weights on the same device as the incoming loss.
        weights = self.group_weights.to(group_loss.device)
        weights = weights * torch.exp(self.step_size * group_loss.detach())
        self.group_weights = weights / weights.sum()

        return group_loss @ self.group_weights
    
    

def test(model, test_loader, writer, epoch):
    model.eval()
    correct = 0
    test_loss = 0
    
    ys = []
    bs = []
    test_losses = []
    corrects = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            inputs, targets, biases = batch
            inputs, targets, biases = inputs.to(device), targets.to(device), biases.to(device)
            
            y_hat = model(inputs)
            
            test_loss = F.cross_entropy(y_hat, targets, reduction='none').detach().cpu()
            _, predicted = y_hat.cpu().max(1)
            correct = predicted.eq(targets.cpu())
            
            test_losses.append(test_loss)
            corrects.append(correct)
            ys.append(targets.cpu())
            bs.append(biases.cpu())
            
    test_losses = torch.cat(test_losses)
    corrects = torch.cat(corrects)
    ys = torch.cat(ys)
    bs = torch.cat(bs)
    gs = ys*2 + bs
    
    loss = test_losses.mean().item()
    correct = corrects.sum().item()
    accuracy = 100. * corrects.sum().item() / len(test_loader.dataset)
    accuracy_y0 = corrects[np.where(ys == 0)[0]].sum().item() / len(np.where(ys == 0)[0])
    accuracy_y1 = corrects[np.where(ys == 1)[0]].sum().item() / len(np.where(ys == 1)[0])
    accuracy_g0 = corrects[np.where(gs == 0)[0]].sum().item() / len(np.where(gs == 0)[0])
    accuracy_g1 = corrects[np.where(gs == 1)[0]].sum().item() / len(np.where(gs == 1)[0])
    accuracy_g2 = corrects[np.where(gs == 2)[0]].sum().item() / len(np.where(gs == 2)[0])
    accuracy_g3 = corrects[np.where(gs == 3)[0]].sum().item() / len(np.where(gs == 3)[0])
    worst_cls_acc = min(accuracy_y0, accuracy_y1) * 100
    worst_group_acc = min(accuracy_g0, accuracy_g1, accuracy_g2, accuracy_g3) * 100
    # writer.add_scalar(f'valid/accuracy', accuracy, epoch)
    print(f'Test set: Average loss: {loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.4f}%)')
    print(f'Test set: Worst Class Accuracy: {worst_cls_acc:.4f}%, Worst Group Accuracy: {worst_group_acc:.4f}%')

    return worst_cls_acc



def train(train_loader, model, optimizer, epoch, scheduler=None):
    """Run one training epoch and return the mean per-batch cross-entropy.

    Reads module-level ``args`` (``group_dro``, ``epochs``). When ``group_dro`` is
    set, it also reads ``num_groups`` and ``group_weight_ema``. With group DRO,
    the loss that is back-propagated is the ``GroupEMA``-weighted mean of
    per-group losses, where the group is the target label. Otherwise it is the
    plain batch mean. The returned value is always the unweighted mean loss.

    Args:
        train_loader: DataLoader yielding ``(inputs, targets, _)`` batches.
        model: network to train; batches are moved to the device of its first
            parameter (assumes it has at least one parameter).
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: zero-based epoch index (used only for printing).
        scheduler: unused; kept for the call signature.

    Returns:
        Mean of the per-batch mean losses, or 0.0 if the loader yielded no batches.
    """
    print(f'\nEpoch: {epoch+1}')

    # Follow the model instead of a global device so training also runs on CPU-only hosts.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss(reduction='none')
    num_batches = len(train_loader)
    train_loss = 0
    num_seen = 0

    model.train()

    for batch_idx, (inputs, targets, _) in enumerate(train_loader):
        num_seen = batch_idx + 1
        inputs, targets = inputs.to(device), targets.to(device)

        y_hat = model(inputs)
        cost_y = criterion(y_hat, targets)
        prec_train = accuracy(y_hat.detach(), targets, topk=(1,))[0]

        if args.group_dro:
            # group_map[g, i] == 1 iff sample i has target g.
            group_ids = torch.arange(num_groups, device=targets.device).unsqueeze(1)
            group_map = (targets == group_ids).float()
            group_count = group_map.sum(1)
            group_denom = group_count + (group_count == 0).float()  # avoid nans for empty groups
            group_loss = (group_map @ cost_y.view(-1)) / group_denom

            loss = group_weight_ema.update(group_loss, group_count)
        else:
            loss = cost_y.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += cost_y.mean().item()

        if num_seen % 50 == 0:
            print('Epoch: [%d/%d]\t'
                  'Iters: [%d/%d]\t'
                  'Loss: %.4f\t'
                  'Prec@1 %.2f\t' % (
                      (epoch + 1), args.epochs, num_seen, num_batches, (train_loss / num_seen),
                      prec_train))

    if num_seen == 0:
        return 0.0
    return train_loss / num_seen


def adjust_learning_rate(optimizer, epoch):
    """Apply step decay: ``args.lr`` times 0.1 for every milestone in ``args.lr_decay`` already reached.

    A milestone ``m`` counts once ``epoch >= m``. With ``args.lr_decay`` unset
    (``None``), the base rate is used. The result is written into every
    parameter group of ``optimizer``.
    """
    new_lr = args.lr
    milestones = args.lr_decay if args.lr_decay is not None else ()
    for milestone in milestones:
        if epoch >= milestone:
            new_lr *= 0.1

    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
        
def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: scores indexed as (batch, classes); ranked with ``topk`` along dim 1.
        target: one class index per row of ``output``.
        topk: the values of k to evaluate.

    Returns:
        A list with one 0-d tensor per k: the percentage of rows whose target is
        among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds every sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` comes from a transposed tensor and may be non-contiguous, in
        # which case ``view(-1)`` fails for k > 1; ``reshape`` copies only if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
        
    
args = prepare_arguments()

# Use CUDA only when it is actually present; prepare_model() applies the same
# check, so ``device`` and the model's placement agree.
use_cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")

print()
print(args)

train_loader, valid_loader = prepare_dataloader()
print('data loaded')
model = prepare_model()
print('model loaded')
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum, weight_decay=args.weight_decay)
print('optimizer prepared')

if args.group_dro:
    # Groups are the two target labels (see train()).
    num_groups = 2
    group_weight_ema = GroupEMA(size=num_groups, step_size=0.01)

ckpt_dir = os.path.join('clf_results', args.dataset, args.name)
log_dir = os.path.join('clf_summary', args.dataset, args.name)

os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# TensorBoard logging is currently disabled; ``writer`` is passed around as None.
# writer = SummaryWriter(log_dir)
writer = None

print('logger prepared')

def main():
    """Run the full training schedule of ``args.epochs`` epochs.

    Each epoch:
    - sets the learning rate;
    - trains for one pass over ``train_loader``;
    - scores the model on ``valid_loader`` with ``test``.

    Whenever that score is at least the best seen so far, ``model.state_dict()``
    is saved under the key ``'model'`` to ``<ckpt_dir>/epoch_<n>.pth``. The best
    score is printed at the end.
    """
    best_acc = 0
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train(train_loader, model, optimizer, epoch)

        valid_acc = test(model, valid_loader, writer, epoch)

        if valid_acc >= best_acc:
            best_acc = valid_acc
            checkpoint_path = os.path.join(ckpt_dir, f'epoch_{epoch+1}.pth')
            torch.save({'model': model.state_dict()}, checkpoint_path)

    print('best accuracy:', best_acc)


if __name__ == '__main__':
    main()

