import torch
from torch import autograd, nn, optim
from typing import Callable
import importlib
import numpy as np
import random
import os

# Number of optimisation steps run by train(); each "epoch" there is one optimiser step.
NUM_EPOCHS = 32_000
# NOTE(review): not referenced in the visible part of this file; presumably a
# loss-term weight used elsewhere -- confirm against the callers.
WEIGHT_FACTOR = 1000
# Learning rate handed to the Adam optimiser in train().
LR = 1e-3


class TransformedFunc2D(nn.Module):
    """Learnable weighted sum of transformed copies of a fixed 2D function.

    Each of the ``num_components`` copies evaluates ``func`` at
    ``scale * (M @ p + shift)`` where ``M`` is an optional 2x2 orthogonal
    matrix, then multiplies the result by a scalar weight and adds an output
    shift. The copies are summed into one output per input point.
    """

    def __init__(
        self,
        num_components: int,
        orientation_preserving: bool,
        func: Callable,
        num_outputs: int,
        apply_rotation: bool,
    ):
        """
        Apply multiple sets of transformations to 2D harmonic function
        and combine them to a single output of desired dimension.

        Args:
            num_components (int): number of transformations that each function undergoes
            orientation_preserving (bool): if True the matrices are pure rotations
                (determinant +1); if False they include a flip (determinant -1).
                Only consulted when ``apply_rotation`` is True.
            func (callable): target function; it receives a tensor of shape
                ``(batch * num_components, 2)`` and its result is reshaped to
                ``(batch, num_components, -1)``
            num_outputs (int): output dimension
            apply_rotation (bool): determines whether rotations are applied to the target function
        """
        super().__init__()
        self.num_components = num_components
        # Angles start uniformly in [0, 2*pi); everything else starts at the
        # identity transformation with zero output contribution.
        self.angles = nn.Parameter(2 * torch.pi * torch.rand(num_components))
        self.shifts = nn.Parameter(torch.zeros(num_components, 2))
        self.input_scales = nn.Parameter(torch.ones(num_components))
        self.output_weights = nn.Parameter(torch.zeros(num_components))
        self.output_shifts = nn.Parameter(torch.zeros(num_components, num_outputs))
        self.orientation_preserving = orientation_preserving
        self.func = func
        self.apply_rotation = apply_rotation

    def forward(self, x):
        """Evaluate the combined function at the points ``x``.

        Args:
            x: tensor of shape ``(batch, 2)``.

        Returns:
            Tensor of shape ``(batch, k)`` where ``k`` is the broadcast of the
            last dimension of ``func``'s output with ``num_outputs``.
        """
        if self.apply_rotation:
            matrices = self.rotation_matrices()
            # (C, 2, 2) @ (2, B) -> (C, 2, B) -> (B, C, 2)
            x_transform = torch.matmul(matrices, x.T).permute([2, 0, 1])
        else:
            # (B, 1, 2); broadcasting against the shifts below expands it to
            # (B, C, 2) without materialising repeated copies first.
            x_transform = x.unsqueeze(1)
        x_transform = x_transform + self.shifts.unsqueeze(0)
        x_transform = x_transform * self.input_scales.unsqueeze(0).unsqueeze(2)
        batch_size = x.shape[0]
        # Flatten the component axis so ``func`` sees a plain batch of 2D points.
        x_transform = x_transform.reshape(batch_size * self.num_components, 2)
        func_outputs = self.func(x_transform).reshape(
            batch_size, self.num_components, -1
        )
        output_weights = self.output_weights.reshape(1, self.num_components, 1)
        output_shifts = self.output_shifts.unsqueeze(0)
        return (output_weights * func_outputs + output_shifts).sum(dim=1)

    def rotation_matrices(self):
        """Return the ``(num_components, 2, 2)`` matrices selected by ``orientation_preserving``."""
        if self.orientation_preserving:
            return self.rotation_matrices_orientation_preserving()
        return self.rotation_matrices_flipping()

    def rotation_matrices_orientation_preserving(self):
        """Return ``[[cos, sin], [-sin, cos]]`` for every angle (determinant +1)."""
        sin = torch.sin(self.angles)
        cos = torch.cos(self.angles)
        # Stacking (rather than filling a fresh torch.zeros buffer) keeps the
        # result on the same device and dtype as the parameters.
        return torch.stack(
            [torch.stack([cos, sin], dim=1), torch.stack([-sin, cos], dim=1)],
            dim=1,
        )

    def rotation_matrices_flipping(self):
        """Return ``[[-cos, sin], [sin, cos]]`` for every angle (determinant -1)."""
        sin = torch.sin(self.angles)
        cos = torch.cos(self.angles)
        # Same device/dtype-preserving construction as the rotation case.
        return torch.stack(
            [torch.stack([-cos, sin], dim=1), torch.stack([sin, cos], dim=1)],
            dim=1,
        )


def set_rng_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators with ``seed``."""
    # Same order as always: stdlib first, then NumPy, then torch.
    for seed_generator in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_generator(seed)


class DivergenceFreeNetwork(nn.Module):
    """Sum of ``TransformedFunc2D`` modules built on fixed 2D vector fields.

    The constructor wraps ``func1`` .. ``func6``; ``func7`` and ``func8`` are
    defined on the class but are not wired into the network here.
    """

    def __init__(self, num_components_per_func):
        """Build one transformed module per base field.

        Args:
            num_components_per_func: number of transformed copies created for
                each base field.
        """
        super().__init__()
        base_fields = [
            self.func1,
            self.func2,
            self.func3,
            self.func4,
            self.func5,
            self.func6,
        ]
        self.transformed_funcs = self.setup_transformed_funcs(
            base_fields, num_components_per_func
        )

    def forward(self, x):
        """Return the sum of every transformed field evaluated at ``x``."""
        # Start from the float 0.0 so the accumulation matches a manual loop.
        return sum((branch(x) for branch in self.transformed_funcs), 0.0)

    @staticmethod
    def setup_transformed_funcs(functions, num_components):
        """Wrap each callable in a two-output ``TransformedFunc2D`` without rotation.

        Returns:
            ``nn.ModuleList`` holding one module per entry of ``functions``,
            in the same order.
        """
        return nn.ModuleList(
            [
                TransformedFunc2D(
                    num_components=num_components,
                    orientation_preserving=False,
                    func=field,
                    num_outputs=2,
                    apply_rotation=False,
                )
                for field in functions
            ]
        )

    # Each base field maps points (column 0 = first coordinate, column 1 =
    # second coordinate) to a two-component vector, stacked along dim 1.

    @staticmethod
    def func1(x):
        """Field ``(y, x)``."""
        a, b = x[:, 0], x[:, 1]
        return torch.stack([b, a], dim=1)

    @staticmethod
    def func2(x):
        """Field ``(x, -y)``."""
        a, b = x[:, 0], x[:, 1]
        return torch.stack([a, -b], dim=1)

    @staticmethod
    def func3(x):
        """Field ``(0, x)``."""
        a = x[:, 0]
        return torch.stack([torch.zeros_like(a), a], dim=1)

    @staticmethod
    def func4(x):
        """Field ``(y, 0)``."""
        a, b = x[:, 0], x[:, 1]
        return torch.stack([b, torch.zeros_like(a)], dim=1)

    @staticmethod
    def func5(x):
        """Field ``(cos(x + y), -cos(x + y))``."""
        a, b = x[:, 0], x[:, 1]
        first = torch.cos(a + b)
        second = -torch.cos(a + b)
        return torch.stack([first, second], dim=1)

    @staticmethod
    def func6(x):
        """Field ``(exp(x + y), -exp(x + y))``."""
        a, b = x[:, 0], x[:, 1]
        first = torch.exp(a + b)
        second = -torch.exp(a + b)
        return torch.stack([first, second], dim=1)

    @staticmethod
    def func7(x):
        """Field ``(x * cos(x * y), -y * cos(x * y))``."""
        a, b = x[:, 0], x[:, 1]
        first = a * torch.cos(a * b)
        second = -b * torch.cos(a * b)
        return torch.stack([first, second], dim=1)

    @staticmethod
    def func8(x):
        """Field ``(exp(x + y), -exp(x + y))``; same expression as ``func6``."""
        a, b = x[:, 0], x[:, 1]
        first = torch.exp(a + b)
        second = -torch.exp(a + b)
        return torch.stack([first, second], dim=1)


class HeatNetwork(nn.Module):
    """Sum of separable ``sin(x) * sin(y) * exp(t)`` modes with learnable parameters.

    Each mode's temporal decay rate is tied to its spatial scales through
    ``D * (x_scale**2 + y_scale**2)``.
    """

    def __init__(self, diffusion_coefficient, num_components):
        """Create ``num_components`` modes with parameters drawn from ``torch.rand``.

        Args:
            diffusion_coefficient: the constant ``D`` used for the decay rate.
            num_components: number of modes that are summed in ``forward``.
        """
        super().__init__()
        # The order below fixes both the sequence of RNG draws and the
        # parameter registration order, so it must not be shuffled.
        parameter_names = (
            "x_shift",
            "x_scale",
            "y_shift",
            "y_scale",
            "t_shift",
            "out_shift",
            "weights",
        )
        for parameter_name in parameter_names:
            setattr(
                self, parameter_name, nn.Parameter(torch.rand(1, num_components))
            )
        self.D = diffusion_coefficient

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: tensor whose columns 0, 1 and 2 hold x, y and t respectively.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # Width-one slices keep a trailing axis so they broadcast against the
        # (1, num_components) parameters.
        xs, ys, ts = x[:, 0:1], x[:, 1:2], x[:, 2:3]
        sin_x = torch.sin(self.x_scale * xs + self.x_shift)
        sin_y = torch.sin(self.y_scale * ys + self.y_shift)
        decay_rate = self.D * (self.x_scale**2 + self.y_scale**2)
        envelope = torch.exp(-decay_rate * ts + self.t_shift)
        modes = self.weights * sin_x * sin_y * envelope + self.out_shift
        return modes.sum(dim=1, keepdim=True)


def train(module, model_id, *, num_epochs=None, lr=None):
    """Optimise a network with Adam and save its ``state_dict``.

    Args:
        module: object exposing ``get_loss_function_and_network()``, which
            returns ``(loss_fn, net)``; ``loss_fn(net)`` must return a scalar
            loss tensor.
        model_id: identifier used in the log lines and as the file stem of the
            saved weights (``./weights/{model_id}.pt``).
        num_epochs: number of optimiser steps; defaults to ``NUM_EPOCHS``.
        lr: Adam learning rate; defaults to ``LR``.
    """
    epochs = NUM_EPOCHS if num_epochs is None else num_epochs
    learning_rate = LR if lr is None else lr
    # Create the output directory before training so a missing or unwritable
    # directory fails immediately instead of after the whole run.
    os.makedirs("./weights", exist_ok=True)
    loss_fn, net = module.get_loss_function_and_network()
    opt = optim.Adam(net.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        loss = loss_fn(net)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if epoch % 100 == 0:
            print(f"Model ID: {model_id}. Epoch: {epoch}. Loss: {loss:.3e}")
    torch.save(net.state_dict(), f"./weights/{model_id}.pt")


def evaluate(experiment_name, method):
    """Score every saved network of one method and print mean and std.

    Imports ``src.{experiment_name}.benchmark`` and
    ``src.{experiment_name}.{method}``, loads each checkpoint in ``./weights/``
    whose file name starts with ``{experiment_name}_{method}_`` into a fresh
    ``Net()`` and passes it to ``Benchmark.evaluate_network``.

    Args:
        experiment_name: name of the experiment package under ``src``.
        method: name of the method module inside that package.
    """
    benchmark_mod = importlib.import_module(f"src.{experiment_name}.benchmark")
    # Seed before building the benchmark so its construction is reproducible.
    set_rng_seed(42)
    benchmark = benchmark_mod.Benchmark()
    net_mod = importlib.import_module(f"src.{experiment_name}.{method}")
    prefix = f"{experiment_name}_{method}_"
    # os.listdir returns entries in arbitrary order; sort so the checkpoints
    # are always evaluated in the same sequence.
    paths = sorted(
        f"./weights/{name}"
        for name in os.listdir("./weights/")
        if name.startswith(prefix)
    )
    if not paths:
        # Without this guard np.mean([]) would emit a warning and report nan.
        print(f"Method: {method}. No saved weights found.")
        return
    # Presumably each score is an RMSE (going by the variable name) -- confirm
    # against Benchmark.evaluate_network.
    rmses = []
    for path in paths:
        net = net_mod.Net()
        # map_location lets checkpoints saved on an accelerator load on a
        # CPU-only host; load_state_dict copies the values into net's tensors.
        net.load_state_dict(torch.load(path, map_location="cpu"))
        rmses.append(benchmark.evaluate_network(net))
    print(f"Method: {method}. Mean: {np.mean(rmses)}. Std: {np.std(rmses)}")


def evaluate_experiment(experiment_name):
    """Run ``evaluate`` for every method that has weights for this experiment.

    Method names are taken from files in ``./weights/`` named
    ``{experiment_name}_{method}_...``: the method is the text between the
    experiment prefix and the next underscore. Method names therefore cannot
    themselves contain underscores.

    Args:
        experiment_name: experiment whose saved weights should be evaluated.
    """
    prefix = f"{experiment_name}_"
    methods = set()
    for name in os.listdir("./weights/"):
        if not name.startswith(prefix):
            continue
        # Strip the prefix instead of splitting the whole name, so experiment
        # names that contain underscores still yield the right method.
        method, separator, _ = name[len(prefix):].partition("_")
        # Files without a "<method>_" segment can never match evaluate()'s
        # own filter, so they are ignored.
        if separator and method:
            methods.add(method)
    # Sorted so methods are reported in a stable order.
    for method in sorted(methods):
        evaluate(experiment_name, method)


def d(y, x):
    """Return the gradient of ``y.sum()`` with respect to ``x``.

    The graph is both retained and extended (``retain_graph`` and
    ``create_graph``), so the result can itself be differentiated again.
    """
    (gradient,) = autograd.grad(
        outputs=y.sum(),
        inputs=x,
        create_graph=True,
        retain_graph=True,
    )
    return gradient


def laplacian(net, num_points=32):
    """Evaluate the 2D Laplacian of ``net`` at random points.

    Coordinates are drawn with ``torch.rand`` (uniform in ``[0, 1)``) and fed
    to ``net`` as a ``(num_points, 2)`` tensor.

    Args:
        net: callable applied to the stacked ``(x, y)`` coordinates.
        num_points: number of sample points (default 32).

    Returns:
        Tensor of shape ``(num_points,)`` holding ``d2/dx2 + d2/dy2`` of the
        summed network output. It is a per-point value provided ``net`` treats
        each row independently -- NOTE(review): confirm for the networks used.
    """
    x = torch.rand(num_points, requires_grad=True)
    y = torch.rand(num_points, requires_grad=True)
    phi = net(torch.stack([x, y], dim=1))
    return d(d(phi, x), x) + d(d(phi, y), y)


def divergence(net, num_points=32):
    """Evaluate the divergence of a two-component ``net`` at random points.

    Coordinates are drawn with ``torch.rand`` (uniform in ``[0, 1)``) and fed
    to ``net`` as a ``(num_points, 2)`` tensor; output columns 0 and 1 are
    read as the two vector components.

    Args:
        net: callable applied to the stacked ``(x, y)`` coordinates.
        num_points: number of sample points (default 32).

    Returns:
        Tensor of shape ``(num_points,)`` holding
        ``d(out[:, 0])/dx + d(out[:, 1])/dy``. It is a per-point value
        provided ``net`` treats each row independently -- NOTE(review):
        confirm for the networks used.
    """
    x = torch.rand(num_points, requires_grad=True)
    y = torch.rand(num_points, requires_grad=True)
    phi = net(torch.stack([x, y], dim=1))
    return d(phi[:, 0], x) + d(phi[:, 1], y)


def heat(net, D, num_points=32):
    """Evaluate the heat-equation residual of ``net`` at random points.

    Coordinates ``x``, ``y`` and ``t`` are drawn with ``torch.rand`` (uniform
    in ``[0, 1)``) and fed to ``net`` as a ``(num_points, 3)`` tensor.

    Args:
        net: callable applied to the stacked ``(x, y, t)`` coordinates.
        D: diffusion coefficient multiplying the spatial second derivatives.
        num_points: number of sample points (default 32).

    Returns:
        Tensor of shape ``(num_points,)`` holding
        ``D * (u_xx + u_yy) - u_t``. It is a per-point value provided ``net``
        treats each row independently -- NOTE(review): confirm for the
        networks used.
    """
    x = torch.rand(num_points, requires_grad=True)
    y = torch.rand(num_points, requires_grad=True)
    t = torch.rand(num_points, requires_grad=True)
    u = net(torch.stack([x, y, t], dim=1))
    rhs = D * (d(d(u, x), x) + d(d(u, y), y))
    lhs = d(u, t)
    return rhs - lhs
