import cebra
import torch
import numpy as np
import itertools


class CEBRAforMovie:
    """Train, apply, save and reload a single-session CEBRA model.

    The class stores the hyper-parameters given at construction time and
    wires together a ``cebra.data.ContinuousDataLoader``, a model created by
    ``cebra.models.init``, an ``InfoNCE`` criterion, an Adam optimizer and a
    ``cebra.solver.SingleSessionSolver``.
    """

    def __init__(self, model_architecture, conditional, time_offsets, max_iterations, batch_size, learning_rate, output_dimension, num_hidden_units, verbose, device=None):
        """Store the hyper-parameters and pick the compute device.

        Args:
            model_architecture: Name passed as first argument to
                ``cebra.models.init``.
            conditional: Forwarded as ``conditional`` to the data loader.
            time_offsets: Forwarded as ``time_offset`` to the data loader.
            max_iterations: Forwarded as ``num_steps`` to the data loader.
            batch_size: Forwarded as ``batch_size`` to the data loader.
            learning_rate: Learning rate of the Adam optimizer.
            output_dimension: Fourth argument to ``cebra.models.init``.
            num_hidden_units: Third argument to ``cebra.models.init``.
            verbose: Forwarded as ``tqdm_on`` to the solver.
            device: Optional device for the model and data. When ``None``
                (the default) CUDA is used if available, otherwise the CPU.
        """
        self.model_architecture = model_architecture
        self.conditional = conditional
        self.time_offsets = time_offsets
        self.max_iterations = max_iterations
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.output_dimension = output_dimension
        self.num_hidden_units = num_hidden_units
        self.verbose = verbose

        # Fall back to the CPU instead of failing later on CUDA-less machines.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def _build_model(self, input_dimension):
        """Create a fresh model for ``input_dimension`` on ``self.device``."""
        return cebra.models.init(
            self.model_architecture,
            input_dimension,
            self.num_hidden_units,
            self.output_dimension,
            True,  # NOTE(review): presumably the "normalize" flag -- confirm against cebra
        ).to(self.device)

    def fit(self, dataset):
        """Build loader, model, criterion, optimizer and solver, then train.

        All created objects are kept as attributes (``data_loader``,
        ``model``, ``criterion``, ``optimizer``, ``solver``). A new model is
        created on every call.

        Args:
            dataset: Dataset handed to ``cebra.data.ContinuousDataLoader``.
        """
        self.data_loader = cebra.data.ContinuousDataLoader(
            dataset,
            num_steps=self.max_iterations,
            batch_size=self.batch_size,
            conditional=self.conditional,
            time_offset=self.time_offsets
        )

        self.data_loader.to(self.device)
        self.model = self._build_model(self.data_loader.dataset.input_dimension)
        self.data_loader.dataset.configure_for(self.model)
        self.criterion = cebra.models.InfoNCE(temperature=1)
        # The criterion's parameters (if any) are optimized with the model's.
        self.optimizer = torch.optim.Adam(
            itertools.chain(self.model.parameters(), self.criterion.parameters()),
            lr=self.learning_rate,
        )

        self.solver = cebra.solver.SingleSessionSolver(
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            tqdm_on=self.verbose
        )

        self.solver.fit(self.data_loader)

    @torch.no_grad()
    def transform(self, dataset):
        """Embed every sample of ``dataset`` with the current model.

        Requires ``self.model`` to exist (set by ``fit`` or ``load``).

        Args:
            dataset: Object providing ``configure_for(model)``, ``len()`` and
                tensor indexing.

        Returns:
            The model output for all samples as a NumPy array.
        """
        dataset.configure_for(self.model)

        # Run the forward pass in eval mode so layers such as dropout behave
        # deterministically, then restore whatever mode the model was in.
        was_training = self.model.training
        self.model.eval()
        try:
            inputs = dataset[torch.arange(len(dataset))].to(self.device)
            return self.model(inputs).cpu().numpy()
        finally:
            self.model.train(was_training)

    def save(self, path):
        """Write the model's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)

    def load(self, input_dimension, path):
        """Recreate the model and load weights previously written by ``save``.

        Args:
            input_dimension: Input dimension used to build the model; it must
                match the one the checkpoint was trained with.
            path: Location of the saved ``state_dict``.
        """
        self.model = self._build_model(input_dimension)
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
