import os
import matplotlib
#Comment out if not on notebook
import torch
from torchvision.models import resnet18

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance, FeatureCost
from otdd.pytorch.flows import OTDD_Gradient_Flow
from otdd.pytorch.flows import CallbackList, ImageGridCallback, TrajectoryDump
import time
from pathlib import Path

# Custom dataset used to load our preprocessed data.
class CustomTensorDataset(torch.utils.data.Dataset):
    """TensorDataset with support of transforms.

    Wraps an ``(inputs, labels)`` pair of tensors that share their first
    dimension. If ``transform`` is given, it is applied to each input sample
    on access; labels are returned unchanged.

    Args:
        tensor: Sequence whose element 0 is the input tensor ``x`` and whose
            element 1 is the label tensor ``y``; ``x[i]`` pairs with ``y[i]``.
        transform: Optional callable applied to each input sample.

    Raises:
        ValueError: If ``x`` and ``y`` differ in size along dimension 0.
    """

    def __init__(self, tensor, transform=None):
        x, y = tensor[0], tensor[1]
        # Mirror torch.utils.data.TensorDataset, which rejects tensors whose
        # first dimensions disagree. Without this, a short label tensor only
        # fails later with an IndexError, and a long one is silently ignored.
        if x.size(0) != y.size(0):
            raise ValueError(
                f"size mismatch between inputs ({x.size(0)}) "
                f"and labels ({y.size(0)})"
            )
        self.tensors = tensor
        self.x = x
        self.y = y
        # Labels are also exposed as ``targets`` -- presumably the attribute
        # otdd reads to discover the label set; TODO confirm.
        self.targets = y
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(x[index], y[index])``, transforming ``x`` if configured."""
        sample = self.x[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.y[index]

    def __len__(self):
        """Number of samples (size of the input tensor along dim 0)."""
        return self.x.size(0)
    
    
# Load our preprocessed data and wrap class-balanced subsets in datasets that
# the otdd framework can consume.
device = "cpu"
data_path = "../data/FMNIST_MNIST.tar"
# torch.load unpickles the file; only point data_path at trusted files.
data = torch.load(data_path, map_location=torch.device(device))

# Reshape to single-channel IMG_SIDE x IMG_SIDE images (presumably stored
# flattened in the file -- TODO confirm).
IMG_SIDE = 20
source_data = data["source_data"].to(device).float().reshape((-1, 1, IMG_SIDE, IMG_SIDE))
target_data = data["target_data"].to(device).float().reshape((-1, 1, IMG_SIDE, IMG_SIDE))
# NOTE(review): not used in the visible part of this script.
clustered_labels = torch.from_numpy(data["labels"]).to(device)

# Subsampling layout. Each data tensor is assumed to hold N_CLASSES
# consecutive blocks of CLASS_STRIDE rows, with block i belonging to class i
# (that is how labels are assigned below) -- TODO confirm against the
# preprocessing that produced data_path.
N_CLASSES = 10
CLASS_STRIDE = 200
N_SRC_PER_CLASS = 30  # 10 * 30 = 300 source samples
N_TGT_PER_CLASS = 5   # 10 * 5 = 50 target samples


def _take_per_class(samples, n_per_class, n_classes=N_CLASSES, stride=CLASS_STRIDE):
    """Return the first ``n_per_class`` rows of each class block, plus labels.

    Args:
        samples: Tensor whose rows are grouped into ``n_classes`` consecutive
            blocks of ``stride`` rows each.
        n_per_class: Number of leading rows to keep from every block.
        n_classes: Number of class blocks.
        stride: Rows per class block.

    Returns:
        ``(subset, labels)``: ``subset`` has ``n_classes * n_per_class`` rows
        on ``samples.device``; ``labels`` is a float32 tensor of the same
        length holding the block index (class) of each row.

    Raises:
        ValueError: If ``n_per_class`` exceeds ``stride`` (rows would leak
            into the next class) or ``samples`` is too short for the last block.
    """
    if n_per_class > stride:
        raise ValueError(
            f"n_per_class={n_per_class} exceeds the class stride {stride}"
        )
    needed = (n_classes - 1) * stride + n_per_class
    if samples.size(0) < needed:
        raise ValueError(
            f"need at least {needed} rows for {n_classes} class blocks of "
            f"{stride}, got {samples.size(0)}"
        )
    # torch.cat copies, so the subset does not alias `samples`.
    subset = torch.cat(
        [samples[i * stride:i * stride + n_per_class] for i in range(n_classes)]
    )
    # NOTE(review): class indices are kept as float32; integer labels are
    # usually torch.long -- confirm what otdd expects before changing.
    labels = (
        torch.arange(n_classes, device=samples.device)
        .repeat_interleave(n_per_class)
        .float()
    )
    return subset, labels


# The `_200` / `_50` suffixes are historical: source_data_200 holds
# N_CLASSES * N_SRC_PER_CLASS samples and target_data_50 holds
# N_CLASSES * N_TGT_PER_CLASS. Names are kept for any downstream references.
source_data_200, source_labels = _take_per_class(source_data, N_SRC_PER_CLASS)
target_data_50, target_labels = _take_per_class(target_data, N_TGT_PER_CLASS)

# Despite the `loaders_` prefix these are Dataset objects, not DataLoaders.
loaders_src = CustomTensorDataset((source_data_200, source_labels))
loaders_tgt = CustomTensorDataset((target_data_50, target_labels))

# Entropy-regularization strength. It is a hidden (undocumented) parameter in
# the otdd code; we set it explicitly here and pass it through `entreg` below.
ent_reg = 0.001

# In our experiments, only method='xyaugm' produced reasonable images.
outdir = os.path.join('outF_M_xyaugm', 'flows')
Path(outdir).mkdir(parents=True, exist_ok=True)

callbacks = CallbackList([
    ImageGridCallback(
        display_freq=1,
        animate=False,
        # os.path.join instead of '/' concatenation keeps the path portable.
        save_path=os.path.join(outdir, 'grid'),
    ),
])

flow = OTDD_Gradient_Flow(
    loaders_src, loaders_tgt,
    ### Gradient Flow Args
    method='xyaugm',
    use_torchoptim=True,
    optim='adam',
    steps=100,
    step_size=0.1,
    callback=callbacks,
    clustering_method='kmeans',
    ### OTDD Args
    online_stats=True,
    diagonal_cov=False,
    # Follow the module-level `device` setting rather than a hard-coded
    # 'cpu', so changing the device in one place affects the whole run.
    device=device,
    path=outdir,
    entreg=ent_reg,
)
d, out = flow.flow()
