from re import M
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import math
import numpy as np
import os
import torchvision.transforms as T
import ot

# Host name, printed at import time so logs show which machine ran the job.
whereIam = os.uname()[1]
print(whereIam)

# Root directory under which the ``data/IMAGENET/...`` feature files live.
# Callers build paths as ``HOME + "data/..."``, so the value must end with a
# path separator (os.path.join(..., "") guarantees that unless it is empty,
# in which case paths are resolved relative to the working directory).
# The original assignment was left empty (a syntax error); it is now read
# from the OT_LOSS_HOME environment variable.
HOME = os.path.join(os.environ.get("OT_LOSS_HOME", ""), "")



class OT_Loss(nn.Module):
    """Optimal-transport loss between a batch of features and a stored class bank.

    For every model name in ``all_models`` a feature bank is loaded from
    ``HOME + data/IMAGENET/all_images_feature/sort_by_class/<layer>/<model>.pt``
    and indexed with ``target_classe``; the result is the target point cloud.
    For the models in ``_STANDARDIZED_MODELS`` the file
    ``HOME + data/IMAGENET/stats_train/<layer>/imagenet_tf_<model>.pt`` is also
    loaded, and its mean/std over dim 0 are used to standardize both the
    features and the targets before the OT computation.

    Args:
        all_models: iterable of model names to load target banks for.
        target_classe: key used to index each loaded bank.
        type: ``"exact"`` (``ot.emd2`` with a p-th power ground cost) or
            ``"sliced"`` (``ot.sliced_wasserstein_distance``).
            ``"canonical_1D"`` was planned but is not implemented; with any
            other value nothing is loaded and ``forward`` raises ``ValueError``.
            NOTE: stored as ``self.type``, which shadows ``nn.Module.type``.
        p: exponent of the ground cost.
        n_projections: number of random directions for the sliced variant.
        layer: sub-directory naming the feature layer; also decides whether
            Swin features are permuted in ``forward``.
        device: device the loaded banks and statistics are moved to.
    """

    _SUPPORTED_TYPES = ("exact", "sliced")
    # Models whose features are standardized with precomputed train statistics.
    _STANDARDIZED_MODELS = frozenset({
        "resnet50", "resnet34", "resnet50_relu_adv", "resnet50_self",
        "resnet50_v2", "efficientnet_b0",
    })
    _SWIN_MODELS = frozenset({"swin_t", "swin_s", "swin_b"})
    # Swin stages whose outputs need their last axis moved to dim 1.
    _SWIN_PERMUTED_LAYERS = frozenset({"layer1", "layer2", "layer3"})
    _DEIT_MODELS = frozenset({"deit_s", "deit_s_adv", "deit_t", "deit_b"})

    def __init__(self, all_models, target_classe, type="sliced", p=2,
                 n_projections=5000, layer=None, device="cuda"):
        super().__init__()
        self.target = target_classe
        self.targets = {}
        self.target_quantile = {}
        self.weights = {}
        self.type = type
        self.p = p
        self.n_projections = n_projections
        self.layer = layer
        # Standardization statistics keyed by model name, so that several
        # CNNs loaded together each keep their own mean/std.
        self.means = {}
        self.stds = {}

        if self.type not in self._SUPPORTED_TYPES:
            return

        for model in all_models:
            bank = torch.load(
                HOME + f"data/IMAGENET/all_images_feature/sort_by_class/{self.layer}/{model}.pt",
                map_location="cpu",
            )
            self.targets[model] = bank[self.target].to(device)

            if model in self._STANDARDIZED_MODELS:
                stats = torch.load(
                    HOME + f"data/IMAGENET/stats_train/{self.layer}/imagenet_tf_{model}.pt",
                    map_location="cpu",
                )
                self.means[model] = stats.mean(0).to(device)
                self.stds[model] = stats.std(0).to(device)
                # Legacy single-model attributes, kept for external readers.
                self.mean = self.means[model]
                self.std = self.stds[model]

    @staticmethod
    def _pool_features(features, model, layer):
        """Reduce backbone activations for ``model`` to one vector per sample."""
        if model in OT_Loss._SWIN_MODELS and layer in OT_Loss._SWIN_PERMUTED_LAYERS:
            # Presumably channels-last (B, H, W, C) for these stages — move
            # channels to dim 1 so the spatial mean below reduces dims 2 and 3.
            features = features.permute(0, 3, 1, 2)
        if model in OT_Loss._DEIT_MODELS:
            # Post-norm DeiT features are used as-is; a pre-norm variant would
            # need the class token, i.e. features[:, 0].
            return features
        return features.mean((2, 3))

    @staticmethod
    def _cost_matrix(features, targets, p):
        """Return C with C[i, j] = sum_k |features[i, k] - targets[j, k]| ** p.

        Built one row at a time to avoid materializing a
        (batch, n_targets, dim) difference tensor.
        """
        if features.size(0) == 0:
            return features.new_zeros((0, targets.size(0)))
        return torch.stack([((row - targets).abs() ** p).sum(1) for row in features])

    def forward(self, features, model):
        """Compute the OT loss between ``features`` and the bank of ``model``.

        Args:
            features: backbone activations for one batch (pooled by
                ``_pool_features``).
            model: model name; must be one of the ``all_models`` given at
                construction.

        Returns:
            The value returned by ``ot.emd2`` (``"exact"``) or
            ``ot.sliced_wasserstein_distance`` (``"sliced"``).

        Raises:
            ValueError: if ``self.type`` is not a supported OT type.
            KeyError: if no bank/statistics were loaded for ``model``.
        """
        if self.type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"Error OT type not conformed: {self.type!r} "
                f"(expected one of {self._SUPPORTED_TYPES})"
            )

        features = self._pool_features(features, model, self.layer)
        targets = self.targets[model]

        if model in self._STANDARDIZED_MODELS:
            mean, std = self.means[model], self.stds[model]
            targets = (targets - mean) / std
            features = (features - mean) / std

        if self.type == "exact":
            cost = self._cost_matrix(features, targets, self.p)
            # Uniform weights on both point clouds.
            mu = torch.ones((features.size(0),), device=features.device) / features.size(0)
            nu = torch.ones((targets.size(0),), device=features.device) / targets.size(0)
            return ot.emd2(mu, nu, cost)

        return ot.sliced_wasserstein_distance(
            features, targets, a=None, b=None,
            n_projections=self.n_projections, p=self.p,
            projections=None, seed=None, log=False,
        )



#---------------------------------------------------------------------
# Smoothness loss function
#---------------------------------------------------------------------
def Smoothness_loss(patch):
    """Return the smoothness penalty of an image (batch).

    Sums the squared differences between vertically adjacent pixels and
    between horizontally adjacent pixels, over every leading dimension
    (batch, channels, ...).

    Args:
        patch: tensor whose last two dimensions are (height, width), e.g.
            (B, C, H, W). If its maximum exceeds 1 it is treated as being in
            [0, 255] and divided by 255 first.

    Returns:
        0-dim tensor holding the penalty.
    """
    # TODO: renormalize more robustly than the "max > 1 means 0..255" heuristic.
    if torch.max(patch) > 1:
        patch = patch / 255
    # Summing the two difference maps directly equals summing them after
    # zero-padding to a common shape, and works for any batch size, channel
    # count, dtype and device (fixed-shape (1, 3, ...) padding did not).
    vertical = torch.square(patch[..., :-1, :] - patch[..., 1:, :])
    horizontal = torch.square(patch[..., :, :-1] - patch[..., :, 1:])
    return vertical.sum() + horizontal.sum()


