import plot
import torch
import time

from experiments.Experiment import Experiment
from models.LinearModel import LinearMapModel
from models.TokenDatasets import TokenDataset


class TokenExperiment(Experiment):
    """Autoregressive experiment over history windows of ``solver.m_u``.

    The model (``LinearMapModel``) receives the previous ``history_length``
    frames flattened into one vector and predicts the next frame. Attributes
    such as ``T_a``, ``device``, ``keep_every`` and ``speed_up`` are not set
    here; presumably the ``Experiment`` base class provides them — verify.
    """

    def __init__(self, settings):
        super().__init__(settings)
        # Number of past frames fed to the model for each prediction.
        self.history_length = settings['history_length']
        # Not read in this class; presumably consumed elsewhere (base class?) — verify.
        self.history_step = 1

    def get_dataset(self, solver):
        """Build the token dataset from ``solver.m_u`` using this experiment's settings."""
        return TokenDataset(solver.m_u, self.history_length, self.keep_every, self.device)

    def get_model(self, solver):
        """Create a ``LinearMapModel`` for ``solver`` and move it to ``self.device``."""
        return LinearMapModel(solver, self.history_length).to(self.device)

    def _synchronize(self):
        """Wait for queued CUDA work on ``self.device`` so wall-clock timings are accurate.

        CUDA kernels are launched asynchronously; without a sync the timer only
        measures launch overhead. No-op on non-CUDA devices.
        """
        device = torch.device(self.device)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    def _rollout(self, model, seed_frames, timesteps):
        """Autoregressively generate ``timesteps`` frames with ``model``.

        Args:
            model: callable mapping a flattened window of ``history_length``
                frames to the next frame.
            seed_frames: tensor whose first ``history_length`` rows (or fewer,
                when ``timesteps`` is smaller) are copied verbatim as warm-up.
            timesteps: total number of rows to produce.

        Returns:
            A float64 tensor of shape ``(timesteps, *seed_frames.shape[1:])``
            on ``self.device``.
        """
        history = self.history_length
        generated = torch.empty(
            (timesteps, *seed_frames.shape[1:]), dtype=torch.float64, device=self.device
        )
        # Clamp so a horizon shorter than the history window still works.
        warmup = min(history, timesteps)
        generated[:warmup] = seed_frames[:warmup]
        # Inference only: without no_grad (and if the model's parameters require
        # grad) every step would extend the autograd graph across the whole
        # rollout, growing memory and tying the result to the graph.
        with torch.no_grad():
            for i in range(warmup, timesteps):
                generated[i] = model(generated[i - history:i].flatten())
        return generated

    def sample(self, solver, checkpoint, offset=500):
        """Generate ``T_a`` frames from ``checkpoint`` and compare them with the data.

        The rollout is seeded with the dataset frames starting at ``offset``;
        the remaining frames are predicted autoregressively. Generated, real,
        difference and full-resolution reference tensors are saved under
        ``generated/<solver.name>/``.

        Args:
            solver: solver providing ``m_u``, ``u``, ``name`` and ``grid_dimension``.
            checkpoint: passed to ``model.load``.
            offset: index of the first dataset frame used (default 500, the
                previously hard-coded value).

        Returns:
            ``(generated, real, diff)`` as CPU tensors with ``T_a`` rows each,
            where ``diff = generated - real``.

        Raises:
            ValueError: if the dataset has fewer than ``T_a`` frames after ``offset``.
        """
        self.prepare_solver(solver)
        model = self.get_model(solver)
        model.load(checkpoint)

        new_data = self.get_dataset(solver)
        raw = new_data.raw_data
        available = max(raw.shape[0] - offset, 0)
        # The comparison slice below needs T_a rows after offset; fail before the
        # (expensive) rollout rather than with a shape mismatch afterwards.
        if available < self.T_a:
            raise ValueError(
                f"sample needs {self.T_a} frames from offset {offset}, "
                f"but the dataset only provides {available}"
            )

        self._synchronize()
        start_time = time.perf_counter()
        generated = self._rollout(model, raw[offset:offset + self.history_length], self.T_a)
        self._synchronize()
        elapsed_time = time.perf_counter() - start_time
        print(f"Elapsed time: {elapsed_time:.6f} seconds")
        if self.T_a:
            print(f"Time per timestep: {elapsed_time / self.T_a:.6f} seconds")

        # Compare on the CPU: raw_data may live on self.device.
        generated = generated.cpu()
        real = raw[offset:offset + self.T_a].cpu()
        real_res = torch.from_numpy(solver.u[offset:offset + self.T_a]).reshape(
            self.T_a, solver.grid_dimension, solver.grid_dimension
        )
        diff = generated - real

        plot.save_tensor(diff, f'generated/{solver.name}/diff_tok', frame_cut=False)
        plot.save_tensor(real, f'generated/{solver.name}/real_tok', frame_cut=False)
        plot.save_tensor(generated, f'generated/{solver.name}/gen_tok', frame_cut=False)
        plot.save_tensor(real_res, f'generated/{solver.name}/real_full_tok', frame_cut=False)

        return generated, real, diff

    def save_graphs(self, solver, dataset, model):
        """Animate the real solution, the model rollout and their difference.

        The rollout is seeded with the first ``history_length`` frames of
        ``dataset.raw_data`` and runs for ``T_a`` frames; every frame is
        reshaped to ``(output_dimension, output_dimension)`` for plotting.
        """
        raw_m_u = dataset.raw_data
        timesteps = self.T_a
        shape = (timesteps, solver.output_dimension, solver.output_dimension)

        print('Saving real solution...')
        # Slice before copying to the host so only the plotted frames are transferred.
        real_m_u = raw_m_u[:timesteps].detach().cpu().numpy().reshape(shape)
        plot.animate(
            real_m_u,
            solver.delta_t, self.T_a,
            solver.seed,
            'm_u_real',
            speed_up=self.speed_up
        )

        generated_m_u = self._rollout(model, raw_m_u[:self.history_length], timesteps)

        print('Saving generated solution...')
        g_m_u_reshaped = generated_m_u.cpu().numpy().reshape(shape)
        plot.animate(g_m_u_reshaped, solver.delta_t, self.T_a, solver.seed, 'model_m_u', speed_up=self.speed_up)

        print('Saving model error gif...')
        # Signed error (generated minus real), not an absolute value.
        model_error = g_m_u_reshaped - real_m_u
        plot.animate(model_error, solver.delta_t, self.T_a, solver.seed, 'model_error', speed_up=self.speed_up)
