import argparse
import datetime
import logging
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch import nn

from debias.datasets.celeba import get_celeba
from debias.networks.resnet import FCResNet18
from debias.utils.logging import set_logging
from debias.utils.utils import (AverageMeter, MultiDimAverageMeter, accuracy,
                                pretty_dict, save_model, set_seed)
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm


def parse_option():
    """Parse this script's command-line options from ``sys.argv``.

    Every option has a default, so the script can run without arguments.
    Within this script ``--mode`` is compared against ``'os'``, ``'us'`` and
    ``'uw'`` during training, ``--rb`` is forwarded to the dataset loaders as
    ``resample_blonde``, and ``--epochs`` is overwritten per task in
    ``main``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # One entry per option: (flag, keyword arguments for add_argument).
    option_table = (
        ('--exp_name', {'type': str, 'default': 'test'}),
        ('--gpu', {'type': int, 'default': 0}),
        ('--task', {'type': str, 'default': 'blonde'}),
        ('--epochs', {'type': int, 'default': 10}),
        ('--seed', {'type': int, 'default': 1}),
        ('--bs', {'type': int, 'default': 128, 'help': 'batch_size'}),
        ('--lr', {'type': float, 'default': 1e-3}),
        ('--ecu', {'type': int, 'default': 0}),
        ('--uw', {'type': int, 'default': 1}),
        ('--rb', {'type': int, 'default': 0}),
        ('--mode', {'type': str, 'default': 'none'}),
    )

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser.parse_args()

def set_model():
    """Create the networks and loss functions used by the experiment.

    Both networks are moved to the current CUDA device.

    Returns:
        tuple: ``(models, criterion)`` where

        * ``models['model']`` is an ``FCResNet18`` with two outputs and
          ``models['pred']`` is a ``512 -> 2`` linear layer;
        * ``criterion['bin']`` is ``BCEWithLogitsLoss`` and
          ``criterion['multi']`` is ``CrossEntropyLoss``, both with
          ``reduction='none'`` so callers receive per-element losses.
    """
    # Instantiate the backbone first, then the head (same order as parameters
    # are initialised, which matters for seeded runs).
    backbone = FCResNet18(num_classes=2).cuda()
    linear_head = nn.Linear(512, 2).cuda()
    networks = {'model': backbone, 'pred': linear_head}

    losses = {
        'bin': nn.BCEWithLogitsLoss(reduction='none'),
        'multi': nn.CrossEntropyLoss(reduction='none'),
    }

    return networks, losses

def get_samples_counts(all_labels_nb, all_bias):
    """Split sample positions into the four (bias, label) groups.

    Args:
        all_labels_nb: array of labels, compared against 0 and 1.
        all_bias: array of bias attributes, compared against 0 and 1; its
            length defines the range of positions.

    Returns:
        tuple: ``(g_idxs, g_counts)``. ``g_idxs`` is a list of four index
        arrays ordered (bias=0, label=0), (bias=0, label=1),
        (bias=1, label=0), (bias=1, label=1); ``g_counts`` holds the length
        of each of those arrays.
    """
    positions = np.arange(len(all_bias))

    # The outer loop runs over the bias value, the inner loop over the label.
    members = [
        positions[np.logical_and(all_bias == bias_value, all_labels_nb == label_value)]
        for bias_value in range(2)
        for label_value in range(2)
    ]
    sizes = [len(group) for group in members]

    return members, sizes

def under_sample(all_bias, all_feats, all_labels_nb):
    """Balance the (bias, label) groups by down-sampling to the smallest one.

    Each of the four groups returned by ``get_samples_counts`` is reduced to
    ``min(g_counts)`` samples, and the retained samples are then shuffled.

    Args:
        all_bias: array of bias attributes, one per sample.
        all_feats: array of features, indexed by sample along axis 0.
        all_labels_nb: array of labels, one per sample.

    Returns:
        tuple: ``(all_feats, all_labels_nb)`` restricted to the retained
        samples, in shuffled order.
    """
    g_idxs, g_counts = get_samples_counts(all_labels_nb, all_bias)
    min_group = min(g_counts)

    # replace=False is essential: np.random.choice samples WITH replacement by
    # default, which would duplicate some samples and drop others even from
    # the smallest group, so the result would not be a true subset.
    kept = np.concatenate([
        np.random.choice(group_idx, min_group, replace=False)
        for group_idx in g_idxs
    ])

    # Shuffle so the groups are interleaved rather than stored back to back.
    kept = kept[np.random.permutation(len(kept))]

    return all_feats[kept], all_labels_nb[kept]


def over_sample(all_bias, all_feats, all_labels_nb):
    """Balance the (bias, label) groups by up-sampling to the largest one.

    Every non-empty group returned by ``get_samples_counts`` is padded with
    randomly repeated members (drawn with replacement) until it has
    ``max(g_counts)`` samples; the enlarged set is then shuffled.

    Args:
        all_bias: array of bias attributes, one per sample.
        all_feats: array of features, indexed by sample along axis 0.
        all_labels_nb: array of labels, one per sample.

    Returns:
        tuple: ``(all_feats, all_labels_nb)`` containing every original
        sample plus the duplicated ones, in shuffled order.
    """
    g_idxs, g_counts = get_samples_counts(all_labels_nb, all_bias)
    max_group = max(g_counts)

    # Start from every original sample and collect the indices to duplicate.
    # Working on indices lets the (large) feature array be gathered once
    # instead of being re-concatenated for every group.
    pieces = [np.arange(len(all_feats))]
    for group_idx in g_idxs:
        to_add = max_group - len(group_idx)
        # Skip groups that are already full, and empty groups: there is
        # nothing to draw from and np.random.choice would raise ValueError.
        if to_add == 0 or len(group_idx) == 0:
            continue
        pieces.append(np.random.choice(group_idx, to_add))

    selected = np.concatenate(pieces)
    selected = selected[np.random.permutation(len(selected))]

    return all_feats[selected], all_labels_nb[selected]

def train(train_loader, model, criterion, optimizer_model, opt):
    """Train the backbone for one epoch, then refit the linear head.

    Phase 1 runs one pass over ``train_loader``, optimising ``model['model']``
    with the masked binary loss ``criterion['bin']``. Label entries equal to
    ``-1`` are excluded from the loss. The features produced for each batch
    are recorded (detached, on the CPU) while the backbone is being updated.

    Phase 2 replaces ``model['pred']`` with a freshly initialised
    ``nn.Linear(512, 2)`` and trains it for ``opt.epochs`` epochs on the
    recorded features with ``criterion['multi']``. ``opt.mode`` selects how
    the recorded set is used: ``'os'`` over-samples it, ``'us'`` under-samples
    it, ``'uw'`` multiplies the per-sample loss by the ``gc`` weights from the
    loader, and any other value uses the features unchanged.

    Args:
        train_loader: iterable yielding 7-tuples
            ``(images, labels, biases, _, labels_nb, gc, gc_imbalance)``.
        model: dict with keys ``'model'`` (backbone returning
            ``(logits, feat)``) and ``'pred'`` (overwritten here).
        criterion: dict with keys ``'bin'`` and ``'multi'``; both are expected
            to return unreduced (per-element) losses.
        optimizer_model: optimizer for the backbone parameters.
        opt: options namespace; reads ``mode``, ``bs``, ``epochs`` and ``lr``.

    Returns:
        ``avg_loss.avg`` of the ``AverageMeter`` updated with every backbone
        batch loss.
    """
    model['model'].train()
    avg_loss = AverageMeter()

    all_labels_nb = []
    all_gc = []
    all_feats = []
    all_bias = []

    # ---- Phase 1: one epoch of backbone training ---------------------------
    for images, labels, biases, _, labels_nb, gc, _ in tqdm(train_loader, ascii=True):
        bsz = labels.shape[0]
        images, labels = images.cuda(), labels.cuda()

        logits, feat = model['model'](images)

        # Keep what the head training below needs, as CPU numpy arrays.
        all_labels_nb.append(labels_nb.detach().cpu().numpy())
        all_gc.append(gc.numpy())
        all_bias.append(biases.detach().cpu().numpy())
        all_feats.append(feat.detach().cpu().numpy())

        # Entries labelled -1 are ignored: zero weight, and the label itself
        # is set to 0 so the loss receives a valid target.
        ignored = labels == -1
        multi = torch.ones_like(labels)
        multi[ignored] = 0
        labels[ignored] = 0

        loss = criterion['bin'](logits, labels) * multi

        # Average over the non-ignored entries. The clamp avoids 0/0 = NaN
        # (and a corrupted backbone) when a whole batch is ignored.
        div = torch.sum(multi).clamp(min=1)
        loss = torch.sum(loss / div)

        avg_loss.update(loss.item(), bsz)

        optimizer_model.zero_grad()
        loss.backward()
        optimizer_model.step()

    all_labels_nb = np.concatenate(all_labels_nb, axis=0)
    all_gc = np.concatenate(all_gc, axis=0)
    all_bias = np.concatenate(all_bias, axis=0)
    all_feats = np.concatenate(all_feats, axis=0)

    # ---- Phase 2: refit the linear head on the recorded features -----------
    if opt.mode == 'os':
        all_feats, all_labels_nb = over_sample(all_bias, all_feats, all_labels_nb)
    elif opt.mode == 'us':
        all_feats, all_labels_nb = under_sample(all_bias, all_feats, all_labels_nb)

    # all_gc is only aligned with the features when no resampling happened,
    # so it is touched exclusively in 'uw' mode.
    use_group_weights = opt.mode == 'uw'

    batch_size = opt.bs
    num_epochs = opt.epochs
    total_samples = len(all_labels_nb)
    # NOTE: floor division, so a trailing partial batch is never used and no
    # head update happens at all when total_samples < batch_size.
    num_batches = total_samples // batch_size

    model['pred'] = nn.Linear(512, 2).cuda()

    decay_epochs = [opt.epochs // 3, opt.epochs * 2 // 3]
    optimizer = torch.optim.Adam(model['pred'].parameters(), lr=opt.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_epochs, gamma=0.1)

    for _ in tqdm(range(num_epochs), ascii=True):
        # Reshuffle by scattering every array through the same permutation,
        # which keeps features, labels and weights aligned.
        all_idx = np.arange(total_samples)
        np.random.shuffle(all_idx)

        all_feats[all_idx] = all_feats
        all_labels_nb[all_idx] = all_labels_nb
        if use_group_weights:
            all_gc[all_idx] = all_gc

        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            end = start + batch_size

            feats = torch.from_numpy(all_feats[start:end]).cuda()
            labels = torch.from_numpy(all_labels_nb[start:end]).cuda()

            optimizer.zero_grad()
            loss = criterion['multi'](model['pred'](feats), labels)
            if use_group_weights:
                gc = torch.from_numpy(all_gc[start:end]).cuda()
                loss = loss * gc
            loss = torch.mean(loss)

            loss.backward()
            optimizer.step()

        scheduler.step()

    return avg_loss.avg


def validate(val_loader, model):
    """Evaluate the backbone plus linear head on ``val_loader``.

    Images are passed through ``model['model']`` (in eval mode, without
    gradients); the resulting features are classified by ``model['pred']``
    and compared with ``labels_nb``.

    Args:
        val_loader: iterable yielding 7-tuples
            ``(images, labels, biases, _, labels_nb, _, _)``.
        model: dict with keys ``'model'`` (backbone returning
            ``(output, feats)``) and ``'pred'`` (head applied to ``feats``).

    Returns:
        tuple: ``(top1.avg, attrwise_acc_meter.get_mean(),
        attrwise_acc_meter.get_acc_diff())`` - the running top-1 accuracy and
        the two summaries of the ``MultiDimAverageMeter`` fed with
        per-sample correctness keyed by ``(labels_nb, biases)``.
    """
    model['model'].eval()

    top1 = AverageMeter()
    attrwise_acc_meter = MultiDimAverageMeter(dims=(2, 2))

    with torch.no_grad():
        for images, labels, biases, _, labels_nb, _, _ in tqdm(val_loader, ascii=True):
            # Only the images are needed on the GPU; labels and biases are
            # consumed on the CPU below.
            bsz = labels.shape[0]

            _, feats = model['model'](images.cuda())
            output = model['pred'](feats).cpu()
            preds = output.argmax(dim=1)

            acc1, = accuracy(output, labels_nb, topk=(1,))
            top1.update(acc1[0], bsz)

            corrects = (preds == labels_nb).long()
            attrwise_acc_meter.add(
                corrects.cpu(),
                torch.stack([labels_nb.cpu(), biases.cpu()], dim=1))

    return top1.avg, attrwise_acc_meter.get_mean(), attrwise_acc_meter.get_acc_diff()


def main():
    """Run the full experiment: setup, training loop and per-epoch evaluation.

    Steps:
        1. Parse options, build the experiment name and create
           ``exp_results/<exp_name>`` (plus a ``checkpoints`` sub-directory).
        2. Configure logging and seeding.
        3. Build the CelebA train loader and the ``valid`` / ``test``
           evaluation loaders.
        4. For every epoch call ``train`` and ``validate``, log the statistics
           and track, per loader, the epoch with the best unbiased accuracy.

    The number of epochs is fixed per task and overrides ``--epochs``.
    Checkpoint saving is currently disabled, so nothing is written to the
    ``checkpoints`` directory.

    Raises:
        AttributeError: if ``opt.task`` is not one of the supported tasks.
    """
    opt = parse_option()

    exp_name = f'us_3-celeba_{opt.task}-{opt.exp_name}-lr{opt.lr}-bs{opt.bs}-seed{opt.seed}'
    opt.exp_name = exp_name

    # Epoch budget is fixed per task; any --epochs value is overwritten.
    epochs_per_task = {'makeup': 40, 'blonde': 10, 'black': 10}
    if opt.task not in epochs_per_task:
        # AttributeError is kept for backward compatibility with the original.
        raise AttributeError(
            f'Unsupported task {opt.task!r}; expected one of {sorted(epochs_per_task)}')
    opt.epochs = epochs_per_task[opt.task]

    output_dir = f'exp_results/{exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_logging(exp_name, 'INFO', str(save_path))
    logging.info(f'Set seed: {opt.seed}')
    set_seed(opt.seed)
    logging.info(f'save_path: {save_path}')

    np.set_printoptions(precision=3)
    torch.set_printoptions(precision=3)

    resample_blonde = opt.rb

    root = './data/celeba'
    train_loader = get_celeba(
        root,
        batch_size=opt.bs,
        target_attr=opt.task,
        split='train',
        aug=False,
        under_sample='bin',
        resample_blonde=resample_blonde)

    val_loaders = {}
    val_loaders['valid'] = get_celeba(
        root,
        batch_size=256,
        target_attr=opt.task,
        split='train_valid',
        aug=False,
        resample_blonde=resample_blonde)

    val_loaders['test'] = get_celeba(
        root,
        batch_size=256,
        target_attr=opt.task,
        split='valid',
        aug=False)

    model, criterion = set_model()
    decay_epochs = [opt.epochs // 3, opt.epochs * 2 // 3]

    optimizer = torch.optim.Adam(model['model'].parameters(), lr=opt.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_epochs, gamma=0.1)

    logging.info(f"decay_epochs: {decay_epochs}")
    (save_path / 'checkpoints').mkdir(parents=True, exist_ok=True)

    best_accs = {'valid': 0, 'test': 0}
    best_epochs = {'valid': 0, 'test': 0}
    best_stats = {}
    start_time = time.time()
    for epoch in range(1, opt.epochs + 1):
        logging.info(f'[{epoch} / {opt.epochs}] Learning rate: {scheduler.get_last_lr()[0]}')
        loss = train(train_loader, model, criterion, optimizer, opt)
        logging.info(f'[{epoch} / {opt.epochs}] Loss: {loss:.4f}')

        scheduler.step()

        stats = pretty_dict(epoch=epoch)
        for key, val_loader in val_loaders.items():
            accs, valid_attrwise_accs, diff = validate(val_loader, model)

            stats[f'{key}/acc'] = accs.item()
            stats[f'{key}/acc_unbiased'] = torch.mean(valid_attrwise_accs).item() * 100
            stats[f'{key}/diff'] = diff.item() * 100

            # Mask over the 2x2 attribute-wise accuracy matrix: cells marked 1
            # are averaged into 'acc_skew', the remaining cells into
            # 'acc_align'.
            if opt.task == 'blonde' and opt.rb == 0:
                eye_tsr = torch.zeros((2, 2))
                eye_tsr[0, 1] = 1
                eye_tsr[1, 1] = 1
            else:
                eye_tsr = torch.eye(2)

            stats[f'{key}/acc_skew'] = valid_attrwise_accs[eye_tsr > 0.0].mean().item() * 100
            stats[f'{key}/acc_align'] = valid_attrwise_accs[eye_tsr == 0.0].mean().item() * 100

        # valid_attrwise_accs here belongs to the last loader evaluated.
        logging.info(f'[{epoch} / {opt.epochs}] {valid_attrwise_accs} {stats}')
        for tag in val_loaders:
            if stats[f'{tag}/acc_unbiased'] > best_accs[tag]:
                best_accs[tag] = stats[f'{tag}/acc_unbiased']
                best_epochs[tag] = epoch
                best_stats[tag] = pretty_dict(**{f'best_{tag}_{k}': v for k, v in stats.items()})
            # .get(): no entry exists until some epoch beats the initial 0
            # (for example when the unbiased accuracy is 0 or NaN), and
            # logging must not crash with KeyError in that case.
            logging.info(
                f'[{epoch} / {opt.epochs}] best {tag} accuracy: {best_accs[tag]:.3f} at epoch {best_epochs[tag]} \n best_stats: {best_stats.get(tag)}')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info(f'Total training time: {total_time_str}')


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
