import torch
import torch.nn as nn
import torch.optim as optim
import os
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random

from utils import general
from networks import networks, eff_kan, cheb_kan
from objectives import ncc
from objectives import regularizers

def setup_deterministic(seed):
    """Seed all random number generators and request deterministic kernels.

    Exports ``PYTHONHASHSEED`` and ``CUBLAS_WORKSPACE_CONFIG``, seeds PyTorch,
    Python's ``random`` module and NumPy with ``seed``, asks PyTorch to use
    deterministic algorithms only, and switches cuDNN (and its benchmark
    mode) off.

    Args:
        seed: value used for every generator; also exported, as a string, in
            ``PYTHONHASHSEED``.
    """
    # Environment variables consulted by the interpreter / cuBLAS.
    os.environ.update(
        {
            "PYTHONHASHSEED": str(seed),
            "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        }
    )

    # One seed for all three generator families.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    torch.use_deterministic_algorithms(True)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.enabled = False
    cudnn.benchmark = False

class ImplicitRegistrator2d:
    """This is a class for registrating implicitly represented images."""

    def __call__(
        self, coordinate_tensor=None, output_shape=(28, 28), dimension=0, slice_pos=0
    ):
        """Return the image-values for the given input-coordinates."""

        # Use standard coordinate tensor if none is given
        if coordinate_tensor is None:
            coordinate_tensor = self.make_coordinate_slice(
                output_shape, dimension, slice_pos
            )

        batch_size = 20000
        output = torch.zeros_like(coordinate_tensor).to('cuda')
        i = 0
        N = coordinate_tensor.shape[0]
        while i < N:
            with torch.no_grad():
                curr_coords = coordinate_tensor[i:min(N, i + batch_size)]
                output[i:min(N, i + batch_size)] = self.network(curr_coords)
                i += batch_size

        # Shift coordinates by 1/n * v
        coord_temp = torch.add(output, coordinate_tensor)
        transformed_image = self.transform_no_add(coord_temp)
        return (
            transformed_image.cpu()
            .detach()
            .numpy()
            .reshape(output_shape[0], output_shape[1])
        )

    def __init__(self, moving_image, fixed_image, **kwargs):
        """Initialize the learning model.

        Args:
            moving_image: image that gets warped; stored as
                ``self.moving_image`` (moved to the GPU when ``gpu`` is set).
            fixed_image: reference image; stored as ``self.fixed_image``
                (moved to the GPU when ``gpu`` is set).
            **kwargs: overrides for the defaults defined in
                ``set_default_arguments``. Only keys that have a default are
                accepted.

        Raises:
            AssertionError: if ``kwargs`` contains a key without a default.
        """

        # Set all default arguments in a dict: self.args
        self.set_default_arguments()

        # Check if all kwargs keys are valid (this checks for typos). Raised
        # explicitly rather than via ``assert`` so the check is not stripped
        # under ``python -O`` and the offending keys are reported.
        unknown = [key for key in kwargs if key not in self.args]
        if unknown:
            raise AssertionError(
                "Unknown keyword argument(s): {}".format(", ".join(unknown))
            )

        def option(name):
            """Return the caller-supplied value for ``name`` or its default."""
            return kwargs.get(name, self.args[name])

        # Parse important arguments from kwargs
        self.mask = option("mask")
        self.epochs = option("epochs")
        self.log_interval = option("log_interval")
        self.gpu = option("gpu")
        self.lr = option("lr")
        self.momentum = option("momentum")
        self.optimizer_arg = option("optimizer")
        self.loss_function_arg = option("loss_function")
        self.layers = option("layers")
        self.kan_layers = option("kan_layers")
        self.weight_init = option("weight_init")
        self.omega = option("omega")
        self.save_folder = option("save_folder")

        # Parse other arguments from kwargs
        self.verbose = option("verbose")
        self.scheduler_type = option("scheduler_type")

        # Make folder for output and add a slash to divide folder and
        # filename. An empty save_folder is left empty: appending "/" to it
        # would turn it into the filesystem root.
        if self.save_folder != "":
            # makedirs also handles nested paths and an already existing folder.
            os.makedirs(self.save_folder, exist_ok=True)
            self.save_folder += "/"

        # Make loss lists to save losses (fit() re-creates them when their
        # length does not match the number of epochs it runs).
        self.loss_list = [0 for _ in range(self.epochs)]
        self.data_loss_list = [0 for _ in range(self.epochs)]

        # Set seed
        self.seed = option("seed")
        self.gen_cpu = torch.Generator(device="cpu").manual_seed(self.seed)
        # A CUDA generator can only be created when CUDA is present; without
        # this guard construction fails on CPU-only hosts even if gpu=False.
        self.gen_cuda = (
            torch.Generator(device="cuda").manual_seed(self.seed)
            if torch.cuda.is_available()
            else None
        )
        setup_deterministic(self.seed)

        # Load network
        self.network_from_file = option("network")
        self.network_type = option("network_type")
        if self.network_from_file is None:
            # NOTE(review): "kan" is matched case-insensitively while "MLP"
            # is matched exactly; any other value builds a Siren.
            if self.network_type.lower() == "kan":
                # The KAN architecture is fixed here; self.kan_layers is not
                # consulted.
                self.network = cheb_kan.ChebyKAN(
                    layers_hidden=[2, 70, 70, 2], degree=24
                )
            elif self.network_type == "MLP":
                self.network = networks.MLP(self.layers)
            else:
                self.network = networks.Siren(
                    self.layers, self.weight_init, self.omega, gen=self.gen_cpu
                )
            if self.verbose:
                print(
                    "Network contains {} trainable parameters.".format(
                        general.count_parameters(self.network)
                    )
                )
        else:
            # SECURITY: torch.load unpickles the file; only pass trusted
            # paths. The loaded object is used as a module below, whereas
            # fit() saves a state_dict.
            self.network = torch.load(self.network_from_file)
            if self.gpu:
                self.network.cuda()
        self.network = torch.jit.script(self.network)

        # Choose the optimizer (unknown names fall back to SGD with a warning)
        optimizer_name = self.optimizer_arg.lower()
        if optimizer_name == "adam":
            self.optimizer = optim.Adam(self.network.parameters(), lr=self.lr)
        elif optimizer_name == "adadelta":
            self.optimizer = optim.Adadelta(self.network.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.SGD(
                self.network.parameters(), lr=self.lr, momentum=self.momentum
            )
            if optimizer_name != "sgd":
                print(
                    "WARNING: "
                    + str(self.optimizer_arg)
                    + " not recognized as optimizer, picked SGD instead"
                )

        # Choose the learning-rate schedule
        if self.scheduler_type == "linear":
            # Scales the lr by 0.1 ** (step / self.epochs), capped at 0.1.
            # self.epochs is read when the scheduler steps, i.e. after it has
            # been overwritten from self.milestones below.
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lambda x: 0.1 ** min(x / self.epochs, 1)
            )
        else:
            # factor=1 leaves the learning rate unchanged.
            self.scheduler = torch.optim.lr_scheduler.ConstantLR(
                self.optimizer, factor=1, total_iters=1
            )

        # Choose the loss function (unknown names fall back to MSE with a
        # warning)
        loss_name = self.loss_function_arg.lower()
        if loss_name == "mse":
            self.criterion = nn.MSELoss()
        elif loss_name == "l1":
            self.criterion = nn.L1Loss()
        elif loss_name == "ncc":
            self.criterion = ncc.NCC()
        elif loss_name == "smoothl1":
            self.criterion = nn.SmoothL1Loss(beta=0.2)
        elif loss_name == "huber":
            self.criterion = nn.HuberLoss()
        else:
            self.criterion = nn.MSELoss()
            print(
                "WARNING: "
                + str(self.loss_function_arg)
                + " not recognized as loss function, picked MSE instead"
            )

        # Move variables to GPU
        if self.gpu:
            self.network.cuda()

        # Parse regularization kwargs
        self.jacobian_regularization = option("jacobian_regularization")
        self.alpha_jacobian = option("alpha_jacobian")

        self.hyper_regularization = option("hyper_regularization")
        self.alpha_hyper = option("alpha_hyper")

        self.bending_regularization = option("bending_regularization")
        self.alpha_bending = option("alpha_bending")

        self.diffusion_regularization = option("diffusion_regularization")
        self.alpha_diffusion = option("alpha_diffusion")

        # Parse arguments from kwargs
        self.image_shape = option("image_shape")
        self.batch_size = option("batch_size")

        # Initialization
        self.moving_image = moving_image
        self.fixed_image = fixed_image

        # Sampling strides and the epochs at which the next stride is used.
        self.strides = [16, 8, 4, 2]
        self.milestones = [2500]
        # NOTE: this overrides any "epochs" value passed through kwargs.
        self.epochs = self.milestones[-1]
        self.stride_ind = 0
        self.possible_coordinate_tensor = general.make_masked_coordinate_tensor_2d(
            self.mask, self.fixed_image.shape, self.strides[0]
        )

        if self.gpu:
            self.moving_image = self.moving_image.cuda()
            self.fixed_image = self.fixed_image.cuda()

    def cuda(self):
        """Move the model to the GPU."""

        # Standard variables
        self.network.cuda()

        # Variables specific to this class
        self.moving_image.cuda()
        self.fixed_image.cuda()
        
    def set_default_arguments(self):
        """Set default arguments."""

        # Inherit default arguments from standard learning model
        self.args = {}

        # Define the value of arguments
        self.args["mask"] = None
        self.args["mask_2"] = None

        self.args["method"] = 1

        self.args["lr"] = 0.00001
        self.args["batch_size"] = 10000
        self.args["layers"] = [2, 256, 256, 256, 2]
        self.args["kan_layers"] = [2, 256, 256, 256, 2]
        self.args["velocity_steps"] = 1

        # Define argument defaults specific to this class
        self.args["output_regularization"] = False
        self.args["alpha_output"] = 0.2
        self.args["reg_norm_output"] = 1

        self.args["jacobian_regularization"] = False
        self.args["alpha_jacobian"] = 0.05

        self.args["hyper_regularization"] = False
        self.args["alpha_hyper"] = 0.25

        self.args["bending_regularization"] = False
        self.args["alpha_bending"] = 10.0

        self.args["diffusion_regularization"] = False
        self.args["alpha_diffusion"] = 0.01

        self.args["image_shape"] = (200, 200)

        self.args["network"] = None

        self.args["epochs"] = 2500
        self.args["log_interval"] = self.args["epochs"] // 4
        self.args["verbose"] = True
        self.args["save_folder"] = "output"

        self.args["network_type"] = "MLP"

        self.args["gpu"] = torch.cuda.is_available()
        self.args["optimizer"] = "Adam"
        self.args["loss_function"] = "ncc"
        self.args["momentum"] = 0.5

        self.args["positional_encoding"] = False
        self.args["weight_init"] = True
        self.args["omega"] = 32

        self.args["seed"] = 1
        self.args['scheduler_type'] = None
        self.args['mask'] = None

    def training_iteration(self, epoch):
        """Perform one iteration of training.

        Samples up to ``self.batch_size`` rows of
        ``self.possible_coordinate_tensor`` without replacement, adds the
        network output to them, compares the moving image sampled at the
        shifted positions with the fixed image sampled at the original
        positions, adds the enabled regularization terms, and takes one
        optimizer and one scheduler step.

        Args:
            epoch: zero-based iteration index; used for the stride schedule
                and, when ``self.verbose`` is set, as the index into
                ``self.data_loss_list`` / ``self.loss_list``.
        """
        # Switch to the next sampling stride once the current milestone is
        # reached. The bounds check keeps the lookup from running past the
        # end of self.milestones after the last milestone has been consumed.
        if (
            epoch + 1 < self.epochs
            and self.stride_ind < len(self.milestones)
            and epoch == self.milestones[self.stride_ind]
        ):
            self.stride_ind += 1
            self.possible_coordinate_tensor = general.make_masked_coordinate_tensor_2d(
                self.mask, self.fixed_image.shape, self.strides[self.stride_ind]
            )

        self.network.train()

        # Draw the batch indices on the device that holds the candidate
        # coordinates, using the generator that belongs to that device. On
        # CUDA this is exactly the previous behaviour; on CPU it avoids the
        # hard-coded 'cuda' device.
        on_gpu = self.possible_coordinate_tensor.is_cuda
        indices = torch.randperm(
            self.possible_coordinate_tensor.shape[0],
            generator=self.gen_cuda if on_gpu else self.gen_cpu,
            device="cuda" if on_gpu else "cpu",
        )[: self.batch_size]
        coordinate_tensor = self.possible_coordinate_tensor[indices, :]
        # Gradients w.r.t. the coordinates are enabled for the regularizers.
        coordinate_tensor = coordinate_tensor.requires_grad_(True)

        # Absolute sampling positions: input coordinates plus network output
        output = self.network(coordinate_tensor)
        coord_temp = torch.add(output, coordinate_tensor)
        output = coord_temp

        transformed_image = self.transform_no_add(coord_temp)
        fixed_image = general.fast_trilinear_interpolation_2d(
            self.fixed_image,
            coordinate_tensor[:, 0],
            coordinate_tensor[:, 1]
        )

        # Compute the loss
        loss = self.criterion(transformed_image, fixed_image)
        # Store the value of the data loss
        if self.verbose:
            self.data_loss_list[epoch] = loss.detach().cpu().numpy()

        # Back from absolute positions to the relative output
        output_rel = torch.subtract(output, coordinate_tensor)

        # Regularization
        if self.jacobian_regularization:
            loss += self.alpha_jacobian * regularizers.compute_jacobian_loss(
                coordinate_tensor, output_rel, batch_size=self.batch_size
            )
        if self.hyper_regularization:
            loss += self.alpha_hyper * regularizers.compute_hyper_elastic_loss(
                coordinate_tensor, output_rel, batch_size=self.batch_size
            )
        if self.bending_regularization:
            loss += self.alpha_bending * regularizers.compute_bending_energy_2d(
                coordinate_tensor, output_rel, batch_size=self.batch_size
            )
        if self.diffusion_regularization:
            loss += self.alpha_diffusion * regularizers.compute_diffusion_loss_2d(
                coordinate_tensor, output_rel
            )

        # Reset the gradient, perform the backpropagation and update the
        # parameters and the learning rate accordingly
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        # Store the value of the total loss
        if self.verbose:
            self.loss_list[epoch] = loss.item()

    def transform(
        self, transformation, coordinate_tensor=None, moving_image=None, reshape=False
    ):
        """Sample an image at ``coordinate_tensor + transformation``.

        Args:
            transformation: relative offsets; columns 0 and 1 of the sum with
                ``coordinate_tensor`` are used as sampling positions.
            coordinate_tensor: base coordinates. Defaults to
                ``self.coordinate_tensor``.
                NOTE(review): that attribute is not assigned anywhere in the
                visible class - confirm it is set before relying on the
                default.
            moving_image: image to sample. Defaults to ``self.moving_image``.
            reshape: accepted but not used by this method.

        Returns:
            The result of ``general.fast_trilinear_interpolation_2d``.
        """
        base_coords = (
            self.coordinate_tensor if coordinate_tensor is None else coordinate_tensor
        )
        image = self.moving_image if moving_image is None else moving_image

        # From relative to absolute positions
        absolute = torch.add(transformation, base_coords)
        first_column, second_column = absolute[:, 0], absolute[:, 1]
        return general.fast_trilinear_interpolation_2d(
            image, first_column, second_column
        )

    def transform_no_add(self, transformation, moving_image=None, reshape=False):
        """Sample an image at the absolute positions in ``transformation``.

        Unlike ``transform`` nothing is added to ``transformation`` first.

        Args:
            transformation: sampling positions; columns 0 and 1 are used.
            moving_image: image to sample. Defaults to ``self.moving_image``.
            reshape: accepted but not used by this method.

        Returns:
            The result of ``general.fast_trilinear_interpolation_2d``.
        """
        image = self.moving_image if moving_image is None else moving_image
        first_column = transformation[:, 0]
        second_column = transformation[:, 1]
        return general.fast_trilinear_interpolation_2d(
            image, first_column, second_column
        )

    def fit(self, epochs=None):
        """Train the network and store the resulting warped image.

        Runs ``training_iteration`` once per epoch, evaluates the model on the
        full coordinate grid of the fixed image (result kept in
        ``self.moved_image``), writes the network's state dict to
        ``anhir_net_<network_type lower-cased>.pth`` and, when
        ``self.verbose`` is set, saves a plot of ``self.loss_list`` to
        ``loss_<network_type>.png``. Both files go to the current working
        directory.

        Args:
            epochs: number of iterations to run; defaults to ``self.epochs``.
        """
        n_epochs = self.epochs if epochs is None else epochs

        # Start from fresh loss lists when their length does not match
        if len(self.loss_list) != n_epochs:
            self.loss_list = [0] * n_epochs
            self.data_loss_list = [0] * n_epochs

        # Perform training iterations
        for step in tqdm.tqdm(range(n_epochs)):
            self.training_iteration(step)

        # Evaluate on the full grid of the fixed image
        full_grid = general.make_coordinate_tensor_2d(self.fixed_image.shape)
        self.moved_image = self.__call__(
            coordinate_tensor=full_grid, output_shape=self.fixed_image.shape
        )

        weights_path = f'anhir_net_{self.network_type.lower()}.pth'
        torch.save(self.network.state_dict(), weights_path)

        if not self.verbose:
            return
        plt.plot(self.loss_list)
        plt.savefig(f'loss_{self.network_type}.png')
        plt.close()
