import datetime
import math
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision
import argparse
import data
import models
import losses
import time
import wandb
import torch.utils.tensorboard
import end

from torchvision import transforms
from torchvision import datasets
from util import AverageMeter, accuracy, ensure_dir, set_seed, arg2bool, warmup_learning_rate, pretty_dict
from lars import LARS
from main_infonce import load_model, load_optimizer
from data.imagenet9 import get_imagenet


def parse_arguments():
    """Declare the script's command-line options and parse them.

    Returns:
        argparse.Namespace: parsed options. ``--data_dir`` is the only
        required flag; all other options fall back to the defaults below.
    """
    cli = argparse.ArgumentParser(
        description="Contrastive debiasing",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = cli.add_argument

    # Runtime and logging.
    add('--device', default='cuda', type=str, help='torch device')
    add('--print_freq', default=10, type=int, help='print frequency')
    add('--trial', default=0, type=int, help='random seed / trial id')
    add('--log_dir', default='logs', type=str, help='tensorboard log dir')

    # Data.
    add('--data_dir', required=True, type=str, help='path of data dir')
    add('--batch_size', default=256, type=int, help='batch size')

    # Optimisation schedule.
    add('--epochs', default=120, type=int, help='number of epochs')
    add('--lr', default=0.01, type=float, help='learning rate')
    add('--warm', default=False, type=arg2bool, help='warmup lr')
    add('--lr_decay', default='cosine', type=str,
        choices=['cosine', 'step', 'none'], help='type of decay')
    add('--lr_decay_epochs', default="100,150", type=str,
        help='steps of lr decay (list)')
    add('--optimizer', default="sgd", type=str,
        choices=["adam", "sgd", "lars"], help="optimizer (adam, sgd or lars)")
    add('--momentum', default=0.9, type=float, help='momentum')
    add('--weight_decay', default=1e-4, type=float, help='weight decay')

    # Architecture.
    add('--model', default='resnet18', type=str, help='model architecture')

    # Loss configuration.
    add('--method', default='infonce', type=str,
        choices=['infonce', 'infonce-strong'], help='loss function')
    add('--form', default='out', type=str, help='loss form (in or out)')
    add('--temp', default=0.1, type=float, help='supcon/infonce temperature')
    add('--epsilon', default=0., type=float, help='infonce epsilon')
    add('--lr_epsilon', default=1e-4, type=float, help='epsilon lr')
    add('--lambd', default=0., type=float,
        help='lagrangian weight for debiasing')
    add('--kld', default=0., type=float, help='weight of std term')
    add('--alpha', default=1., type=float, help='infonce weight')
    add('--beta', default=0, type=float,
        help='cross-entropy weight WITH supcon')

    # Projection head and downstream classifier.
    add('--feat_dim', default=128, type=int, help='size of projection head')
    add('--mlp_lr', default=0.001, type=float, help='mlp lr')
    add('--mlp_lr_decay', default='constant', type=str, help='mlp lr decay')
    add('--mlp_max_iter', default=500, type=int, help='mlp training epochs')
    add('--mlp_optimizer', default='adam', type=str, help='mlp optimizer')
    add('--mlp_batch_size', default=None, type=int, help='mlp batch size')
    add('--test_freq', default=1, type=int, help='test frequency')
    add('--train_on_head', default=True, type=arg2bool,
        help="train clf on projection head features")

    # Augmentation and mixed precision.
    add('--aug', default=True, type=arg2bool, help='training aug')
    add('--amp', action='store_true', help='use amp')

    return cli.parse_args()

def load_data(opts):
    """Build the training loader and the evaluation loaders via ``get_imagenet``.

    Reads ``opts.data_dir``, ``opts.batch_size``, ``opts.aug`` and
    ``opts.trial``. As a side effect sets ``opts.dataset = 'imagenet9'`` and
    ``opts.n_classes = 9``.

    Returns:
        tuple: ``(train_loader, val_loaders)`` where ``val_loaders`` is a dict
        keyed, in this order, by ``'biased'``, ``'unbiased'`` and
        ``'ImageNet-A'``.
    """
    imagenet_root = os.path.join(opts.data_dir, 'imagenet')
    imagenet_a_root = os.path.join(opts.data_dir, 'imagenet-a')
    # Placeholder location of the bias features; must be edited before use.
    bias_feature_root = '/PATH/TO/FEATS'

    train_loader = get_imagenet(
        f'{imagenet_root}/train',
        batch_size=opts.batch_size,
        bias_feature_root=bias_feature_root,
        train=True,
        aug=opts.aug,
        seed=opts.trial,
        ratio=0,
        load_bias_feature=True,
    )

    # Settings common to every evaluation loader.
    eval_kwargs = dict(batch_size=128, train=False)
    val_split = f'{imagenet_root}/val'

    # 'biased' and 'unbiased' are built from identical arguments; the two keys
    # only select a different metric downstream.
    val_loaders = {
        'biased': get_imagenet(val_split, aug=False, **eval_kwargs),
        'unbiased': get_imagenet(val_split, aug=False, **eval_kwargs),
        'ImageNet-A': get_imagenet(imagenet_a_root, val_data='ImageNet-A',
                                   **eval_kwargs),
    }

    opts.dataset = 'imagenet9'
    opts.n_classes = 9
    return train_loader, val_loaders

def imagenet_unbiased_accuracy(
        outputs, labels, cluster_labels,
        num_correct, num_instance,
        num_cluster_repeat=3):
    """Accumulate per-(class, cluster) hit and sample counts for one batch.

    The top-1 prediction is computed once for the whole batch and transferred
    to the host in a single copy, instead of calling ``.item()`` once per
    sample and per repeat.

    Args:
        outputs: 2-D tensor of scores, one row per sample.
        labels: integer class index per sample.
        cluster_labels: indexable by repeat ``j``; ``cluster_labels[j]`` holds
            one integer cluster index per sample.
        num_correct: list of 2-D numpy arrays (one per repeat), indexed
            ``[label, cluster]``; incremented in place by 1 for every correct
            prediction.
        num_instance: same layout as ``num_correct``; incremented in place by
            1 for every sample.
        num_cluster_repeat: number of cluster assignments to process.

    Returns:
        tuple: the same ``(num_correct, num_instance)`` objects passed in.
    """
    batch = outputs.size(0)

    # One batched top-1 over the class dimension, then a single host copy.
    _, pred = outputs.topk(1, 1, largest=True, sorted=True)
    pred = pred.reshape(-1).cpu()
    label_idx = labels.reshape(-1)[:batch].cpu()
    correct = pred.eq(label_idx).numpy().astype(np.float64)
    label_idx = label_idx.numpy()

    for j in range(num_cluster_repeat):
        cluster_idx = torch.as_tensor(cluster_labels[j]).reshape(-1)[:batch].cpu().numpy()
        # np.add.at is unbuffered, so repeated (label, cluster) pairs inside
        # the batch are each counted.
        np.add.at(num_correct[j], (label_idx, cluster_idx), correct)
        np.add.at(num_instance[j], (label_idx, cluster_idx), 1)

    return num_correct, num_instance


def n_correct(pred, labels):
    """Count the rows of ``pred`` whose arg-max along dim 1 equals ``labels``.

    Args:
        pred: 2-D tensor of scores, one row per sample.
        labels: tensor of target class indices, one per row.

    Returns:
        int: number of matching rows.
    """
    top_class = pred.data.max(dim=1)[1]
    hits = top_class == labels
    return hits.sum().item()


def train(train_loader, model, criterion, optimizers, opts, epoch, scaler=None):
    """Run one training epoch.

    Args:
        train_loader: iterable yielding 5-tuples; elements 0, 1 and 4 are used
            as images, labels and bias features, elements 2 and 3 are ignored.
        model: called as ``model(images)`` and expected to return
            ``(projected, feats, logits)``.
        criterion: called as
            ``criterion(projected, feats, logits, labels, bias_features)`` and
            expected to return the contrastive loss.
        optimizers: pair ``(optimizer, optimizer_fc)``; ``optimizer_fc`` may be
            ``None``, in which case only ``optimizer`` is stepped.
        opts: options object; ``device``, ``beta`` and ``print_freq`` are read
            here and the whole object is forwarded to ``warmup_learning_rate``.
        epoch: current epoch number (used for warm-up and for log lines).
        scaler: optional gradient scaler; when given, the forward pass runs
            under autocast and all backward/step calls go through the scaler.

    Returns:
        tuple: ``(average loss, accuracy_train, average batch time,
        average data-loading time)``. ``accuracy_train`` is element 0 of
        ``accuracy(logits, labels)`` over the whole epoch (presumably top-1;
        see ``util.accuracy``).
    """
    loss = AverageMeter()
    nce = AverageMeter()
    ce = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    optimizer, optimizer_fc = optimizers

    all_outputs, all_labels = [], []

    t1 = time.time()
    for idx, (images, labels, _, _, bias_features) in enumerate(train_loader):
        data_time.update(time.time() - t1)

        images, labels, bias_features = images.to(opts.device), labels.to(opts.device), bias_features.to(opts.device)
        bsz = images.shape[0]

        warmup_learning_rate(opts, epoch, idx, len(train_loader), optimizer)

        with torch.set_grad_enabled(True):
            with torch.cuda.amp.autocast(scaler is not None):
                projected, feats, logits = model(images)
                running_nce = criterion(projected, feats, logits, labels, bias_features)
                running_ce = F.cross_entropy(logits, labels)

                # Cross-entropy only joins the main loss when beta > 0.
                running_loss = running_nce
                if opts.beta > 0:
                    running_loss = running_nce + opts.beta*running_ce

        optimizer.zero_grad()

        if optimizer_fc is not None:
            optimizer_fc.zero_grad()

        if scaler is None:
            if optimizer_fc is not None:
                # The graph is kept because running_loss is back-propagated
                # through it right after.
                running_ce.backward(retain_graph=True) # Backward cross-entropy from last layer
                optimizer_fc.step()
                # Discard the cross-entropy gradients on the parameters owned
                # by `optimizer` before the main backward pass.
                optimizer.zero_grad() # Stop-gradient on the encoder

            running_loss.backward() # Backward infonce loss on the encoder
            optimizer.step()
        else:
            # Same sequence as above, routed through the gradient scaler.
            if optimizer_fc is not None:
                scaler.scale(running_ce).backward(retain_graph=True)
                scaler.step(optimizer_fc)
                optimizer.zero_grad()

            scaler.scale(running_loss).backward()
            scaler.step(optimizer)

            scaler.update()

        loss.update(running_loss.item(), bsz)
        nce.update(running_nce.item(), bsz)
        ce.update(running_ce.item(), bsz)
        batch_time.update(time.time() - t1)
        t1 = time.time()
        # Batch `idx` (0-based) has just finished, so idx + 1 batches are done.
        eta = batch_time.avg * (len(train_loader) - (idx + 1))

        if (idx + 1) % opts.print_freq == 0:
            print(f"Train: [{epoch}][{idx + 1}/{len(train_loader)}]:\t"
                  f"BT {batch_time.avg:.3f}\t"
                  f"ETA {datetime.timedelta(seconds=eta)}\t"
                  f"NCE {nce.avg:.3f}\t"
                  f"CE {ce.avg:.3f}\t"
                  f"loss {loss.avg:.3f}\t")

        # Keep detached logits so accuracy is computed once over the epoch.
        all_outputs.append(logits.detach())
        all_labels.append(labels)

    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    accuracy_train = accuracy(all_outputs, all_labels)[0]

    return loss.avg, accuracy_train, batch_time.avg, data_time.avg

def validate(val_loader,
             model,
             num_classes=9,
             num_clusters=9,
             num_cluster_repeat=3,
             key=None):
    """Evaluate ``model`` on ``val_loader`` and return an accuracy in [0, 1].

    Args:
        val_loader: iterable yielding 5-tuples; elements 0, 1 and 2 are used
            as images, labels and bias (cluster) labels, the rest are ignored.
        model: called as ``model(images)``; the third returned element is used
            as the class scores.
        num_classes: number of rows of the per-repeat count tables.
        num_clusters: number of columns of the per-repeat count tables.
        num_cluster_repeat: number of cluster assignments per sample.
        key: ``'unbiased'`` selects the cluster-balanced metric; any other
            value selects plain accuracy.

    Returns:
        float: for ``key == 'unbiased'``, the mean over repeats of the average
        per-(class, cluster) accuracy, counting only cells with at least 10
        samples (a repeat with no such cell contributes 0). Otherwise
        ``correct / total``, or ``0.0`` when the loader yields no samples.
    """
    model.eval()

    # Run on whatever device the model lives on; fall back to CUDA (the
    # previous hard-coded behaviour) when that cannot be determined.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = torch.device('cuda')

    total = 0
    f_correct = 0
    num_correct = [np.zeros([num_classes, num_clusters]) for _ in range(num_cluster_repeat)]
    num_instance = [np.zeros([num_classes, num_clusters]) for _ in range(num_cluster_repeat)]

    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for images, labels, bias_labels, _, _ in val_loader:
            images, labels = images.to(device), labels.to(device)

            _, _, output = model(images)

            total += labels.size(0)

            if key == 'unbiased':
                # bias_labels stay on the host: they are only used to index
                # the numpy count tables.
                num_correct, num_instance = imagenet_unbiased_accuracy(
                    output.data, labels, bias_labels,
                    num_correct, num_instance, num_cluster_repeat)
            else:
                f_correct += n_correct(output, labels)

    if key != 'unbiased':
        return f_correct / total if total else 0.0

    for k in range(num_cluster_repeat):
        # Only cells with enough samples give a meaningful accuracy.
        populated = num_instance[k] >= 10
        cell_acc = (num_correct[k][populated] / num_instance[k][populated]).tolist()
        if cell_acc:
            f_correct += sum(cell_acc) / len(cell_acc)

    return f_correct / num_cluster_repeat

if __name__ == '__main__':
    # Entry point: parse options, build data/model/optimizers, then alternate
    # training epochs with periodic evaluation on every validation loader.
    opts = parse_arguments()
    set_seed(opts.trial)

    train_loader, val_loaders = load_data(opts)
    model, infonce = load_model(opts)
    # NOTE(review): scheduler_fc is returned here but never stepped below;
    # confirm whether that is intended.
    (optimizer, scheduler), (optimizer_fc, scheduler_fc) = load_optimizer(model, infonce, opts)

    # Large batches always use learning-rate warm-up.
    if opts.batch_size > 256:
        opts.warm = True

    if opts.warm:
        opts.warm_epochs = 10
        opts.warmup_from = 0.01
        opts.model = f"{opts.model}_warm"

        if opts.lr_decay == 'cosine':
            # Warm up to the value a cosine schedule (floor lr * 1e-3) reaches
            # at the end of the warm-up epochs.
            eta_min = opts.lr * (0.1 ** 3)
            opts.warmup_to = eta_min + (opts.lr - eta_min) * (1 + math.cos(math.pi * opts.warm_epochs / opts.epochs)) / 2
        else:
            opts.warmup_to = opts.lr

    ensure_dir(opts.log_dir)
    run_name = (f"{opts.method}_{opts.form}_{opts.dataset}_{opts.model}_"
                f"{opts.optimizer}"
                f"bsz{opts.batch_size}_lr{opts.lr}_t{opts.temp}_eps{opts.epsilon}_"
                f"lr-eps{opts.lr_epsilon}_feat{opts.feat_dim}_"
                f"{'identity_' if opts.train_on_head else 'head_'}"
                f"alpha{opts.alpha}_beta{opts.beta}_lambda{opts.lambd}_kld{opts.kld}_"
                f"mlp_lr{opts.mlp_lr}_mlp_optimizer_{opts.mlp_optimizer}_"
                f"trial{opts.trial}")
    tb_dir = os.path.join(opts.log_dir, run_name)
    # Record the resolved components in the config that is sent to wandb.
    opts.model_class = model.__class__.__name__
    opts.criterion = infonce
    opts.optimizer_class = optimizer.__class__.__name__
    opts.scheduler = scheduler.__class__.__name__ if scheduler is not None else None

    wandb.init(project="contrastive-learning-debiasing", config=opts, name=run_name, sync_tensorboard=True)
    print('Config:', opts)
    print('Model:', model)
    print('Criterion:', infonce)
    print('Optimizer:', optimizer)
    print('Scheduler:', scheduler)

    writer = torch.utils.tensorboard.writer.SummaryWriter(tb_dir)

    def criterion_cont(projected, feats, logits, labels, bias_features):
        """InfoNCE loss plus, when ``opts.lambd != 0``, the continuous-bias
        Lagrangian constraint computed on L2-normalised features."""
        loss = opts.alpha * infonce(projected, labels)
        if opts.lambd != 0:
            feats = F.normalize(feats)
            R = opts.lambd * losses.lagrangian_constraint_cont(feats, labels,
                                                               bias_features, 1.0,
                                                               kld=opts.kld)
            loss += R
        return loss

    # NOTE(review): criterion_disc is defined but not used by this script.
    def criterion_disc(projected, feats, logits, labels, bias_labels):
        """InfoNCE loss plus, when ``opts.lambd != 0``, the discrete-bias
        Lagrangian constraint computed on L2-normalised features."""
        loss = opts.alpha * infonce(projected, labels)
        if opts.lambd != 0:
            feats = F.normalize(feats)
            R = opts.lambd * losses.lagrangian_constraint(feats, labels,
                                                          bias_labels, 1.0,
                                                          kld=opts.kld)
            loss += R
        return loss

    scaler = torch.cuda.amp.GradScaler() if opts.amp else None
    if opts.amp:
        print("Using AMP")

    best_accs = pretty_dict(**{'biased': 0, 'unbiased': 0, 'ImageNet-A': 0})
    best_epochs = pretty_dict(**{'biased': 0, 'unbiased': 0, 'ImageNet-A': 0})

    start_time = time.time()
    try:
        for epoch in range(1, opts.epochs + 1):
            t1 = time.time()
            loss_train, accuracy_train, batch_time, data_time = train(train_loader, model, criterion_cont, (optimizer, optimizer_fc), opts, epoch, scaler)
            t2 = time.time()

            writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar("train/loss", loss_train, epoch)
            writer.add_scalar("train/acc@1", accuracy_train, epoch)
            if "auto" in opts.method:
                writer.add_scalar("train/epsilon", infonce.epsilon, epoch)

            writer.add_scalar("BT", batch_time, epoch)
            writer.add_scalar("DT", data_time, epoch)
            print(f"epoch {epoch}, total time {t2-start_time:.2f}, epoch time {t2-t1:.3f} "
                  f"acc {accuracy_train:.2f} loss {loss_train:.4f}")

            if scheduler is not None:
                scheduler.step()

            # Evaluate every test_freq epochs, and always on the first and
            # last epoch.
            if (epoch % opts.test_freq == 0) or epoch == 1 or epoch == opts.epochs:
                for name, val_loader in val_loaders.items():
                    accuracy_test = validate(val_loader, model, key=name)
                    writer.add_scalar(f"{name}/acc@1", accuracy_test, epoch)
                    print(f"{name} accuracy {accuracy_test:.2f}")

                    #aligned_sim = np.histogram(aligned_sim.cpu(), bins=min(512, aligned_sim.shape[0]))
                    #conflicting_sim = np.histogram(conflicting_sim.cpu(), bins=min(512, conflicting_sim.shape[0]))
                    # try:
                    #     writer.add_histogram('test/aligned_sim', aligned_sim[0], epoch, bins=256, max_bins=512)
                    #     writer.add_histogram('test/conflicting_sim', conflicting_sim[0], epoch, bins=256, max_bins=512)

                    #     print('aligned_sim', aligned_sim[1], 'conflicting_sim', conflicting_sim[1])
                    #     writer.add_scalar('test/aligned_sim_mean', aligned_sim[1], epoch)
                    #     writer.add_scalar('test/conflicting_sim_mean', conflicting_sim[1], epoch)
                    # except:
                    #     pass

                    if accuracy_test > best_accs[name]:
                        best_accs[name] = accuracy_test
                        best_epochs[name] = epoch

                    writer.add_scalar(f"{name}/best_acc@1", best_accs[name], epoch)

            print("best accuracy:", best_accs, best_epochs)
    finally:
        # Flush and close the event file even if training is interrupted, so
        # buffered scalars are not lost.
        writer.close()

    # clf_train_accuracy, clf_test_accuracy = test_mlp(train_loader, test_loader, model, opts)
    # print("CLF train accuracy:", clf_train_accuracy)
    # print("CLF test accuracy:", clf_test_accuracy)
    # writer.add_scalar("train/clf_acc@1", clf_train_accuracy, opts.epochs)
    # writer.add_scalar("test/clf_acc@1", clf_test_accuracy, opts.epochs)