# adapted from https://github.com/SCLBD/Effective_backdoor_defense/blob/main/train_attack_noTrans.py
import sys
import os
import json
import shutil
from tqdm import tqdm
import numpy as np
import csv
from PIL import Image
from train_utils import *
from trainnew import train

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms


import argparse
from pathlib import Path
from dataloader_ffcv import create_dataloader
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as transforms

import os
import sys
import time 

from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import roc_curve, auc, precision_recall_curve

def train_epoch(arg, trainloader, model, optimizer, scheduler, criterion, epoch, scaler):
    """Run one finetuning epoch that pushes per-class feature centers apart.

    For every batch, the features produced by all children of ``model`` except
    the last one (flattened per sample) are averaged per label. The per-label
    centers are L2-normalised, and the loss is the mean of their cosine
    similarity matrix with the diagonal zeroed. The classifier outputs are only
    used for the progress metrics, not for the loss.

    Args:
        arg: parsed CLI namespace; reads ``dataset`` and ``target_label``.
        trainloader: iterable of ``(inputs, labels, gt_labels, isCleans)``
            batches. ``inputs`` are divided by 255 here. Samples with
            ``isCleans == 1`` are counted as poisoned (despite the name).
        model: network whose last child is treated as the classifier head.
        optimizer: optimiser updated once per batch.
        scheduler: LR scheduler, stepped once at the end of the epoch.
        criterion: unused; kept for signature compatibility.
        epoch: epoch index shown in the progress bar.
        scaler: ``GradScaler`` used for the mixed-precision backward/step.

    Returns:
        ``(mean_loss, clean_acc, attack_success_rate, robust_acc)``, where the
        three accuracies are percentage tensors accumulated over the epoch.

    Raises:
        ValueError: if ``arg.dataset`` is unsupported or ``trainloader`` is empty.
    """
    classes_per_dataset = {'cifar10': 10, 'imagenet200': 200, 'tinyimagenet': 200}
    if arg.dataset not in classes_per_dataset:
        raise ValueError(f"Unsupported dataset: {arg.dataset!r}")
    num_classes = classes_per_dataset[arg.dataset]

    model.train()

    # Everything except the final child (the FC head). The child modules are
    # shared with ``model``, so this already lives on the model's device and
    # sees every weight update; building it once per epoch is enough.
    feature_extractor = nn.Sequential(*list(model.children())[:-1])

    total_clean, total_poison = 0, 0
    total_clean_correct, total_attack_correct, total_robust_correct = 0, 0, 0
    train_loss = 0
    num_batches = 0

    iterator = tqdm(enumerate(trainloader, start=1), total=len(trainloader))
    for num_batches, (inputs, labels, gt_labels, isCleans) in iterator:
        inputs = inputs / 255.
        poison_idx = torch.where(isCleans == 1)[0]

        with autocast():
            outputs = model(inputs)
            # Second forward pass through the backbone only. In train mode any
            # layers that keep running statistics see each batch twice.
            features = feature_extractor(inputs)
            features = features.view(features.size(0), -1)

        # One center per label that actually occurs in this batch.
        centers = []
        for j in range(num_classes):
            j_idx = torch.where(labels == j)[0]
            if j_idx.shape[0] == 0:
                continue
            centers.append(torch.mean(features[j_idx], dim=0))

        centers = F.normalize(torch.stack(centers, dim=0), dim=1)
        similarity_matrix = torch.matmul(centers, centers.T)
        # Ignore each center's similarity with itself.
        diagonal = torch.eye(similarity_matrix.shape[0], dtype=torch.bool,
                             device=similarity_matrix.device)
        similarity_matrix = similarity_matrix.masked_fill(diagonal, 0.0)
        loss = torch.mean(similarity_matrix)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        predictions = torch.argmax(outputs, dim=1)
        train_loss += loss.item()
        total_clean_correct += torch.sum(predictions == labels)
        total_attack_correct += torch.sum(predictions[poison_idx] == arg.target_label)
        total_robust_correct += torch.sum(predictions == gt_labels)
        total_clean += inputs.shape[0]
        total_poison += poison_idx.shape[0]
        avg_acc_clean = total_clean_correct * 100.0 / total_clean
        avg_acc_attack = total_attack_correct * 100.0 / total_poison
        avg_acc_robust = total_robust_correct * 100.0 / total_clean

        iterator.set_description(f"Epoch {epoch}")
        iterator.set_postfix(loss=train_loss / num_batches, train_acc=avg_acc_clean.item(),
                             train_asr=avg_acc_attack.item(), train_racc=avg_acc_robust.item())
        iterator.refresh()

    if num_batches == 0:
        raise ValueError("trainloader produced no batches")

    scheduler.step()
    return train_loss / num_batches, avg_acc_clean, avg_acc_attack, avg_acc_robust

def train(train_loader, model, criterion, optimizer, scheduler, scaler, epoch, args, device):
    """Run one epoch of mixed-precision supervised training.

    Args:
        train_loader: iterable of ``(image, target, _, _)`` batches; images are
            divided by 255 here, the last two fields are ignored.
        model: network to train (switched to train mode).
        criterion: loss applied to ``(model(image), target)``.
        optimizer: optimiser updated once per batch.
        scheduler: unused here (per-batch stepping is disabled).
        scaler: ``GradScaler`` used for the mixed-precision backward/step.
        epoch: epoch index shown in the progress bar.
        args, device: unused; kept for signature compatibility.

    Returns:
        The same ``model`` object, after training.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    model.train()

    iterator = tqdm(train_loader, total=len(train_loader))
    for image, target, _, _ in iterator:
        image = image / 255.

        optimizer.zero_grad(set_to_none=True)
        with autocast():
            output_clean = model(image)
            loss = criterion(output_clean, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics are computed in fp32 because autocast may yield half tensors.
        output = output_clean.float()
        loss = loss.float()
        prec1 = accuracy(output.data, target)[0]

        losses.update(loss.item(), image.size(0))
        top1.update(prec1.item(), image.size(0))

        # `.4g` keeps small learning rates readable; the previous `.2f`
        # displayed every LR below 0.005 as 0.00.
        iterator.set_description(f"Epoch {epoch} | LR {optimizer.param_groups[0]['lr']:.4g}")
        iterator.set_postfix(loss=loss.item(), accuracy=prec1.item())
        iterator.refresh()

    return model
        



def main(arg, device):
    """Train, finetune and score every training sample with a feature-consistency test.

    Stages:
      1. Two epochs of cross-entropy training with ``train``; the model is
         saved to ``model_init2.pt``.
      2. Five epochs of ``train_epoch`` finetuning; the model is saved to
         ``model_ft.pt``.
      3. A sequential pass over the training set (batch size 100). For each
         sample, the score is the mean squared difference between backbone
         features of the image and of a randomly rotated/translated copy.
         Scores and poison labels are saved to ``FCT_pred.pt`` and
         ``poisonlab_true.pt``. The ROC AUC of scores vs. poison labels is
         written as JSON to ``AUROC_SDFCT`` and printed.

    All outputs go to
    ``Results/<dataset>/<attack>/Poisonratio_<ratio>/<arch>/Trial <trialno>/SD_FCT``.

    Args:
        arg: parsed CLI namespace (dataset, attack, poison_ratio, arch,
            trialno, batch_size, lr, ...).
        device: torch device the model is moved to and passed to the loaders.

    Raises:
        ValueError: if ``arg.dataset`` is not supported.
        FileExistsError: if the output directory already exists.
    """
    train_sizes = {'cifar10': 50000, 'imagenet200': 100000, 'tinyimagenet': 100000}
    if arg.dataset not in train_sizes:
        raise ValueError(f"Unsupported dataset: {arg.dataset!r}")
    train_no = train_sizes[arg.dataset]

    model_path = f'Results/{arg.dataset}/{arg.attack}/Poisonratio_{arg.poison_ratio}/{arg.arch}/Trial {arg.trialno}'
    pathname = f'{model_path}/SD_FCT'
    # No exist_ok: this raises FileExistsError if the directory is already
    # there (e.g. on a rerun of the same trial).
    Path(pathname).mkdir(parents=True)

    batch_size = arg.batch_size

    # ------------------------------------------------------------------
    # Stage 1: first 2 training epochs (plain cross-entropy).
    # ------------------------------------------------------------------
    model = build_model(arg)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()
    trainloader, _, _ = create_dataloader(arg, batch_size, '', device, partition='None', seq=False)

    # NOTE(review): momentum and weight decay are hard-coded here; the
    # --momentum / --weight_decay CLI flags are not read.
    optimizer = torch.optim.SGD(model.parameters(), lr=arg.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    for epoch in range(2):
        model = train(trainloader, model, criterion, optimizer, scheduler, scaler, epoch, arg, device)
        scheduler.step()

    torch.save(model, pathname + "/model_init2.pt")

    # ------------------------------------------------------------------
    # Stage 2: finetuning epochs with the center-separation loss.
    # ------------------------------------------------------------------
    optimizer = torch.optim.SGD(model.parameters(), lr=arg.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    scaler = GradScaler()
    criterion = nn.CrossEntropyLoss()

    for epoch in range(5):
        train_epoch(arg, trainloader, model, optimizer, scheduler, criterion, epoch, scaler)

    torch.save(model, pathname + "/model_ft.pt")

    # ------------------------------------------------------------------
    # Stage 3: feature consistency under random rotation/translation.
    # ------------------------------------------------------------------
    model.eval()

    poison_label_full = torch.zeros(train_no)
    feature_consistency_full = torch.zeros(train_no)

    batch_size = 100
    # NOTE(review): the offset-based writes below assume seq=True yields the
    # samples in a fixed sequential order - confirm in create_dataloader.
    trainloader, _, _ = create_dataloader(arg, batch_size, '', device, partition='None', seq=True)

    # NOTE(review): torchvision samples one set of random parameters per call,
    # so presumably the whole batch shares the same rotation/shift - confirm.
    transform_fct = transforms.Compose([transforms.RandomRotation(180),
                                        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2))])
    # Backbone without the final (FC) child; shares weights with ``model``.
    feature_extractor = nn.Sequential(*list(model.children())[:-1])

    with torch.no_grad():
        for ix, (inputs, _labels, _gt_labels, poison_label) in enumerate(trainloader):
            start = ix * batch_size
            stop = start + batch_size

            inputs1 = inputs / 255.
            inputs2 = transform_fct(inputs1)
            poison_label_full[start:stop] = poison_label

            with autocast():
                features1 = feature_extractor(inputs1)
                features2 = feature_extractor(inputs2)
            # Flatten and upcast before differencing: squaring half-precision
            # differences can overflow to inf.
            features1 = features1.view(features1.size(0), -1).float()
            features2 = features2.view(features2.size(0), -1).float()

            feature_consistency_full[start:stop] = torch.mean((features1 - features2) ** 2, dim=1)

    roc_auc = roc_auc_score(poison_label_full, feature_consistency_full)
    torch.save(poison_label_full, f'{pathname}/poisonlab_true.pt')
    torch.save(feature_consistency_full, f'{pathname}/FCT_pred.pt')
    with open(f'{pathname}/AUROC_SDFCT', 'w') as f:
        json.dump(float(roc_auc), f, indent=2)

    print(roc_auc)


# Terminal width used by progress_bar to pad its line and re-center the cursor.
# `stty size` reads the terminal on stdin and prints nothing when there is none
# (pipes, nohup, batch jobs), which made the old two-value unpacking raise
# ValueError at import time. shutil.get_terminal_size() checks $COLUMNS and
# stdout and falls back to 80 columns instead.
term_width = shutil.get_terminal_size().columns
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    """Draw a single-line text progress bar on stdout.

    The bar is ``TOTAL_BAR_LENGTH`` characters wide, followed by the time
    since the previous call and since ``current == 0``, plus ``msg`` when
    given. The line is padded to ``term_width``. The cursor is then moved back
    with backspaces to print ``current+1/total``. The line ends with ``\\r``
    until the final step, which ends with ``\\n``.

    Updates the module-level ``last_time`` and ``begin_time`` timers.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # a new bar starts: restart the total timer

    filled = int(TOTAL_BAR_LENGTH*current/total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    pieces = ['  Step: %s' % format_time(step_time), ' | Total: %s' % format_time(tot_time)]
    if msg:
        pieces.append(' | ' + msg)
    msg = ''.join(pieces)

    sys.stdout.write(msg)
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))
    # Step the cursor back to the middle of the bar for the counter.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
    sys.stdout.write(' %d/%d ' % (current+1, total))
    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()



def format_time(seconds):
    """Render a duration in seconds using at most two units, e.g. ``'1h1m'``.

    Possible units, largest first: days (``D``), hours (``h``), minutes
    (``m``), seconds (``s``) and milliseconds (``ms``). Only units with a
    positive value are shown, and only the first two of those. A duration with
    no positive unit (zero or negative) renders as ``'0ms'``.
    """
    day_count = int(seconds / 3600 / 24)
    rest = seconds - day_count * 3600 * 24
    hour_count = int(rest / 3600)
    rest = rest - hour_count * 3600
    minute_count = int(rest / 60)
    rest = rest - minute_count * 60
    second_count = int(rest)
    milli_count = int((rest - second_count) * 1000)

    labelled = (
        (day_count, 'D'),
        (hour_count, 'h'),
        (minute_count, 'm'),
        (second_count, 's'),
        (milli_count, 'ms'),
    )
    shown = [f'{amount}{suffix}' for amount, suffix in labelled if amount > 0]
    return ''.join(shown[:2]) or '0ms'


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # dataset and model
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
    parser.add_argument('--arch', type=str, default='res18', help='model architecture')

    # training hyper parameters
    parser.add_argument('--batch_size', type=int, default=128, help='The size of batch')
    parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')
    # NOTE(review): main() builds SGD with hard-coded momentum=0.9 and
    # weight_decay=5e-4 instead of reading these two flags.
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')

    # poisoning / attack setup
    parser.add_argument('--poison_ratio', default=0.1, type=float, help='Poison Ratio')
    parser.add_argument('--attack', type=str, help='Give attack name')
    parser.add_argument('--save_samples', type=str, default='False',
                        help='Save-samples flag, given as the string "True" or "False"')
    parser.add_argument('--target_label', default=1, type=int, help='Backdoor target label')
    parser.add_argument('--trialno', type=int, help='Trial number (part of the results path)')

    # `args` and `device` stay module-level names: other functions in this
    # file may read them as globals.
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    main(args, device)
