import torch.nn as nn
import torch.nn.functional as F
import os 
import numpy as np
import torch as ch
import random
from torchvision import datasets

import tqdm

from torch.utils.data import TensorDataset, DataLoader
from torchvision.models import resnet18, resnet50
from torch.optim import Adam


def l2_attack(x, y, clf, niters=5, alpha=1, eps=None):
    """Run an iterative L2 gradient-ascent attack against a binary classifier.

    Each iteration moves ``x_adv`` a distance ``alpha`` along the per-sample
    L2-normalised gradient of the BCE loss (i.e. towards *higher* loss),
    clamps the result into [0, 1], and, if ``eps`` is given, scales each
    sample's total perturbation ``x_adv - x`` down to L2 norm at most ``eps``.

    Args:
        x: input batch whose first dimension is the batch dimension (any rank
            >= 2). Values are assumed to lie in [0, 1], since iterates are
            clamped to that range.
        y: binary targets, one per sample. They are cast to the dtype of the
            classifier output, so integer labels are accepted.
        clf: module mapping ``x`` to probabilities of shape ``(batch, 1)``.
            The output is squeezed on dim 1 and fed to ``BCELoss``.
        niters: number of ascent steps.
        alpha: L2 length of each step.
        eps: optional per-sample L2 radius for the total perturbation.

    Returns:
        A detached tensor with the same shape as ``x``.
    """
    loss_fn = nn.BCELoss()
    x = x.detach()
    x_adv = x.clone()
    for _ in range(niters):
        x_adv.requires_grad_(True)
        z = clf(x_adv).squeeze(1)
        loss = loss_fn(z, y.to(z.dtype))
        # Gradient w.r.t. the input only; parameter .grad fields are untouched.
        (g,) = ch.autograd.grad([loss], [x_adv])
        with ch.no_grad():
            # Per-sample L2 norm over every non-batch dim, shaped to broadcast
            # against g.
            bcast = (g.size(0),) + (1,) * (g.dim() - 1)
            g_norm = g.flatten(1).norm(p=2, dim=1).view(bcast)
            step = g / ch.clamp(g_norm, min=1e-5)  # floor avoids divide-by-zero
            x_adv = ch.clamp(x_adv + alpha * step, min=0, max=1)
            if eps is not None:
                # renorm along dim 0 caps each sample's perturbation at eps.
                x_adv = x + (x_adv - x).renorm(2, 0, eps)
    return x_adv.detach()
    

def train_clf(loader, clf, weight_decay=5e-4, nepochs=50, lr=0.0005, adv=True, device="cuda", **kwargs):
    """Train a sigmoid-output binary classifier in place with Adam and BCE loss.

    Epochs with ``i <= nepochs / 2`` run the model in train mode. Later epochs
    run it in eval mode, which for example freezes BatchNorm running
    statistics for the rest of training.

    Args:
        loader: iterable of ``(x, y)`` batches with binary labels ``y``.
        clf: module mapping ``x`` to probabilities of shape ``(batch, 1)``.
        weight_decay: Adam weight decay.
        nepochs: number of passes over ``loader``.
        lr: Adam learning rate.
        adv: if True, each batch is replaced by ``l2_attack`` examples before
            the update. The attack is crafted with the model in eval mode.
        device: device used for training. The default is ``"cuda"``, matching
            the previously hard-coded behaviour. The model is moved back to
            CPU when training finishes.
        **kwargs: forwarded to ``l2_attack`` (e.g. ``niters``, ``alpha``,
            ``eps``). They are unused when ``adv`` is False.

    Returns:
        None. ``clf`` is updated in place.
    """
    clf.to(device)
    opt = Adam(clf.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.BCELoss()
    pbar = tqdm.tqdm(range(nepochs))
    for i in pbar:
        train_mode = i <= nepochs / 2
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            if adv:
                clf.eval()
                x_in = l2_attack(x, y, clf, **kwargs)
            else:
                x_in = x
            # Set the mode explicitly, so the non-adversarial path does not
            # depend on whatever mode the caller left the model in.
            clf.train(train_mode)
            z = clf(x_in).squeeze(1)
            loss = loss_fn(z, y.float())
            opt.zero_grad()
            loss.backward()
            opt.step()

            # Monitoring only: accuracy on the clean batch. Run it in eval mode,
            # because a train-mode forward updates BatchNorm running statistics
            # even under no_grad. Then restore the training mode.
            clf.eval()
            with ch.no_grad():
                acc = ((clf(x).squeeze(1) > 0.5) == y).float().mean()
            clf.train(train_mode)
            pbar.set_description(f"loss: {loss.item()} acc: {acc.item()}")
    clf.cpu()


class C2ST:
    """Classifier two-sample test statistic.

    A binary classifier is trained to separate ``sample_1`` (label 0) from
    ``sample_2`` (label 1). Its thresholded accuracy is the statistic.
    """

    def __init__(self, clf):
        """Store the classifier, building a ResNet for the known string names.

        Args:
            clf: either a module mapping inputs to probabilities of shape
                ``(batch, 1)``, or one of the strings ``"resnet18"`` /
                ``"resnet50"``. A string builds that torchvision ResNet with a
                single output followed by a sigmoid.

        Raises:
            ValueError: if ``clf`` is any other string.
        """
        if isinstance(clf, str):
            if clf == "resnet18":
                clf = nn.Sequential(resnet18(num_classes=1), nn.Sigmoid())
            elif clf == "resnet50":
                clf = nn.Sequential(resnet50(num_classes=1), nn.Sigmoid())
            else:
                raise ValueError("Unknown classifier type")
        self.clf = clf

    def statistic(self, sample_1, sample_2, batch_size=512, **kwargs):
        """Train ``self.clf`` to separate the two samples and return its accuracy.

        NOTE(review): accuracy is measured on the same samples the classifier
        was trained on (training accuracy), not on a held-out split.

        Args:
            sample_1: tensor of samples labelled 0, batch dimension first.
            sample_2: tensor of samples labelled 1, same trailing shape.
            batch_size: batch size for both training and evaluation.
            **kwargs: forwarded to ``train_clf``.

        Returns:
            0-dim float tensor holding the fraction of correct predictions.
        """
        y = ch.cat([ch.zeros(sample_1.size(0)), ch.ones(sample_2.size(0))], dim=0)
        X = ch.cat([sample_1, sample_2], dim=0)

        loader = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)

        self.clf.train()
        train_clf(loader, self.clf, **kwargs)
        return self._accuracy(X, y, batch_size)

    def _accuracy(self, X, y, batch_size=512):
        """Return the fraction of rows of ``X`` whose ``> 0.5`` prediction equals ``y``.

        Runs in eval mode and under ``no_grad``, in chunks of ``batch_size``,
        so evaluation neither builds an autograd graph nor needs the whole
        dataset in one forward pass.
        """
        self.clf.eval()
        with ch.no_grad():
            preds = ch.cat([self.clf(xb).squeeze(1) > 0.5 for xb in X.split(batch_size)])
        return (preds == y).float().mean()


if __name__ == '__main__':
    # Smoke test under the null hypothesis: both samples come from the same
    # distribution. Pixels are drawn from U[0, 1) rather than N(0, 1) because
    # l2_attack clamps its iterates into [0, 1]. Gaussian inputs would be
    # clipped on the first attack step, so the classifier would train on
    # inputs that look nothing like the clean ones it is scored on.
    ch.manual_seed(0)
    X1 = ch.rand(32, 3, 32, 32)
    X2 = ch.rand(32, 3, 32, 32)

    c2st = C2ST("resnet18")
    stat = c2st.statistic(X1, X2, alpha=0.1, adv=True)
    print(f"Statistic: {stat.item()}")
