import numpy as np

# Shared, unseeded NumPy random Generator; MetaLearner draws its synthetic data and its initial
# U/V factors from it.
# NOTE(review): the star import on the following line runs after this assignment, so a public name
# `rng` exported by `utilities` would silently replace this generator -- confirm that it does not.
rng = np.random.default_rng()
from utilities import *


def arr_tp(arr):
    """Return ``arr`` with its last two axes swapped, i.e. every matrix in a 3-D stack transposed.

    The leading (stack) axis stays in place. ``arr`` must expose an
    ndarray-style ``transpose`` method and have exactly three axes.
    """
    stack_axis, row_axis, col_axis = 0, 1, 2
    return arr.transpose((stack_axis, col_axis, row_axis))


def fro_sq(X):
    """Return the sum of the squared entries of ``X`` (the squared Frobenius norm for a matrix)."""
    squared_entries = np.square(X)
    return squared_entries.sum()


def ith_stack_op(A, B, i):
    """Return ``A[i] @ B[i].T``: the product of the ``i``-th slice of ``A`` with the transposed ``i``-th slice of ``B``."""
    left_slice = A[i]
    right_slice = B[i]
    return left_slice @ right_slice.T


# def mat_stack_op(A, B): return np.matmul(A, B.transpose(0, 2, 1))


# def mat_stack_op_sum(A, B): return np.sum(np.matmul(A, B.transpose(0, 2, 1)), axis=0)


class InfDataLearner:
    """Fit a shared matrix ``A`` plus per-task rank-``k`` terms ``U[t] @ V[t].T``.

    The targets are ``A_star + U_star[t] @ V_star[t].T`` for ``t`` in
    ``range(T)``. The objective (see ``loss_fn``) is the mean over tasks of
    half the squared Frobenius distance between target and estimate.

    ``mat_stack_op`` is not defined in this class; it is expected to come from
    the ``utilities`` star import and is presumably the batched counterpart of
    ``ith_stack_op`` (``A[t] @ B[t].T`` for every ``t``) -- TODO confirm.

    Attributes populated by ``optimize_GM`` / ``optimize_GG`` when
    ``diagnostics`` is true: ``A_loss``, ``UV_loss``, ``UV_err_var``, ``loss``.
    """

    def __init__(self, d, k, T, eta=0.1, num_iters=10000, A_star=None, U_star=None, V_star=None, rand_inits=True):
        """Set up targets and initial iterates.

        Args:
            d: Side length of the square matrices.
            k: Number of columns of each per-task factor.
            T: Number of tasks (length of the leading stack axis).
            eta: Step size used by ``A_step``, ``U_step`` and ``V_step``.
            num_iters: Number of outer iterations run by the ``optimize_*`` methods.
            A_star: Optional ``(d, d)`` target; drawn from N(0, 1) when ``None``.
            U_star: Optional ``(T, d, k)`` target factor; drawn from N(0, 1) when ``None``.
            V_star: Optional ``(T, d, k)`` target factor; drawn from N(0, 1) when ``None``.
            rand_inits: If true, ``A``, ``U``, ``V`` start as N(0, 1) draws; otherwise as zeros.

        Raises:
            AssertionError: If a supplied target does not have the expected shape.
        """
        self.dim = d
        self.k = k
        self.T = T

        stack_shape = (T, d, k)

        # Each target is resolved independently. Previously, supplying V_star
        # re-assigned A_star and U_star from the raw arguments, replacing freshly
        # drawn targets with None whenever those arguments had been omitted.
        # The draw order (A_star, U_star, V_star, then A, U, V) is unchanged.
        self.A_star = self._given_or_random(A_star, (d, d))
        self.U_star = self._given_or_random(U_star, stack_shape)
        self.V_star = self._given_or_random(V_star, stack_shape)

        if rand_inits:
            self.A = np.random.normal(0, 1, (d, d))
            self.U = np.random.normal(0, 1, stack_shape)
            self.V = np.random.normal(0, 1, stack_shape)
        else:
            self.A = np.zeros((d, d))
            self.U = np.zeros(stack_shape)
            # np.zeros takes the shape as a single tuple; passing T, d, k as
            # separate positional arguments raised TypeError.
            self.V = np.zeros(stack_shape)

        self.eta = eta
        self.num_iters = num_iters

    @staticmethod
    def _given_or_random(value, shape):
        """Return ``value`` after checking its shape, or a fresh N(0, 1) array of ``shape`` if it is ``None``."""
        if value is None:
            return np.random.normal(0, 1, shape)
        assert value.shape == shape, f"expected shape {shape}, got {value.shape}"
        return value

    def _uv_error(self):
        """Return the stack ``mat_stack_op(U, V) - mat_stack_op(U_star, V_star)``."""
        return mat_stack_op(self.U, self.V) - mat_stack_op(self.U_star, self.V_star)

    def A_grad(self):
        """Return the update direction for ``A``: ``A - A_star`` plus the task-averaged U/V error."""
        return self.A - self.A_star + (1 / self.T) * np.sum(self._uv_error(), axis=0)

    def U_grad(self):
        """Return the update direction for the whole ``U`` stack (residual times ``V``, scaled by ``1/T``)."""
        # (A - A_star) is 2-D and broadcasts against the task stack.
        residual = self.A - self.A_star + mat_stack_op(self.U, self.V) - mat_stack_op(self.U_star, self.V_star)
        return (1 / self.T) * np.matmul(residual, self.V)

    def V_grad(self):
        """Return the update direction for the whole ``V`` stack (transposed residual times ``U``, scaled by ``1/T``)."""
        residual = self.A - self.A_star + mat_stack_op(self.U, self.V) - mat_stack_op(self.U_star, self.V_star)
        return (1 / self.T) * np.matmul(arr_tp(residual), self.U)

    def A_step(self):
        """Take one in-place step of size ``eta`` on ``A``."""
        self.A -= self.eta * self.A_grad()

    def U_step(self):
        """Take one in-place step of size ``eta`` on ``U``."""
        self.U -= self.eta * self.U_grad()

    def V_step(self):
        """Take one in-place step of size ``eta`` on ``V``."""
        self.V -= self.eta * self.V_grad()

    def UV_step(self):
        """Step ``U`` and then ``V``; the ``V`` step therefore sees the already-updated ``U``."""
        self.U_step()
        self.V_step()

    def UV_update(self):
        """Overwrite each ``U[i]``, ``V[i]`` with the top-``k`` SVD factors of the current residual.

        For every task the residual ``A_star + U_star[i] @ V_star[i].T - A`` is
        decomposed; ``U[i]`` receives the first ``k`` left singular vectors
        scaled by their singular values and ``V[i]`` the first ``k`` right
        singular vectors. Slices are written in place.
        """
        for i in range(self.T):
            target = self.A_star + ith_stack_op(self.U_star, self.V_star, i) - self.A
            # Tuple unpacking works for both the plain tuple returned by older
            # NumPy releases and the named SVDResult of NumPy 2.x.
            left, sing, right_h = np.linalg.svd(target)
            self.U[i, :, :] = sing[:self.k] * left[:, :self.k]
            self.V[i, :, :] = right_h.T[:, :self.k]

    def A_loss_fn(self):
        """Return the squared Frobenius distance between ``A`` and ``A_star``."""
        return fro_sq(self.A - self.A_star)

    def UV_loss_fn(self, t):
        """Return the squared Frobenius distance between ``U[t] @ V[t].T`` and its target for task ``t``."""
        return fro_sq(ith_stack_op(self.U, self.V, t) - ith_stack_op(self.U_star, self.V_star, t))

    def loss_fn(self):
        """Return the full objective: summed per-task squared residuals divided by ``2 * T``."""
        per_task = [
            fro_sq(self.A_star + ith_stack_op(self.U_star, self.V_star, i) - self.A - ith_stack_op(self.U, self.V, i))
            for i in range(self.T)
        ]
        return np.sum(per_task) / (2 * self.T)

    def _init_diagnostics(self):
        """Allocate the per-iteration diagnostic arrays, discarding any earlier run's values."""
        self.A_loss = np.zeros(self.num_iters)
        self.UV_loss = np.zeros((self.num_iters, self.T))
        self.UV_err_var = np.zeros(self.num_iters)
        self.loss = np.zeros(self.num_iters)

    def _record_diagnostics(self, i):
        """Store the diagnostics for outer iteration ``i``."""
        self.A_loss[i] = self.A_loss_fn()
        for t in range(self.T):
            self.UV_loss[i, t] = self.UV_loss_fn(t)
        # The error stack is evaluated once and reused for both the mean and the spread.
        uv_err = self._uv_error()
        uv_err_avg = np.mean(uv_err, axis=0)
        self.UV_err_var[i] = np.sum(np.square(uv_err - uv_err_avg))
        self.loss[i] = self.loss_fn()

    def optimize_GM(self, diagnostics=True):
        """Run ``num_iters`` rounds of: one step on ``A``, then the SVD-based ``UV_update``.

        Args:
            diagnostics: If true, record ``A_loss``, ``UV_loss``, ``UV_err_var`` and ``loss`` every iteration.
        """
        if diagnostics:
            self._init_diagnostics()

        for i in range(self.num_iters):
            self.A_step()
            self.UV_update()

            if diagnostics:
                self._record_diagnostics(i)

    def optimize_GG(self, diagnostics=True, num_inner=1):
        """Run ``num_iters`` rounds of: one step on ``A``, then ``num_inner`` calls to ``UV_step``.

        Args:
            diagnostics: If true, record ``A_loss``, ``UV_loss``, ``UV_err_var`` and ``loss`` every iteration.
            num_inner: Number of ``UV_step`` calls per outer iteration.
        """
        if diagnostics:
            self._init_diagnostics()

        for i in range(self.num_iters):
            self.A_step()
            for _ in range(num_inner):
                self.UV_step()

            if diagnostics:
                self._record_diagnostics(i)

    def check_UV_condition(self):
        """Return whether every task has (numerically) the same U/V error as task 0.

        For each task ``i`` the norm of ``err_0 - err_i`` is computed, where
        ``err_t = U_star[t] @ V_star[t].T - U[t] @ V[t].T``; the result is true
        when all norms are close to zero under ``np.isclose`` defaults.
        """
        err_0 = ith_stack_op(self.U_star, self.V_star, 0) - ith_stack_op(self.U, self.V, 0)
        diff = [
            np.linalg.norm(err_0 - ith_stack_op(self.U_star, self.V_star, i) + ith_stack_op(self.U, self.V, i))
            for i in range(self.T)
        ]
        return np.all(np.isclose(diff, 0))


class MetaLearner:
    """Gradient-based learner for a shared matrix ``A`` plus per-task low-rank terms, fitted from samples.

    The per-task model is ``A + U[t] @ U[t].T`` when ``symmetric`` is true and
    ``A + U[t] @ V[t].T`` otherwise, applied to inputs ``X`` to predict ``Y``.

    ``mat_stack_op`` and ``mmt`` are not defined in this class; they are
    expected to come from the ``utilities`` star import. ``mat_stack_op(P, Q)``
    is presumably the batched ``P[t] @ Q[t].T`` and ``mmt(P)`` presumably
    ``P[t] @ P[t].T`` -- TODO confirm.
    """

    def __init__(self, dim, num_samples, low_rank_dim, num_tasks, tasks, sigma=0.1, inner_learning_rate=0.01,
                 outer_learning_rate=0.01, data=None, symmetric=True):
        """Store the configuration, obtain data, and initialise ``A``, ``U``, ``V``.

        Args:
            dim: Side length of ``A`` and row count of each factor.
            num_samples: Number of samples per task (last axis of generated ``X``).
            low_rank_dim: Column count of each per-task factor.
            num_tasks: Number of tasks (leading axis of ``U``, ``V`` and generated data).
            tasks: Ground-truth operator(s); must support ``tasks @ X`` and
                ``tasks - A`` (shape not checked here -- presumably
                ``(num_tasks, dim, dim)``; confirm with callers).
            sigma: Noise parameter used by ``gen_data``.
            inner_learning_rate: Step size for the U/V updates.
            outer_learning_rate: Step size for the ``A`` update.
            data: Optional ``(X, Y)`` pair; when ``None`` data is generated with ``gen_data``.
            symmetric: Selects the ``U U^T`` model (true) or the ``U V^T`` model (false).
        """
        self.dim = dim
        self.num_samples = num_samples
        self.low_rank_dim = low_rank_dim
        self.num_tasks = num_tasks
        self.sigma = sigma
        self.tasks = tasks
        self.outer_learning_rate = outer_learning_rate
        self.inner_learning_rate = inner_learning_rate
        self.symmetric = symmetric
        if data is None:
            self.gen_data()
        else:
            self.X, self.Y = data
        self.A = np.zeros((self.dim, self.dim))
        self.U = rng.normal(0, 0.1, (self.num_tasks, self.dim, self.low_rank_dim))
        # V is always allocated, but only the asymmetric model reads or updates it.
        self.V = rng.normal(0, 0.1, (self.num_tasks, self.dim, self.low_rank_dim))
        # Per-task X X^T, cached at construction; not read by any method of this class.
        self.XXT = self.X @ arr_tp(self.X)

    def _right_factor(self):
        """Return the right-hand low-rank factor of the active model: ``U`` if symmetric, else ``V``."""
        return self.U if self.symmetric else self.V

    def gen_data(self):
        """Draw standard-normal inputs ``X`` and set ``Y = tasks @ X + noise``."""
        # NOTE(review): Generator.normal takes a standard deviation as its second
        # argument, so the noise std here is sigma**2, not sigma. Left unchanged
        # because the intended meaning of `sigma` is not visible -- confirm.
        noise = rng.normal(0, self.sigma ** 2, (self.num_tasks, self.dim, self.num_samples))
        self.X = rng.normal(0, 1, (self.num_tasks, self.dim, self.num_samples))
        self.Y = self.tasks @ self.X + noise

    def inner_gradient_step(self):
        """Take one in-place step on the per-task factors, with gradients divided by ``num_samples``.

        Only ``U`` is updated in symmetric mode; both ``U`` and ``V`` are updated otherwise.
        """
        if self.symmetric:
            self.U -= self.inner_learning_rate * self.inner_gradient_symmetric() / self.num_samples
        else:
            u_grad, v_grad = self.inner_gradient_asymmetric()
            self.U -= self.inner_learning_rate * u_grad / self.num_samples
            self.V -= self.inner_learning_rate * v_grad / self.num_samples

    def outer_gradient_step(self):
        """Take one in-place step on ``A``, with the gradient divided by ``num_samples * num_tasks``."""
        self.A -= self.outer_learning_rate * self.outer_gradient() / (self.num_samples * self.num_tasks)

    def inner_gradient_symmetric(self):
        """Return the update direction for ``U`` under the ``A + U U^T`` model."""
        y_prime = (self.A + mat_stack_op(self.U, self.U)) @ self.X - self.Y
        return (self.X @ arr_tp(y_prime) + y_prime @ arr_tp(self.X)) @ self.U

    def inner_gradient_asymmetric(self):
        """Return ``(u_grad, v_grad)``, the update directions for ``U`` and ``V`` under the ``A + U V^T`` model."""
        y_prime = (self.A + mat_stack_op(self.U, self.V)) @ self.X - self.Y
        return y_prime @ arr_tp(self.X) @ self.V, self.X @ arr_tp(y_prime) @ self.U

    def outer_gradient(self):
        """Return the update direction for ``A``: prediction residual times ``X^T``, summed over tasks.

        The residual is computed with the active model, matching ``loss``,
        ``true_loss`` and the inner gradients. The low-rank term was previously
        always ``U U^T``, which is the wrong model when ``symmetric`` is false.
        """
        low_rank = mat_stack_op(self.U, self._right_factor())
        return (((self.A + low_rank) @ self.X - self.Y) @ arr_tp(self.X)).sum(axis=0)

    def loss(self):
        """Return the summed squared prediction error on ``(X, Y)`` divided by ``num_tasks``."""
        if self.symmetric:
            return np.sum(((self.Y - (self.A + mmt(self.U)) @ self.X) ** 2)) / self.num_tasks
        else:
            return np.sum(((self.Y - (self.A + mat_stack_op(self.U, self.V)) @ self.X) ** 2)) / self.num_tasks

    def true_loss(self):
        """Return half the summed squared difference between ``tasks`` and the current model matrices."""
        return np.sum((self.tasks - self.A - mat_stack_op(self.U, self._right_factor())) ** 2) / 2

    def detach(self, tasks, num_samples, new_num_tasks=1, data=None, new_learning_rate=None):
        """Create an asymmetric ``MetaLearner`` for new tasks that starts from this learner's ``A``.

        Args:
            tasks: Ground-truth operator(s) for the new learner.
            num_samples: Samples per task for the new learner.
            new_num_tasks: Number of tasks in the new learner.
            data: Optional ``(X, Y)`` pair forwarded to the new learner.
            new_learning_rate: Inner step size for the new learner; defaults to this learner's.

        Returns:
            The new ``MetaLearner`` (always ``symmetric=False``).
        """
        if new_learning_rate is None:
            new_learning_rate = self.inner_learning_rate
        # NOTE(review): the factor width is tripled only when this learner has
        # exactly two tasks; the reason is not visible here -- confirm intent.
        new_rank = self.low_rank_dim * (3 if self.num_tasks == 2 else 1)
        newLearner = MetaLearner(self.dim,
                                 num_samples,
                                 new_rank,
                                 new_num_tasks,
                                 tasks,
                                 sigma=self.sigma,
                                 inner_learning_rate=new_learning_rate,
                                 outer_learning_rate=self.outer_learning_rate,
                                 data=data,
                                 symmetric=False)
        # NOTE(review): this shares the same array object, so an in-place
        # outer_gradient_step on either learner changes A for both -- confirm
        # that the sharing is intended (otherwise use a copy).
        newLearner.A = self.A
        return newLearner
