import csv
import os
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F

class Strategy:
    """Base class for sample-selection strategies over an unlabeled pool.

    Stores the selection configuration read from ``train_args`` together with
    metadata about the unlabeled samples.

    ``chose`` and ``quarry`` are no-op hooks here, presumably overridden by
    concrete strategies.

    The ``predict_*`` / ``get_*`` wrappers delegate to ``self.model``, which
    this class never assigns. A subclass or caller must set ``self.model``
    before using them.
    """

    def __init__(self, train_args, unlabeled_originlabels, unlabeled_img_path, add_ratio, num_classes, unlabeled_target):
        """Copy selection settings from ``train_args`` and store pool metadata.

        Args:
            train_args: object exposing ``select_strategy``, ``select_ratio``,
                ``select_type``, ``classifier_name``, ``dataset_name``,
                ``num_train_set`` and ``begin_select_ratio`` attributes.
            unlabeled_originlabels: stored as-is (presumably the original
                labels of the unlabeled samples; confirm with callers).
            unlabeled_img_path: stored as-is (presumably image paths of the
                unlabeled samples).
            add_ratio: stored as-is.
            num_classes: stored as-is.
            unlabeled_target: stored as-is.
        """
        self.select_strategy = train_args.select_strategy
        self.select_ratio = train_args.select_ratio
        self.select_type = train_args.select_type
        self.unlabeled_originlabels = unlabeled_originlabels
        self.classifier_name = train_args.classifier_name
        self.dataset_name = train_args.dataset_name
        self.add_ratio = add_ratio
        self.unlabeled_img_path = unlabeled_img_path
        self.num_train_set = train_args.num_train_set
        self.num_classes = num_classes
        self.unlabeled_target = unlabeled_target
        self.begin_select_ratio = train_args.begin_select_ratio
        # File handles opened by make_csv_path; released by close_csv_files.
        self._csv_files = []

    def chose(self):
        """No-op hook; expected to be overridden by subclasses."""
        pass

    def quarry(self):
        """No-op hook; expected to be overridden by subclasses."""
        pass

    def update(self, idxs_lb):
        """Store ``idxs_lb`` on the instance as ``self.idxs_lb``."""
        self.idxs_lb = idxs_lb

    def compute_entropy(self, distance_softmax, eps=1e-6):
        """Return the base-2 Shannon entropy of each row of ``distance_softmax``.

        Args:
            distance_softmax: 2-D array-like whose rows are summed over axis 1
                (presumably per-sample probability vectors, e.g. softmax
                outputs).
            eps: small constant added inside ``log2`` so that zero
                probabilities do not produce ``-inf``/``nan``.

        Returns:
            NumPy array with one entropy value per row.
        """
        log_probs = np.log2(distance_softmax + eps)
        return -1 * np.sum(distance_softmax * log_probs, axis=1)

    def compute_entropy_dim1(self, distance_softmax):
        """Return the base-2 entropy summed over ALL elements of a torch tensor.

        Unlike ``compute_entropy`` there is no axis argument, so the result is
        a scalar tensor. The name suggests a 1-D probability vector is
        expected.
        """
        log_probs = torch.log2(distance_softmax + 0.000001)
        entropy = -1 * torch.sum(distance_softmax * log_probs)
        return entropy

    def make_csv_path(self, rtflag=False):
        """Create the output directory and open CSV writer(s) for selection results.

        Files live under
        ``./Selcetion/<dataset>/<classifier>/<strategy>/<type>/`` and are
        opened in ``'w'`` mode, so an existing file with the same name is
        truncated.

        Args:
            rtflag: when truthy, also open ``<dataset>_<select_ratio>_remove.csv``.

        Returns:
            A ``csv.writer`` for ``<dataset>_<select_ratio>.csv``, or the tuple
            ``(selection_writer, remove_writer)`` when ``rtflag`` is truthy.

        The underlying file handles are kept on the instance. Call
        ``close_csv_files()`` after writing so rows are flushed and the handles
        are released.
        """
        # NOTE(review): "Selcetion" is misspelled, but the path is kept as-is
        # because other code may read results from this location.
        filepath = './Selcetion/{}/{}/{}/{}/'.format(
            self.dataset_name, self.classifier_name, self.select_strategy, self.select_type)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(filepath, exist_ok=True)

        ft_csv = self._open_csv('{}{}_{}.csv'.format(filepath, self.dataset_name, self.select_ratio))
        if rtflag:
            rt_csv = self._open_csv(
                '{}{}_{}_{}.csv'.format(filepath, self.dataset_name, self.select_ratio, 'remove'))
            return ft_csv, rt_csv
        return ft_csv

    def _open_csv(self, path):
        """Open ``path`` for writing, remember its handle, and return a csv.writer."""
        handle = open(path, 'w', newline='')  # newline='' as required by the csv module
        if not hasattr(self, '_csv_files'):  # a subclass may have bypassed __init__
            self._csv_files = []
        self._csv_files.append(handle)
        return csv.writer(handle)

    def close_csv_files(self):
        """Close every file opened by ``make_csv_path`` (flushing pending rows)."""
        for handle in getattr(self, '_csv_files', []):
            handle.close()
        self._csv_files = []

    def predict_all_representations(self, X, Y):
        """Delegate to ``self.model.predict_all_representations``."""
        return self.model.predict_all_representations(X, Y)

    def predict_embedding_prob(self, X_embedding):
        """Delegate to ``self.model.predict_embedding_prob``."""
        return self.model.predict_embedding_prob(X_embedding)

    def predict_prob_dropout(self, X, Y, n_drop):
        """Delegate to ``self.model.predict_prob_dropout``."""
        return self.model.predict_prob_dropout(X, Y, n_drop)

    def predict_prob_dropout_split(self, X, Y, n_drop):
        """Delegate to ``self.model.predict_prob_dropout_split``."""
        return self.model.predict_prob_dropout_split(X, Y, n_drop)

    def predict_prob_embed_dropout_split(self, X, Y, n_drop):
        """Delegate to ``self.model.predict_prob_embed_dropout_split``."""
        return self.model.predict_prob_embed_dropout_split(X, Y, n_drop)

    def get_embedding(self, X, Y):
        """Delegate to ``self.model.get_embedding``."""
        return self.model.get_embedding(X, Y)

    def get_grad_embedding(self, X, Y, is_embedding=False):
        """Delegate to ``self.model.get_grad_embedding``."""
        return self.model.get_grad_embedding(X, Y, is_embedding)
