#import sys
#sys.path.append('../')

from dataloader_ffcv import create_dataloader
from train_utils import *


from tqdm import tqdm
import argparse

from pathlib import Path
import torch.nn as nn
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import roc_curve, auc, precision_recall_curve

from torch.cuda.amp import GradScaler, autocast


def compute_loss_value(args, model_ascent, device=None):
    """Flag the lowest-loss training examples as suspected poison and score them.

    Every training example is fed through ``model_ascent`` one at a time (a
    sequential loader with batch size 1). The ``args.isolation_ratio``
    fraction of examples with the smallest loss is flagged as suspected
    poison. The flags are compared against the loader's ground-truth poison
    labels with ROC-AUC.

    Results are written under
    ``Results/<dataset>/<attack>/Poisonratio_<ratio>/<arch>/Trial <trialno>/ABL``:

    - ``poisonlab_true.pt``: ground-truth poison labels.
    - ``ABL_pred.pt``: binary "isolated" predictions.
    - ``AUROC_ABL``: the ROC-AUC, serialized as JSON.

    Args:
        args: parsed options. Reads ``dataset``, ``attack``, ``poison_ratio``,
            ``arch``, ``trialno`` and ``isolation_ratio``; it is also passed
            on to ``create_dataloader``.
        model_ascent: the isolation model produced by ``train``.
        device: device handed to ``create_dataloader``. Defaults to the
            device of the model's first parameter.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    train_sizes = {'cifar10': 50000, 'imagenet200': 100000, 'tinyimagenet': 100000}
    if args.dataset not in train_sizes:
        raise ValueError(f'Unsupported dataset: {args.dataset!r}')
    train_no = train_sizes[args.dataset]

    if device is None:
        # Keep data on the same device the model lives on.
        device = next(model_ascent.parameters()).device

    model_path = f'Results/{args.dataset}/{args.attack}/Poisonratio_{args.poison_ratio}/{args.arch}/Trial {args.trialno}'
    pathname = f'{model_path}/ABL'
    # exist_ok: re-running a trial must not crash here, after training has
    # already been paid for.
    Path(pathname).mkdir(parents=True, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    model_ascent.eval()
    losses_record = []

    # seq=True with batch size 1, so the loop index is the example index.
    example_data_loader, _, _ = create_dataloader(args, 1, '', device, partition='None', seq=True)
    poison_label_full = torch.zeros(train_no)

    with torch.no_grad():
        for idx, (img, target, _, poison_label) in enumerate(tqdm(example_data_loader)):
            img = img / 255.
            with autocast():
                output = model_ascent(img)
                # NOTE(review): target appears to arrive without a batch
                # dimension at batch size 1; unsqueeze restores it. Confirm
                # against the loader's label pipeline.
                loss = criterion(output, target.unsqueeze(0))
            losses_record.append(loss.item())
            poison_label_full[idx] = poison_label

    losses_record_arr = np.array(losses_record)
    # Indices of examples ordered by ascending loss.
    losses_idx = np.argsort(losses_record_arr)

    # Show the lowest 10 loss values.
    print('Top ten loss value:', losses_record_arr[losses_idx[:10]])

    # The lowest-loss fraction of examples is flagged as suspected poison.
    perm = losses_idx[:int(len(losses_idx) * args.isolation_ratio)]
    ABL_poison_label = torch.zeros(train_no)
    ABL_poison_label[torch.as_tensor(perm, dtype=torch.long)] = 1

    roc_auc = roc_auc_score(poison_label_full, ABL_poison_label)
    print(roc_auc)
    fpr, tpr, _ = roc_curve(poison_label_full, ABL_poison_label)
    print(tpr)
    print(fpr)

    torch.save(poison_label_full, f'{pathname}/poisonlab_true.pt')
    torch.save(ABL_poison_label, f'{pathname}/ABL_pred.pt')
    with open(f'{pathname}/AUROC_ABL', 'w') as f:
        json.dump(roc_auc, f, indent=2)
        

def train_step(args, train_loader, scaler, model_ascent, optimizer, criterion, epoch):
    """Run one mixed-precision training epoch of the isolation model.

    The per-batch loss is transformed according to
    ``args.gradient_ascent_type`` before backpropagation:

    - ``'LGA'``: ``sign(loss - args.gamma) * loss``.
    - ``'Flooding'``: ``|loss - args.flooding| + args.flooding``.

    Both forms minimize the loss while it is above the threshold and push
    it back up (gradient ascent) once it drops below.

    Args:
        args: parsed options. Reads ``gradient_ascent_type``, ``gamma`` or
            ``flooding``, and ``print_freq``.
        train_loader: iterable of ``(img, target, _, _)`` batches.
        scaler: ``GradScaler`` used for the AMP backward/step.
        model_ascent: model being trained (switched to train mode here).
        optimizer: optimizer stepped once per batch.
        criterion: base loss function applied to ``(output, target)``.
        epoch: epoch number, used only for logging.

    Raises:
        NotImplementedError: for an unknown ``gradient_ascent_type``.
    """
    # Resolve the objective once, instead of re-dispatching on every batch.
    if args.gradient_ascent_type == 'LGA':
        def ascent_objective(loss):
            # Local Gradient Ascent (LGA) loss.
            return torch.sign(loss - args.gamma) * loss
    elif args.gradient_ascent_type == 'Flooding':
        def ascent_objective(loss):
            # Flooding loss.
            return (loss - args.flooding).abs() + args.flooding
    else:
        raise NotImplementedError

    losses = AverageMeter()
    top1 = AverageMeter()

    model_ascent.train()

    for idx, (img, target, _, _) in enumerate(train_loader, start=1):
        img = img / 255.

        with autocast():
            output = model_ascent(img)
            loss = criterion(output, target)
            loss_ascent = ascent_objective(loss)

        prec1 = accuracy(output.data, target)[0]
        # Meters track the raw loss, not the transformed ascent objective.
        losses.update(loss.item(), img.size(0))
        top1.update(prec1.item(), img.size(0))

        optimizer.zero_grad()
        scaler.scale(loss_ascent).backward()
        scaler.step(optimizer)
        scaler.update()

        if idx % args.print_freq == 0:
            print('Epoch[{0}]:[{1:03}/{2:03}] '
                  'Loss:{losses.val:.4f}({losses.avg:.4f})  '
                  'Prec@1:{top1.val:.2f}({top1.avg:.2f})  '.format(epoch, idx, len(train_loader), losses=losses, top1=top1))




def train(args, device):
    """Build the isolation model and train it for ``args.tuning_epochs`` epochs.

    Uses SGD with Nesterov momentum, cross-entropy as the base loss, and AMP
    gradient scaling. The learning rate is reset at the start of every epoch
    by ``adjust_learning_rate``.

    Args:
        args: parsed options. Reads ``lr``, ``momentum``, ``weight_decay``,
            ``batch_size`` (defaults to 128 if absent) and ``tuning_epochs``.
            It is also passed to ``build_model``, ``create_dataloader`` and
            ``train_step``.
        device: device the model is moved to and the loader is built for.

    Returns:
        The trained model.
    """
    print('----------- Network Initialization --------------')
    model_ascent = build_model(args)
    model_ascent.to(device)

    optimizer = torch.optim.SGD(model_ascent.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    print('----------- Data Initialization --------------')
    # Honour the --batch_size option (default 128) instead of a fixed size.
    batch_size = getattr(args, 'batch_size', 128)
    train_loader, _, _ = create_dataloader(args, batch_size, '', device, partition='None', seq=False)

    print('----------- Train Initialization --------------')
    for epoch in range(args.tuning_epochs):
        adjust_learning_rate(optimizer, epoch, args)
        train_step(args, train_loader, scaler, model_ascent, optimizer, criterion, epoch + 1)

    return model_ascent


def adjust_learning_rate(optimizer, epoch, opt):
    """Write the learning rate for ``epoch`` into every parameter group.

    Epochs before ``opt.tuning_epochs`` get ``opt.lr``; later epochs get a
    fixed 0.01. The chosen rate is printed before it is applied.
    """
    new_lr = opt.lr if epoch < opt.tuning_epochs else 0.01
    print('epoch: {}  lr: {:.4f}'.format(epoch, new_lr))
    for group in optimizer.param_groups:
        group['lr'] = new_lr



def main(opt, device):
    """Train the isolation model, then rank and score per-example losses with it."""
    print('----------- Train isolated model -----------')
    isolated_model = train(opt, device)

    print('----------- Calculate loss value per example -----------')
    compute_loss_value(opt, isolated_model)



if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # various path
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
    parser.add_argument('--arch', type=str, default='res18', help='model architecture')

    # training hyper parameters
    parser.add_argument('--print_freq', type=int, default=2, help='frequency of showing training results on console')
    parser.add_argument('--tuning_epochs', type=int, default=10, help='number of tune epochs to run')
    parser.add_argument('--batch_size', type=int, default=128, help='The size of batch')
    parser.add_argument('--lr', type=float, default=0.1, help='initial learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--isolation_ratio', type=float, default=0.1, help='ratio of isolation data')
    parser.add_argument('--gradient_ascent_type', type=str, default='Flooding', help='type of gradient ascent')
    # gamma / flooding are fractional loss thresholds (default 0.5), so they
    # must be parsed as floats; type=int rejected e.g. "--gamma 0.5".
    parser.add_argument('--gamma', type=float, default=0.5, help='value of gamma')
    parser.add_argument('--flooding', type=float, default=0.5, help='value of flooding')

    parser.add_argument('--poison_ratio', default=0.1, type=float, help='Poison Ratio')
    parser.add_argument('--attack', type=str, help='Give attack name')
    parser.add_argument('--save_samples', type=str, default='False', help="whether to save samples ('True'/'False')")
    parser.add_argument('--trialno', type=int, help='trial number (used in the results path)')

    opt = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    main(opt, device)
