import os
import torch
from utils import set_seed, select_device
from datasets import Dataset
from market.specifications import Specification

class User:
    """One user: its dataset, its specification, and a cached copy of its test labels.

    Construction is eager: the RNG is seeded from ``cfg['seed']``, the
    user's ``Dataset`` and ``Specification`` are built, and the test-split
    labels are written to disk if they are not already there.
    """

    def __init__(self, cfg, user_id):
        # Carriage return instead of newline, so successive
        # "Creating user ..." messages overwrite each other on one line.
        print(f'Creating user {user_id}...', end='\r')
        # Seed first so the constructors below run after seeding
        # (assumes set_seed covers the RNGs they use — confirm).
        set_seed(cfg['seed'])

        self.cfg, self.user_id = cfg, user_id
        self.dataset = Dataset(cfg, 'user', user_id)
        self.spec = Specification(cfg, 'user', user_id, **self.spec_kwargs())
        self.task_labels()

    def spec_kwargs(self):
        """Build the keyword arguments passed to ``Specification`` for this user.

        Returns:
            dict with keys:
              ``'path'``     -- ``<dataset_path>/specifications/<specification>/user/<user_id>.npz``
              ``'device'``   -- result of ``select_device(cfg, user_id)``
              ``'phi_path'`` -- ``<dataset_path>/phi/user/<user_id>.pt``

        Raises:
            KeyError: if ``cfg`` lacks ``'specification'`` or ``'dataset_path'``.
        """
        spec_name = self.cfg['specification']
        root = self.cfg['dataset_path']
        uid = self.user_id

        spec_file = os.path.join(root, 'specifications', spec_name, 'user', f'{uid}.npz')
        device = select_device(self.cfg, uid)
        # Used for neural phi.
        phi_file = os.path.join(root, 'phi', 'user', f'{uid}.pt')

        return {'path': spec_file, 'device': device, 'phi_path': phi_file}

    def task_labels(self):
        """Cache this user's test-split labels at ``logs/user_task_labels/<user_id>.pt``.

        Does nothing if that file already exists, so a file left by an
        earlier run is kept unchanged. The path is relative to the current
        working directory.
        """
        out_file = os.path.join('logs', 'user_task_labels', f'{self.user_id}.pt')
        if os.path.exists(out_file):
            return
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        # Element 1 of the test split is what gets saved; presumably the
        # (inputs, labels) pair's labels — confirm against Dataset.get_data.
        labels = self.dataset.get_data('test')[1]
        torch.save(labels, out_file)