import math

import torch
from einops import rearrange, repeat, pack
from torch import nn
from torch.nn import functional as F
from torch import distributions as D

from models import Model
from models.components import COMPONENT, Mlp
from utils import angle_loss, cross_entropy

class Bayesian(Model):
    """Encoder followed by a sequentially-updated Bayesian linear read-out.

    The encoder maps every input to a ``hidden_dim`` feature vector. A
    weight ``mean`` of shape ``(hidden_dim, y_dim)`` and a Cholesky-style
    factor ``dec_inv_cov`` of shape ``(hidden_dim, hidden_dim)`` are kept as
    plain tensor attributes on the instance. They are updated from the
    training examples of every ``forward`` call and written back afterwards,
    so the state carries over from one call to the next.
    """

    def __init__(self, config):
        """Build the encoder and initialise the prior state.

        Args:
            config: mapping read for ``input_type`` (``'image'`` or
                ``'vector'``), ``output_type``, ``y_dim`` and, depending on
                the input type, ``x_encoder``/``x_h``/``x_w`` or
                ``x_dim``/``hidden_dim``.

        Raises:
            NotImplementedError: if ``input_type`` is neither ``'image'``
                nor ``'vector'``.
        """
        super().__init__(config)
        self.config = config
        self.loss_fn = nn.MSELoss(reduction='none')
        self.input_type = config['input_type']
        self.output_type = config['output_type']

        if self.input_type == 'image':
            self.x_encoder = COMPONENT[config['x_encoder']](config)
            # NOTE(review): presumably the image encoder emits 256 channels at
            # 1/16 of the input resolution -- confirm against the encoder.
            self.hidden_dim = 256 * (config['x_h'] // 16) * (config['x_w'] // 16)
        elif self.input_type == 'vector':
            self.x_encoder = Mlp(config, input_dim=config['x_dim'])
            self.hidden_dim = config['hidden_dim']
        else:
            raise NotImplementedError

        # Prior mean of the read-out weights, torch.Size([h, c]), zero-initialised.
        self.mean = torch.zeros(self.hidden_dim, self.config['y_dim'])
        # Factor S with inv_cov = S^T S, torch.Size([h, h]), identity-initialised.
        self.dec_inv_cov = torch.diag(torch.ones(self.hidden_dim))

    def _encode_support_and_query(self, train_x, test_x):
        """Encode train and test inputs in one encoder pass and split them back."""
        if self.input_type == 'image':
            batch, train_num, _, _, _ = train_x.shape
            batch, test_num, _, _, _ = test_x.shape
            x, _ = pack([
                rearrange(train_x, 'b l c h w -> (b l) c h w'),
                rearrange(test_x, 'b l c h w -> (b l) c h w'),
            ], '* c h w')
            x_enc = self.x_encoder(x)
            # Flatten the feature map into one vector per example.
            x_enc = rearrange(x_enc, 'bl c h w -> bl (c h w)')
        elif self.input_type == 'vector':
            batch, train_num, _ = train_x.shape
            batch, test_num, _ = test_x.shape
            x, _ = pack([
                rearrange(train_x, 'b l d -> (b l) d'),
                rearrange(test_x, 'b l d -> (b l) d'),
            ], '* d')
            x_enc = self.x_encoder(x)
        else:
            raise NotImplementedError

        # Train rows were packed first, test rows after them.
        split = batch * train_num
        train_x_enc = rearrange(x_enc[:split], '(b l) h -> b l h', b=batch, l=train_num)
        test_x_enc = rearrange(x_enc[split:], '(b l) h -> b l h', b=batch, l=test_num)
        return train_x_enc, test_x_enc

    def forward(self, train_x, train_y, test_x, test_y, evaluate):
        """Update the posterior on the train examples and score the test examples.

        Args:
            train_x: training inputs, ``(b, l, c, h, w)`` for image input or
                ``(b, l, d)`` for vector input.
            train_y: training targets, indexed as ``train_y[:, i]`` for each
                training position ``i``.
            test_x: test inputs, same layout as ``train_x``.
            test_y: test targets compared element-wise against the prediction.
            evaluate: accepted for interface compatibility; not used here.

        Returns:
            dict with ``'loss'`` (per-example MSE averaged over the last axis,
            shape ``(b, l, 1)``) and ``'logit'`` (detached test predictions).

        Raises:
            NotImplementedError: if ``input_type`` is not supported.
        """
        train_x_enc, test_x_enc = self._encode_support_and_query(train_x, test_x)
        train_num = train_x_enc.shape[1]

        assert train_x_enc.shape[-1] == self.hidden_dim
        assert test_x_enc.shape[-1] == self.hidden_dim

        device = train_x_enc.device
        # Start from a graph-free copy of the persisted state.
        mean = self.mean.detach().to(device)
        dec_inv_cov = self.dec_inv_cov.detach().to(device)
        inv_cov = dec_inv_cov.T @ dec_inv_cov
        # bayesian_forward_train keeps mean == inv_cov @ Q, so recover Q here.
        Q = torch.linalg.inv(inv_cov) @ mean

        with torch.enable_grad():
            # Forward train data sequentially
            for i in range(train_num):
                x_i = train_x_enc[:, i]  # torch.Size([b, h])
                y_i = train_y[:, i]
                _, mean, inv_cov, Q = bayesian_forward_train(mean, inv_cov, Q, x_i, y_i)
            # Forward test data
            logit = bayesian_forward_test(mean, test_x_enc)
            meta_loss = self.loss_fn(logit, test_y).mean(-1)
            meta_loss = rearrange(meta_loss, 'b l -> b l 1')

        # Persist only detached values: storing the graph-attached tensors
        # would keep this call's autograd graph alive until the next call.
        self.mean = mean.detach()
        self.dec_inv_cov = torch.linalg.cholesky(inv_cov.detach()).mH

        output = {
            'loss': meta_loss,
            'logit': logit.detach(),
        }

        return output

def bayesian_forward_train(mean, inv_cov, Q, x, y):
    """Apply one inverse-free posterior update from a batch of (x, y) pairs.

    ``inv_cov`` is reduced by a rank-one correction per example (the
    Sherman-Morrison form when ``inv_cov`` is symmetric), averaged over the
    batch. ``Q`` accumulates the batch-mean outer product of ``x`` and ``y``,
    and the new mean is ``post_inv_cov @ Q``.

    Args:
        mean: current mean, ``(h, c)``; only used for the shape check.
        inv_cov: current ``(h, h)`` matrix being updated.
        Q: running ``(h, c)`` accumulator.
        x: encoded inputs, ``(b, h)``.
        y: targets, ``(b, c)``.

    Returns:
        Tuple ``(target, post_mean, post_inv_cov, Q)`` where ``target`` is the
        ``(b, c)`` prediction for ``x`` under the updated mean.
    """
    # Row i holds inv_cov @ x_i.
    projected = x @ inv_cov.T
    # Per-example quadratic form x_i^T inv_cov x_i.
    quad_form = torch.einsum('bh,bh->b', x, projected)

    scale = 1 / (1 + quad_form).view(-1, 1, 1)
    outer = projected.unsqueeze(-1) * projected.unsqueeze(1)
    # Average the per-example corrections over the batch.
    post_inv_cov = inv_cov - (scale * outer).mean(0)

    Q = (x.unsqueeze(-1) * y.unsqueeze(1)).mean(0) + Q
    post_mean = post_inv_cov @ Q

    assert mean.size() == post_mean.size()
    assert inv_cov.size() == post_inv_cov.size()

    target = torch.einsum('hc,bh->bc', post_mean, x)
    return target, post_mean, post_inv_cov, Q

def bayesian_forward_test(mean, x):
    """Predict targets for test inputs under the given mean weights.

    Args:
        mean: weight matrix of shape ``(h, c)``.
        x: encoded test inputs of shape ``(b, t, h)``.

    Returns:
        Tensor of shape ``(b, t, c)``: ``x`` contracted with ``mean`` over
        the shared ``h`` axis.
    """
    prediction = torch.einsum('hc,bth->btc', mean, x)
    return prediction
