import torch
import wandb
import plot

from learn import util

import torch.optim as optim

from solvers.Solver import Solver
from models.Generator import Generator, SuperResGenerator

from tqdm import tqdm


def torch_train(model, dataset, epochs, lr, save=True, eval_interval=2000):
    """Train ``model`` full-batch with Adam and log progress to Weights & Biases.

    The dataset is split 90/10 into train and test parts via
    ``util.split_test_train``. Each epoch is one optimiser step on the whole
    training split. Every ``eval_interval`` epochs (including epoch 0) a
    generated image is saved, the model is optionally checkpointed, and the
    test split is evaluated without gradient tracking.

    Args:
        model: Object exposing ``parameters()``, ``__call__``,
            ``loss_function(prediction, target)`` and ``save(epoch, optimiser)``.
        dataset: Object exposing ``data`` and ``labels``.
        epochs: Number of optimisation steps to run.
        lr: Learning rate passed to ``optim.Adam``.
        save: When true, checkpoint the model at every evaluation interval.
        eval_interval: Number of epochs between image dumps, checkpoints and
            test evaluations. Must be a positive integer.

    Returns:
        A list with one single-element list ``[loss]`` per epoch.
    """
    train_set, test_set, train_labels, test_labels = util.split_test_train(
        dataset.data, dataset.labels, test_ratio=0.1)

    tokens = train_set
    y = train_labels
    optimiser = optim.Adam(model.parameters(), lr=lr)

    real_image = wandb.Image(plot.to_image(y[0]), caption='Real Image')
    wandb.watch(model, log_freq=100)
    wandb.log({'real_image': real_image})

    train_step, test_step = util.define_metrics()

    loss_curves = []
    x = None  # Stays None when epochs == 0 so the final residual report can be skipped
    with tqdm(total=epochs, unit='epochs') as bar:
        for i in range(epochs):
            optimiser.zero_grad()
            x = model(tokens)
            loss = model.loss_function(x, y)
            loss.backward()
            # Read the scalar once: every .item() call forces a device-to-host sync.
            loss_value = loss.item()
            loss_curves.append([loss_value])   # TODO: Fix the jank that requires this to be done
            optimiser.step()
            wandb.log({'train/loss_g': loss_value, 'train/step': train_step})
            train_step += 1
            bar.update(1)
            bar.set_postfix(loss=loss_value)

            if i % eval_interval == 0:
                util.save_generated_image(model, tokens[0], y[0], i)
                if save:  # Checkpoint periodically in case training fails part-way
                    print('Saving model...')
                    model.save(i, optimiser)
                # NOTE(review): the model is not switched to eval mode here; if it
                # contains dropout/batch-norm layers this may matter — confirm.
                with torch.no_grad():
                    x_test = model(test_set)
                    loss_test = model.loss_function(x_test, test_labels)
                    l1_loss = torch.nn.functional.l1_loss(x_test, test_labels).item()
                    wandb.log({'test/loss': loss_test.item(), 'test/step': test_step, 'test/l1_loss': l1_loss})
                    test_step += 1
    if x is not None:
        # Report the residual error of the last training prediction.
        util.get_res_err(x, y)
    return loss_curves

def gan_train(model, dataset, epochs, lr, epsilon=1e-8, save=True):
    """Train ``model`` against a learned adversarial weighting tensor ``v``.

    The first sample of ``dataset.data`` / ``dataset.labels`` is held out (and
    removed from the dataset in place) to serve as the reference image. Each
    step first updates ``v`` to maximise the largest weighted residual, then
    updates the model to minimise the mean absolute weighted residual.

    Args:
        model: Object exposing ``parameters()``, ``__call__`` and
            ``save(epoch, optimiser)``.
        dataset: Object exposing ``data``, ``labels`` and ``device``.
        epochs: Number of alternating update steps.
        lr: Learning rate used for both Adam optimisers.
        epsilon: Added to the norm of ``v`` to avoid division by zero.
        save: When true, checkpoint the model once after the last step.
    """
    inputs = dataset.data
    targets = dataset.labels

    # Hold out the first sample as the "test set"; pop() mutates the dataset.
    held_out_input = inputs[0]
    held_out_target = targets[0]
    inputs.pop(0)
    targets.pop(0)

    v = torch.randn_like(held_out_target, device=dataset.device, requires_grad=True)

    wandb.watch(model, log_freq=100)

    reference = wandb.Image(plot.to_image(held_out_target), caption='Real Image')
    wandb.log({'real_image': reference})

    generator_optimiser = optim.Adam(model.parameters(), lr=lr)
    v_optimiser = optim.Adam([v], lr=lr)

    with tqdm(total=epochs, unit='steps') as progress:
        for step in range(epochs):
            prediction = model(inputs)
            residual = prediction - targets

            # --- Adversary update: v ascends on the largest weighted residual ---
            v_optimiser.zero_grad()
            # Rescale v to unit norm; epsilon guards against division by zero.
            v.data = v.data / (v.data.norm() + epsilon)

            adversary_loss = -torch.max(v * residual).mean()
            # The residual graph is needed again for the generator loss below.
            adversary_loss.backward(retain_graph=True)
            v_optimiser.step()

            # --- Generator update: minimise the mean absolute weighted residual ---
            generator_optimiser.zero_grad()
            generator_loss = (v * residual).abs().mean()
            generator_loss.backward()  # TODO: Can use Tikhonov regularisation here? Maybe not necessary
            generator_optimiser.step()

            progress.update(1)
            progress.set_postfix(loss_g=generator_loss.item())

            wandb.log({'loss_g': generator_loss.item(), 'loss_v': adversary_loss.item()})

            util.save_generated_image(model, held_out_input, held_out_target, step)
        if save:
            print('Saving model...')
            model.save(epochs, generator_optimiser)
    print(f"Avg resid: {torch.abs(residual).mean().item()}")

def adversarial_train(model, dataset, solver, experiment, save=True):
    """Train ``model`` on data generated from adversarially perturbed initial conditions.

    Every fifth iteration a fresh initial condition ``u`` is drawn from
    ``solver``; on the other iterations the ``u`` left by the previous
    adversarial step is reused. After each backward pass ``u`` is moved
    along the sign of its gradient, scaled by ``experiment.epsilon``, and
    re-normalised, before the model parameters are updated.

    Args:
        model: Object exposing ``parameters()``, ``__call__``,
            ``loss_function``, ``save(epoch, optimiser)`` and ``model.weight``.
        dataset: Object exposing ``initial``, ``device``, ``op``,
            ``generate_data(u)`` and ``tokens_from_data(data)``.
        solver: Object exposing ``seed`` and ``generate_initial_conditions(seed)``.
        experiment: Object exposing ``epochs``, ``lr`` and ``epsilon``.
        save: When true, checkpoint the model once after training.

    Returns:
        A list with one single-element list ``[loss]`` per epoch.
    """
    # This method only works with the heat equation, due to limitations with backprop
    print("Beginning training...")
    loss_curves = []

    epochs = experiment.epochs
    optimiser = optim.Adam(model.parameters(), lr=experiment.lr)

    # Starting value; replaced on the first iteration because 0 % 5 == 0.
    u = dataset.initial

    with tqdm(total=epochs, unit='epochs') as bar:
        for i in range(epochs):
            optimiser.zero_grad()

            if i % 5 == 0: # Introduce a new random condition every 5 iterations
                # Seed varies with the iteration index so each draw differs.
                u = solver.generate_initial_conditions(solver.seed + i)
                u = Solver.normalise(u)
                u = torch.tensor(u, requires_grad=True, dtype=torch.half).to(dataset.device)
                # Keep u.grad populated even if .to() returned a non-leaf copy.
                u.retain_grad()

            # The training batch is derived from u, so the loss gradient
            # is expected to flow back into u (read via u.grad below).
            data = dataset.generate_data(u)
            tokens, y = dataset.tokens_from_data(data)

            x = model(tokens)
            loss = model.loss_function(x, y)
            loss.backward()

            # Perform adversarial update
            # TODO: Use an optimiser here instead?
            u_grad = u.grad.data  # Get the gradient of u
            u = u + experiment.epsilon * u_grad.sign()  # Adversarial step
            u = Solver.normalise(u)  # Ensure values are within valid range
            u = u.detach()  # Detach u from the computation graph
            u.requires_grad_()  # Re-enable gradient tracking for the new u

            # Model parameters are updated after the adversarial step on u.
            optimiser.step()

            loss_curves.append([loss.item()])
            bar.update(1)
            # Progress-bar diagnostic only: compares dataset.op with the model weight.
            op_loss = model.loss_function(dataset.op, model.model.weight)
            bar.set_postfix(loss=loss.item(), op_loss=op_loss.item())

    # Report the residual error of the last prediction.
    util.get_res_err(x, y)
    if save:
        print('Saving model...')
        model.save(epochs, optimiser)
    return loss_curves


def quasi_gan_train(dataset, epochs, lr, batch_size, seed, epsilon=1e-8, tokens=False):
    """Train a generator in mini-batches against an adversarial weighting tensor ``v``.

    A ``Generator`` is built when ``tokens`` is true, otherwise a
    ``SuperResGenerator`` using ``dataset.history_length``. Each training
    step first updates ``v`` to maximise the largest weighted residual, then
    updates the generator to minimise the mean absolute weighted residual.
    After every epoch the mean absolute residual is logged per test batch,
    the GAN state is saved and a generated image is written out.

    Args:
        dataset: Object exposing ``labels``, ``device`` and, for the
            super-resolution generator, ``history_length``.
        epochs: Number of passes over the training loader.
        lr: Learning rate used for both Adam optimisers.
        batch_size: Batch size forwarded to ``util.ks_setup``.
        seed: Forwarded to ``generator.save_gan``.
        epsilon: Added to the norm of ``v`` to avoid division by zero.
        tokens: Selects the generator architecture (see above).
    """
    device = dataset.device

    # v is drawn before the generator is built so RNG consumption order is fixed.
    v = torch.randn_like(dataset.labels[0], device=device, requires_grad=True)
    generator = (
        Generator(1, 1)
        if tokens
        else SuperResGenerator(dataset.history_length, 1, 1)
    ).to(device)
    wandb.watch(generator, log_freq=100)

    generator_optimiser = optim.Adam(generator.parameters(), lr=lr)
    v_optimiser = optim.Adam([v], lr=lr)

    train_loader, test_loader = util.ks_setup(dataset, batch_size)
    test_sample, test_label = util.save_real_image(test_loader, 'GAN')

    train_step, test_step = util.define_metrics()

    total_steps = epochs * len(train_loader)
    with tqdm(total=total_steps, unit='steps') as progress:
        for epoch in range(epochs):
            for inputs, targets in train_loader:
                generated = generator(inputs)
                residual = generated - targets

                # --- Adversary update ---
                v_optimiser.zero_grad()
                # Rescale v to unit norm; epsilon guards against division by zero.
                v.data = v.data / (v.data.norm() + epsilon)

                adversary_loss = -torch.max(v * residual).mean()
                # The residual graph is needed again for the generator loss below.
                adversary_loss.backward(retain_graph=True)
                v_optimiser.step()

                # --- Generator update ---
                generator_optimiser.zero_grad()
                generator_loss = (v * residual).abs().mean()
                generator_loss.backward()   # TODO: Can use Tikhonov regularisation here? Maybe not necessary
                generator_optimiser.step()

                progress.update(1)
                progress.set_postfix(loss_g=generator_loss.item())
                wandb.log({
                    'train/loss_g': generator_loss.item(),
                    'train/loss_v': adversary_loss.item(),
                    'train/step': train_step,
                })
                train_step += 1

            # Per-batch evaluation on the test loader, without gradient tracking.
            with torch.no_grad():
                for inputs, targets in test_loader:
                    test_residual = (generator(inputs) - targets).abs().mean()
                    wandb.log({'test/loss': test_residual.item(), 'test/step': test_step})
                    test_step += 1

            generator.save_gan(epoch, seed, generator_optimiser, v_optimiser)
            util.save_generated_image(generator, test_sample, test_label, epoch)


def ar_train(dataset, epochs, lr, batch_size, epsilon=1e-8):
    """Train a ``Generator(1, 1)`` in mini-batches with a mean-squared-error loss.

    After every epoch the model is saved and a generated image is written
    out for the sample returned by ``util.save_real_image``.

    Args:
        dataset: Object exposing ``device``; also passed to ``util.ks_setup``.
        epochs: Number of passes over the training loader.
        lr: Learning rate passed to ``optim.Adam``.
        batch_size: Batch size forwarded to ``util.ks_setup``.
        epsilon: Currently unused; kept so existing callers keep working.
    """
    model = Generator(1, 1).to(dataset.device)

    wandb.watch(model, log_freq=100)

    optimiser = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_loader, test_loader = util.ks_setup(dataset, batch_size)
    test_sample, test_label = util.save_real_image(test_loader, 'AR')

    total_steps = epochs * len(train_loader)
    with tqdm(total=total_steps, unit='steps') as progress:
        for epoch in range(epochs):
            for inputs, targets in train_loader:
                prediction = model(inputs)

                optimiser.zero_grad()
                loss = criterion(prediction, targets)
                loss.backward()
                optimiser.step()

                progress.update(1)
                progress.set_postfix(loss=loss.item())
                wandb.log({'loss': loss.item()})

            # End of epoch: checkpoint and dump a sample image.
            model.save(epoch, optimiser)
            util.save_generated_image(model, test_sample, test_label, epoch)
