import math
import numpy as np
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib.pyplot as plt
from models_utils import Network

from tqdm import tqdm


class AEBNN(nn.Module):
    """Deterministic network with an autoencoder-based approximate posterior.

    The model wraps a project ``Network`` (``self.f``) whose parameters are
    held in ``self.theta`` and, when ``cond_size > 0``, an additional
    conditioning vector ``self.theta_cond``.  After training, the posterior is
    approximated in three stages:

    1. ``explore_fiber`` runs ``n_traj`` random walks in parameter space, each
       random step being followed by a few gradient-descent steps on the
       dataset loss.
    2. ``train_metric`` fits an encoder/decoder pair on the visited points.
    3. ``posterior`` samples latent points between the encoded starting point
       and the encoded trajectory end points, decodes them and evaluates the
       network with each decoded sample.

    "Fiber variable" below denotes the vector that is explored: the
    conditioning vector when ``cond_size > 0``, otherwise the flattened
    network parameters (this is the convention used by ``batched_loss``).

    NOTE(review): the exact contract of ``Network`` (``get_flat_params``,
    ``get_unflat_params``, ``loss_neglikelihood`` ...) is not visible from
    this file; comments about it describe how it is *used* here.
    """

    # Number of mini-batches used to estimate the dataset loss.
    MAX_LOSS_BATCHES = 51

    # Keys read by ``set_global_variables`` (kept in the original order so a
    # missing key leaves the same attributes set as before).
    _HYPERPARAMETERS = (
        'alpha', 'T', 'n_steps', 'n_traj', 'use_brownian', 'inner_lr',
        'k', 'batch_size', 'epochs', 'pos_lambda', 'neg_lambda', 'dec_lambda',
    )

    def __init__(self, network_specs, cond_size, weight_decay, lr, loss_type, device):
        """Build the network, its trainable parameters and the optimizer.

        Args:
            network_specs: Specification forwarded to ``Network``.
            cond_size: Size of the conditioning vector; ``0`` disables it.
            weight_decay: Used both by AdamW and as the L2 coefficient in
                ``batched_loss``.
            lr: Learning rate of the AdamW optimizer.
            loss_type: Forwarded to ``Network``.
            device: Device for the auxiliary tensors/modules created here.
        """
        super().__init__()

        self.loss_type = loss_type
        self.device = device
        self.weight_decay = weight_decay

        self.cond_size = cond_size

        self.f = Network(network_specs, loss_type, probabilistic=False)
        if cond_size > 0:
            self.theta_cond = nn.Parameter(torch.zeros(1, cond_size), requires_grad=True)
        else:
            self.theta_cond = None
        self.theta = nn.ParameterList(
            [nn.Parameter(torch.zeros(t_size), requires_grad=True) for t_size in self.f.get_theta_shape()]
        )
        self.f.init_params(self.theta)

        # Kept for external users of ``model.I``; not read inside this class.
        self.I = torch.eye(self.f.tot_params, device=self.device)

        # Created before the encoder/decoder exist, so it only optimizes
        # ``theta`` / ``theta_cond`` (and any parameters owned by ``self.f``).
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        self.traj_theta = None
        self.encoder = None
        self.decoder = None

    def forward(self, x, theta=None, c=None):
        """Evaluate the network on ``x``.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size.
            theta: Network parameters.  Defaults to ``self.theta``.  If the
                network rejects it as given, it is retried after
                ``self.f.get_unflat_params(theta)`` (the path used for flat
                parameter vectors).
            c: Conditioning vector, tiled to one row per input.  Defaults to
                ``self.theta_cond`` when conditioning is enabled and to
                ``None`` otherwise.

        Returns:
            Whatever ``self.f`` returns for ``(x, theta, c)``.
        """
        if theta is None:
            theta = self.theta

        if c is not None:
            c = c.repeat(x.shape[0], 1)
        elif self.cond_size > 0:
            c = self.theta_cond.repeat(x.shape[0], 1)
        # else: no conditioning configured and none supplied -> c stays None.
        # (Previously this case called ``None.repeat`` and crashed.)

        try:
            return self.f(x, theta, c)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            return self.f(x, self.f.get_unflat_params(theta), c)

    def loss(self, y_pred, y, logs=False):
        """Return the negative log-likelihood of ``y`` under ``y_pred``.

        With ``logs=True`` a ``(dict, loss)`` pair is returned, the dict
        holding the detached scalar under the key ``'Evidence'``.
        """
        L_evidence = self.f.loss_neglikelihood(y_pred, y)
        if not logs:
            return L_evidence
        losses = {'Evidence': L_evidence.detach().cpu().item()}
        return losses, L_evidence

    def _predict(self, x, sample):
        """Evaluate the network with ``sample`` as the fiber variable."""
        if self.cond_size > 0:
            return self(x, c=sample)
        return self(x, theta=sample)

    def _fiber_origin(self):
        """Return the trained fiber variable (still attached to autograd).

        NOTE(review): assumes ``get_flat_params`` returns a 1-D tensor.
        """
        if self.cond_size > 0:
            return self.theta_cond.reshape(-1)
        return self.f.get_flat_params(self.theta)

    def batched_loss(self, x, y, theta):
        """Per-sample-averaged NLL plus L2 penalty for one fiber sample.

        ``theta`` is interpreted as the conditioning vector when
        ``cond_size > 0`` and as the network parameters otherwise.
        """
        nll = self.f.loss_neglikelihood(self._predict(x, theta), y)
        return (1 / x.shape[0]) * nll + self.weight_decay * (theta ** 2).mean()

    def get_dataset_loss(self, theta):
        """Sum ``batched_loss`` over the first batches of ``self.loader``.

        Args:
            theta: Stack of fiber samples; the loss is vmapped over dim 0.

        Returns:
            Tensor with one accumulated loss per row of ``theta``.

        Raises:
            ValueError: If the loader yields no batch.
        """
        get_batched_loss = torch.vmap(self.batched_loss, in_dims=(None, None, 0), randomness='same')
        losses = None
        for batch_idx, (x, y) in enumerate(self.loader):
            if batch_idx >= self.MAX_LOSS_BATCHES:
                break
            batch_loss = get_batched_loss(x, y, theta)
            losses = batch_loss if losses is None else losses + batch_loss
        if losses is None:
            raise ValueError('get_dataset_loss: the data loader produced no batches')
        return losses

    def _dataset_loss_grad(self, theta):
        """Gradient of every row's dataset loss w.r.t. that row.

        Each vmapped loss depends only on its own row of ``theta``, so the
        gradient of the summed loss equals ``jacobian(loss, theta).sum(0)``
        while needing a single backward pass instead of one per row.
        """
        theta = theta.detach().requires_grad_(True)
        with torch.enable_grad():
            total = self.get_dataset_loss(theta).sum()
        (grad,) = torch.autograd.grad(total, theta, allow_unused=True)
        return torch.zeros_like(theta) if grad is None else grad

    def explore_fiber(self):
        """Run ``n_traj`` random walks starting from the trained solution.

        At each of the ``T`` outer steps every trajectory moves by a random
        direction rescaled to norm ``alpha`` and is then refined with
        ``n_steps`` gradient-descent steps (step size ``inner_lr``) on the
        dataset loss.  When ``use_brownian`` is true, one direction per
        trajectory is drawn up front and reused at every step; otherwise a
        fresh direction is drawn each step.

        Requires ``set_global_variables`` and ``self.loader`` to be set.

        Returns:
            Detached tensor of shape ``(n_traj, T, d)``.
        """
        # Detach so the walk does not build an autograd graph back to the
        # trained parameters across all T * n_steps updates.
        origin = self._fiber_origin().detach()

        global_drift = torch.stack([torch.randn_like(origin) for _ in range(self.n_traj)], 0)

        last_theta = torch.unsqueeze(origin, 0).repeat(self.n_traj, 1).clone()
        theta_fiber = []
        for _ in tqdm(range(self.T)):
            drift = global_drift if self.use_brownian else torch.randn_like(last_theta)
            norm_drift = (drift / drift.norm(dim=1, keepdim=True)) * self.alpha

            last_theta = last_theta + norm_drift

            for _step in range(self.n_steps):
                last_theta = last_theta - self.inner_lr * self._dataset_loss_grad(last_theta)

            theta_fiber.append(last_theta.detach())

        return torch.stack(theta_fiber, 1)

    def adapt_theta(self, theta_hat, epochs, lr=0.001):
        """Refine ``theta_hat`` with ``epochs`` gradient steps on the dataset loss."""
        for _ in tqdm(range(epochs)):
            theta_hat = theta_hat - lr * self._dataset_loss_grad(theta_hat)
        return theta_hat

    def train_metric(self, traj_theta):
        """Fit ``self.encoder`` / ``self.decoder`` on explored trajectories.

        The loss combines: squared norm of the encoded starting point, the
        (clamped) latent distance between consecutive trajectory points, a
        negative-log repulsion between a point and a randomly re-paired
        point, and the decoder reconstruction error, weighted by
        ``pos_lambda``, ``neg_lambda`` and ``dec_lambda``.

        Args:
            traj_theta: Tensor of shape ``(n, T, d)`` as returned by
                ``explore_fiber``.

        Raises:
            ValueError: If the trajectories have fewer than two time steps.
        """
        n, T, d = traj_theta.shape[0], traj_theta.shape[1], traj_theta.shape[-1]
        if T < 2:
            raise ValueError('train_metric needs trajectories with at least 2 time steps')

        self.encoder = nn.Sequential(
            nn.Linear(d, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.k)).to(self.device)
        self.decoder = nn.Sequential(
            nn.Linear(self.k, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, d)).to(self.device)
        opt = torch.optim.Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=0.001)

        # Same space as the trajectories, and detached so that backward()
        # does not accumulate gradients on the model parameters.
        theta_map = torch.unsqueeze(self._fiber_origin().detach(), 0)
        min_step = 1 / T

        for _ in tqdm(range(self.epochs)):
            trj_i = np.random.randint(0, n, self.batch_size)
            t = np.random.randint(0, T - 1, self.batch_size)

            theta_t = traj_theta[trj_i, t]
            theta_t1 = traj_theta[trj_i, t + 1]
            # Negatives: the successor points re-paired at random.
            rnd_idx = np.arange(self.batch_size)
            np.random.shuffle(rnd_idx)
            theta_rnd = theta_t1[rnd_idx]

            z_0 = self.encoder(theta_map)
            z_t = self.encoder(theta_t)
            z_t1 = self.encoder(theta_t1)
            z_rnd = self.encoder(theta_rnd)

            loss_0 = torch.mean(torch.linalg.norm(z_0, dim=-1) ** 2)
            dist_pos = torch.clamp(torch.linalg.norm(z_t - z_t1, dim=-1), min=min_step)
            loss_pos = torch.mean(dist_pos)
            loss_neg = - torch.mean(torch.log(torch.linalg.norm(z_t - z_rnd, dim=-1) / min_step + 1e-6))

            theta_t_hat = self.decoder(z_t)
            theta_t1_hat = self.decoder(z_t1)

            loss_dec = torch.mean((theta_t - theta_t_hat) ** 2) + torch.mean((theta_t1 - theta_t1_hat) ** 2)

            loss = self.pos_lambda * loss_pos + self.neg_lambda * loss_neg + self.dec_lambda * loss_dec + loss_0

            opt.zero_grad()
            loss.backward()
            opt.step()

    def fit_svd(self, x, k, num_samples=100):
        """Sample from a Gaussian fitted in the top-``k`` SVD subspace of ``x``.

        Args:
            x: Tensor of shape ``(m, d)``; rows are observations.
            k: Number of right singular vectors to keep.
            num_samples: Number of samples to draw.

        Returns:
            Tensor of shape ``(num_samples, d)`` mapped back to input space.
        """
        mu = x.mean(0)
        norm_x = x - mu
        # Reduced SVD: only Vt is needed, so avoid building the (m, m) U.
        _, _, Vt = torch.linalg.svd(norm_x, full_matrices=False)
        top_k_singular_vectors = Vt[:k, :]

        z_i = norm_x @ top_k_singular_vectors.T

        z_mu = z_i.mean(dim=0)
        z_sigma = torch.cov(z_i.T)

        z_posterior = MultivariateNormal(z_mu, z_sigma).sample((num_samples,))

        return z_posterior @ top_k_singular_vectors + mu

    def set_global_variables(self, variables):
        """Copy the exploration/metric hyperparameters from ``variables``.

        Raises:
            KeyError: If one of the expected keys is missing.
        """
        for name in self._HYPERPARAMETERS:
            setattr(self, name, variables[name])

    def fit_posterior(self, loader):
        """Explore the fiber on ``loader`` and train the encoder/decoder."""
        self.loader = loader

        print('EXPLORE FIBER')
        self.traj_theta = self.explore_fiber()
        print()
        print('TRAIN Z')
        self.train_metric(self.traj_theta)
        print()

    def posterior(self, all_x, loader, num_samples=30):
        """Predict on ``all_x`` with samples decoded from the latent space.

        Latent samples are drawn uniformly on the segments joining the
        encoded starting point to randomly chosen encoded trajectory end
        points.  Trajectories and the encoder are computed on demand if
        ``fit_posterior`` has not been called.

        Returns:
            ``(y_map, y_mu, y_std, py)``: prediction with the trained
            parameters, mean and std over samples, and the stacked
            per-sample predictions.
        """
        self.loader = loader

        traj_theta = self.traj_theta if self.traj_theta is not None else self.explore_fiber()
        theta_map = torch.unsqueeze(self._fiber_origin(), 0)

        if self.encoder is None:
            self.train_metric(traj_theta)

        theta_vertices = traj_theta[:, -1]
        z_vertices = self.encoder(theta_vertices)
        z_0 = self.encoder(theta_map)
        sampled_vertices = np.random.randint(0, theta_vertices.shape[0], num_samples)
        sampled_t = torch.rand(num_samples).to(self.device).view(-1, 1)
        z_posterior = z_0 + (z_vertices[sampled_vertices] - z_0) * sampled_t

        theta_posterior = self.decoder(z_posterior)
        py = torch.stack([self._predict(all_x, theta_sample) for theta_sample in theta_posterior], 0)

        y_map = self(all_x, self.theta)
        y_mu = py.mean(0)
        y_std = py.std(0)

        return y_map, y_mu, y_std, py

    def posterior_naive(self, all_x, loader):
        """Predict on ``all_x`` with samples from ``fit_svd`` over all trajectory points.

        Falls back to ``explore_fiber`` when no trajectories are cached.
        Returns the same 4-tuple as ``posterior``.
        """
        self.loader = loader

        traj_theta = self.traj_theta if self.traj_theta is not None else self.explore_fiber()
        n, T = traj_theta.shape[0], traj_theta.shape[1]

        all_theta = traj_theta.reshape([n * T, -1])
        theta_posterior = self.fit_svd(all_theta, self.k)

        py = torch.stack([self._predict(all_x, theta_sample) for theta_sample in theta_posterior], 0)

        y_map = self(all_x, self.theta)
        y_mu = py.mean(0)
        y_std = py.std(0)

        return y_map, y_mu, y_std, py

    def overall_posterior(self, all_x, loader):
        """Predict on ``all_x`` with every point of freshly explored trajectories."""
        self.loader = loader

        traj_theta = self.explore_fiber()

        theta_posterior = traj_theta.reshape([-1, traj_theta.shape[-1]])

        return torch.stack([self._predict(all_x, theta_sample) for theta_sample in theta_posterior], 0)











