import os
import torch
from sentence_transformers import SentenceTransformer
import pandas as pd
import re

# Directory holding the raw MovieLens-100K files read below (u.data, u.item, u.user).
ml100k_path = '/home/data/ml-100k/ml-100k/'   
# Output directory for the saved .pth files; created by pre_process_rating().
processed_path = '/home/data/ml-100k/processed/t5/'  
# Device handed to SentenceTransformer when the encoder is loaded.
device = "cpu"
# Sentence encoder loaded from a local path at import time; it is used to embed
# the movie and user description strings.
model = SentenceTransformer('/home/data/llm/sentence-t5-base', device=device)


# 1. ratings
def pre_process_rating():
    """Re-index the raw ratings and save the interaction tables.

    Reads ``u.data`` from ``ml100k_path``, replaces raw user and movie IDs
    with contiguous 0-based indices (assigned in order of first appearance
    in the file), and writes four files into ``processed_path``:

    * ``graph_user.pth`` -- dict: user index -> array of movie indices rated
    * ``graph_item.pth`` -- dict: movie index -> array of user indices
    * ``ratings.pth``    -- the full re-indexed table as a NumPy array
    * ``ratings_fl.pth`` -- dict: user index -> that user's rows of the table

    Returns:
        tuple: ``(user2id, item2id)``, the raw-ID -> index mappings.
    """
    os.makedirs(processed_path, exist_ok=True)

    column_names = ['UserID', 'MovieID', 'Rating', 'Timestamp']
    frame = pd.read_csv(
        os.path.join(ml100k_path, 'u.data'),
        sep='\t',
        header=None,
        names=column_names,
    )

    # Build the raw-ID -> contiguous-index lookups.
    raw_users = frame['UserID'].unique()
    raw_items = frame['MovieID'].unique()
    user2id = dict(zip(raw_users, range(len(raw_users))))
    item2id = dict(zip(raw_items, range(len(raw_items))))

    frame['UserID'] = frame['UserID'].map(user2id)
    frame['MovieID'] = frame['MovieID'].map(item2id)

    print(f"Users: {len(user2id)}, Items: {len(item2id)}, Interactions: {len(frame)}")

    # Per-user views: all rows of the user, and just the movie column (index 1).
    rows_by_user = {}
    items_by_user = {}
    for user_idx, user_rows in frame.groupby('UserID'):
        table = user_rows.values
        rows_by_user[user_idx] = table
        items_by_user[user_idx] = table[:, 1]

    # Per-movie view: the user column (index 0) of every row for that movie.
    users_by_item = {}
    for item_idx, item_rows in frame.groupby('MovieID'):
        users_by_item[item_idx] = item_rows.values[:, 0]

    outputs = (
        ('graph_user.pth', items_by_user),
        ('graph_item.pth', users_by_item),
        ('ratings.pth', frame.values),
        ('ratings_fl.pth', rows_by_user),
    )
    for filename, payload in outputs:
        torch.save(payload, os.path.join(processed_path, filename))

    print("finish ratings processing")

    return user2id, item2id


# 2. movies
def pre_process_movies(item2id):
    """Embed a one-sentence description of every rated movie.

    Reads ``u.item`` from ``ml100k_path``, builds for each movie present in
    ``item2id`` a sentence of the form
    ``A Movie '<title>', released in <year>, belongs to the <genres> genres.``
    and encodes all sentences with the module-level ``model``. The resulting
    tensor is saved as ``items.pth`` in ``processed_path``; row ``i`` belongs
    to the movie whose index in ``item2id`` is ``i``.

    Args:
        item2id: Mapping from raw MovieID to contiguous row index.
    """
    genre_names = [
        "unknown", "Action", "Adventure", "Animation", "Children's", "Comedy", "Crime", "Documentary",
        "Drama", "Fantasy", "Film-Noir", "Horror", "Musical", "Mystery", "Romance", "Sci-Fi",
        "Thriller", "War", "Western"
    ]
    meta_columns = ['MovieID', 'Title', 'ReleaseDate', 'VideoReleaseDate', 'IMDbURL']
    flag_columns = [f'Genre_{i}' for i in range(19)]

    movies = pd.read_csv(
        os.path.join(ml100k_path, 'u.item'),
        sep='|',
        encoding='latin1',
        header=None,
        names=meta_columns + flag_columns,
    )

    def active_genres(flags):
        # Collapse the one-hot genre flags of a row into "A|B|C".
        picked = []
        for position, flag in enumerate(flags):
            if flag == 1:
                picked.append(genre_names[position])
        return '|'.join(picked)

    movies['Genres'] = movies.iloc[:, 5:].apply(active_genres, axis=1)

    # Keep only movies that actually occur in the ratings.
    movies = movies[movies['MovieID'].isin(item2id.keys())]

    # Slots of movies missing from u.item stay as empty strings.
    movies_data = [""] * len(item2id)
    columns = zip(movies['MovieID'], movies['Title'], movies['ReleaseDate'], movies['Genres'])
    for raw_id, title, release, genres in columns:
        if pd.notna(release):
            year = release[-4:]  # the last four characters of the release date
        else:
            year = "unknown"
        headline = f"A Movie '{str(title)}', released in {year}"
        genre_clause = ", belongs to the " + str(genres).replace('|', ',') + " genres."
        movies_data[item2id[raw_id]] = headline + genre_clause

    embedding = model.encode(movies_data, convert_to_tensor=True)
    print(embedding.shape)
    torch.save(embedding, os.path.join(processed_path, 'items.pth'))
    print("finish movies processing")


# 3. users
def pre_process_users(user2id):
    """Embed a one-sentence description of every rated user.

    Reads ``u.user`` from ``ml100k_path``, builds for each user present in
    ``user2id`` a sentence of the form
    ``A <gender> user, aged <bracket>, with occupation of <occupation>.``
    and encodes all sentences with the module-level ``model``. The resulting
    tensor is saved as ``users.pth`` in ``processed_path``; row ``i`` belongs
    to the user whose index in ``user2id`` is ``i``.

    Args:
        user2id: Mapping from raw UserID to contiguous row index.

    Raises:
        KeyError: If a user's gender is neither ``'F'`` nor ``'M'``.
    """
    gender_feature = {'F': 'female', 'M': 'male'}

    # (exclusive upper bound, label) pairs in ascending order. MovieLens-100K
    # stores the real age in years, so ages are bucketed by range. An
    # exact-key lookup on the bracket codes (1, 18, 25, ...) would label
    # nearly every user "unknown age". The bracket codes themselves still
    # land in the bracket they name (1 -> "under 18", 18 -> "18-24", ...).
    age_brackets = (
        (18, "under 18"),
        (25, "18-24"),
        (35, "25-34"),
        (45, "35-44"),
        (50, "45-49"),
        (56, "50-55"),
    )

    def describe_age(raw_age):
        # Missing or non-positive ages cannot be bucketed meaningfully.
        if pd.isna(raw_age):
            return "unknown age"
        years = int(raw_age)
        if years < 1:
            return "unknown age"
        for upper_bound, label in age_brackets:
            if years < upper_bound:
                return label
        return "56+"

    users = pd.read_csv(
        os.path.join(ml100k_path, 'u.user'),
        sep='|',
        header=None,
        names=['UserID', 'Age', 'Gender', 'Occupation', 'Zip-code']
    )

    # Keep only users that actually occur in the ratings.
    users = users[users['UserID'].isin(user2id.keys())]

    # Slots of users missing from u.user stay as empty strings.
    users_data = [""] * len(user2id)
    for _, user in users.iterrows():
        new_id = user2id[user['UserID']]
        gender = gender_feature[user['Gender']]
        age = describe_age(user['Age'])
        occ = str(user['Occupation'])
        users_data[new_id] = f"A {gender} user, aged {age}, with occupation of {occ}."

    embedding = model.encode(users_data, convert_to_tensor=True)
    print(embedding.shape)
    torch.save(embedding, os.path.join(processed_path, 'users.pth'))
    print("finish users processing")


if __name__ == "__main__":
    # Ratings go first: that step creates the output directory and returns the
    # raw-ID -> index maps the movie and user steps need to order their rows.
    user2id, item2id = pre_process_rating()
    pre_process_movies(item2id)
    pre_process_users(user2id)

