import sys
import os
from tqdm import tqdm
import numpy as np
import argparse
import torch
from torch import nn
sys.path.append("./")
sys.path.append(os.getcwd())
print(os.getcwd())
from utils.defense.utils_dst.dataloader_bd import normalization

def calculate_consistency(args, dataloader, model):
    """Record per-sample feature consistency between two views of each input.

    For every batch, ``inputs[0]`` and ``inputs[2]`` are normalized and passed
    through ``model`` with its last child module removed (presumably the FC
    classifier). The outputs are flattened and compared with a per-sample mean
    squared difference. The scores are appended, one value per line, to
    ``<args.save_path>/data_produce/all.txt``. The scores are also split by the
    loader's ``isCleans`` flag into ``clean.txt`` (flag == True) and
    ``poison.txt`` (flag == False).

    Args:
        args: namespace providing ``device`` and ``save_path``, plus whatever
            ``normalization`` reads.
        dataloader: yields ``(inputs, labels, _, isCleans, gt_labels)``.
            ``inputs`` must be indexable at 0 and 2 (presumably a list of
            augmented views -- confirm against the dataset). ``isCleans`` is
            assumed to be a 1-D tensor aligned with the batch.
        model: classifier, optionally wrapped (e.g. DataParallel) so that the
            real network lives in ``model.module``.

    Note:
        Files are opened in append mode, so scores from earlier runs with the
        same ``save_path`` are kept. Delete the files first for a fresh run.
    """
    model.eval()

    # The headless extractor depends only on ``model``, so build it once
    # instead of once per batch.
    net = model.module if hasattr(model, "module") else model
    feature_extractor = nn.Sequential(*list(net.children())[:-1]).to(args.device)

    all_scores, clean_scores, poison_scores = [], [], []

    # Scores are only saved to disk, so no autograd graph is needed.
    with torch.no_grad():
        for inputs, _labels, _, isCleans, _gt_labels in dataloader:
            inputs1 = normalization(args, inputs[0]).to(args.device)
            inputs2 = normalization(args, inputs[2]).to(args.device)

            ### Feature ###
            features1 = feature_extractor(inputs1)
            features2 = feature_extractor(inputs2)
            features1 = features1.view(features1.size(0), -1)
            features2 = features2.view(features2.size(0), -1)

            ### Calculate consistency ###
            scores = torch.mean((features1 - features2) ** 2, dim=1).cpu().numpy()

            # Compare with == True / == False rather than casting to bool, so
            # non-bool flag tensors are split exactly as before.
            clean_mask = (isCleans == True).cpu().numpy()  # noqa: E712
            poison_mask = (isCleans == False).cpu().numpy()  # noqa: E712

            all_scores.append(scores)
            clean_scores.append(scores[clean_mask])
            poison_scores.append(scores[poison_mask])

    if not all_scores:
        # Empty loader: nothing to write (same as before, no files created).
        return

    ### Save ###
    # Writing the concatenated scores once yields the same lines as writing
    # them batch by batch.
    f_path = os.path.join(args.save_path, 'data_produce')
    os.makedirs(f_path, exist_ok=True)
    outputs = (
        ('all.txt', all_scores),
        ('clean.txt', clean_scores),
        ('poison.txt', poison_scores),
    )
    for file_name, chunks in outputs:
        with open(os.path.join(f_path, file_name), 'ab') as f:
            np.savetxt(f, np.concatenate(chunks), delimiter=" ")
    return

def calculate_gamma(args, clean_ratio=0.20, poison_ratio=0.05):
    """Derive the clean/poison consistency thresholds from saved scores.

    Reads ``<args.save_path>/data_produce/all.txt`` (one score per line) and
    sorts the scores ascending. It returns
    ``gamma_low = sorted[int(N * clean_ratio)]`` and
    ``gamma_high = sorted[N - int(N * poison_ratio)]``. Both indices are
    clamped into ``[0, N - 1]``, so tiny score files do not index past the end.

    Args:
        args: namespace providing ``save_path``. ``args.clean_ratio`` and
            ``args.poison_ratio`` are (re)set to the ratios used here, as
            before.
        clean_ratio: fraction of lowest-scoring samples that bounds
            ``gamma_low``. The default keeps the previous hard-coded value.
        poison_ratio: fraction of highest-scoring samples that bounds
            ``gamma_high``. The default keeps the previous hard-coded value.

    Returns:
        ``(gamma_low, gamma_high)`` taken from the sorted scores.

    Raises:
        ValueError: if the score file contains no data.
    """
    args.clean_ratio = clean_ratio
    args.poison_ratio = poison_ratio

    f_path = os.path.join(args.save_path, 'data_produce')
    f_all = os.path.join(f_path, 'all.txt')

    # ndmin=1 keeps a single-line file from loading as a 0-d array.
    all_data = np.loadtxt(f_all, ndmin=1)
    all_size = all_data.shape[0]
    if all_size == 0:
        raise ValueError(f"no consistency scores found in {f_all}")

    clean_size = int(all_size * args.clean_ratio)
    poison_size = int(all_size * args.poison_ratio)

    def _clamp(index):
        # With few samples poison_size can be 0 (index == all_size), which
        # would raise IndexError. Clamp it onto the array instead.
        return min(max(index, 0), all_size - 1)

    new_data = np.sort(all_data)  # ascending order
    gamma_low = new_data[_clamp(clean_size)]
    gamma_high = new_data[_clamp(all_size - poison_size)]
    print("gamma_low: ", gamma_low)
    print("gamma_high: ", gamma_high)
    return gamma_low, gamma_high

def separate_samples(args, trainloader, model):
    """Partition training samples into clean / suspicious / poison sets.

    For each sample, ``inputs[0]`` and ``inputs[2]`` are normalized and passed
    through ``model`` with its last child module removed. The per-sample mean
    squared difference of the flattened features is compared with
    ``args.gamma_low`` / ``args.gamma_high``:

    * ``<= gamma_low``  -> clean, flag 0
    * ``>= gamma_high`` -> poison, flag 2
    * otherwise         -> suspicious, flag 1

    Each sample is stored as ``(img, target, flag)``. ``img`` is the
    un-normalized ``inputs[0]`` converted to an HWC ``uint8`` array, and
    ``target`` is the squeezed label as a NumPy array. The three groups are
    written to
    ``<args.save_path>/data_produce/{clean,poison,suspicious}_samples.npy`` as
    object arrays (``np.load`` needs ``allow_pickle=True``).

    Args:
        args: namespace providing ``gamma_low``, ``gamma_high``, ``debug``,
            ``device`` and ``save_path``, plus whatever ``normalization``
            reads. With ``args.debug`` set, only the first 10001 samples are
            processed.
        trainloader: yields ``(inputs, labels, _, _, gt_labels)`` with batch
            size 1. ``inputs`` must be indexable at 0 and 2. Image tensors are
            assumed to be CHW floats in [0, 1] before normalization -- TODO
            confirm against the dataset.
        model: classifier, optionally wrapped so the network lives in
            ``model.module``.

    Raises:
        ValueError: if a batch holds more than one sample.
    """
    gamma_low, gamma_high = args.gamma_low, args.gamma_high
    model.eval()

    # Build the headless extractor (last child, presumably the FC layer,
    # dropped) once instead of once per sample.
    net = model.module if hasattr(model, "module") else model
    feature_extractor = nn.Sequential(*list(net.children())[:-1]).to(args.device)

    clean_samples, poison_samples, suspicious_samples = [], [], []

    # Only scalar scores are needed, so skip building autograd graphs.
    with torch.no_grad():
        for i, (inputs, labels, _, _, _gt_labels) in enumerate(trainloader):
            if args.debug and i == 10001:
                break
            if i % 1000 == 0:
                print("Processing samples:", i)
            inputs1, inputs2 = inputs[0], inputs[2]
            if inputs1.size(0) != 1:
                raise ValueError(
                    f"separate_samples expects batch_size=1, got {inputs1.size(0)}"
                )

            ### Prepare for saved ###
            # Index the batch axis instead of squeeze(), so single-channel
            # images keep their channel axis for the CHW -> HWC transpose.
            # Round and clamp so float error cannot lower or wrap pixel values
            # in the uint8 cast.
            img = (inputs1[0] * 255).round().clamp(0, 255)
            img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype('uint8')
            target = labels.squeeze().cpu().numpy()

            inputs1 = normalization(args, inputs1).to(args.device)
            inputs2 = normalization(args, inputs2).to(args.device)

            ### Features ###
            features1 = feature_extractor(inputs1)
            features2 = feature_extractor(inputs2)
            features1 = features1.view(features1.size(0), -1)
            features2 = features2.view(features2.size(0), -1)

            ### Compare consistency ###
            consistency = torch.mean((features1 - features2) ** 2, dim=1).item()

            ### Separate samples ###
            if consistency <= gamma_low:
                clean_samples.append((img, target, 0))
            elif consistency >= gamma_high:
                poison_samples.append((img, target, 2))
            else:
                suspicious_samples.append((img, target, 1))

    ### Save samples ###
    folder_path = os.path.join(args.save_path, 'data_produce')
    os.makedirs(folder_path, exist_ok=True)

    outputs = (
        ('clean_samples.npy', clean_samples),
        ('poison_samples.npy', poison_samples),
        ('suspicious_samples.npy', suspicious_samples),
    )
    for file_name, samples in outputs:
        # dtype=object is required. Each entry mixes an image array, a label
        # array and an int flag, and NumPy >= 1.24 refuses to build such a
        # ragged array implicitly. This yields the same (N, 3) object layout
        # older NumPy produced.
        np.save(os.path.join(folder_path, file_name), np.array(samples, dtype=object))
