import torch
from torch import nn
import numpy as np

from utils.miscs import Tensor2Numpy, sliced_wasserstein_distance
from model.NF import RealNVP


class Classifier(nn.Module):
    """MLP classifier over the concatenated pair ``(x, y)``.

    With ``l = net([x, y])``, ``forward`` returns the weighted logistic loss
    ``-(1/B) * sum_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]`` and
    ``update_weight`` returns ``exp(-l)`` for every row as a flat array.
    """

    def __init__(self, dim, hid_dim=128):
        super().__init__()
        self.dim = dim
        # x and y are concatenated on the last axis, hence dim * 2 inputs.
        self.net = nn.Sequential(nn.Linear(dim*2, hid_dim), nn.ELU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ELU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ELU(),
                                 nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()

    def _logits(self, x, y):
        """Return the raw ``(rows, 1)`` logits for the pair ``(x, y)``."""
        return self.net(torch.cat([x, y], dim=-1))

    def forward(self, x, y, weight):
        """Compute the weighted logistic loss for one batch.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.shape[0]
        logits = self._logits(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the sum.
            weight = weight.unsqueeze(-1)
        # logsigmoid is the stable form of log(sigmoid(.)); the naive form
        # yields -inf once sigmoid underflows to 0 for large |logits|.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        return -(log_pos + weight * log_neg).sum() / batch_size

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        with torch.no_grad():
            logits = torch.cat(
                [self._logits(x[start:start + batch_size],
                              y[start:start + batch_size])
                 for start in range(0, n, batch_size)],
                dim=0)
            weight = (-logits).exp()
        return Tensor2Numpy(weight.reshape(-1))  # (len,)

class ClassifierMSDense(nn.Module):
    """MLP classifier on multi-scale copies of ``(x, y)``.

    The concatenated input is multiplied by every integer scale
    ``1..scaling_num`` and flattened before the MLP. With ``l`` the logits,
    ``forward`` returns ``-(1/B) * sum_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]``
    and ``update_weight`` returns ``exp(-l)`` per row.
    """

    def __init__(self, dim, scaling_num=10, hid_dim=512):
        super().__init__()
        self.dim = dim
        self.scaling_num = scaling_num
        # Scales 1..scaling_num as a (scaling_num, 1) column. A non-persistent
        # buffer follows .to()/.cuda() with the module while leaving the
        # state_dict keys unchanged.
        self.register_buffer(
            'B',
            torch.arange(1, scaling_num + 1, dtype=torch.float32).view(-1, 1),
            persistent=False)
        self.net = nn.Sequential(nn.Linear(scaling_num*dim*2, hid_dim), nn.ELU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ELU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ELU(),
                                 nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()

    def _logits(self, x, y):
        """Return the raw ``(rows, 1)`` logits for the pair ``(x, y)``."""
        xy = torch.cat([x, y], dim=-1)
        # (rows, 2*dim) -> (rows, scaling_num, 2*dim) -> (rows, scaling_num*2*dim);
        # flatten (unlike reshape(rows, -1)) also works for zero rows.
        scaled = (xy[..., None, :] * self.B).flatten(start_dim=1)
        return self.net(scaled)

    def forward(self, x, y, weight):
        """Compute the weighted logistic loss for one batch.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.shape[0]
        logits = self._logits(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the sum.
            weight = weight.unsqueeze(-1)
        # Stable log(sigmoid(.)); the naive form yields -inf on underflow.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        return -(log_pos + weight * log_neg).sum() / batch_size

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        chunks = []
        with torch.no_grad():
            for start in range(0, n, batch_size):
                stop = start + batch_size
                chunks.append(Tensor2Numpy(self._logits(x[start:stop], y[start:stop])))
        logits = np.concatenate(chunks, axis=0)
        return np.exp(-logits).reshape(-1)  # (len,)

class ClassifierMSDense2(nn.Module):
    """MLP classifier on sin/cos of multi-scale copies of ``(x, y)``.

    The concatenated input is multiplied by every integer scale
    ``1..scaling_num``, flattened, and both ``sin`` and ``cos`` of the result
    are fed to the MLP (hence ``scaling_num*dim*4`` inputs). With ``l`` the
    logits, ``forward`` returns
    ``-(1/B) * sum_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]`` and
    ``update_weight`` returns ``exp(-l)`` per row.
    """

    def __init__(self, dim, scaling_num=10, hid_dim=512):
        super().__init__()
        self.dim = dim
        self.scaling_num = scaling_num
        # Scales 1..scaling_num as a (scaling_num, 1) column. A non-persistent
        # buffer follows .to()/.cuda() with the module while leaving the
        # state_dict keys unchanged.
        self.register_buffer(
            'B',
            torch.arange(1, scaling_num + 1, dtype=torch.float32).view(-1, 1),
            persistent=False)
        self.net = nn.Sequential(nn.Linear(scaling_num*dim*4, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()

    def _logits(self, x, y):
        """Return the raw ``(rows, 1)`` logits for the pair ``(x, y)``."""
        xy = torch.cat([x, y], dim=-1)
        # (rows, 2*dim) -> (rows, scaling_num, 2*dim) -> (rows, scaling_num*2*dim)
        scaled = (xy[..., None, :] * self.B).flatten(start_dim=1)
        return self.net(torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1))

    def forward(self, x, y, weight):
        """Compute the weighted logistic loss for one batch.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.shape[0]
        logits = self._logits(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the sum.
            weight = weight.unsqueeze(-1)
        # Stable log(sigmoid(.)); the naive form yields -inf on underflow.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        return -(log_pos + weight * log_neg).sum() / batch_size

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        chunks = []
        with torch.no_grad():
            for start in range(0, n, batch_size):
                stop = start + batch_size
                chunks.append(Tensor2Numpy(self._logits(x[start:stop], y[start:stop])))
        logits = np.concatenate(chunks, axis=0)
        return np.exp(-logits).reshape(-1)  # (len,)

class ClassifierFEDense(nn.Module):
    """MLP classifier on random Fourier features of ``(x, y)``.

    The concatenated input ``xy`` is projected by a random Gaussian matrix
    ``B`` (scaled by ``frequence_sigma``), and ``[xy, sin(xy B^T), cos(xy B^T)]``
    is fed to the MLP. With ``l`` the logits, ``forward`` returns
    ``-(1/B) * sum_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]`` and
    ``update_weight`` returns ``exp(-l)`` per row.
    """

    def __init__(self, dim, frequency_num=64, frequence_sigma=10, hid_dim=64):
        super().__init__()
        self.dim = dim
        # Random frequencies, one row per frequency over the 2*dim input.
        # Registered as a persistent buffer so it moves with the module and is
        # saved/restored by state_dict; as a plain attribute a reloaded model
        # silently drew a different B. NOTE: state_dicts saved before this
        # change lack 'B'; loading them needs strict=False.
        self.register_buffer('B', torch.randn(frequency_num, dim*2) * frequence_sigma)
        self.net = nn.Sequential(nn.Linear(2*frequency_num+dim*2, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()

    def _logits(self, x, y):
        """Return the raw ``(rows, 1)`` logits for the pair ``(x, y)``."""
        xy = torch.cat([x, y], dim=-1)
        # Same as sum(B * xy[..., None, :], -1) without the (rows, F, 2*dim)
        # intermediate.
        proj = xy @ self.B.t()  # (rows, frequency_num)
        return self.net(torch.cat([xy, torch.sin(proj), torch.cos(proj)], dim=-1))

    def forward(self, x, y, weight):
        """Compute the weighted logistic loss for one batch.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.shape[0]
        logits = self._logits(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the sum.
            weight = weight.unsqueeze(-1)
        # Stable log(sigmoid(.)); the naive form yields -inf on underflow.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        return -(log_pos + weight * log_neg).sum() / batch_size

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        chunks = []
        with torch.no_grad():
            for start in range(0, n, batch_size):
                stop = start + batch_size
                chunks.append(Tensor2Numpy(self._logits(x[start:stop], y[start:stop])))
        logits = np.concatenate(chunks, axis=0)
        return np.exp(-logits).reshape(-1)  # (len,)

class ClassifierFEDense2(nn.Module):
    """MLP classifier on element-wise random frequency features of ``(x, y)``.

    Each input coordinate is multiplied by every row of a random Gaussian
    matrix ``B`` (scaled by ``frequence_sigma``) without summing, giving
    ``frequency_num*2*dim`` products; ``[xy, sin(products), cos(products)]``
    is fed to the MLP. With ``l`` the logits, ``forward`` returns
    ``-(1/B) * sum_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]`` and
    ``update_weight`` returns ``exp(-l)`` per row.
    """

    def __init__(self, dim, frequency_num=32, frequence_sigma=10, hid_dim=64):
        super().__init__()
        self.dim = dim
        self.frequency_num = frequency_num
        # Random frequencies, shape (frequency_num, 2*dim). Registered as a
        # persistent buffer so it moves with the module and is saved/restored
        # by state_dict; as a plain attribute a reloaded model silently drew a
        # different B. NOTE: state_dicts saved before this change lack 'B';
        # loading them needs strict=False.
        self.register_buffer('B', torch.randn(frequency_num, dim*2) * frequence_sigma)
        self.net = nn.Sequential(nn.Linear(2*frequency_num*dim*2+dim*2, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                 nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()

    def _logits(self, x, y):
        """Return the raw ``(rows, 1)`` logits for the pair ``(x, y)``."""
        xy = torch.cat([x, y], dim=-1)
        # (rows, 1, 2*dim) * (F, 2*dim) -> (rows, F, 2*dim) -> (rows, F*2*dim)
        products = (self.B * xy[..., None, :]).flatten(start_dim=1)
        return self.net(torch.cat([xy, torch.sin(products), torch.cos(products)], dim=-1))

    def forward(self, x, y, weight):
        """Compute the weighted logistic loss for one batch.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.shape[0]
        logits = self._logits(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the sum.
            weight = weight.unsqueeze(-1)
        # Stable log(sigmoid(.)); the naive form yields -inf on underflow.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        return -(log_pos + weight * log_neg).sum() / batch_size

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        chunks = []
        with torch.no_grad():
            for start in range(0, n, batch_size):
                stop = start + batch_size
                chunks.append(Tensor2Numpy(self._logits(x[start:stop], y[start:stop])))
        logits = np.concatenate(chunks, axis=0)
        return np.exp(-logits).reshape(-1)  # (len,)


class ClassifierNF(nn.Module):
    """Classifier applied to the output of a RealNVP flow over ``(x, y)``.

    ``nf_net`` maps the concatenated input to ``(z, sum_log_det_jacobians)``;
    ``cls_net`` turns ``z`` into logits ``l``. ``forward`` returns
    ``-mean_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]
    + alpha * (-mean(sum_log_det_jacobians))`` and ``update_weight`` returns
    ``exp(-l)`` per row.
    """

    def __init__(self, dim, hid_dim=128, alpha=0.5):
        super().__init__()
        self.dim = dim
        self.nf_net = RealNVP(
            n_blocks=5, input_size=2*dim,
            hidden_size=100, n_hidden=1
        )
        self.cls_net = nn.Sequential(nn.Linear(dim*2, hid_dim), nn.ReLU(),
                                     nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                     nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                     nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()
        # Weight of the log-det-Jacobian term in the total loss.
        self.alpha = alpha

    def _logits_and_logdet(self, x, y):
        """Return ``(logits, sum_log_det_jacobians)`` for the pair ``(x, y)``."""
        z, sum_log_det_jacobians = self.nf_net(torch.cat([x, y], dim=-1))
        return self.cls_net(z), sum_log_det_jacobians

    def forward(self, x, y, weight):
        """Compute classification loss plus ``alpha`` times the flow term.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        logits, sum_log_det_jacobians = self._logits_and_logdet(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the mean.
            weight = weight.unsqueeze(-1)
        # Stable log(sigmoid(.)); the naive form yields -inf on underflow.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        loss_cls = -(log_pos + weight * log_neg).mean()
        l_loss = -sum_log_det_jacobians.mean()
        return loss_cls + self.alpha * l_loss

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        with torch.no_grad():
            logits = torch.cat(
                [self._logits_and_logdet(x[start:start + batch_size],
                                         y[start:start + batch_size])[0]
                 for start in range(0, n, batch_size)],
                dim=0)
            weight = (-logits).exp()
        return Tensor2Numpy(weight.reshape(-1))  # (len,)

class ClassifierLatent(nn.Module):
    """Classifier applied to a learned embedding of ``(x, y)``.

    ``embedding_net`` maps the concatenated input to ``z`` (same width,
    ``2*dim``) and ``cls_net`` turns ``z`` into logits ``l``. ``forward``
    returns ``-mean_i [log sigmoid(l_i) + w_i * log sigmoid(-l_i)]`` plus
    ``alpha`` times the sliced Wasserstein distance between ``z`` and a
    standard-normal sample of the same shape; ``update_weight`` returns
    ``exp(-l)`` per row.
    """

    def __init__(self, dim, hid_dim=128, alpha=0.5):
        super().__init__()
        self.dim = dim
        self.embedding_net = nn.Sequential(nn.Linear(dim*2, hid_dim), nn.ReLU(),
                                           nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                           nn.Linear(hid_dim, dim*2))
        self.cls_net = nn.Sequential(nn.Linear(dim*2, hid_dim), nn.ReLU(),
                                     nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                     nn.Linear(hid_dim, hid_dim), nn.ReLU(),
                                     nn.Linear(hid_dim, 1))
        # Kept as an attribute for external users; the loss uses logsigmoid.
        self.final = nn.Sigmoid()
        # Weight of the latent-matching (sliced Wasserstein) term.
        self.alpha = alpha

    def _latent_and_logits(self, x, y):
        """Return ``(z, logits)`` for the pair ``(x, y)``."""
        z = self.embedding_net(torch.cat([x, y], dim=-1))
        return z, self.cls_net(z)

    def forward(self, x, y, weight):
        """Compute classification loss plus ``alpha`` times the latent term.

        Args:
            x, y: tensors concatenated on the last axis (assumed
                ``(batch, dim)`` each — TODO confirm against callers).
            weight: scalar, ``(batch,)`` or ``(batch, 1)`` weights applied to
                the ``log sigmoid(-logits)`` term.

        Returns:
            Scalar loss tensor.
        """
        z, logits = self._latent_and_logits(x, y)
        weight = torch.as_tensor(weight, device=logits.device)
        if weight.dim() == 1:
            # A flat (batch,) vector would broadcast against the (batch, 1)
            # logits into a (batch, batch) matrix and corrupt the mean.
            weight = weight.unsqueeze(-1)
        # Stable log(sigmoid(.)); the naive form yields -inf on underflow.
        log_pos = nn.functional.logsigmoid(logits)
        log_neg = nn.functional.logsigmoid(-logits)
        loss_cls = -(log_pos + weight * log_neg).mean()
        l_loss = sliced_wasserstein_distance(z, torch.randn_like(z), projection_num=50, p=2, device=x.device)
        return loss_cls + self.alpha * l_loss

    def update_weight(self, x, y, batch_size=10000):
        """Return ``exp(-logits)`` for every row of ``(x, y)`` as a flat array.

        Rows are evaluated in chunks of ``batch_size`` without recording an
        autograd graph. Empty input yields an empty float32 array.
        """
        n = x.shape[0]
        if n == 0:
            return np.zeros(0, dtype=np.float32)
        with torch.no_grad():
            logits = torch.cat(
                [self._latent_and_logits(x[start:start + batch_size],
                                         y[start:start + batch_size])[1]
                 for start in range(0, n, batch_size)],
                dim=0)
            weight = (-logits).exp()
        return Tensor2Numpy(weight.reshape(-1))  # (len,)
    