from typing import Optional
import torch
from umfavi.data.preference_dataset import PreferenceDataset
from umfavi.data.demonstration_dataset import DemonstrationDataset
from umfavi.data.rating_dataset import RatingDataset
from umfavi.data.ranking_dataset import RankingDataset
from umfavi.data.stop_dataset import StopDataset
from umfavi.data.utils import derive_seed
from umfavi.types import FeedbackType
from umfavi.utils.policies import QValueModel
from torch.utils.data import DataLoader

# Fraction of the configured sample count used for any split other than
# "train". At 1.0 the non-train splits get the same count as "train"
# (clamped to at least one sample).
_NON_TRAIN_SAMPLE_FRACTION = 1.0


def _split_sample_count(configured, name):
    """Return the number of samples to generate for the split ``name``.

    The "train" split uses ``configured`` unchanged. Any other split uses
    ``int(configured * _NON_TRAIN_SAMPLE_FRACTION)``, but never fewer than 1.
    """
    if name == "train":
        return configured
    return max(int(configured * _NON_TRAIN_SAMPLE_FRACTION), 1)


def _shuffled_loader(dataset, batch_size, generator):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` that uses ``generator``."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)


def make_dataset(
    active_feedback_types,
    args,
    make_env_fn,
    policies,
    device,
    obs_transform,
    action_transform,
    name: Optional[str] = "train",
    q_true: Optional[QValueModel] = None,
):
    """Build one dataset and one shuffled DataLoader per active feedback type.

    For every feedback type found in ``active_feedback_types`` (preference,
    demonstration, rating, ranking, stop) the matching dataset class is
    instantiated from the hyper-parameters in ``args`` and wrapped in a
    ``DataLoader`` with ``batch_size=args.batch_size`` and ``shuffle=True``.

    Args:
        active_feedback_types: Container of ``FeedbackType`` members; only
            membership tests (``in``) are performed on it.
        args: Configuration object. Attributes are read only for the feedback
            types that are active, plus ``seed``, ``batch_size``, ``gamma``
            and ``td_error_weight``.
        make_env_fn: Environment factory, forwarded to every dataset.
        policies: Indexed by ``FeedbackType`` to obtain the policy passed to
            the corresponding dataset.
        device: Forwarded to every dataset.
        obs_transform: Forwarded to every dataset as ``obs_transform``.
        action_transform: Forwarded to every dataset as ``act_transform``.
        name: Split name. ``"train"`` uses the configured sample counts
            directly; any other value goes through ``_split_sample_count``.
            Also forwarded to the datasets and mixed into the derived seeds.
        q_true: Q-value model forwarded to ``StopDataset`` as ``q_model``.
            Required when ``FeedbackType.STOP`` is active.

    Returns:
        A ``(datasets, dataloaders)`` tuple of dicts, both keyed by
        ``FeedbackType`` and containing entries only for active types.

    Raises:
        ValueError: If ``FeedbackType.STOP`` is active and ``q_true`` is None.
    """
    # Validate up front so a misconfigured call fails before any of the
    # other datasets have been constructed.
    if FeedbackType.STOP in active_feedback_types and q_true is None:
        raise ValueError("q_true must be provided for stop feedback")

    datasets = {}
    dataloaders = {}

    # A single generator, seeded from args.seed, is shared by all loaders.
    g = torch.Generator()
    g.manual_seed(args.seed)

    if FeedbackType.PREFERENCE in active_feedback_types:
        num_samples = _split_sample_count(args.n_pref_samples, name)
        base_seed = derive_seed(args.seed, FeedbackType.PREFERENCE.value, name)
        pref_dataset = PreferenceDataset(
            base_seed=base_seed,
            num_episodes=args.n_pref_episodes,
            num_pref_pairs=num_samples,
            policy=policies[FeedbackType.PREFERENCE],
            make_env_fn=make_env_fn,
            device=device,
            beta=args.pref_rationality,
            gamma=args.gamma,
            obs_transform=obs_transform,
            act_transform=action_transform,
            segment_len=args.pref_seg_len,
            name=name,
            min_reward_threshold=args.min_reward_pref,
            td_error_weight=args.td_error_weight,
        )
        datasets[FeedbackType.PREFERENCE] = pref_dataset
        dataloaders[FeedbackType.PREFERENCE] = _shuffled_loader(pref_dataset, args.batch_size, g)
        print(f"Created preference dataset with {len(pref_dataset)} samples")

    if FeedbackType.DEMONSTRATION in active_feedback_types:
        num_samples = _split_sample_count(args.n_demo_samples, name)
        base_seed = derive_seed(args.seed, FeedbackType.DEMONSTRATION.value, name)
        demo_dataset = DemonstrationDataset(
            base_seed=base_seed,
            num_demonstrations=num_samples,
            num_steps=None,
            make_env_fn=make_env_fn,
            policy=policies[FeedbackType.DEMONSTRATION],
            device=device,
            beta=args.demo_rationality,
            gamma=args.gamma,
            td_error_weight=args.td_error_weight,
            obs_transform=obs_transform,
            act_transform=action_transform,
            name=name,
            step_offset=args.step_offset,
            min_reward_threshold=args.min_reward_demo,
            subsample_factor=args.subsample_factor,
        )
        datasets[FeedbackType.DEMONSTRATION] = demo_dataset
        dataloaders[FeedbackType.DEMONSTRATION] = _shuffled_loader(demo_dataset, args.batch_size, g)
        print(f"Created demonstration dataset with {len(demo_dataset)} samples")

    if FeedbackType.RATING in active_feedback_types:
        num_samples = _split_sample_count(args.n_rating_samples, name)
        base_seed = derive_seed(args.seed, FeedbackType.RATING.value, name)
        rating_dataset = RatingDataset(
            base_seed=base_seed,
            num_episodes=args.n_rating_episodes,
            num_samples=num_samples,
            policy=policies[FeedbackType.RATING],
            make_env_fn=make_env_fn,
            device=device,
            num_categories=5,  # Standard Likert scale
            gamma=args.gamma,
            obs_transform=obs_transform,
            act_transform=action_transform,
            segment_len=args.rating_seg_len,
            name=name,
            min_reward_threshold=args.min_reward_rating,
            td_error_weight=args.td_error_weight,
        )
        datasets[FeedbackType.RATING] = rating_dataset
        dataloaders[FeedbackType.RATING] = _shuffled_loader(rating_dataset, args.batch_size, g)
        print(f"Created rating dataset with {len(rating_dataset)} samples")

    if FeedbackType.RANKING in active_feedback_types:
        num_samples = _split_sample_count(args.n_ranking_samples, name)
        base_seed = derive_seed(args.seed, FeedbackType.RANKING.value, name)
        ranking_dataset = RankingDataset(
            base_seed=base_seed,
            num_episodes=args.n_ranking_episodes,
            num_ranking_samples=num_samples,
            num_ranked_items=args.num_ranked_items,
            policy=policies[FeedbackType.RANKING],
            make_env_fn=make_env_fn,
            device=device,
            beta=args.ranking_rationality,
            gamma=args.gamma,
            obs_transform=obs_transform,
            act_transform=action_transform,
            segment_len=args.ranking_seg_len,
            name=name,
            min_reward_threshold=args.min_reward_ranking,
            td_error_weight=args.td_error_weight,
        )
        datasets[FeedbackType.RANKING] = ranking_dataset
        dataloaders[FeedbackType.RANKING] = _shuffled_loader(ranking_dataset, args.batch_size, g)
        print(f"Created ranking dataset with {len(ranking_dataset)} samples")

    if FeedbackType.STOP in active_feedback_types:
        # q_true is guaranteed non-None here by the check at the top.
        num_samples = _split_sample_count(args.n_stop_samples, name)
        base_seed = derive_seed(args.seed, FeedbackType.STOP.value, name)
        stop_dataset = StopDataset(
            base_seed=base_seed,
            num_episodes=args.n_stop_episodes,
            num_samples=num_samples,
            segment_len=args.stop_seg_len,
            q_model=q_true,
            policy=policies[FeedbackType.STOP],
            make_env_fn=make_env_fn,
            device=device,
            c=args.stop_c,
            regret_percentile=args.stop_regret_percentile,
            regret_discount=args.stop_regret_discount,
            gamma=args.gamma,
            obs_transform=obs_transform,
            act_transform=action_transform,
            name=name,
            td_error_weight=args.td_error_weight,
        )
        datasets[FeedbackType.STOP] = stop_dataset
        dataloaders[FeedbackType.STOP] = _shuffled_loader(stop_dataset, args.batch_size, g)
        print(f"Created stop dataset with {len(stop_dataset)} samples")

    return datasets, dataloaders