import numpy as np
import pickle
from scipy.special import expit

import torch
from torch.utils.data import Dataset, DataLoader

import lightning.pytorch as L
from lightning.pytorch.utilities.combined_loader import CombinedLoader

class _Dataset:
    """Map-style dataset of paired hidden states with sampled preference labels.

    Two pickle files are read on construction:

    * the hidden-state file, a mapping with keys ``'weight'``,
      ``'hidden_states_1'`` and ``'hidden_states_2'``;
    * the reward file, a mapping with keys ``'rewards_1'`` and ``'rewards_2'``.

    Every loaded value is converted to ``torch.float32``. From the rewards the
    class derives ``z = r1 - r2`` and a binary label
    ``y ~ Bernoulli(sigmoid(z))``.

    Args:
        i: Optional index collection (anything usable both for tensor indexing
            and as input to ``torch.tensor``, e.g. a NumPy integer array).
            When given, ``x1``, ``x2``, ``r1``, ``r2``, ``z`` and ``y`` are
            restricted to those rows and ``self.i`` records the selected
            source indices. When ``None`` the full data is kept and ``self.i``
            is ``arange(len(x1))``. ``weight`` is never subset.
        hidden_state_path: Pickle file holding the hidden states and weight.
        reward_path: Pickle file holding the rewards. The default is the URM
            reward file; ``'data/split_preference/reward/Skywork-Reward-Gemma-2-27B-v0.2.pkl'``
            was previously used here as an alternative.
    """

    DEFAULT_HIDDEN_STATE_PATH = 'data/split_preference/hidden_state/Skywork-Reward-Llama-3.1-8B-v0.2.pkl'
    DEFAULT_REWARD_PATH = 'data/split_preference/reward/URM-LLaMa-3.1-8B.pkl'

    def __init__(
        self,
        i=None,
        *,
        hidden_state_path: str = DEFAULT_HIDDEN_STATE_PATH,
        reward_path: str = DEFAULT_REWARD_PATH,
    ):
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point these paths at files you trust.
        with open(hidden_state_path, 'rb') as f:
            hidden_state = pickle.load(f)

        self.weight = hidden_state['weight'].to(torch.float32)
        self.x1 = hidden_state['hidden_states_1'].to(torch.float32)
        self.x2 = hidden_state['hidden_states_2'].to(torch.float32)

        with open(reward_path, 'rb') as f:
            reward = pickle.load(f)

        self.r1 = reward['rewards_1'].to(torch.float32)
        self.r2 = reward['rewards_2'].to(torch.float32)

        # Reward margin between the first and the second item of each pair.
        self.z = self.r1 - self.r2

        # Labels are drawn for ALL rows before any subsetting, using torch's
        # global RNG. Each instance therefore draws its own labels, and the
        # number of random draws does not depend on `i`.
        self.y = torch.bernoulli(torch.sigmoid(self.z)).to(torch.float32)

        if i is None:
            self.i = torch.arange(len(self.x1))
        else:
            self.x1 = self.x1[i]
            self.x2 = self.x2[i]
            self.r1 = self.r1[i]
            self.r2 = self.r2[i]
            self.z = self.z[i]
            self.y = self.y[i]
            # Keep the source row index of every retained example.
            self.i = torch.tensor(i, dtype=torch.int64)

    def __len__(self):
        """Return the number of retained examples (rows of ``x1``)."""
        return len(self.x1)

    def __getitem__(self, idx):
        """Return the example at position ``idx`` as a dict of tensors.

        Keys: ``x1``, ``x2`` (hidden states), ``r1``, ``r2`` (rewards),
        ``z`` (``r1 - r2``), ``y`` (sampled label) and ``idx`` (the entry of
        ``self.i`` at that position, i.e. the source row index).
        """
        return {
            "x1": self.x1[idx],
            "x2": self.x2[idx],
            "r1": self.r1[idx],
            "r2": self.r2[idx],
            "z": self.z[idx],
            "y": self.y[idx],
            "idx": self.i[idx]
        }

class LLMPreferenceDataModule(L.LightningDataModule):
    """Lightning data module exposing four resampled preference datasets.

    On construction, 9600 row indices are drawn with replacement from
    ``range(9691)`` using ``rng`` and cut into four consecutive chunks of
    2400 indices. Each chunk builds one ``_Dataset``; the datasets are stored
    in ``self.train_datasets`` under the keys ``"data1"`` .. ``"data4"``.

    Args:
        rng: NumPy random generator used to draw the row indices.
        n: Accepted for interface compatibility but currently has no effect;
            the number of drawn indices is fixed at 9600.
        batch_size: Batch size used by every ``DataLoader`` this module builds.
    """

    def __init__(self, rng: np.random.Generator, n: int, batch_size: int):
        super().__init__()
        self.batch_size = batch_size

        # NOTE(review): the caller-supplied `n` is deliberately not consulted
        # below; the sample count is pinned to 9600.
        pool_size = 9691
        total_draws = 9600
        chunk_len = total_draws // 4

        drawn = rng.choice(pool_size, total_draws, replace=True)

        datasets = {}
        for k, start in enumerate(range(0, total_draws, chunk_len), start=1):
            datasets[f"data{k}"] = _Dataset(drawn[start:start + chunk_len])
        self.train_datasets = datasets

    def _build_loaders(self, shuffle: bool):
        """Wrap every stored dataset in a ``DataLoader`` keyed by its name."""
        loaders = {}
        for name, ds in self.train_datasets.items():
            loaders[name] = DataLoader(ds, batch_size=self.batch_size, shuffle=shuffle)
        return loaders

    def train_dataloader(self):
        """Return shuffled loaders combined in ``"min_size"`` mode."""
        return CombinedLoader(self._build_loaders(shuffle=True), mode="min_size")

    def predict_dataloader(self):
        """Return unshuffled loaders over the same datasets in ``"sequential"`` mode."""
        return CombinedLoader(self._build_loaders(shuffle=False), mode="sequential")


