{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e03bfa86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from sklearn.manifold import TSNE\n",
    "import tensorflow as tf\n",
    "from ts2vec import TS2Vec\n",
    "import datautils\n",
    "from tools import MinMaxScaler\n",
    "import pickle\n",
    "from scipy.linalg import sqrtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c8c1c007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset name; it selects the array file ../datasets/NAME.npy that is loaded below.\n",
    "dataset = 'sine_cpx'\n",
    "full_train_data = np.load(f'../datasets/{dataset}.npy')\n",
    "\n",
    "# The array must be 3-D for this unpacking to succeed\n",
    "# (presumably series count, time steps, feature channels -- TODO confirm).\n",
    "N, T, D = full_train_data.shape\n",
    "\n",
    "# Fraction of the series held out for validation; 0.0 keeps everything for training.\n",
    "valid_perc = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2404522c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the loaded series into train/validation parts and rescale both with a\n",
    "# scaler fitted on the training part only.\n",
    "N_train = int(N * (1 - valid_perc))\n",
    "N_valid = N - N_train\n",
    "\n",
    "# Shuffle through a seeded index permutation instead of np.random.shuffle:\n",
    "# the loaded array is left untouched, so re-running this cell always yields\n",
    "# the same split rather than reshuffling already-shuffled data.\n",
    "split_rng = np.random.default_rng(0)\n",
    "shuffled_data = full_train_data[split_rng.permutation(N)]\n",
    "train_data = shuffled_data[:N_train]\n",
    "valid_data = shuffled_data[N_train:]\n",
    "\n",
    "# MinMaxScaler comes from the project-local tools module; presumably it maps\n",
    "# values into [0, 1] using statistics of the fitted data -- TODO confirm.\n",
    "scaler = MinMaxScaler()\n",
    "x_train = scaler.fit_transform(train_data)\n",
    "x_valid = scaler.transform(valid_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4d22d3f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generated samples saved for the same dataset name under ../save_model/.\n",
    "# NOTE(review): no scaler is applied here, so the file is presumably already\n",
    "# in the scaled space of x_train -- confirm against the code that wrote it.\n",
    "x_gen = np.load(f'../save_model/gen_{dataset}.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe39a8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_fid(act1, act2):\n",
    "    # calculate mean and covariance statistics\n",
    "    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)\n",
    "    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)\n",
    "    # calculate sum squared difference between means\n",
    "    ssdiff = np.sum((mu1 - mu2)**2.0)\n",
    "    # calculate sqrt of product between cov\n",
    "    covmean = sqrtm(sigma1.dot(sigma2))\n",
    "    # check and correct imaginary numbers from sqrt\n",
    "    if np.iscomplexobj(covmean):\n",
    "        covmean = covmean.real\n",
    "    # calculate score\n",
    "    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "    return fid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c5a112d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = dict(\n",
    "        batch_size=8,\n",
    "        lr=0.001,\n",
    "        output_dims=320,\n",
    "        max_train_length=3000\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b96c5311",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg:1.8230335689306258±0.08929151793637233\n"
     ]
    }
   ],
   "source": [
    "# Train a fresh TS2Vec encoder five times, embed the real and the generated\n",
    "# series with it, and collect one Frechet distance per run.\n",
    "fid_s = []\n",
    "for i in range(5):\n",
    "    model = TS2Vec(\n",
    "        input_dims=x_train.shape[-1],\n",
    "        device=0,\n",
    "        **config\n",
    "    )\n",
    "    model.fit(x_train, verbose=False)\n",
    "    # assumes encode(..., encoding_window='full_series') returns one embedding\n",
    "    # row per input series -- TODO confirm against ts2vec\n",
    "    ori_repr = model.encode(x_train, encoding_window='full_series')\n",
    "    gen_repr = model.encode(x_gen, encoding_window='full_series')\n",
    "    # Compare equally sized random subsets. Capping at the smaller set avoids\n",
    "    # the IndexError that indexing ori_repr with a permutation of the generated\n",
    "    # count raises when there are more generated than real series.\n",
    "    select = min(ori_repr.shape[0], gen_repr.shape[0])\n",
    "    ori_idx = np.random.permutation(ori_repr.shape[0])[:select]\n",
    "    gen_idx = np.random.permutation(gen_repr.shape[0])[:select]\n",
    "    ori = ori_repr[ori_idx]\n",
    "    gen = gen_repr[gen_idx]\n",
    "    fid_s.append(calculate_fid(ori, gen))\n",
    "# Mean over runs, plus 1.96 * std / sqrt(n) as the half-width of a\n",
    "# normal-approximation 95% interval (np.std defaults to ddof=0).\n",
    "print(\"Avg:{}\\xB1{}\".format(np.mean(fid_s), 1.96*(np.std(fid_s)/np.sqrt(len(fid_s)))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
