from utils import *
from torch.utils.data import Dataset
import torch
import random
import os
from tqdm import trange
import pickle
import numpy as np
from tempfile import TemporaryFile


class InContextDataset(Dataset):
    """Training dataset of in-context prompts drawn from a table of Markov chain samples.

    The table is cached on disk at ``./dataset/markov_chain_{num_var}.npy``.
    When the cache is missing, ``NUM_GRAPHS`` chains are created with
    ``get_markov_chain`` and ``SAMPLES_PER_GRAPH`` one-hot samples are drawn
    from each, stacked along a new leading axis and written to the cache.
    """

    # Size of the table generated on a cache miss. __getitem__ does not rely
    # on these; it reads the actual bounds from the loaded tensor.
    NUM_GRAPHS = 1000
    SAMPLES_PER_GRAPH = 2000

    def __init__(self, N, num_example, num_var, start_curriculum=1):
        """
        N: number of samples (the value reported by ``__len__``)
        num_example: number of examples in one input
        num_var: number of variables in the Markov Chain
        start_curriculum: inclusive upper bound for the randomly chosen index
            of the variable that is masked out in the query row
        """
        self.N = N
        self.num_example = num_example
        self.num_var = num_var
        self.graphs = []
        # Only written by update_mode(); initialised so reads never raise.
        self.mode = None
        self.mask_out_range = [0, start_curriculum]

        cache_path = f"./dataset/markov_chain_{self.num_var}.npy"
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                self.data = torch.from_numpy(np.load(f))
            # NOTE(review): this branch keeps a get_markov_chain() graph while
            # the cache-miss branch keeps a get_markov_chain_test() graph.
            # __getitem__ only calls to_onehot/mask_var on it -- confirm that
            # those helpers do not depend on which chain is used.
            self.graphs.append(get_markov_chain(num_var=self.num_var))
        else:
            # Create every chain first, then sample, so the order of calls
            # into the generators matches the original two-pass layout.
            chains = []
            for _ in trange(self.NUM_GRAPHS):
                chains.append(get_markov_chain(num_var=self.num_var))

            tables = []
            for i in trange(self.NUM_GRAPHS):
                drawn = chains[i].sample(self.SAMPLES_PER_GRAPH)
                tables.append(chains[i].to_onehot(drawn))

            # Only a single helper graph is kept after generation.
            self.graphs = [get_markov_chain_test(num_var=self.num_var)]
            self.data = torch.stack(tables, dim=0)

            # The cache directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as f:
                np.save(f, self.data.numpy())

    def update_mode(self, mode):
        """Store ``mode`` on the instance; nothing in this class reads it."""
        self.mode = mode

    def __len__(self):
        """Return the nominal dataset length ``N``."""
        return self.N

    def __getitem__(self, idx):
        """Build one random prompt; ``idx`` is ignored.

        A graph (first axis of ``self.data``) is picked at random, then
        ``num_example + 1`` sample indices are drawn with replacement. The
        first ``num_example`` rows are the in-context examples and the last
        one is the query, in which one randomly chosen variable is masked.

        Returns ``(prompt, y)``:
            prompt: example rows (flattened, with ``num_var`` zero columns
                appended) followed by the flattened masked query row.
            y: class index (argmax over the last axis) of the masked variable
                in the query row.
        """
        # Bounds come from the loaded table instead of hard-coded 1000/2000 so
        # a cache of any size is sampled correctly.
        num_graphs = self.data.size(0)
        num_samples = self.data.size(1)

        graph_idx = random.randint(0, num_graphs - 1)
        sampled_idx = torch.randint(0, num_samples, (self.num_example + 1,))
        example_idx = sampled_idx[:self.num_example]
        test_idx = sampled_idx[-1]

        table = self.data[graph_idx]
        examples = table[example_idx]
        # One-hot -> class indices, with a leading batch axis of size 1.
        test_token = table[test_idx].unsqueeze(0).argmax(-1)

        # random.randint is inclusive on both ends.
        mask_out = random.randint(self.mask_out_range[0], self.mask_out_range[1])
        y = test_token[:, mask_out]

        helper = self.graphs[0]
        test_token = helper.to_onehot(test_token)
        # presumably hides variable `mask_out` and widens each variable by one
        # indicator column -- TODO confirm against utils.
        test_token = helper.mask_var(test_token, mask_out)
        test_token = test_token.view(test_token.size(0), -1)

        # Zero columns appended to the example rows so that their width can
        # match the query row (see the mask_var assumption above).
        pos_enc = torch.zeros(examples.size(0), self.num_var)
        examples = examples.view(examples.size(0), -1)
        examples = torch.cat([examples, pos_enc], dim=-1)
        return torch.cat([examples, test_token], dim=0), y
    

class InContextDatasetTest(Dataset):
    """Evaluation dataset of in-context prompts drawn from one test Markov chain.

    Samples are cached on disk at ``./dataset/test_markov_{num_var}.npy``.
    When the cache is missing, ``NUM_SAMPLES`` one-hot samples are drawn from
    the chain returned by ``get_markov_chain_test`` and written to the cache
    with a leading axis of size 1.
    """

    # Number of samples generated on a cache miss. __getitem__ reads the
    # actual count from the loaded tensor instead.
    NUM_SAMPLES = 50000

    def __init__(self, N, num_example, mask_out, num_var):
        """
        N: number of samples (the value reported by ``__len__``)
        num_example: number of examples in one input
        mask_out: index of the variable masked out in the query row
        num_var: number of variables in the Markov Chain
        """
        self.N = N
        self.num_example = num_example
        self.mask_out = mask_out
        self.num_var = num_var

        # Both the cached and the cache-miss path start by building this graph.
        self.graphs = [get_markov_chain_test(num_var=self.num_var)]

        cache_path = f"./dataset/test_markov_{num_var}.npy"
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                self.data = torch.from_numpy(np.load(f))
        else:
            chain = self.graphs[0]
            # presumably (NUM_SAMPLES, num_var, num_states) -- TODO confirm
            examples = chain.to_onehot(chain.sample(self.NUM_SAMPLES))
            # Leading axis of size 1 keeps the same layout as a stacked table.
            self.data = torch.stack([examples], dim=0)

            # The cache directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as f:
                np.save(f, self.data.numpy())

    def update_mask_out(self, m):
        """Change which variable index is masked in subsequent items."""
        self.mask_out = m

    def update_num_example(self, m):
        """Change the number of in-context examples in subsequent items."""
        self.num_example = m

    def __len__(self):
        """Return the nominal dataset length ``N``."""
        return self.N

    def __getitem__(self, idx):
        """Build one random evaluation prompt; ``idx`` is ignored.

        ``num_example + 1`` sample indices are drawn with replacement from the
        single cached table; the last one is the query row, in which variable
        ``self.mask_out`` is masked.

        Returns ``(prompt, y, probs, graph_pred)``:
            prompt: example rows (flattened, with ``num_var`` zero columns
                appended) followed by the flattened masked query row.
            y: class index of variable ``mask_out`` in the query row.
            probs: result of ``get_prob(mask_out)`` on the helper graph.
            graph_pred: the helper graph's ``graph.predict`` output for
                variable ``mask_out`` when only the variables before it are
                marked as observed.
        """
        helper = self.graphs[0]
        table = self.data[0]

        # The bound comes from the loaded table instead of a hard-coded 50000
        # so that a cache of any size is sampled correctly.
        sampled_idx = torch.randint(0, table.size(0), (self.num_example + 1,))
        example_idx = sampled_idx[:self.num_example]
        test_idx = sampled_idx[-1]

        examples = table[example_idx]
        # One-hot -> class indices, with a leading batch axis of size 1.
        test_token = table[test_idx].argmax(-1).unsqueeze(0)

        # True marks an observed position: everything from mask_out onward is
        # hidden from the graph's own prediction.
        mask = torch.ones(1, self.num_var, dtype=torch.bool)
        mask[:, self.mask_out:] = False
        X_masked = torch.masked.MaskedTensor(test_token, mask=mask)
        graph_pred = helper.graph.predict(X_masked)[:, self.mask_out].squeeze(0)

        mask_out = self.mask_out
        y = test_token[:, mask_out]
        test_token = helper.to_onehot(test_token)
        # presumably hides variable `mask_out` and widens each variable by one
        # indicator column -- TODO confirm against utils.
        test_token = helper.mask_var(test_token, mask_out)
        probs = helper.get_prob(mask_out)
        test_token = test_token.view(test_token.size(0), -1)

        # Zero columns appended to the example rows so that their width can
        # match the query row (see the mask_var assumption above).
        pos_enc = torch.zeros(examples.size(0), self.num_var)
        examples = examples.view(examples.size(0), -1)
        examples = torch.cat([examples, pos_enc], dim=-1)
        return torch.cat([examples, test_token], dim=0), y, probs, graph_pred