import os
import pickle as pk

import numpy as np
import torch
import torch.nn as nn

from retrain.src.seq_dataset import SeqDataset


# Experiment configuration.
T = 15  # horizon: number of time steps / datasets in the sequence
offline_t = 7  # time index splitting the offline and online periods
size_of_datasets = 20  # samples per full dataset (val + test + train)
seed = 2
trial_index = 0

# Destination of the pickled experiment artefacts; the file name encodes
# the run configuration.
example_path_name = (
    f'retrain/experiment_data/wild/'
    f'{trial_index}_seed_{seed}_{T}_{size_of_datasets}_{offline_t}.pk'
)

def create_fake_dataset_in_my_format(T:int, offline_t:int, trial_index:int, size_of_datasets:int):
    """Build a random synthetic ``SeqDataset`` for one trial.

    The data is reproducible. The random generator is seeded from the
    module-level ``seed`` together with ``trial_index``. The same
    (seed, trial_index) pair therefore always yields the same sequence, while
    different trials get different data.

    Args:
        T (int): horizon, i.e. the number of datasets in the sequence.
        offline_t (int): time index splitting the offline and online periods;
            forwarded to ``SeqDataset``.
        trial_index (int): index of the trial, used to seed the generator.
            It must be non-negative, as NumPy seed sequences require.
        size_of_datasets (int): number of samples per full dataset
            (val + test + train).

    Returns:
        SeqDataset: built from ``T`` feature arrays of shape
        ``(size_of_datasets, 3)`` with values in [0, 1), and ``T`` label
        arrays of shape ``(size_of_datasets, 1)`` with values in {0, 1}.
    """
    # Seed from both the global seed and trial_index. Each trial is then
    # reproducible yet distinct, instead of depending on NumPy's global state.
    rng = np.random.default_rng([seed, trial_index])
    X_sequence = [rng.random((size_of_datasets, 3)) for _ in range(T)]
    Y_sequence = [rng.integers(0, 2, size=(size_of_datasets, 1)) for _ in range(T)]

    # NOTE: the ``seed`` forwarded to SeqDataset is the module-level global,
    # not a parameter of this function.
    seqdataset = SeqDataset(
        X_sequence,
        Y_sequence,
        offline_t=offline_t,
        test_split=0.33,
        val_split=0.1,
        seed=seed,
    )
    return seqdataset

# Build the synthetic dataset sequence for the configured trial.
seqdataset = create_fake_dataset_in_my_format(
    T=T,
    offline_t=offline_t,
    trial_index=trial_index,
    size_of_datasets=size_of_datasets,
)

# dummy torch network
class SimpleNet(nn.Module):
    """Placeholder network: a single fully connected 3 -> 3 layer."""

    def __init__(self):
        super().__init__()
        # The attribute is named ``fc``, so the state_dict keys are
        # "fc.weight" / "fc.bias".
        self.fc = nn.Linear(in_features=3, out_features=3)

    def forward(self, x):
        """Apply the linear layer to ``x``, whose last dimension must be 3."""
        out = self.fc(x)
        return out
   
# NOTE(review): trained_models is never populated below; kept as scaffolding.
trained_models = []
# loss_dict_matrix[i][j] = loss of model i evaluated on dataset j (j >= i).
loss_dict_matrix = {}
# Maps model index i -> path of its saved checkpoint. It is built up across
# the loop rather than rebound each iteration, so every model's path is kept.
dict_trained_f = {}


# Convert datasets and labels to PyTorch tensors
datasets = [torch.tensor(data, dtype=torch.float32) for data in seqdataset.X_sequence]
labels = [torch.tensor(label, dtype=torch.float32) for label in seqdataset.Y_sequence]

for i in range(T):
    loss_dict_matrix[i] = {}
    # Create a new model for each dataset
    model = SimpleNet()
    # TODO actually train it on dataset i
    print('training model', i , '...')

    # Eval it on the current and all later datasets
    for j in range(i, T):
        print('evaluating model', i , 'on dataset', j)
        # TODO compute the acc of the model on the test split of dataset j
        acc = 0.8
        loss = 1 - acc
        loss_dict_matrix[i][j] = loss

    checkpoint_path = 'retrain/ckpt/wild/{0}_seed_model_{1}.pk'.format(trial_index, i)
    # torch.save does not create parent directories; make sure they exist.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model, checkpoint_path)
    # Store the path to the trained model to be able to retrieve it later
    dict_trained_f[i] = checkpoint_path


dict_of_things_needed = {
    'loss_dict_matrix': loss_dict_matrix,
    'dict_trained_f': dict_trained_f,
    'seqdataset': seqdataset,
}


# open(..., 'wb') also fails if the output directory is missing.
os.makedirs(os.path.dirname(example_path_name), exist_ok=True)
with open(example_path_name, 'wb') as file:
    pk.dump(dict_of_things_needed, file)

