import numpy as np
import torch
from torch.utils.data import Subset, Dataset

def infinitify(loader):
    """Yield ``(x, y)`` pairs from ``loader`` forever, re-iterating it after each pass.

    Each item produced by ``loader`` is unpacked into two values and re-yielded
    as a ``(x, y)`` tuple.

    Raises:
        ValueError: if a complete pass over ``loader`` produces no items (e.g. an
            empty dataset, or a one-shot iterator that is already exhausted);
            without this check the loop would spin forever without yielding.
    """
    while True:
        produced_any = False
        for x, y in loader:
            produced_any = True
            yield x, y
        if not produced_any:
            raise ValueError('infinitify: loader produced no batches in a full pass')
            
class Custom_Dataset(Dataset):
    """Dataset view that keeps each sample's input but substitutes its label.

    Item ``ind`` is ``(dataset[ind][0], labels[ind])``. The reported length is
    ``len(dataset)``; ``labels`` is not validated against it here.
    """

    def __init__(self, dataset, labels):
        # Both sources are stored by reference; nothing is copied.
        self.dataset, self.labels = dataset, labels

    def __len__(self):
        size = len(self.dataset)
        return size

    def __getitem__(self, ind):
        # Fetch the input first, then the replacement label (same order as before).
        sample = self.dataset[ind]
        image = sample[0]
        target = self.labels[ind]
        return image, target


            
def _ood_extract_features(model, loader, device, num_passes):
    """Run ``model`` over every batch of ``loader`` ``num_passes`` times, without gradients.

    Returns ``(features, labels)`` on the CPU, concatenated in loader order,
    with pass ``p`` occupying the ``p``-th consecutive chunk of rows.
    """
    features, labels = [], []
    with torch.no_grad():
        for _ in range(num_passes):
            for x, y in loader:
                # Only x[0] is fed to the model; x is presumably a list of
                # augmented views per batch (TODO confirm against the dataset).
                inputs = x[0].to(device)
                features.append(model(inputs).cpu())
                labels.append(y[: len(inputs)])
    if not features:
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    # Concatenate once instead of re-copying the accumulated tensor per batch.
    return torch.cat(features), torch.cat(labels)


def _ood_cosine_similarity(features, prototypes, eps=1e-6):
    """Return the ``(len(features), len(prototypes))`` cosine-similarity matrix."""
    sim = torch.zeros(len(features), len(prototypes))
    if len(features) == 0:
        return sim
    for j, prototype in enumerate(prototypes):
        # Vectorised over all feature rows; same formula/eps as nn.CosineSimilarity.
        sim[:, j] = torch.nn.functional.cosine_similarity(
            features, prototype.unsqueeze(0), dim=1, eps=eps)
    return sim


def _ood_select(sim_max, threshold, num_main):
    """Select samples whose max similarity exceeds ``threshold``.

    Returns ``(indices, n_from_main, n_from_peripheral)`` where the first
    ``num_main`` positions are counted as the main split and the rest as
    the peripheral split.
    """
    mask = sim_max > threshold
    indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
    return indices, int(mask[:num_main].sum()), int(mask[num_main:].sum())


def get_ood_loader(args, model, labeled_dataset, unlabeled_dataset, task_number,
                   num_passes=5):
    """Select likely in-distribution unlabeled samples and build training loaders.

    Features are extracted ``num_passes`` times for both datasets; unlabeled
    features are averaged over the passes. Each labeled class gets a prototype
    equal to the mean of its labeled features. An unlabeled sample is selected
    when its maximum cosine similarity to the prototypes exceeds
    ``mean + coeff * var`` of the labeled samples' maximum similarities.

    Args:
        args: namespace read for ``batch_size``, ``device``, ``num_workers``,
            ``num_main_unlabeled``, ``num_peripheral_unlabeled``,
            ``ood_in_dist_var_coeff``, ``ood_pl_var_coeff`` and ``directory``.
        model: called on ``x[0]`` of each batch; its outputs are treated as
            per-sample feature vectors (assumed 2-D ``(batch, dim)`` — confirm).
        labeled_dataset: samples whose batches unpack to ``(x, y)`` with ``y``
            the class labels.
        unlabeled_dataset: must contain exactly ``num_main_unlabeled +
            num_peripheral_unlabeled`` samples, the main split first.
        task_number: written to ``ood_result.txt`` as ``task_number + 1``.
        num_passes: number of feature-extraction passes (default 5, as before).

    Returns:
        ``(loader, iterator)``: a shuffled DataLoader over the samples selected
        with ``ood_pl_var_coeff`` (labelled with their nearest prototype's class)
        plus ``labeled_dataset``; and an endless iterator over a shuffled
        DataLoader of the samples selected with ``ood_in_dist_var_coeff`` plus
        ``labeled_dataset``.

    Raises:
        ValueError: if ``num_passes < 1``, the unlabeled dataset size does not
            match the configured split sizes, or ``labeled_dataset`` is empty.

    Side effects: prints detection statistics and appends a threshold sweep to
    ``args.directory + 'ood_result.txt'``.
    """
    if num_passes < 1:
        raise ValueError('num_passes must be >= 1, got {}'.format(num_passes))

    total_number_of_data = args.num_main_unlabeled + args.num_peripheral_unlabeled
    # Rows are averaged pass-by-pass below; that only lines up with the
    # dataset when its size matches the configured split sizes exactly.
    if len(unlabeled_dataset) != total_number_of_data:
        raise ValueError(
            'unlabeled_dataset has {} samples but num_main_unlabeled + '
            'num_peripheral_unlabeled = {}'.format(len(unlabeled_dataset), total_number_of_data))

    labeled_loader = torch.utils.data.DataLoader(
        labeled_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=4)
    unlabeled_loader = torch.utils.data.DataLoader(
        unlabeled_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=4)

    # NOTE(review): the model's train/eval mode is not changed here; if it has
    # BatchNorm/Dropout the caller presumably sets eval mode — confirm.
    labeled_features, labeled_labels = _ood_extract_features(
        model, labeled_loader, args.device, num_passes)
    unlabeled_features, _ = _ood_extract_features(
        model, unlabeled_loader, args.device, num_passes)

    # Average each unlabeled sample's features over the passes.
    unlabeled_features = unlabeled_features.reshape(
        num_passes, total_number_of_data, *unlabeled_features.shape[1:]).mean(dim=0)

    # One prototype per label value actually present. Iterating the real values
    # (not range(num_classes)) avoids NaN prototypes for non-contiguous labels.
    classes = torch.unique(labeled_labels)
    num_classes = len(classes)

    print('ood num classes: ', num_classes)
    if num_classes == 0:
        raise ValueError('labeled_dataset is empty; cannot build class prototypes')

    prototypes = torch.stack(
        [labeled_features[labeled_labels == c].mean(dim=0) for c in classes])

    unlabeled_sim = _ood_cosine_similarity(unlabeled_features, prototypes)
    unlabeled_sim_max, unlabeled_sim_argmax = torch.max(unlabeled_sim, dim=1)
    labeled_sim_max = torch.max(
        _ood_cosine_similarity(labeled_features, prototypes), dim=1)[0]

    sim_mean = torch.mean(labeled_sim_max)
    sim_var = torch.var(labeled_sim_max)
    num_main = args.num_main_unlabeled

    th = sim_mean + args.ood_in_dist_var_coeff * sim_var
    index_list, n_main, n_peripheral = _ood_select(unlabeled_sim_max, th, num_main)

    print('number of in-dist detected from in-dist dataset = {}'.format(n_main))
    print('number of in-dist detected from out-dist dataset = {}'.format(n_peripheral))

    selected_dataset = Subset(unlabeled_dataset, index_list)
    gathered_dataset = torch.utils.data.ConcatDataset([selected_dataset, labeled_dataset])

    print('OOD gathered dataset size: {}'.format(len(gathered_dataset)))

    gathered_loader = torch.utils.data.DataLoader(
        gathered_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers)

    # Sweep thresholds (mean - i * var) and record how many samples from each
    # split pass, for computing AUROC and precision offline.
    in_dist, out_dist = [], []
    for i in range(-10, 30):
        _, n_main_i, n_peripheral_i = _ood_select(
            unlabeled_sim_max, sim_mean - i * sim_var, num_main)
        in_dist.append(n_main_i)
        out_dist.append(n_peripheral_i)

    with open(args.directory + 'ood_result.txt', 'a') as f:
        f.write("\n")
        f.write("task number : %s\n" % (task_number + 1))
        f.write("".join("%s," % item for item in in_dist))
        f.write("\n")
        f.write("".join("%s," % item for item in out_dist))

    # Pseudo-labels use a (typically stricter) confidence threshold.
    with_confidence_th = sim_mean + args.ood_pl_var_coeff * sim_var
    index_list_with_confidence, n_main_conf, n_peripheral_conf = _ood_select(
        unlabeled_sim_max, with_confidence_th, num_main)

    print('number of in-dist detected from in-dist dataset with confidence = {}'.format(n_main_conf))
    print('number of in-dist detected from out-dist dataset with confidence = {}'.format(n_peripheral_conf))

    selected_dataset_with_confidence = Subset(unlabeled_dataset, index_list_with_confidence)
    gathered_dataset_with_confidence = torch.utils.data.ConcatDataset(
        [selected_dataset_with_confidence, labeled_dataset])

    # Labeled samples keep their own labels, taken from the first in-order
    # (shuffle=False) extraction pass instead of re-reading every sample.
    labeled_targets = labeled_labels[: len(labeled_dataset)].long()
    # Selected unlabeled samples get the class of their nearest prototype.
    confident_idx = torch.tensor(index_list_with_confidence, dtype=torch.long)
    pseudo_labels = classes[unlabeled_sim_argmax[confident_idx]].long()
    labels = torch.cat([pseudo_labels, labeled_targets])

    with_confidence_gathered_dataset = Custom_Dataset(gathered_dataset_with_confidence, labels)

    print('OOD gathered dataset with confidence size: {}'.format(len(with_confidence_gathered_dataset)))

    gathered_loader_with_confidence = torch.utils.data.DataLoader(
        with_confidence_gathered_dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers)

    return gathered_loader_with_confidence, infinitify(gathered_loader)