import json
import numpy as np
import pathlib
import torch
import torch.nn as nn
import torch.optim as optim

class Trainer:
    """Train ``model`` one batch per "epoch", periodically evaluating and checkpointing it.

    Each item drawn from ``train_data`` counts as one epoch: a single Adam
    step is taken on it. Every ``test_freq`` epochs the model is evaluated on
    ``test_data``. Every ``save_freq`` epochs the whole model object is saved
    with ``torch.save``. Per-epoch metrics are collected in
    ``self.train_metrics`` / ``self.test_metrics`` (keyed by epoch number) and
    written to ``train_metrics.json`` / ``test_metrics.json`` under ``path``
    when training finishes.
    """

    def __init__(self, model, train_data, test_data, n_epochs=2500, lr=0.001, weight_decay=0, feedback=False, feedback_freq=0, compute_all_metrics=True, record_errors=False, test_freq=100, save_freq=100, path="results", device="cuda"):
        """Store the configuration, move ``model`` to ``device`` and build the optimizer.

        Args:
            model: Module called as ``model(data, init_state=..., feedback=...)``
                that returns a 3-tuple whose last element is the output
                tensor. It must expose
                ``model.task.compute_metrics(outputs, targets, aux)``, which
                returns ``(loss, metrics)`` where ``metrics`` is dict-like.
            train_data: Iterable of dicts with ``"data"``, ``"init_state"`` and
                ``"targets"`` tensors. Each item is consumed as one epoch.
            test_data: A single dict with the same keys, used for evaluation.
            n_epochs: Training stops once the epoch counter exceeds this value.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            feedback: If true, the training targets are passed to the model as
                ``feedback`` during training steps.
            feedback_freq: Stored on ``model.feedback_freq``. It is presumably
                read by the model itself; the trainer does not use it.
            compute_all_metrics: If true, the raw batch dict is forwarded to
                ``compute_metrics`` as ``aux``. Otherwise ``aux`` is ``None``.
            record_errors: If true, ``targets - outputs`` is saved to
                ``errors_<epoch>.npy`` after every training step.
            test_freq: Evaluate every ``test_freq`` epochs; ``0`` disables it.
            save_freq: Checkpoint every ``save_freq`` epochs; ``0`` disables it.
            path: Output directory. :meth:`train` creates it if it is missing.
            device: Device the model and all tensors are moved to.
        """
        self.model = model
        self.train_data = train_data
        self.test_data = test_data
        self.n_epochs = n_epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.feedback = feedback
        # Stored on the model rather than the trainer; the trainer never reads it.
        self.model.feedback_freq = feedback_freq
        self.compute_all_metrics = compute_all_metrics
        self.record_errors = record_errors
        self.test_freq = test_freq
        self.save_freq = save_freq
        self.path = path
        self.device = device

        self.model.to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def _output_path(self, filename):
        """Return ``filename`` resolved inside the output directory ``self.path``."""
        return pathlib.Path(self.path).joinpath(filename)

    @staticmethod
    def _print_metrics(epoch, phase, metrics):
        """Print a metrics dict under an ``Epoch <n> (<phase>):`` header."""
        print(f"Epoch {epoch} ({phase}):")
        for k, v in metrics.items():
            print(f"  - {k} = {v}.")

    def _train_step(self, batch):
        """Run one optimisation step on ``batch``.

        Returns:
            ``(targets, outputs, metrics)`` for the step. ``targets`` and
            ``outputs`` are already on ``self.device``.
        """
        self.model.train()
        self.optimizer.zero_grad()

        data = batch["data"].to(self.device)
        init_state = batch["init_state"].to(self.device)
        targets = batch["targets"].to(self.device)

        # When feedback is enabled, the ground-truth targets are fed back into the model.
        feedback_inputs = targets if self.feedback else None
        _, _, outputs = self.model(data, init_state=init_state, feedback=feedback_inputs)

        aux = batch if self.compute_all_metrics else None
        loss, metrics = self.model.task.compute_metrics(outputs, targets, aux)
        loss.backward()
        self.optimizer.step()
        return targets, outputs, metrics

    def _save_errors(self, epoch, targets, outputs):
        """Save ``targets - outputs`` for ``epoch`` as ``errors_<epoch>.npy``."""
        with torch.no_grad():
            errors = targets - outputs
            np.save(self._output_path(f"errors_{epoch}.npy"), errors.cpu().detach().numpy())
        print(f"Errors recorded at epoch {epoch}.")

    def _save_checkpoint(self, epoch):
        """Pickle the whole model object to ``model_<epoch>.pt``."""
        torch.save(self.model, self._output_path(f"model_{epoch}.pt"))
        print(f"Model saved at epoch {epoch}.")

    def _evaluate(self):
        """Evaluate the model on ``self.test_data`` without gradients and return its metrics."""
        with torch.no_grad():
            self.model.eval()

            data = self.test_data["data"].to(self.device)
            init_state = self.test_data["init_state"].to(self.device)
            targets = self.test_data["targets"].to(self.device)

            aux = self.test_data if self.compute_all_metrics else None

            # NOTE(review): unlike training, no `feedback` is passed here; this
            # looks intentional (no teacher forcing at test time) — confirm.
            _, _, outputs = self.model(data, init_state=init_state)

            _, metrics = self.model.task.compute_metrics(outputs, targets, aux)
        return metrics

    def _dump_json(self, obj, filename):
        """Write ``obj`` as indented JSON to ``filename`` and close the file."""
        with open(self._output_path(filename), "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=4)

    def train(self, start_epoch=0):
        """Run the training loop, then write the collected metrics to disk.

        Args:
            start_epoch: Epoch counter value before the first batch. The first
                processed batch is epoch ``start_epoch + 1``. ``train_metrics``
                and ``test_metrics`` are reset on every call.

        The loop ends when the counter exceeds ``n_epochs`` or when
        ``train_data`` is exhausted, whichever comes first. Metric values must
        be JSON-serialisable because they are written with ``json.dump``.
        Integer epoch keys become strings in the written files.
        """
        # Create the output directory up front so every save below can succeed.
        pathlib.Path(self.path).mkdir(parents=True, exist_ok=True)

        self.train_metrics = dict()
        self.test_metrics = dict()

        epoch = start_epoch
        for batch in self.train_data:
            epoch += 1
            if epoch > self.n_epochs:
                break

            targets, outputs, train_metric = self._train_step(batch)
            self.train_metrics[epoch] = train_metric.copy()
            self._print_metrics(epoch, "train", train_metric)

            if self.record_errors:
                self._save_errors(epoch, targets, outputs)

            # A frequency of 0 disables the action instead of raising ZeroDivisionError.
            if self.save_freq and epoch % self.save_freq == 0:
                self._save_checkpoint(epoch)

            if self.test_freq and epoch % self.test_freq == 0:
                test_metric = self._evaluate()
                self.test_metrics[epoch] = test_metric.copy()
                self._print_metrics(epoch, "test", test_metric)

        self._dump_json(self.train_metrics, "train_metrics.json")
        self._dump_json(self.test_metrics, "test_metrics.json")
