import torch
import numpy as np
from torch import nn

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from tqdm import tqdm

import torch.nn.functional as F
import random
import os

class NTXentLoss(torch.nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    ``forward`` expects the two views of a batch stacked along dim 0: rows
    ``[0, n)`` hold the first view and rows ``[n, 2n)`` the second, so row ``i``
    and row ``i + n`` form a positive pair and every other row is a negative.
    The effective batch size is therefore ``reps.shape[0] // 2``.
    """

    def __init__(self, temperature = 0.5, use_cosine_similarity = True):
        super().__init__()
        self.temperature = temperature
        self.use_cosine_similarity = use_cosine_similarity

    def forward(self, reps):
        """Return the mean NT-Xent loss over all ``2n`` rows of ``reps``."""
        z = F.normalize(reps, dim=-1) if self.use_cosine_similarity else reps
        logits = z @ z.T / self.temperature
        # A row must never be scored as its own positive.
        logits.fill_diagonal_(-np.inf)

        half = reps.shape[0] // 2
        offsets = torch.arange(half)
        # First-view rows point half a batch forward, second-view rows half a batch back.
        targets = torch.cat((offsets + half, offsets)).to(reps.device)
        return F.cross_entropy(logits, targets)

class BarlowLoss(torch.nn.Module):
    """Barlow Twins redundancy-reduction loss.

    ``forward`` expects the two views of a batch stacked along dim 0: rows
    ``[0, n)`` are the first view and rows ``[n, 2n)`` the second, with
    ``n = reps.shape[0] // 2``. Each view is standardized per feature, their
    (D, D) cross-correlation matrix is formed, and the loss is

        sum_i (1 - C_ii)^2  +  lambd * sum_{i != j} C_ij^2

    Args:
        lambd: weight of the off-diagonal (redundancy) term.
        eps: added to the per-feature variance before the square root so a
            constant feature yields 0 instead of NaN.
    """

    def __init__(self, lambd = 5e-3, eps = 1e-5):
        super(BarlowLoss, self).__init__()
        self.lambd = lambd
        self.eps = eps

    def _standardize(self, x):
        # Biased variance (as BatchNorm uses) makes diag(C) exactly 1 for
        # identical views; the unbiased estimator would cap it at (n-1)/n.
        return (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + self.eps)

    def forward(self, reps):
        """Return the scalar Barlow Twins loss for the stacked two-view ``reps`` (2-D)."""
        batch_size = reps.shape[0] // 2
        D = reps.shape[-1]

        a = self._standardize(reps[:batch_size])
        b = self._standardize(reps[batch_size:])
        cross_co = (a.T @ b) / batch_size

        on_diag = torch.diagonal(cross_co)
        invariance = ((on_diag - 1) ** 2).sum()
        off_diag_mask = ~torch.eye(D, dtype=torch.bool, device=cross_co.device)
        redundancy = (cross_co[off_diag_mask] ** 2).sum()
        return invariance + self.lambd * redundancy

class DeepSet(nn.Module):
    """Deep Sets encoder: per-element MLP, mean pooling over axis -2, MLP decoder.

    Args:
        dim_input: feature size of each set element (last axis of the input).
        num_outputs: number of output vectors produced per set.
        dim_output: size of each output vector.
        dim_hidden: width of every hidden layer in both MLPs.

    ``forward`` returns a tensor reshaped to ``(-1, num_outputs, dim_output)``;
    an unbatched ``(set_size, dim_input)`` input therefore comes back with a
    leading batch axis of 1.
    """

    def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128):
        super(DeepSet, self).__init__()
        self.num_outputs = num_outputs
        self.dim_output = dim_output
        # Encoder is built before decoder so parameter init order is unchanged.
        self.enc = self._four_layer_mlp(dim_input, dim_hidden, dim_hidden)
        self.dec = self._four_layer_mlp(dim_hidden, dim_hidden, num_outputs * dim_output)

    @staticmethod
    def _four_layer_mlp(d_in, d_hidden, d_out):
        """Build Linear-ReLU-Linear-ReLU-Linear-ReLU-Linear as one nn.Sequential."""
        widths = [d_in, d_hidden, d_hidden, d_hidden, d_out]
        layers = []
        for k, (w_in, w_out) in enumerate(zip(widths, widths[1:])):
            if k > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(w_in, w_out))
        return nn.Sequential(*layers)

    def forward(self, X):
        pooled = self.enc(X).mean(dim=-2)  # permutation-invariant pooling over set elements
        flat = self.dec(pooled)
        return flat.reshape(-1, self.num_outputs, self.dim_output)

class MixerLayer(nn.Module):
    """One MLP-Mixer style block: LayerNorm -> Linear(2x) -> GELU -> Linear.

    The MLP acts on the last axis, or on the second-to-last axis when
    ``transpose=True`` (the tensor is transposed in and back out). When
    ``dim_in == dim_out`` the block is residual: the MLP output is added to its
    (possibly transposed) input.
    """

    def __init__(self, dim_in, dim_out, transpose = False):
        super(MixerLayer, self).__init__()
        hidden = dim_in * 2
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim_in),
            nn.Linear(dim_in, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim_out),
        )
        self.transpose = transpose
        self.dim_in = dim_in
        self.dim_out = dim_out

    def forward(self, x):
        def swap(t):
            return t.transpose(-1, -2) if self.transpose else t

        x = swap(x)
        out = self.mlp(x)
        if self.dim_in == self.dim_out:
            out = out + x  # residual only when shapes line up
        return swap(out)


class Mixer(nn.Module):
    """Five-layer MLP-Mixer that pools a set/channel axis away.

    Presumably the input is shaped ``(..., channels, dim_input)`` — TODO
    confirm against the data. Layers alternate between mixing the last
    (feature) axis and, via ``transpose=True``, the second-to-last axis. The
    fourth layer collapses that axis to size 1 and the last maps features to
    ``dim_output``. ``forward`` then averages over the remaining singleton
    axis, giving ``(..., dim_output)``.
    """

    def __init__(self, dim_input, channels, dim_output):
        super(Mixer, self).__init__()
        # (width_in, width_out, mix along second-to-last axis?)
        plan = [
            (dim_input, dim_input, False),
            (channels, channels, True),
            (dim_input, dim_input, False),
            (channels, 1, True),
            (dim_input, dim_output, False),
        ]
        self.net = nn.Sequential(*[MixerLayer(w_in, w_out, transpose=t) for w_in, w_out, t in plan])

    def forward(self, X):
        mixed = self.net(X)
        return mixed.mean(dim=-2)


class ContrastiveNetworkDeepset(nn.Module):
    """DeepSet encoder followed by a linear projection head.

    ``forward`` returns ``(emb, proj)``:
      * ``emb`` is the DeepSet output flattened to ``(N, emb_size)``;
      * ``proj`` is ``Linear(relu(emb))`` with ``projection_head_out_size`` features.
    The ReLU is applied only on the way into the head; ``emb`` is returned
    before the activation.
    """

    def __init__(self, input_size, projection_head_out_size = 128, emb_size = 256, dim_hidden = 128):
        super(ContrastiveNetworkDeepset, self).__init__()
        # DeepSet(.., num_outputs=1, ..) yields (N, 1, emb_size); drop the singleton axis.
        self.model_emb = nn.Sequential(
            DeepSet(input_size, 1, emb_size, dim_hidden = dim_hidden),
            Rearrange("a b c -> a (b c)"),
        )
        self.projection_head = nn.Sequential(
            nn.Linear(emb_size, projection_head_out_size)
        )

    def forward(self, x):
        x = self.model_emb(x)
        p = self.projection_head(F.relu(x))
        return x, p

class ContrastiveNetworkMixer(nn.Module):
    """Mixer encoder with a GELU output and a BatchNorm + Linear projection head.

    ``forward`` returns ``(emb, proj)``, where ``emb = gelu(Mixer(x))`` has
    ``emb_size`` features and ``proj`` is the head applied to ``emb``. The head
    starts with ``BatchNorm1d``. That layer needs a batched ``(N, emb_size)``
    input — NOTE(review): an unbatched sample will not pass through it.
    """

    def __init__(self, input_size, projection_head_out_size = 128, emb_size = 256, channels = 128):
        super(ContrastiveNetworkMixer, self).__init__()
        backbone = Mixer(input_size, channels, emb_size)
        self.model_emb = nn.Sequential(backbone)
        head_layers = [
            nn.BatchNorm1d(emb_size),
            nn.Linear(emb_size, projection_head_out_size),
        ]
        self.projection_head = nn.Sequential(*head_layers)

    def forward(self, x):
        emb = F.gelu(self.model_emb(x))
        return emb, self.projection_head(emb)

def train_contrastive_network(net, data_set, batch_size = 256, epochs = 10, num_workers = 0, lr = 0.5e-3, val=False, barlow=True, device = "cuda"):
    """Train ``net`` in place with a self-supervised contrastive objective.

    Each batch from ``data_set`` has shape ``(a, b, ...)``. It is flattened
    view-major to ``(b*a, ...)``, so all first views precede all second views
    (presumably ``b == 2`` — TODO confirm; both losses split the flattened
    batch in half).

    Args:
        net: module whose forward returns ``(embedding, projection)``; the loss
            is computed on the projection.
        data_set: map-style dataset whose items collate into tensors.
        batch_size: batch size for training and validation (incomplete
            batches are dropped).
        epochs: number of passes over the training split.
        num_workers: worker processes for the training DataLoader.
        lr: initial Adam learning rate, multiplied by 0.95 after every epoch.
        val: if True, hold out 1028 samples for validation. Otherwise a single
            sample is held out, which (unless ``batch_size == 1``) yields no
            full validation batch, so validation is skipped.
        barlow: use ``BarlowLoss`` if True, else ``NTXentLoss(temperature=0.1)``.
        device: device to train on; defaults to "cuda" as before.

    Returns:
        ``(losses, val_losses)``: the per-step training losses and the per-epoch
        mean validation losses (empty when validation is skipped).
    """
    net.to(device)
    on_cuda = str(device).startswith("cuda")  # pinned memory only helps host->GPU copies
    optimizer = torch.optim.Adam(net.parameters(), lr = lr)
    lr_schedule = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.95**epoch)

    val_size = 1028 if val else 1
    train_set, val_set = torch.utils.data.random_split(data_set, [len(data_set)-val_size, val_size])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, drop_last = True, shuffle = True, num_workers=num_workers, pin_memory=on_cuda)
    # Fixed (unshuffled) validation batches keep the loss comparable across epochs.
    val_loader = torch.utils.data.DataLoader(val_set, batch_size = batch_size, drop_last = True, shuffle = False, pin_memory=on_cuda)

    criterion = BarlowLoss() if barlow else NTXentLoss(temperature = 0.1)

    losses = []
    val_losses = []

    for epoch in range(epochs):
        net.train()
        for i, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()

            _, projs = net(rearrange(data, "a b ... -> (b a) ..."))
            loss = criterion(projs)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, losses[-1]))

        # drop_last may leave the validation loader empty; don't report a fake 0.
        if len(val_loader) > 0:
            net.eval()
            val_loss = 0
            with torch.no_grad():  # evaluation only: don't build autograd graphs
                for data in val_loader:
                    data = data.to(device)
                    _, projs = net(rearrange(data, "a b ... -> (b a) ..."))
                    val_loss += criterion(projs).item() / len(val_loader)

            print("val_loss:", val_loss)
            val_losses.append(val_loss)

        lr_schedule.step()

    return losses, val_losses

def get_embs(net, data_set, indices = None, device = "cuda"):
    """Embed selected items of ``data_set`` one at a time with ``net``.

    Args:
        net: module whose forward returns ``(embedding, projection)``; only the
            embedding is kept. It is switched to eval mode and left there.
        data_set: indexable; each ``data_set[idx]`` must be convertible with
            ``torch.as_tensor`` and is fed to ``net`` as-is (no batching).
        indices: indices to embed; defaults to ``0 .. len(data_set) - 1``.
        device: device to run ``net`` on; defaults to "cuda" as before.

    Returns:
        ``np.ndarray`` stacking the per-item embeddings in ``indices`` order.
    """
    net.eval()

    if indices is None:
        indices = np.arange(len(data_set))

    embs = []
    # Inference only: no_grad avoids building an autograd graph per item.
    with torch.no_grad():
        for idx in tqdm(indices):
            # as_tensor avoids torch.tensor's copy/warning when items are already tensors.
            e, _ = net(torch.as_tensor(data_set[idx]).to(device))
            embs.append(e.cpu().numpy())

    return np.array(embs)
