{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import h5py\n",
    "from scipy import sparse\n",
    "\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "UAM_MATLAB_FILE = '../Source_Data/LFM/LFM-1b/LFM-1b_LEs.mat'\n",
    "user_file = '../Source_Data/LFM/LFM-1b/LFM-1b_users.txt'\n",
    "\n",
    "# Read the user-artist-matrix and corresponding artist and user indices from Matlab file\n",
    "def read_UAM(m_file):\n",
    "    mf = h5py.File(m_file, 'r')\n",
    "    user_ids = np.array(mf.get('idx_users')).astype(np.int64)\n",
    "    artist_ids = np.array(mf.get('idx_artists')).astype(np.int64)\n",
    "    # Load UAM\n",
    "    UAM = sparse.csr_matrix((mf['/LEs/'][\"data\"],\n",
    "                             mf['/LEs/'][\"ir\"],\n",
    "                             mf['/LEs/'][\"jc\"])).transpose()    #.tocoo().transpose()\n",
    "    # user and artist indices to access UAM\n",
    "    UAM_user_idx = UAM.indices #UAM.row -> for COO matrix\n",
    "    UAM_artist_idx = UAM.indptr #UAM.col -> for COO matrix\n",
    "    return UAM, UAM_user_idx, UAM_artist_idx, user_ids, artist_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top100 = [\"J. Cole\",\"Pentatonix\",\"The Weeknd\",\"Bruno Mars\",\"Drake\",\"Ariana Grande\",\"twenty one pilots\",\"Taylor Swift\",\"Shawn Mendes\",\"Rae Sremmurd\",\n",
    "    \"Rihanna\",\"The Chainsmokers\",\"Garth Brooks\",\"Sundance Head\",\"Zayn\",\"Michael Buble\",\"Adele\",\"Justin Bieber\",\"Alessia Cara\",\"Post Malone\",\n",
    "    \"Metallica\",\"Trans-Siberian Orchestra\",\"Elvis Presley\",\"Mariah Carey\",\"Sia\",\"Lady Gaga\",\"Maroon 5\",\"John Legend\",\"Keith Urban\",\"Blake Shelton\",\n",
    "    \"Billy Gilman\",\"Niall Horan\",\"Louis Tomlinson\",\"DJ Snake\",\"Gucci Mane\",\"The Rolling Stones\",\"Amine\",\"Lil Uzi Vert\",\"Bing Crosby\",\"Brett Eldredge\",\n",
    "    \"Kendrick Lamar\",\"Nicki Minaj\",\"Zay Hilfigerrr & Zayion McCall\",\"Frank Sinatra\",\"Camila Cabello\",\"Halsey\",\"Daft Punk\",\"BTS\",\"Florida Georgia Line\",\"Jordan Smith\",\n",
    "    \"Future\",\"Chris Stapleton\",\"X Ambassadors\",\"Jon Bellion\",\"Carrie Underwood\",\"Tech N9ne\",\"Thomas Rhett\",\"Beyonce\",\"Miranda Lambert\",\"Andy Williams\",\n",
    "    \"Fifth Harmony\",\"Chris Tomlin\",\"Childish Gambino\",\"Tim McGraw\",\"Justin Timberlake\",\"Josh Gallagher\",\"gnash\",\"Shelley FKA DRAM\",\"Burl Ives\",\"Amy Grant\",\n",
    "    \"Lauren Daigle\",\"Kidz Bop Kids\",\"Steve Aoki\",\"Kanye West\",\"Shakira\",\"Machine Gun Kelly\",\"Eminem\",\"Daya\",\"Big Sean\",\"Calvin Harris\",\n",
    "    \"Vince Guaraldi Trio\",\"DJ Khaled\",\"Bryson Tiller\", \"Little Big Town\", \"G-Eazy\", \"21 Savage\", \"Lil Yachty\", \"Martin Garrix\", \"Kenny Chesney\", \"Nat King Cole\",\n",
    "    \"Young M.A\",\"Johnny Mathis\",\"Bebe Rexha\",\"Brenda Lee\",\"Maren Morris\",\"Desiigner\",\"Old Dominion\",\"Carpenters\",\"Mannheim Steamroller\",\"Jason Aldean\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dirty_artist_gender_file = '../Source_Data/LFM/baseline-data/LFM1b-MB-artists.txt'\n",
    "cols = ['id', 'artist', 'gender_dist']\n",
    "artist = pd.read_csv(dirty_artist_gender_file, sep='\\t', names=cols)[cols[:2]]\n",
    "\n",
    "artist100 = pd.DataFrame(columns=artist.columns)\n",
    "for i in range(len(top100)):\n",
    "    if top100[i] in artist['artist'].values:\n",
    "        artist_row = artist[artist['artist'] == top100[i]]\n",
    "        artist100 = pd.concat([artist100, artist_row], ignore_index=True)\n",
    "    else:\n",
    "        continue\n",
    "\n",
    "artist100 = artist100.drop_duplicates('artist')\n",
    "user_df = pd.read_csv(user_file, sep='\\t')[['user_id', 'gender']]\n",
    "filtered_user_df = user_df[(user_df['gender'] == 'm') | (user_df['gender'] == 'f')]\n",
    "user_ids = np.array(filtered_user_df['user_id'])\n",
    "\n",
    "UAM, UAM_user_idx, UAM_artist_idx, _, artist_ids = read_UAM(UAM_MATLAB_FILE)\n",
    "indices = [list(artist_ids.flatten()).index(x) if x in list(artist_ids.flatten()) else None for x in artist100['id'].to_list()]\n",
    "artist_idx = [element for element in indices if element is not None][:80]\n",
    "artist_name_all = artist100['artist'].to_list()\n",
    "artist_name = [artist_name_all[i] for i in range(len(indices)) if indices[i] is not None][:80]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(columns=artist_idx)\n",
    "df['user_id'] = []\n",
    "df = df.set_index('user_id')\n",
    "print('Users: ', len(user_ids))\n",
    "print('Artists: ', len(artist_idx))\n",
    "\n",
    "for i in range(0, 10000):\n",
    "    idx = 0\n",
    "    idx_list = []\n",
    "    while idx not in idx_list:\n",
    "        idx = random.randint(0, len(user_ids))\n",
    "        pc_i = UAM.getrow(idx).toarray().flatten()\n",
    "        df.loc[user_ids[idx]] = pc_i[artist_idx]\n",
    "        idx_list.append(idx)\n",
    "\n",
    "# df.to_csv('../Source_Data/LFM/LFM2000.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LFM = df\n",
    "\n",
    "row_sums = LFM.sum(axis=1)\n",
    "rows_to_remove = row_sums[row_sums < 400].index\n",
    "LFM = LFM.drop(index=rows_to_remove)\n",
    "# LFM = LFM.loc[:, LFM.sum() >= 50000]\n",
    "\n",
    "LFM = LFM.merge(user_df, how='inner', left_on='user_id', right_on='user_id')\n",
    "LFM.iloc[:, 1:-1] = (LFM.iloc[:, 1:-1] > 0).astype(int)\n",
    "\n",
    "# LFM.to_csv('../Source_Data/LFM/LFM_binary.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "sys.path.insert(0, '../GRAPH_Framework-main')\n",
    "from tasks.experiment import ModelTest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LFM = pd.read_csv('../Source_Data/LFM/LFM_binary.csv')\n",
    "\n",
    "male = LFM.loc[LFM['gender'] == 'm']\n",
    "female = LFM.loc[LFM['gender'] == 'f']\n",
    "\n",
    "D = np.array(LFM[list(LFM.columns[2:-1])])\n",
    "D1 = np.array(male[list(LFM.columns[2:-1])])\n",
    "D2 = np.array(female[list(LFM.columns[2:-1])])\n",
    "\n",
    "D_lt = [D1, D2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameters = {'max_iter':100, 'step_size':1e-2, 'lam':1e-5, 'lamm':1e-5, 'rhom':1, 'tol':1e-5}\n",
    "ising_test = ModelTest(model_type='Ising',normalization=False,showfig=True)\n",
    "ising_test.group_graph(D,D_lt,parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ising_test.runtime(1,D,D_lt,parameters)\n",
    "ising_test.summary()\n",
    "ising_test.plot()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Graph_Learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
