import torch
import plot
import time

import numpy as np

import learn.analytical as analytical
import learn.ml as ml

from models.Generator import Generator, SuperResGenerator
from models.TokenDatasets import TokenConcatDataset
from models.LinearModel import get_error
from models.AdversarialDataset import AdversarialDataset
from models.GANDataset import TensorDataset
from tqdm import tqdm

from solvers.HeatSolver import HeatSolver


class Experiment:
    def __init__(self, settings):
        """Copy the run configuration out of ``settings`` and choose a device.

        Args:
            settings: Mapping that provides the keys ``keep_every``,
                ``sample_size``, ``speed_up``, ``epochs``, ``lr``,
                ``batch_size``, ``save``, ``T_a``, ``c``, ``data_dir`` and
                ``epsilon``. The mapping itself is stored (not copied) as
                ``self.settings``.
        """
        # Use the GPU whenever torch reports one, otherwise fall back to CPU.
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        # (attribute, settings key) pairs; only ``c`` is stored under a
        # different name (``mask_dimension``).
        copied_fields = (
            ('keep_every', 'keep_every'),
            ('sample_size', 'sample_size'),
            ('speed_up', 'speed_up'),
            ('epochs', 'epochs'),
            ('lr', 'lr'),
            ('batch_size', 'batch_size'),
            ('save', 'save'),
            ('T_a', 'T_a'),
            ('mask_dimension', 'c'),
            ('data_dir', 'data_dir'),
            ('epsilon', 'epsilon'),
        )
        for attribute, key in copied_fields:
            setattr(self, attribute, settings[key])
        self.settings = settings

        # Start out as None; methods such as vary_history read both values,
        # so they have to be assigned before those methods are used.
        self.history_length = None
        self.history_step = None
        print(f'Device: {self.device}')

    def get_dataset(self, solver):
        """Build the dataset for ``solver`` (subclass hook).

        Within this class the hook is always called after
        ``prepare_solver(solver)``.

        Args:
            solver: The solver to build a dataset from.

        Raises:
            NotImplementedError: Always, in this base class; the message
                names the concrete class that is missing the override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement get_dataset()'
        )

    def get_model(self, solver):
        """Build the model used with ``solver`` (subclass hook).

        Args:
            solver: The solver the model is built for. Some callers in this
                class pass a solver that has not been prepared yet.

        Raises:
            NotImplementedError: Always, in this base class; the message
                names the concrete class that is missing the override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement get_model()'
        )

    def save_graphs(self, solver, dataset, model):
        """Save graphs for a solver/dataset/model triple (subclass hook).

        Args:
            solver: The solver the graphs relate to.
            dataset: The dataset used; ``adversary_learn`` passes ``None``.
            model: The model whose results are saved.

        Raises:
            NotImplementedError: Always, in this base class; the message
                names the concrete class that is missing the override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement save_graphs()'
        )

    def sample(self, solver, checkpoint):
        """Subclass hook taking a solver and a checkpoint.

        This base class only fixes the signature.

        Raises:
            NotImplementedError: Always, in this base class; the message
                names the concrete class that is missing the override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement sample()'
        )

    def prepare_solver(self, solver):
        solver.calculate_fields()

    def get_multiple_datasets(self, Solver, track_solvers=False):
        """Run ``settings['sample_size']`` solvers and collect their results.

        Each iteration constructs ``Solver(self.settings)``, hands it the
        ``op``, ``op_powers`` and ``stacked_ops`` of the previous solver
        (``None`` on the first pass), prepares it, and then increments
        ``self.settings['seed']`` so the next solver is built from a
        different seed.

        Args:
            Solver: Solver class, called as ``Solver(self.settings)``.
            track_solvers: When true, return the solver objects themselves
                instead of building a dataset from each one.

        Returns:
            ``(solvers, op)`` if ``track_solvers`` is true, otherwise
            ``(datasets, op)``. ``op`` is the ``op`` attribute of the last
            solver, or ``None`` when no iteration ran.

        Side effects:
            ``self.settings['seed']`` is incremented once per iteration and
            is not restored afterwards.
        """
        collected = []
        iterations = self.settings['sample_size']
        op = None
        op_powers = None
        stacked_ops = None
        for _ in tqdm(range(iterations), desc="Generating datasets", total=iterations, unit="iter"):
            solver = Solver(self.settings)
            # Carry the operator data over from the previous solver
            # (original note: keep op_seed the same).
            solver.op = op
            solver.op_powers = op_powers
            solver.stacked_ops = stacked_ops
            self.prepare_solver(solver)
            self.settings['seed'] += 1

            # Keep only what is returned: when datasets are requested the
            # solver is not retained, so it can be freed once the loop moves on.
            if track_solvers:
                collected.append(solver)
            else:
                collected.append(self.get_dataset(solver))

            op_powers = solver.op_powers
            stacked_ops = solver.stacked_ops
            # TODO: This is a temporary fix, for frame skip of 1
            # (the alternative was op_powers[self.keep_every]).
            op = solver.op
        return collected, op

    def save_multiple_datasets(self, Solver, name):
        """Generate several solver runs and save their fields and the operator.

        Each solver's ``u`` is reshaped to
        ``(timesteps, grid_dimension, grid_dimension)`` and saved under
        ``{name}/{op_seed}/{seed}``; the operator returned by
        ``get_multiple_datasets`` is saved under ``{name}/{op_seed}/op``
        using the first solver's ``op_seed``.

        Args:
            Solver: Solver class forwarded to ``get_multiple_datasets``.
            name: Leading path component of every saved tensor.
        """
        solvers, op = self.get_multiple_datasets(Solver, track_solvers=True)
        cut = self.settings['T'] // self.settings['keep_every'] + 1

        for run in solvers:
            frame_shape = (run.timesteps, run.grid_dimension, run.grid_dimension)
            frames = run.u.reshape(frame_shape)
            target = f'{name}/{run.op_seed}/{run.seed}'
            plot.save_tensor(frames, target, frame_cut_proportion=cut)

        # The operator gets a leading axis of length one before being saved.
        op_with_axis = np.expand_dims(op, axis=0)
        plot.save_tensor(op_with_axis, f'{name}/{solvers[0].op_seed}/op', frame_cut_proportion=cut)


    def lin_alg(self, solver):
        """Prepare ``solver``, run ``analytical.lin_alg`` on its dataset and print the fit.

        Only the residual and the error (second and third returned values)
        are printed.

        Args:
            solver: Solver to prepare and build the dataset from.
        """
        self.prepare_solver(solver)
        fit = analytical.lin_alg(self.get_dataset(solver), solver)
        _params, residual, error = fit

        print(residual, error)


    def torch_learn(self, solver):
        """Train a model on a single solver's dataset with ``ml.torch_train``.

        Prints the analytical linear-regression result first, trains for
        ``self.epochs`` epochs at learning rate ``self.lr``, saves the loss
        curves, and calls ``save_graphs`` when ``self.save`` is truthy.

        Args:
            solver: Solver to prepare, build the dataset from and train on.
        """
        self.prepare_solver(solver)
        dataset = self.get_dataset(solver)
        model = self.get_model(solver)

        print('Theoretical best:')
        analytical.linear_regression(dataset)

        training_result = ml.torch_train(model, dataset, self.epochs, self.lr)
        # Only the loss curves are used here; the second value is ignored.
        loss_curves, _err_curve = training_result

        print('Saving loss curves and config file...')
        plot.save_loss_curves(loss_curves, solver.seed)

        if not self.save:
            return
        self.save_graphs(solver, dataset, model)


    def torch_learn_multiple(self, Solver):
        """Train a model on the concatenation of several solver datasets.

        A solver is constructed first (its seed names the saved loss curves
        and it is passed to ``get_model``/``save_graphs``), then
        ``get_multiple_datasets`` produces the individual datasets, which
        are merged into one ``TokenConcatDataset`` on ``self.device``.

        Args:
            Solver: Solver class, called as ``Solver(self.settings)``.
        """
        solver = Solver(self.settings)
        model = self.get_model(solver)
        datasets, op = self.get_multiple_datasets(Solver)
        dataset = TokenConcatDataset(datasets, self.device)
        # Drop the per-solver datasets before training and let torch return
        # cached GPU memory.
        del datasets, op
        torch.cuda.empty_cache()

        print('Theoretical best:')
        analytical.linear_regression(dataset)

        # ml.torch_train is unpacked as (loss_curves, err) at the other call
        # sites in this class; unpack it the same way here so only the loss
        # curves reach the plot instead of the whole tuple.
        loss_curves, _err_curve = ml.torch_train(model, dataset, self.epochs, self.lr)

        print('Saving loss curves and config file...')
        plot.save_loss_curves(loss_curves, solver.seed)

        if self.save:
            self.save_graphs(solver, dataset, model)


    def adversary_learn(self, solver):
        """Run adversarial training against a ``HeatSolver``.

        Builds an ``AdversarialDataset`` from the solver's first field,
        timestep count, operator and mask, trains with
        ``ml.adversarial_train``, saves the loss curves, and calls
        ``save_graphs`` (with no dataset) when ``self.save`` is truthy.

        Args:
            solver: Must be a ``HeatSolver`` instance (checked with ``assert``).
        """
        assert isinstance(solver, HeatSolver)

        self.prepare_solver(solver)
        initial_field = solver.u[0]
        dataset = AdversarialDataset(
            initial_field,
            solver.timesteps,
            solver.op,
            solver.mask,
            self.history_length,
            self.device,
        )
        model = self.get_model(solver)

        # The experiment itself is handed over as the last argument.
        loss_curves = ml.adversarial_train(model, dataset, solver, self)

        print('Saving loss curves and config file...')
        plot.save_loss_curves(loss_curves, solver.seed)

        if not self.save:
            return
        self.save_graphs(solver, None, model)

    def save_random_graphs(self, solver):
        model = self.get_model(solver)
        model.load(solver.seed)

        solver.seed = solver.seed // 2  # Make sure we get a new, unseen dataset
        self.prepare_solver(solver)
        dataset = self.get_dataset(solver)

        self.save_graphs(solver, dataset, model)

    def vary_history(self, Solver, analytic, save=True):
        """Measure the error for history lengths ``1, 1 + step, ...`` up to ``self.history_length``.

        For every history length the token labels of all datasets are reset
        and the error is obtained either analytically or by training a
        fresh model.

        Args:
            Solver: Solver class, called as ``Solver(self.settings)``.
            analytic: If true, use ``analytical.linear_regression`` on the
                concatenated datasets; otherwise train the model from
                ``get_model`` with ``ml.torch_train``.
            save: If true, plot the curves and return ``None``; otherwise
                return the list of curves.

        Returns:
            ``None`` when ``save`` is true, else one entry per history length.

        Note:
            ``self.history_length`` is changed while the sweep runs and is
            put back to its starting value before this method returns, even
            if an iteration raises.
        """
        max_history_length = self.history_length
        solver = Solver(self.settings)
        datasets, _op = self.get_multiple_datasets(Solver)
        curves = []
        try:
            for history_length in range(1, max_history_length + 1, self.history_step):
                print("History length: ", history_length)
                # Set before get_model is called so the current sweep value
                # is visible on the instance.
                self.history_length = history_length
                for dataset in datasets:
                    dataset.reset_token_labels(history_length)
                if analytic:
                    _params, _resid, err = analytical.linear_regression(TokenConcatDataset(datasets, self.device))
                    err = [err]
                else:
                    model = self.get_model(solver)
                    # NOTE(review): the raw list of datasets is passed here,
                    # not a TokenConcatDataset as elsewhere - confirm that
                    # ml.torch_train accepts a list.
                    _loss_curves, err = ml.torch_train(model, datasets, self.epochs, self.lr, save=save)
                curves.append(err)
        finally:
            # Do not leak the last sweep value into later calls.
            self.history_length = max_history_length

        if save:
            plot.plot_err_curves(curves, solver.seed)
            plot.plot_err_against_length(curves, solver.seed)
        else:
            return curves

    # This function varies parameters to do with the dataset (which is only history_length)
    def vary_history_many(self, Solver, analytic, iterations=100):
        """Repeat ``vary_history`` several times and save and plot all curves.

        After each repetition ``self.settings['seed']`` is incremented. The
        collected curves are written to
        ``output/history_length_{max}_{Solver.name}.npz`` and passed to
        ``plot.plot_sample_size_err``.

        Args:
            Solver: Solver class; its ``name`` attribute is used in file names.
            analytic: Forwarded to ``vary_history``.
            iterations: Number of repetitions.
        """
        max_history_length = self.history_length
        # vary_history sweeps range(1, max + 1, step); count exactly those
        # points. ``max // step`` undercounts when step does not divide max
        # (e.g. max=10, step=3 gives 1, 4, 7, 10).
        history_points = len(range(1, max_history_length + 1, self.history_step))

        curve_collection = []
        for i in range(iterations):
            print("Starting iteration ", i+1)
            curve_collection.append(self.vary_history(Solver, analytic, save=False))
            self.settings['seed'] += 1
        np.savez(f'output/history_length_{max_history_length}_{Solver.name}.npz', *curve_collection)
        plot.plot_sample_size_err(
            curve_collection,
            history_points,
            'history_length',
            f'history_length_{max_history_length}_{Solver.name}_size_{iterations}'
        )

    # This function varies parameters to do with the solver (of which there are many)
    def vary_solver_parameter(self, Solver, parameter, values):
        """Sweep one solver setting and plot the analytical error for each value.

        For each of ``self.sample_size`` repetitions, every value is written
        to ``self.settings[parameter]``, a fresh solver is built and
        prepared, and ``analytical.linear_regression`` is run on its
        dataset. ``self.settings['seed']`` is incremented after each
        repetition.

        Args:
            Solver: Solver class, called as ``Solver(self.settings)``.
            parameter: Key in ``self.settings`` to overwrite.
            values: Values to try; converted with ``np.array``.

        Side effects:
            ``self.settings[parameter]`` keeps the last value tried and
            ``self.settings['seed']`` stays incremented.
        """
        values = np.array(values)
        curve_collection = []
        for _ in range(self.sample_size):
            curve = []
            for value in values:
                print("Parameter", parameter, ":", value)
                self.settings[parameter] = value

                solver = Solver(self.settings)
                self.prepare_solver(solver)
                dataset = self.get_dataset(solver)

                # Other call sites in this class unpack three values from
                # linear_regression while this one expected four; the error
                # is the last element either way, so take it positionally
                # from the end.
                *_, err_curve = analytical.linear_regression(dataset)
                curve.append([err_curve])
            self.settings['seed'] += 1
            curve_collection.append(curve)
        plot.plot_sample_size_err(
            curve_collection,
            len(values),
            parameter,
            f'{parameter}_{values.min()}_{values.max()}_{self.sample_size}'
        )

    def alternate_initial_conditions_torch(self, solver, seed):
        """Evaluate a saved model on the dataset produced by ``solver`` and print the error.

        Args:
            solver: Solver to prepare and build the evaluation dataset from.
            seed: Identifier passed to ``model.load``.
        """
        self.prepare_solver(solver)
        dataset = self.get_dataset(solver)
        model = self.get_model(solver)
        model.load(seed)

        inputs = dataset.data
        # Labels and predictions are both moved to CPU numpy arrays before
        # they are compared.
        targets = dataset.labels.cpu().detach().numpy()
        predictions = model(inputs).cpu().detach().numpy()
        err = model.get_error(targets, predictions)
        print("New err: ", err)

    def alternate_initial_conditions(self, solver, Solver, seed):
        """Fit parameters on one solver and print their error on a second, re-seeded solver.

        The parameters come from ``analytical.linear_regression`` on
        ``solver``'s dataset. ``self.settings['seed']`` is then set to
        ``seed`` and a new ``Solver`` provides the evaluation data.

        Args:
            solver: Solver used for fitting; its ``seed`` attribute is
                overwritten with ``seed``.
            Solver: Solver class used to build the evaluation solver.
            seed: Seed for the evaluation solver.
        """
        self.prepare_solver(solver)
        fit_dataset = self.get_dataset(solver)
        params, _resid, _fit_err = analytical.linear_regression(fit_dataset)

        self.settings['seed'] = seed
        new_solver = Solver(self.settings)
        self.prepare_solver(new_solver)
        new_dataset = self.get_dataset(new_solver)

        # Keep only as many leading rows as params has along its first axis.
        rows = params.shape[0]
        x = new_dataset.data.cpu().detach().numpy()[:rows]
        y = new_dataset.labels.cpu().detach().numpy()[:rows]
        # TODO: Doesn't account for different op_seed!!
        solver.seed = seed
        y_gen = params @ x
        print("New err: ", get_error(y, y_gen))

    def linear_regression_multiple(self, Solver):
        """Run ``analytical.linear_regression`` on several concatenated datasets and print the fit.

        Args:
            Solver: Solver class forwarded to ``get_multiple_datasets``.
        """
        datasets, _op = self.get_multiple_datasets(Solver)

        combined = TokenConcatDataset(datasets, self.device)
        _params, residual, error = analytical.linear_regression(combined)
        print(residual, error)

    def train_simplified_gan(self, seed, tokens=True):
        """Load the tensor dataset from ``self.data_dir`` and run ``ml.quasi_gan_train`` on it.

        Args:
            seed: Forwarded to ``ml.quasi_gan_train``.
            tokens: Forwarded to both the dataset and the trainer.
        """
        dataset = TensorDataset(
            self.data_dir,
            self.mask_dimension,
            self.history_length,
            self.device,
            tokens=tokens,
        )
        ml.quasi_gan_train(dataset, self.epochs, self.lr, self.batch_size, seed, tokens=tokens)

    def train_ks_ar(self):
        """Load the tensor dataset from ``self.data_dir`` and run ``ml.ar_train`` on it."""
        dataset = TensorDataset(
            self.data_dir,
            self.mask_dimension,
            self.history_length,
            self.device,
        )
        ml.ar_train(dataset, self.epochs, self.lr, self.batch_size)

    def train_linear_gan(self, Solver):
        """Run ``ml.gan_train`` on the concatenation of several solver datasets.

        Args:
            Solver: Solver class; one instance is built for ``get_model``,
                more are built inside ``get_multiple_datasets``.
        """
        model = self.get_model(Solver(self.settings))
        datasets, _op = self.get_multiple_datasets(Solver)
        combined = TokenConcatDataset(datasets, self.device)

        ml.gan_train(model, combined, self.epochs, self.lr)

    def sample_ks_many(self, epoch, checkpoint, method='gan'):
        """Roll out the low-res generator and upsample it for every dataset entry.

        For each entry ``j`` of ``dataset.masked_raw`` the first 16 frames
        seed an autoregressive rollout of ``self.T_a`` further frames. A
        ``SuperResGenerator`` then maps every 7-frame window of that
        rollout to one 256x256 frame. The generated frames and the matching
        slice of ``dataset.raw_data[j]`` are saved with ``plot.save_tensor``.

        The history lengths (16 and 7) are fixed inside this method and
        kept in local variables, so ``self.history_length`` is left
        untouched.

        Args:
            epoch: Checkpoint identifier passed to ``generator.load``.
            checkpoint: Checkpoint identifier passed to ``res_generator.load``.
            method: Forwarded to both ``load`` calls.
        """
        low_res_history = 16
        res_history = 7
        half = res_history // 2
        window = 2 * half + 1
        low_res_side = 256 // self.mask_dimension

        dataset = TensorDataset(self.data_dir, self.mask_dimension, low_res_history, self.device)
        generator = Generator(1, 1).to(dataset.device)
        res_generator = SuperResGenerator(res_history, 1, 1).to(self.device)

        # Only the first value returned by each load is used. The second
        # load overwrites ``epoch``, so output names carry the value
        # returned by the super-resolution generator.
        epoch, _, _ = generator.load(epoch, method)
        epoch, _, _ = res_generator.load(checkpoint, method=method)

        for j in range(dataset.masked_raw.size(0)):
            # Pure inference: no autograd bookkeeping for either rollout.
            with torch.no_grad():
                # Low res section
                tensor = torch.empty(
                    (self.T_a + low_res_history, 1, low_res_side, low_res_side),
                    device=dataset.device,
                    requires_grad=False,
                )
                tensor[:low_res_history] = dataset.masked_raw[j, :low_res_history].unsqueeze(1)
                for i in range(self.T_a):
                    history = tensor[i:i + low_res_history].transpose(0, 1).unsqueeze(0)
                    tensor[i + low_res_history] = generator(history).transpose(0, 1)

                # High res section: one output frame per window of the
                # low-res rollout.
                res_tensor = torch.empty((self.T_a, 256, 256), device=self.device, requires_grad=False)
                for start in range(self.T_a):
                    frames = tensor[start:start + window].transpose(0, 1).unsqueeze(0)
                    res_tensor[start] = res_generator(frames).transpose(0, 1).squeeze(0)

            plot.save_tensor(res_tensor, f'generated/kse/for_cors/res_gen_{epoch}_{j}', frame_cut=False)

            real = dataset.raw_data[j][half: self.T_a + half].clone().to(torch.float32)
            plot.save_tensor(real, f'generated/kse/for_cors/res_real_{epoch}_{j}', frame_cut=False)

    def sample_ks(self, epoch, method='gan'):
        """Roll out the generator from the first dataset entry and save gen/real/diff outputs.

        The first ``self.history_length`` frames of ``dataset.masked_raw[0]``
        seed an autoregressive rollout of ``self.T_a`` further frames. The
        rollout, the corresponding real frames and their difference are
        each saved as a tensor and as a video. The rollout is timed and the
        timings are printed.

        Args:
            epoch: Checkpoint identifier passed to ``generator.load``; the
                first value returned by ``load`` is used in output names.
            method: Forwarded to ``generator.load``.
        """
        history = self.history_length
        side = 256 // self.mask_dimension

        dataset = TensorDataset(self.data_dir, self.mask_dimension, history, self.device)
        generator = Generator(1, 1).to(dataset.device)
        epoch, _, _ = generator.load(epoch, method)

        tensor = torch.empty((self.T_a + history, 1, side, side), device=dataset.device, requires_grad=False)
        tensor[:history] = dataset.masked_raw[0, :history].unsqueeze(1)

        # CUDA kernels run asynchronously; synchronise around the timed
        # region so the measurement covers the actual computation.
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)
        start_time = time.perf_counter()
        # Inference only: without no_grad every step would extend one
        # autograd graph spanning the whole rollout.
        with torch.no_grad():
            for i in range(self.T_a):
                context = tensor[i:i + history].transpose(0, 1).unsqueeze(0)
                tensor[i + history] = generator(context).transpose(0, 1)
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)
        elapsed_time = time.perf_counter() - start_time
        print(f"Elapsed time: {elapsed_time:.6f} seconds")
        print(f"Time per timestep: {elapsed_time / self.T_a:.6f} seconds")
        plot.save_tensor(tensor.squeeze(1), f'generated/kse/gen_{epoch}', frame_cut=False)
        plot.save_video(tensor.squeeze(1).detach().cpu().numpy(), 'data/generated/kse/', f'gen_{epoch}', fps=60)

        real = dataset.masked_raw[0, :self.T_a + history].unsqueeze(1)
        plot.save_tensor(real.squeeze(1), f'generated/kse/real_{epoch}', frame_cut=False)
        plot.save_video(real.squeeze(1).detach().cpu().numpy(), 'data/generated/kse/', f'real_{epoch}', fps=60)

        # Bring both operands to the CPU so the subtraction cannot fail on
        # a device mismatch.
        diff = tensor.cpu() - real.cpu()
        plot.save_tensor(diff.squeeze(1), f'generated/kse/diff_{epoch}', frame_cut=False)
        plot.save_video(diff.squeeze(1).detach().cpu().numpy(), 'data/generated/kse/', f'diff_{epoch}', fps=60)

    def save_real_ks(self, name):
        """Load the tensor dataset and call ``dataset.save_example(self.T_a, name)``.

        Args:
            name: Forwarded to ``dataset.save_example``.
        """
        dataset = TensorDataset(
            self.data_dir,
            self.mask_dimension,
            self.history_length,
            self.device,
        )
        dataset.save_example(self.T_a, name)

    def sample_ks_res(self, checkpoint, data, method='gan'):
        """Upsample a saved low-res sequence and save gen/real/diff outputs.

        Every window of ``2 * (history_length // 2) + 1`` consecutive
        frames of the loaded sequence is mapped to one 256x256 frame by a
        ``SuperResGenerator``. The result, the matching slice of
        ``dataset.raw_data[0]`` and their difference are each saved as a
        tensor and as a video.

        Args:
            checkpoint: Checkpoint identifier passed to ``generator.load``;
                the first value returned by ``load`` is used in output names.
            data: Path (or file object) given to ``torch.load``.
            method: Forwarded to ``generator.load``.
        """
        dataset = TensorDataset(self.data_dir, self.mask_dimension, self.history_length, self.device)
        # NOTE: torch.load unpickles its input; only pass files from a
        # trusted source.
        low_res = torch.load(data).unsqueeze(1).to(self.device)
        generator = SuperResGenerator(self.history_length, 1, 1).to(self.device)
        epoch, _, _ = generator.load(checkpoint, method=method)

        tensor = torch.empty((self.T_a, 256, 256), device=self.device, requires_grad=False)
        half = self.history_length // 2
        # Inference only: skip autograd bookkeeping for the generated frames.
        with torch.no_grad():
            for i in range(half, self.T_a + half):
                frames = low_res[i - half : i + half + 1].transpose(0, 1).unsqueeze(0)
                tensor[i - half] = generator(frames).transpose(0, 1).squeeze(0)

        plot.save_tensor(tensor, f'generated/kse/res_gen_{epoch}', frame_cut=False)
        plot.save_video(tensor.detach().cpu().numpy(), 'data/generated/kse/', f'res_gen_{epoch}', fps=60)

        real = dataset.raw_data[0][half : self.T_a + half]
        plot.save_tensor(real, f'generated/kse/res_real_{epoch}', frame_cut=False)
        plot.save_video(real.detach().cpu().numpy(), 'data/generated/kse/', f'res_real_{epoch}', fps=60)

        # Bring both operands to the CPU so the subtraction cannot fail on
        # a device mismatch.
        diff = tensor.cpu() - real.cpu()
        plot.save_tensor(diff, f'generated/kse/res_diff_{epoch}', frame_cut=False)
        plot.save_video(diff.detach().cpu().numpy(), 'data/generated/kse/', f'res_diff_{epoch}', fps=60)
