import time

import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
from . import helper_functions
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class DatasetBuilder(Dataset):
    """Map-style dataset that pairs each row of ``data`` with its index.

    ``data`` must expose ``.shape`` (its first entry is used as the dataset
    length) and support indexing with whatever ``__getitem__`` receives.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        # The leading dimension of the wrapped container is the sample count.
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        """Return ``{'data': <row(s)>, 'index': <idx>}`` for the requested index.

        Tensor indices are converted to plain Python values first, so the
        returned ``'index'`` is an int (or a list of ints), never a tensor.
        """
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return {'data': self.data[position], 'index': position}


class encoder_a(nn.Module):
    """Two-branch encoder made of an "F network" and a "G network".

    The input is permuted and split into two tensors by
    ``helper_functions.positive_matrice_builder``; one goes through the F
    network (``fc1``..``fc3``), the other through the G network
    (``fc1_y``..``fc3_y``). Both outputs have ``hdn_size`` features on the
    last axis and are L2-normalised along dim 1 and then dim 2.

    Args:
        kernel_size: input width of the G network; also forwarded to
            ``positive_matrice_builder``.
        hdn_size: output width of both branches.
        d: the F network's first layer takes ``d - kernel_size`` features and
            every batch-norm layer is built for ``d - kernel_size + 1``
            features (presumably ``d`` is the raw feature count -- confirm
            with the caller).
    """

    def __init__(self, kernel_size, hdn_size, d):
        super().__init__()
        # Sizes shared by several layers, computed once.
        bn_features = d - kernel_size + 1
        quarter_width = int(hdn_size / 4)
        half_width = int(hdn_size / 2)
        leaky_slope = 0.2

        # F network
        self.fc1 = nn.Linear(d - kernel_size, hdn_size)
        self.activation1 = nn.Tanh()
        self.fc2 = nn.Linear(hdn_size, hdn_size * 2)
        self.activation2 = nn.LeakyReLU(leaky_slope)
        self.fc3 = nn.Linear(hdn_size * 2, hdn_size)
        self.activation3 = nn.LeakyReLU(leaky_slope)
        self.batchnorm_1 = nn.BatchNorm1d(bn_features)
        self.batchnorm_2 = nn.BatchNorm1d(bn_features)

        # G network
        self.fc1_y = nn.Linear(kernel_size, quarter_width)
        self.activation1_y = nn.LeakyReLU(leaky_slope)
        self.fc2_y = nn.Linear(quarter_width, half_width)
        self.activation2_y = nn.LeakyReLU(leaky_slope)
        self.fc3_y = nn.Linear(half_width, hdn_size)
        self.activation3_y = nn.LeakyReLU(leaky_slope)
        self.kernel_size = kernel_size
        self.batchnorm1_y = nn.BatchNorm1d(bn_features)

    def forward(self, x):
        """Encode ``x`` and return ``(f_out, g_out)``, both normalised.

        ``x`` must be 3-D (it is permuted with ``(0, 2, 1)``). The shapes of
        the two tensors produced by ``positive_matrice_builder`` are not
        visible here; the layer sizes suggest ``(batch, d-k+1, d-k)`` for the
        F input and ``(batch, d-k+1, k)`` for the G input -- TODO confirm.
        """
        swapped = x.permute(0, 2, 1)
        # The helper returns the G-branch input first, the F-branch input second.
        g_in, f_in = helper_functions.positive_matrice_builder(swapped, self.kernel_size)

        # F branch: batch-norm after the first two activations, none after the third.
        f_out = self.batchnorm_1(self.activation1(self.fc1(f_in)))
        f_out = self.batchnorm_2(self.activation2(self.fc2(f_out)))
        f_out = self.activation3(self.fc3(f_out))

        # G branch: a single batch-norm, placed after the first activation.
        g_out = self.batchnorm1_y(self.activation1_y(self.fc1_y(g_in)))
        g_out = self.activation2_y(self.fc2_y(g_out))
        g_out = self.activation3_y(self.fc3_y(g_out))

        # Normalise along dim 1 first, then along dim 2, for both outputs.
        l2_normalize = nn.functional.normalize
        for axis in (1, 2):
            f_out = l2_normalize(f_out, dim=axis)
            g_out = l2_normalize(g_out, dim=axis)
        return (f_out, g_out)


class SCADTrainer():
    """Train an ``encoder_a`` with a contrastive cross-entropy loss and score samples.

    Batches may be 2-item sequences whose first item is the input tensor
    (the second item is ignored), or dicts carrying the input under
    ``'data'`` (the form ``DatasetBuilder`` in this module produces).

    Args:
        input_dim: forwarded to ``encoder_a`` as ``d``.
        tau: temperature forwarded to ``helper_functions.scores_calc_internal``.
        k: forwarded to ``encoder_a`` as ``kernel_size``.
        weight_decay: Adam weight decay.
        hidden_dim: forwarded to ``encoder_a`` as ``hdn_size``.
        batch_size: stored as ``no_btchs``; not read by this class.
        device: device the model and every batch are moved to.
    """

    def __init__(self, input_dim, tau, k, weight_decay, hidden_dim, batch_size=256, device='cuda'):
        self.num_epochs = 200  # 2000 default
        self.no_btchs = batch_size
        self.no_negatives = 1000
        self.temperature = tau
        self.lr = 0.001
        self.k = k
        self.weight_decay = weight_decay
        self.hidden_dim = hidden_dim
        self.device = device

        self.model_a = encoder_a(self.k, self.hidden_dim, input_dim).to(device)
        self.optimizer_a = torch.optim.Adam(self.model_a.parameters(), lr=self.lr,
                                            weight_decay=self.weight_decay)

    @staticmethod
    def _extract_inputs(sample):
        """Return the input tensor from a dict batch or a 2-item sequence batch."""
        if isinstance(sample, dict):
            return sample['data']
        inputs, _ = sample
        return inputs

    def _contrastive_logits(self, sample):
        """Run one batch through the model; return ``(logits, targets)`` for cross-entropy."""
        inputs = self._extract_inputs(sample).to(self.device).float()
        # Insert a singleton axis at dim 1; the encoder permutes it to the last axis.
        inputs = torch.unsqueeze(inputs, 1)
        queries, positives = self.model_a(inputs)
        logits = helper_functions.scores_calc_internal(
            queries, positives, self.no_negatives, self.temperature, device=self.device
        ).to(self.device)
        # CrossEntropyLoss expects the class axis at dim 1.
        logits = logits.permute(0, 2, 1)
        # The target class is index 0 at every position.
        targets = torch.zeros((logits.shape[0], logits.shape[2]),
                              dtype=torch.long, device=logits.device)
        return logits, targets

    def train(self, d, trainloader):
        """Fit ``model_a`` for at most ``num_epochs`` epochs with early stopping.

        Training stops as soon as the mean batch loss of an epoch drops below
        a threshold chosen from ``d``: ``0.001`` when ``d <= 40``, otherwise
        ``0.01``.

        Args:
            d: value compared against 40 to select the stopping threshold
                (presumably the input dimensionality -- confirm with caller).
            trainloader: re-iterable source of batches (see class docstring).

        Raises:
            ValueError: if ``trainloader`` yields no batch in the first epoch.
        """
        stop_threshold = 0.001 if d <= 40 else 0.01
        criterion = nn.CrossEntropyLoss()

        for epoch in range(self.num_epochs):
            self.model_a.train()
            running_loss = 0
            num_batches = 0
            for num_batches, sample in enumerate(trainloader, start=1):
                self.model_a.zero_grad()
                logits, targets = self._contrastive_logits(sample)
                loss = criterion(logits, targets)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model_a.parameters(), 1)

                self.optimizer_a.step()
                running_loss += loss.item()

            if num_batches == 0:
                if epoch == 0:
                    raise ValueError("trainloader yielded no batches; nothing to train on")
                # A one-shot iterable ran dry after an earlier epoch: stop training.
                break
            if running_loss / num_batches < stop_threshold:
                break

    def test(self, test_loader):
        """Return one score per sample as a 1-D NumPy array.

        The score is the per-position cross-entropy loss averaged over dim 1.
        The model is switched to eval mode and left in it. An empty
        ``test_loader`` yields an empty ``float32`` array.
        """
        self.model_a.eval()
        criterion_test = nn.CrossEntropyLoss(reduction='none')
        scores = []
        with torch.no_grad():
            for sample in test_loader:
                logits, targets = self._contrastive_logits(sample)
                per_position_loss = criterion_test(logits, targets)
                scores.append(per_position_loss.mean(dim=1))

        if not scores:
            # torch.cat rejects an empty list; no samples means no scores.
            return np.empty(0, dtype=np.float32)
        return torch.cat(scores).cpu().numpy()

