import argparse
import datetime
import logging
import os
import time
from pathlib import Path
from torch import nn
import torch.nn.functional as F

import numpy as np
import torch
from torch import nn
from datasets.corrupted_cifar import get_cifar
from models.resnet import ResNet18
from utils.logging import set_logging
from utils.utils import (
    AverageMeter,
    MultiDimAverageMeter,
    accuracy,
    load_model,
    pretty_dict,
    save_model,
    set_seed,
)
from tqdm import tqdm



def parse_option():
    """Parse the command-line options and pin the visible CUDA device.

    Returns:
        argparse.Namespace: parsed options with the attributes ``exp_name``,
        ``gpu``, ``dataset``, ``data_dir``, ``epochs``, ``seed``, ``bs``,
        ``cbs``, ``lr``, ``ratio`` and ``aug``.

    Side effects:
        Sets the ``CUDA_VISIBLE_DEVICES`` environment variable to the string
        form of ``--gpu``.
    """
    # One row per flag: (flag, type, default, help). Rows are registered in
    # this order; a ``None`` help is argparse's own default.
    arg_specs = (
        ("--exp_name", str, "test", None),
        ("--gpu", int, 0, None),
        ("--dataset", str, "corrupted-cifar10_", None),
        ("--data_dir", str, "./data", None),
        ("--epochs", int, 200, None),
        ("--seed", int, 1, None),
        ("--bs", int, 256, "batch_size"),
        ("--cbs", int, 256, "batch_size"),
        ("--lr", float, 1e-3, None),
        ("--ratio", int, 10, None),
        ("--aug", int, 1, None),
    )

    cli = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in arg_specs:
        cli.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    options = cli.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(options.gpu)
    return options


def set_model(opt):
    """Create the classifier, the loss and the auxiliary linear layer.

    Args:
        opt: parsed options; not read by this function.

    Returns:
        tuple: ``(model, criterion1, protected_net)`` where ``model`` is a
        ``ResNet18`` with 10 output classes moved to the GPU, ``criterion1``
        is an ``nn.CrossEntropyLoss`` and ``protected_net`` is an
        ``nn.Linear(10, 512)`` moved to the GPU.
    """
    # The backbone is constructed before the linear layer so that parameter
    # initialisation draws from the RNG in the same order.
    backbone = ResNet18(num_classes=10)
    backbone = backbone.cuda()

    ce_loss = nn.CrossEntropyLoss()

    linear_head = nn.Linear(10, 512).cuda()
    return backbone, ce_loss, linear_head


def train(
    train_loader,
    cont_train_loader,
    model,
    criterion,
    optimizer,
    optimizer_fc,
    protected_net,
    opt,
    scaler,
):
    """Run one training epoch with cross-entropy on the concat-forward logits.

    For every batch the bias labels are one-hot encoded (10 classes), passed
    through ``protected_net`` without tracking gradients, and handed together
    with the images to ``model.concat_forward``. The resulting logits are
    trained with a cross-entropy loss.

    Args:
        train_loader: iterable yielding ``(images, labels, biases)`` batches.
        cont_train_loader: unused; kept for interface compatibility.
        model: network exposing ``concat_forward(images, pr_feat)`` that
            returns ``(logits, feats)``.
        criterion: unused; a fresh ``nn.CrossEntropyLoss`` is created locally.
        optimizer: optimizer that is zeroed and stepped on every batch.
        optimizer_fc: unused; kept for interface compatibility.
        protected_net: module applied to the one-hot bias labels under
            ``torch.no_grad()``.
        opt: unused; kept for interface compatibility.
        scaler: ``None`` to train in full precision. Otherwise autocast is
            enabled and the object is used for loss scaling; it is expected
            to behave like ``torch.cuda.amp.GradScaler`` (``scale``, ``step``
            and ``update``).

    Returns:
        tuple: ``(avg_loss, avg_clloss, avg_miloss)``, the running averages of
        the three meters. The last two meters are only ever updated with 0.
    """
    model.train()
    avg_loss = AverageMeter()
    avg_clloss = AverageMeter()
    avg_miloss = AverageMeter()
    use_amp = scaler is not None
    loss_ce_fn = nn.CrossEntropyLoss()
    for images, labels, biases in tqdm(train_loader):
        bsz = labels.shape[0]
        labels, biases = labels.cuda(), biases.cuda()
        images = images.cuda()

        # Only the forward pass and the loss belong inside autocast.
        with torch.cuda.amp.autocast(enabled=use_amp):
            with torch.no_grad():
                # No gradients are tracked through protected_net.
                biases = F.one_hot(biases, num_classes=10).float()
                pr_feat = protected_net(biases)
            logits, _ = model.concat_forward(images, pr_feat)
            loss = loss_ce_fn(logits, labels)

        avg_loss.update(loss.item(), bsz)
        # These two meters are kept so the return shape stays the same; they
        # only ever record 0.
        avg_clloss.update(0, bsz)
        avg_miloss.update(0, bsz)

        optimizer.zero_grad()
        if use_amp:
            # With autocast on, the loss must be scaled and the optimizer
            # stepped through the scaler; otherwise fp16 gradients may
            # underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

    return avg_loss.avg, avg_clloss.avg, avg_miloss.avg


def validate(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: iterable yielding ``(images, labels, biases)`` batches.
        model: network whose forward call returns a pair; the first element
            is used as the class scores.

    Returns:
        tuple: ``(top1_avg, attrwise_mean)`` where ``top1_avg`` is the running
        average of the top-1 value returned by ``accuracy`` and
        ``attrwise_mean`` is ``MultiDimAverageMeter.get_mean()`` for a
        ``(10, 10)`` meter indexed by ``(label, bias)`` pairs.
    """
    model.eval()

    overall_meter = AverageMeter()
    group_meter = MultiDimAverageMeter(dims=(10, 10))

    with torch.no_grad():
        for batch in tqdm(val_loader):
            # Each batch must unpack into exactly three tensors.
            images, labels, biases = (tensor.cuda() for tensor in batch)
            batch_size = labels.shape[0]

            scores, _ = model(images)
            _, predicted = torch.max(scores.data, 1)

            (acc1,) = accuracy(scores, labels, topk=(1,))
            overall_meter.update(acc1[0], batch_size)

            # 1 where the prediction matches the label, 0 otherwise.
            hits = (predicted == labels).long().cpu()
            group_index = torch.stack([labels.cpu(), biases.cpu()], dim=1)
            group_meter.add(hits, group_index)

    return overall_meter.avg, group_meter.get_mean()


def main():
    """Train the model end to end and keep the best and last checkpoints.

    Steps:
        1. Parse options, derive the experiment name and create
           ``results/<exp_name>/checkpoints``.
        2. Configure logging, seed the RNGs and build the data loaders.
        3. Build the model, an Adam optimizer and a cosine learning-rate
           schedule spanning ``opt.epochs``.
        4. For each epoch: train, step the scheduler, evaluate on every
           loader in ``val_loaders`` and save ``best_<tag>.pth`` whenever the
           unbiased accuracy improves.
        5. Log the total wall-clock time and save ``last.pth``.
    """
    opt = parse_option()

    exp_name = f"badd-cifar10_{opt.dataset}-{opt.exp_name}-lr{opt.lr}-bs{opt.bs}-seed{opt.seed}"
    opt.exp_name = exp_name

    output_dir = f"results/{exp_name}"
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_logging(exp_name, "INFO", str(save_path))
    logging.info(f"Set seed: {opt.seed}")
    set_seed(opt.seed)
    logging.info(f"save_path: {save_path}")

    np.set_printoptions(precision=3)
    torch.set_printoptions(precision=3)

    train_loader, val_loader = get_cifar(opt, twocrop=False, aug=True)
    cont_train_loader, _ = get_cifar(opt, twocrop=True, aug=True)

    val_loaders = {}

    val_loaders["test"] = val_loader

    model, criterion, protected_net = set_model(opt)

    # Only logged: the schedule below is cosine annealing, so these epochs
    # are not used to decay the learning rate.
    decay_epochs = [opt.epochs // 3, opt.epochs * 2 // 3]

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs)

    logging.info(f"decay_epochs: {decay_epochs}")

    (save_path / "checkpoints").mkdir(parents=True, exist_ok=True)

    best_accs = {"test": 0}
    best_epochs = {"test": 0}
    best_stats = {}
    start_time = time.time()
    for epoch in range(1, opt.epochs + 1):
        logging.info(
            f"[{epoch} / {opt.epochs}] Learning rate: {scheduler.get_last_lr()[0]}"
        )
        # The same optimizer is passed for both the `optimizer` and
        # `optimizer_fc` parameters; scaler=None keeps autocast disabled.
        loss, cllossp, milossp = train(
            train_loader,
            cont_train_loader,
            model,
            criterion,
            optimizer,
            optimizer,
            protected_net,
            opt,
            scaler=None,
        )
        logging.info(
            f"[{epoch} / {opt.epochs}] Loss: {loss}  Loss CE: {cllossp}  Loss MI: {milossp}"
        )

        scheduler.step()
        stats = pretty_dict(epoch=epoch)
        for key, val_loader in val_loaders.items():
            accs, valid_attrwise_accs = validate(val_loader, model)

            stats[f"{key}/acc"] = accs.item()
            # Mean over all (label, bias) cells of the 10x10 accuracy table.
            stats[f"{key}/acc_unbiased"] = torch.mean(valid_attrwise_accs).item() * 100
            eye_tsr = torch.eye(10)
            # Off-diagonal cells (label != bias index).
            stats[f"{key}/acc_skew"] = (
                valid_attrwise_accs[eye_tsr == 0.0].mean().item() * 100
            )
            # Diagonal cells (label == bias index).
            stats[f"{key}/acc_align"] = (
                valid_attrwise_accs[eye_tsr > 0.0].mean().item() * 100
            )

        logging.info(f"[{epoch} / {opt.epochs}] {valid_attrwise_accs} {stats}")
        for tag in val_loaders.keys():
            if stats[f"{tag}/acc_unbiased"] > best_accs[tag]:
                best_accs[tag] = stats[f"{tag}/acc_unbiased"]
                best_epochs[tag] = epoch
                best_stats[tag] = pretty_dict(
                    **{f"best_{tag}_{k}": v for k, v in stats.items()}
                )

                save_file = save_path / "checkpoints" / f"best_{tag}.pth"
                save_model(model, optimizer, opt, epoch, save_file)
            # best_stats has no entry for `tag` until some epoch beats the
            # initial best of 0 (e.g. when the accuracy is NaN or 0), so use
            # .get() to avoid a KeyError that would abort training.
            logging.info(
                f"[{epoch} / {opt.epochs}] best {tag} accuracy: {best_accs[tag]:.3f} at epoch {best_epochs[tag]} \n best_stats: {best_stats.get(tag)}"
            )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info(f"Total training time: {total_time_str}")

    save_file = save_path / "checkpoints" / "last.pth"
    save_model(model, optimizer, opt, opt.epochs, save_file)


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
