import plot
import torch
import time

import numpy as np

from experiments.Experiment import Experiment
from models.LinearModel import LinearMapModel
from models.TokenDatasets import TokenDataset


class ResolutionExperiment(Experiment):
    """Experiment mapping a window of token frames to a full-resolution frame.

    The training dataset is a ``TokenDataset`` built from ``solver.m_u`` whose
    labels are set from ``solver.u`` (``set_resolution_labels``), and the model
    is a ``LinearMapModel`` constructed with ``resolution=True``.
    NOTE(review): ``m_u`` is presumably a reduced/tokenized view of ``u`` --
    confirm against ``TokenDataset``.
    """

    def __init__(self, settings):
        super().__init__(settings)
        # Number of token frames forming one model input window.
        self.history_length = settings['res_history_length']
        # Not read anywhere in this class; presumably consumed by the base
        # Experiment -- confirm before removing.
        self.history_step = 2

    def get_dataset(self, solver):
        """Build a ``TokenDataset`` over ``solver.m_u`` labelled with ``solver.u``."""
        dataset = TokenDataset(solver.m_u, self.history_length, self.keep_every, self.device)
        dataset.set_resolution_labels(solver.u)
        return dataset

    def get_model(self, solver):
        """Create a resolution-mode ``LinearMapModel`` placed on ``self.device``."""
        return LinearMapModel(solver, self.history_length, resolution=True).to(self.device)

    def sample(self, solver, checkpoint, data_path=None):
        """Run the trained model over saved tokens and compare to the reference.

        For each of ``self.T_a`` output frames, the model receives the
        flattened window ``data[i - half : i + half + 1]`` of token frames
        centred on frame ``i`` (``half = history_length // 2``). Timing is
        printed, and the generated, reference and difference tensors are saved
        and animated under ``generated/{solver.name}/``.

        Args:
            solver: Provides ``grid_dimension``, ``name``, ``delta_t``, ``seed``.
            checkpoint: Passed to ``model.load``.
            data_path: Prefix concatenated directly with ``'real_tok.pt'`` and
                ``'real_full_tok.pt'`` (so a directory must end with a
                separator).

        Returns:
            ``(generated, real, diff)`` on the CPU. ``generated`` has shape
            ``(T_a, grid_dimension, grid_dimension)``; ``real`` is the stored
            full-resolution data sliced to ``[half : half + T_a]`` (assumed to
            share that shape -- confirm); ``diff = generated - real``.

        Raises:
            ValueError: If ``data_path`` is None.
        """
        # Validate before doing any work (model construction, checkpoint I/O).
        if data_path is None:
            raise ValueError("Data must be provided for sampling.")

        model = self.get_model(solver)
        model.load(checkpoint)

        tensor = torch.empty((self.T_a, solver.grid_dimension ** 2), dtype=torch.float64, device=self.device, requires_grad=False)
        # map_location keeps the data on the CPU regardless of the device it
        # was saved from; windows are moved to self.device per step below.
        data = torch.load(data_path + 'real_tok.pt', map_location='cpu').to(torch.float64)
        half = self.history_length // 2
        # NOTE(review): the window spans 2 * half + 1 frames, i.e.
        # history_length + 1 for even history_length -- confirm this matches
        # the input size LinearMapModel expects.

        # Inference only: without no_grad every in-place write into `tensor`
        # is recorded by autograd, keeping a T_a-step graph alive, slowing the
        # timed loop, and making the returned tensors require grad.
        with torch.no_grad():
            if tensor.is_cuda:
                torch.cuda.synchronize(tensor.device)
            start_time = time.perf_counter()
            for i in range(half, self.T_a + half):
                tokens = data[i - half:i + half + 1].flatten().to(self.device)
                tensor[i - half] = model(tokens)
            if tensor.is_cuda:
                # CUDA kernels run asynchronously; wait so the timing covers
                # the actual computation, not just the kernel launches.
                torch.cuda.synchronize(tensor.device)
            elapsed_time = time.perf_counter() - start_time
        print(f"Elapsed time: {elapsed_time:.6f} seconds")
        print(f"Time per timestep: {elapsed_time / self.T_a:.6f} seconds")

        real = torch.load(data_path + 'real_full_tok.pt', map_location='cpu').to(torch.float64)[half:half + self.T_a]
        tensor = tensor.view(self.T_a, solver.grid_dimension, solver.grid_dimension)
        # Signed error (generated minus reference).
        diff = tensor.cpu() - real

        plot.save_tensor(diff, f'generated/{solver.name}/diff_res', frame_cut=False)
        plot.save_tensor(real, f'generated/{solver.name}/real_res', frame_cut=False)
        plot.save_tensor(tensor, f'generated/{solver.name}/gen_res', frame_cut=False)

        plot.animate(real.cpu().detach().numpy(), solver.delta_t, self.T_a, solver.seed, f'{solver.name}/real_res', speed_up=8)
        plot.animate(tensor.cpu().detach().numpy(), solver.delta_t, self.T_a, solver.seed, f'{solver.name}/gen_res', speed_up=8)

        return tensor.cpu(), real, diff

    def save_graphs(self, solver, dataset, model):
        """Animate the real solution, then one model rollout and error per chunk.

        ``solver.timesteps`` is split into ``timesteps // T_a`` chunks. For
        chunk ``j`` the first ``middle`` frames are copied from ground truth,
        frames ``middle .. middle + T_a - 1`` come from the model applied to
        ``dataset[j * T_a + i]``, and the last ``middle`` frames are copied
        from the end of ``u``. Frames not written stay zero.
        """
        generations = solver.timesteps // self.T_a
        u = solver.u
        shape = (solver.timesteps, solver.grid_dimension, solver.grid_dimension)
        middle = self.history_length // 2

        print('Saving real solution...')
        plot.animate(u.reshape(shape), solver.delta_t, self.T_a, solver.seed, 'u', speed_up=self.speed_up)

        for j in range(generations):
            offset = j * self.T_a
            # zeros, not empty: unwritten frames are still animated/subtracted
            # below, so they must not contain uninitialized memory.
            generated_u = np.zeros(u.shape, dtype=np.float64)
            # Seed the leading frames (before the first full window) with
            # ground truth.
            generated_u[:middle] = u[offset:offset + middle]
            # NOTE(review): dataset[offset + i] must exist for the last chunk;
            # confirm len(dataset) covers timesteps - 1.
            with torch.no_grad():
                for i in range(self.T_a):
                    tokens, _ = dataset[offset + i]
                    generated_u[i + middle] = model(tokens).cpu().detach().numpy()
            # Guard middle == 0: `arr[-0:]` is the WHOLE array, which would
            # overwrite every prediction with ground truth.
            if middle > 0:
                generated_u[-middle:] = u[-middle:]

            print(f'Saving generated solution [{j + 1}/{generations}]...')
            g_u_reshaped = generated_u.reshape(shape)
            plot.animate(g_u_reshaped, solver.delta_t, self.T_a, solver.seed, f'model_{j}_u', speed_up=self.speed_up)

            print(f'Saving model error gif [{j + 1}/{generations}]...')
            sliced_raw_m_u = np.zeros(u.shape, dtype=np.float64)
            sliced_raw_m_u[:self.T_a] = u[offset:offset + self.T_a]
            # Signed error (generated minus reference), matching `sample`.
            model_error = (generated_u - sliced_raw_m_u).reshape(shape)
            plot.animate(model_error, solver.delta_t, self.T_a, solver.seed, f'model_{j}_error', speed_up=self.speed_up)
