import argparse
import datetime
import logging
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from debias.datasets.utk_face import get_utk_face
from debias.networks.resnet import FCResNet18
from debias.utils.logging import set_logging
from debias.utils.utils import (AverageMeter, MultiDimAverageMeter, accuracy,
                                pretty_dict, save_model, set_seed)

from tqdm import tqdm 

def parse_option():
    """Parse the command-line options of this script.

    Dashed flags are exposed as underscored attributes (``--lr-w`` ->
    ``opt.lr_w``, ``--keep-ratio`` -> ``opt.keep_ratio``,
    ``--epochs-w`` -> ``opt.epochs_w``).

    Returns:
        argparse.Namespace holding the experiment settings (``exp_name``,
        ``gpu``, ``task``), main-training settings (``epochs``, ``seed``,
        ``bs``, ``lr``), weight-learning settings (``lr_w``, ``epochs_w``)
        and subsampling settings (``sampling``, ``threshold``,
        ``keep_ratio``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default='test')
    # NOTE(review): --gpu is parsed but not read anywhere in this file;
    # tensors are placed with a bare .cuda() (current CUDA device).
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--task', type=str, default='race')

    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--bs', type=int, default=128, help='batch_size')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr-w', default=0.1, type=float)
    parser.add_argument('--epochs-w', type=int, default=200,
                        help='number of epochs used to learn the per-example weights (train_w)')

    parser.add_argument('--sampling', default='rank', choices=['threshold', 'rank', 'cls_rank', 'sample', 'uniform'], type=str)
    parser.add_argument('--threshold', type=float, default=0.5)
    parser.add_argument('--keep-ratio', type=float, default=0.5)

    return parser.parse_args()


def set_model():
    """Create the network and the loss used for the main training stage.

    Returns:
        tuple: ``(model, criterion)`` -- an ``FCResNet18`` moved to the
        current CUDA device, and a fresh ``nn.CrossEntropyLoss``.
    """
    # Tuple elements are evaluated left to right, so the model is still
    # constructed before the criterion.
    return FCResNet18().cuda(), nn.CrossEntropyLoss()

def get_keep_idx(w, cls_idx, args, mode='threshold'):
    """Select the indices of the training examples to keep.

    Args:
        w: 1-D tensor with one weight per training example. For ``'sample'``
            the values are used as Bernoulli probabilities, so they must lie
            in [0, 1].
        cls_idx: 2-D tensor with one row per class; row ``c`` is nonzero
            exactly at the examples of class ``c``. Only used by
            ``'cls_rank'``; any number of classes is supported.
        args: namespace providing ``threshold`` (for ``'threshold'``) and
            ``keep_ratio`` (for ``'rank'``, ``'cls_rank'``, ``'uniform'``).
        mode: selection strategy, one of ``'threshold'``, ``'rank'``,
            ``'cls_rank'``, ``'sample'``, ``'uniform'``.

    Returns:
        1-D LongTensor of kept example indices, on the CPU. It stays 1-D even
        when zero or one example is kept.

    Raises:
        ValueError: if ``mode`` is not one of the strategies above.
    """
    # strategy 1: fixed threshold
    if mode == 'threshold':
        # view(-1) instead of squeeze(): squeeze() would turn a single
        # surviving index into a 0-d tensor.
        keep_idx = torch.nonzero(w > args.threshold).view(-1)

    # strategy 2: top k% examples
    elif mode == 'rank':
        keep_examples = round(args.keep_ratio * len(w))
        keep_idx = w.sort(descending=True)[1][:keep_examples]

    # strategy 3: top k% examples each class
    elif mode == 'cls_rank':
        keep_idx_list = []
        # One row per class: iterate over the rows actually present rather
        # than a fixed class count.
        for class_row in cls_idx:
            c_idx = class_row.nonzero().view(-1)
            keep_examples = round(args.keep_ratio * len(c_idx))
            sort_idx = w[c_idx].sort(descending=True)[1]
            keep_idx_list.append(c_idx[sort_idx][:keep_examples])
        if keep_idx_list:
            keep_idx = torch.cat(keep_idx_list)
        else:
            keep_idx = torch.empty(0, dtype=torch.long)

    # strategy 4: sampling according to weights
    elif mode == 'sample':
        keep_idx = torch.bernoulli(w).nonzero().view(-1)

    # strategy 5: random uniform sampling
    elif mode == 'uniform':
        keep_examples = round(args.keep_ratio * len(w))
        keep_idx = torch.randperm(len(w))[:keep_examples]

    else:
        raise ValueError(
            f"unknown sampling mode {mode!r}; expected one of "
            "'threshold', 'rank', 'cls_rank', 'sample', 'uniform'")

    return keep_idx.cpu()

def train_w(train_loader,epochs, lr, lr_w):
    """Learn one sigmoid weight per training example.

    A linear classifier ``nn.Linear(1, n_cls)`` predicts the label ``y`` from
    the scalar bias value ``x`` of each example. In alternation with that
    classifier, a vector of per-example logits ``weight_param`` is updated by
    SGD to minimise ``1 - loss / entropy``. Here ``loss`` is the
    example-weighted cross-entropy of the bias-only classifier, and
    ``entropy`` is the example-weighted ``-log q[y]`` under the reweighted
    class prior ``q``.

    Args:
        train_loader: DataLoader over the training set. Its dataset must be
            map-style (indexable, with ``len``) and expose ``targets``. Each
            batch unpacks as ``(images, _, biases, _, labels, _)``. Its
            ``batch_size``, ``num_workers``, ``collate_fn`` and
            ``pin_memory`` are reused; ``batch_size`` is assumed not to be
            ``None`` (TODO confirm for custom batch samplers).
        epochs: number of passes over the training set (0 is allowed).
        lr: SGD learning rate (momentum 0.9) of the bias-only classifier.
        lr_w: SGD learning rate of the per-example weight logits.

    Returns:
        (w, cls_idx): ``w`` is the length-N tensor ``sigmoid(weight_param)``,
        computed under ``torch.no_grad()``, with one weight per example in
        dataset order. ``cls_idx`` is the float (n_cls, N) matrix whose row
        ``c`` is 1.0 at the examples of class ``c``.
    """
    from torch.utils.data import DataLoader

    dataset = train_loader.dataset
    n_examples = len(dataset)

    targets = dataset.targets.cuda()
    n_cls = int(targets.max()) + 1
    # Row c is 1.0 at the examples of class c, so ``cls_idx @ w`` sums the
    # weights per class.
    cls_idx = torch.stack([targets == c for c in range(n_cls)]).float()

    model = nn.Linear(1, n_cls).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    weight_param = nn.Parameter(torch.zeros(n_examples).cuda())
    optimizer_w = torch.optim.SGD([weight_param], lr=lr_w)

    for _epoch in tqdm(range(epochs)):
        # Per-example weights need the dataset index of every example in a
        # batch. An ``enumerate`` counter is only the batch number. Iterate
        # in an explicit random order instead, so that each batch is exactly
        # ``perm[offset:offset + bsz]``.
        perm = torch.randperm(n_examples)
        epoch_loader = DataLoader(
            dataset,
            batch_size=train_loader.batch_size,
            sampler=perm.tolist(),
            num_workers=train_loader.num_workers,
            collate_fn=train_loader.collate_fn,
            pin_memory=train_loader.pin_memory,
        )
        offset = 0
        for _, _, biases, _, labels, _ in epoch_loader:
            bsz = labels.shape[0]
            idx = perm[offset:offset + bsz].cuda()
            offset += bsz

            x, y = biases.cuda().float(), labels.cuda()

            # class probabilities
            w = torch.sigmoid(weight_param)
            z = w[idx] / w.mean()  # per-example weights, normalised by the dataset mean
            cls_w = cls_idx @ w
            q = cls_w / cls_w.sum()  # reweighted class prior

            # linear classifier on the scalar bias value
            out = model(x[:, None])
            loss_vec = F.cross_entropy(out, y, reduction='none')
            loss = (loss_vec * z).mean()

            optimizer.zero_grad()
            # The graph is reused below by loss_w, which also depends on loss.
            loss.backward(retain_graph=True)
            optimizer.step()

            # example weights
            optimizer_w.zero_grad()
            entropy = -(q[y].log() * z).mean()
            loss_w = 1 - loss / entropy
            loss_w.backward()
            optimizer_w.step()

    with torch.no_grad():
        w = torch.sigmoid(weight_param)

    return w, cls_idx

def train(train_loader, model, criterion, optimizer):
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch unpacks as ``(images, _, biases, _, labels, _)``. Only images
    and labels are used, and both are moved to CUDA. ``model`` must return a
    pair whose first element is the logits.

    Args:
        train_loader: iterable of training batches.
        model: network to train; it is switched to train mode.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer stepping ``model``'s parameters once per batch.

    Returns:
        ``avg`` of an ``AverageMeter`` that was updated with each batch's
        loss value, weighted by that batch's size.
    """
    model.train()
    avg_loss = AverageMeter()

    for images, _, _, _, labels, _ in train_loader:
        bsz = labels.shape[0]
        images, labels = images.cuda(), labels.cuda()

        logits, _ = model(images)
        loss = criterion(logits, labels)

        avg_loss.update(loss.item(), bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return avg_loss.avg

def validate(val_loader, model, dims=(2, 2)):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: iterable of batches unpacking as
            ``(images, _, biases, _, labels, _)``.
        model: network returning a pair whose first element is the logits;
            it is switched to eval mode.
        dims: ``(n_label_values, n_bias_values)`` grid size of the
            per-(label, bias) accuracy meter. Labels and biases must be
            integer indices below these sizes. The default ``(2, 2)``
            matches the previous fixed grid.

    Returns:
        (top1_avg, attrwise_mean, acc_diff):
        - ``top1_avg``: top-1 accuracy from ``accuracy``, averaged over
          batches and weighted by batch size.
        - ``attrwise_mean``: ``MultiDimAverageMeter.get_mean()`` over
          per-example correctness, grouped by (label, bias).
        - ``acc_diff``: ``get_acc_diff()`` of that same meter.
    """
    model.eval()

    top1 = AverageMeter()
    attrwise_acc_meter = MultiDimAverageMeter(dims=dims)

    with torch.no_grad():
        for images, _, biases, _, labels, _ in val_loader:
            bsz = labels.shape[0]
            images, labels = images.cuda(), labels.cuda()

            output, _ = model(images)
            preds = output.argmax(dim=1)

            acc1, = accuracy(output, labels, topk=(1,))
            top1.update(acc1[0], bsz)

            corrects = (preds == labels).long()
            # Biases are only needed on the CPU for the (label, bias)
            # grouping, so they are never copied to the GPU.
            group_keys = torch.stack([labels.cpu(), biases.cpu()], dim=1)
            attrwise_acc_meter.add(corrects.cpu(), group_keys)

    return top1.avg, attrwise_acc_meter.get_mean(), attrwise_acc_meter.get_acc_diff()


def _log_label_bias_counts(dataset, stage):
    """Log the ``pos``/``neg`` results of ``dataset.count_pos_neg(targets, bias_targets)``."""
    pos, neg, _, _ = dataset.count_pos_neg(dataset.targets, dataset.bias_targets)
    logging.info(f'[{stage}] count_pos_neg pos: {pos}')
    logging.info(f'[{stage}] count_pos_neg neg: {neg}')


def main():
    """Learn per-example weights, subsample the training set, then train and evaluate.

    Steps:
    1. Parse options and set up logging and the seed.
    2. Build the UTK Face loaders.
    3. Learn per-example weights with ``train_w`` on the ``repair=True``
       training loader.
    4. Keep the subset of a fresh training loader chosen by ``get_keep_idx``.
    5. Train ``model`` with Adam + MultiStepLR for ``opt.epochs``, evaluating
       on the valid and test splits every epoch.
    6. Save a checkpoint whenever a split's unbiased accuracy improves, and
       one after the last epoch.
    """
    opt = parse_option()

    exp_name = f'ce-utk_face_{opt.task}-{opt.exp_name}-lr{opt.lr}-bs{opt.bs}-seed{opt.seed}'
    opt.exp_name = exp_name

    output_dir = f'exp_results/{exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_logging(exp_name, 'INFO', str(save_path))
    logging.info(f'Set seed: {opt.seed}')
    set_seed(opt.seed)
    logging.info(f'save_path: {save_path}')

    np.set_printoptions(precision=3)
    torch.set_printoptions(precision=3)

    root = './data/utk_face'
    train_loader = get_utk_face(
        root,
        batch_size=opt.bs,
        bias_attr=opt.task,
        split='train',
        aug=False,
        repair=True)

    val_loaders = {
        'valid': get_utk_face(
            root,
            batch_size=256,
            bias_attr=opt.task,
            split='valid',
            aug=False),
        'test': get_utk_face(
            root,
            batch_size=256,
            bias_attr=opt.task,
            split='test',
            aug=False),
    }

    model, criterion = set_model()

    # Number of weight-learning epochs: taken from ``opt.epochs_w`` when the
    # parser defines it, otherwise 200.
    epochs_w = getattr(opt, 'epochs_w', 200)
    w, cls_idx = train_w(train_loader, epochs_w, opt.lr, opt.lr_w)

    # Fresh training loader (built without repair=True) that is subsampled
    # to the kept examples and used for the main training stage.
    train_loader = get_utk_face(
        root,
        batch_size=opt.bs,
        bias_attr=opt.task,
        split='train',
        aug=False)
    _log_label_bias_counts(train_loader.dataset, 'before resampling')

    keep_idx = get_keep_idx(w, cls_idx, opt, mode=opt.sampling)
    train_loader.dataset.set_to_keep(keep_idx)
    logging.info(f'Kept {len(keep_idx)} / {len(w)} training examples ({opt.sampling})')
    _log_label_bias_counts(train_loader.dataset, 'after resampling')

    decay_epochs = [opt.epochs // 3, opt.epochs * 2 // 3]

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_epochs, gamma=0.1)
    logging.info(f"decay_epochs: {decay_epochs}")

    (save_path / 'checkpoints').mkdir(parents=True, exist_ok=True)

    # Mask over the (label, bias) accuracy grid: diagonal cells are reported
    # as acc_align, off-diagonal cells as acc_skew.
    eye_tsr = torch.eye(2)

    best_accs = {'valid': 0, 'test': 0}
    best_epochs = {'valid': 0, 'test': 0}
    best_stats = {}
    start_time = time.time()
    for epoch in range(1, opt.epochs + 1):
        logging.info(f'[{epoch} / {opt.epochs}] Learning rate: {scheduler.get_last_lr()[0]}')
        loss = train(train_loader, model, criterion, optimizer)
        logging.info(f'[{epoch} / {opt.epochs}] Loss: {loss:.4f}')

        scheduler.step()

        stats = pretty_dict(epoch=epoch)
        for key, val_loader in val_loaders.items():
            accs, valid_attrwise_accs, diff = validate(val_loader, model)

            stats[f'{key}/acc'] = accs.item()
            stats[f'{key}/acc_unbiased'] = torch.mean(valid_attrwise_accs).item() * 100
            stats[f'{key}/diff'] = diff.item() * 100
            stats[f'{key}/acc_skew'] = valid_attrwise_accs[eye_tsr == 0.0].mean().item() * 100
            stats[f'{key}/acc_align'] = valid_attrwise_accs[eye_tsr > 0.0].mean().item() * 100

        # valid_attrwise_accs holds the grid of the last split evaluated.
        logging.info(f'[{epoch} / {opt.epochs}] {valid_attrwise_accs} {stats}')
        for tag in val_loaders:
            if stats[f'{tag}/acc_unbiased'] > best_accs[tag]:
                best_accs[tag] = stats[f'{tag}/acc_unbiased']
                best_epochs[tag] = epoch
                best_stats[tag] = pretty_dict(**{f'best_{tag}_{k}': v for k, v in stats.items()})

                save_file = save_path / 'checkpoints' / f'best_{tag}.pth'
                save_model(model, optimizer, opt, epoch, save_file)
            # .get: no best stats exist yet if acc_unbiased has never
            # exceeded 0 (e.g. it is NaN); indexing would raise KeyError.
            logging.info(
                f'[{epoch} / {opt.epochs}] best {tag} accuracy: {best_accs[tag]:.3f} at epoch {best_epochs[tag]} \n best_stats: {best_stats.get(tag)}')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info(f'Total training time: {total_time_str}')

    save_file = save_path / 'checkpoints' / 'last.pth'
    save_model(model, optimizer, opt, opt.epochs, save_file)


if __name__ == '__main__':
    # Run the full pipeline only when executed as a script, not when imported.
    main()
