#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's global RNG so the synthetic problem is reproducible.
np.random.seed(123)

# Problem size, the random "observed" matrix, and the penalty weight.
M, N = 4, 3
Y = np.random.normal(size=[M, N])
lam = 15.0

# Thin SVD of Y *before* noise is added.  Split Y into its leading rank-1
# direction (u as a column, v as a row) and the remainder O built from all
# the other singular components.
svdy = np.linalg.svd(Y, full_matrices=False)
_U, _s, _Vt = svdy
u = _U[:, :1]
v = _Vt[:1, :]
O = _U[:, 1:] @ np.diag(_s[1:]) @ _Vt[1:, :]

# Perturb the observations afterwards, so u, v and O describe the clean
# matrix while Y (used by the cost functions) holds the noisy one.
Y += 7.5e-2 * np.random.normal(size=[M, N])


def cost_big(X, sigma2, bad=True, *, y_obs=None, lam_value=None):
    """Return an unnormalised log-density (higher is better) for X and sigma2.

    The result is the sum of four terms:

    * ``-log(sigma2)``: a term depending only on the noise variance;
    * ``-lam * ||X||_*`` when ``bad`` is True, or
      ``-lam * ||X||_* / sqrt(sigma2)`` when ``bad`` is False
      (``||X||_*`` is the nuclear norm, the sum of singular values);
    * ``-M*N*log(sigma2)``: a variance-dependent normalising term;
    * ``-0.5 * ||X - Y||_F**2 / sigma2``: Gaussian misfit to the observations.

    Args:
        X: 2-D candidate matrix of shape (M, N).
        sigma2: Noise variance; must be positive for the logs and sqrt.
        bad: If True, the nuclear-norm penalty ignores sigma2. If False, it is
            divided by sqrt(sigma2).
        y_obs: Observed matrix compared against X. Defaults to the
            module-level ``Y``.
        lam_value: Nuclear-norm penalty weight. Defaults to the module-level
            ``lam``.

    Returns:
        The summed log-density as a scalar.
    """
    if y_obs is None:
        y_obs = Y
    if lam_value is None:
        lam_value = lam

    M, N = X.shape
    t1 = -np.log(sigma2)
    # Only the singular values are needed, so skip computing U and V.
    nuclear_norm = np.sum(np.linalg.svd(X, compute_uv=False))
    if bad:
        t2 = -lam_value * nuclear_norm
    else:
        t2 = -lam_value * nuclear_norm / np.sqrt(sigma2)
    # NOTE(review): for bad=False, the sigma-scaled penalty's normaliser
    # contributes -(M*N/2)*log(sigma2). Together with the Gaussian
    # likelihood's -(M*N/2)*log(sigma2), that gives this term. For bad=True
    # the penalty does not depend on sigma2, so only -(M*N/2)*log(sigma2)
    # would be expected. Confirm the shared factor is intended.
    t3 = -M * N * np.log(sigma2)
    t4 = -0.5 * np.sum(np.square(X - y_obs)) / sigma2
    return t1 + t2 + t3 + t4

def cost_small(a, logsigma2, bad = True):
    """Evaluate cost_big along the one-parameter family X(a) = a*u@v + O.

    Args:
        a: Coefficient of the rank-1 component built from u and v.
        logsigma2: Base-10 logarithm of the noise variance.
        bad: Forwarded unchanged to cost_big.

    Returns:
        Whatever cost_big returns for X(a) and sigma2 = 10**logsigma2.
    """
    sigma2 = np.power(10., logsigma2)
    candidate = (a * u) @ v + O
    return cost_big(candidate, sigma2, bad)

# Number of samples along each axis; each panel uses an ng x ng grid.
ng = 50
# Largest singular value from svdy; it scales the a-axis ranges below.
svt = svdy[1][0]

# Grid for the sigma2-independent penalty (bad=True).
ag = np.linspace(-0.5*svt, 1.5*svt, num=ng)
sg = np.linspace(-3.25, 1.75, num=ng)
# Wider grid for the sigma2-scaled penalty (bad=False).
agg = np.linspace(-7.00*svt, 7.00*svt, num=ng)
sgg = np.linspace(-1.0, 3.75, num=ng)

# Flattened row-major grids: position i*ng + j pairs the i-th a-value with
# the j-th log10(sigma2) value.
av, sv = np.repeat(ag, ng), np.tile(sg, ng)
avg, svg = np.repeat(agg, ng), np.tile(sgg, ng)

# Evaluate the cost at every grid point.
ld = np.array([cost_small(a_val, s_val) for a_val, s_val in zip(av, sv)])
ldg = np.array(
    [cost_small(a_val, s_val, bad = False) for a_val, s_val in zip(avg, svg)]
)

# Two side-by-side contour panels: the sigma2-independent penalty (left)
# and the sigma2-scaled penalty (right).
fig = plt.figure(figsize=[5,3])

# Clip both surfaces from below at a single shared floor (presumably so that
# very low regions do not dominate the colour scale). Use `thresh` instead
# of repeating the literal, so the floor is defined in one place.
thresh = -130
ld = np.maximum(ld, thresh)
ldg = np.maximum(ldg, thresh)

# NOTE(review): the x values are the coefficient `a` on u @ v (see
# cost_small). `a` can be negative, so the sigma_1(X) label is only
# approximate. Confirm the intended axis label.
plt.subplot(1,2,1)
plt.tricontourf(av, sv, ld)
plt.xlabel(r"$\sigma_1(X)$")
plt.ylabel(r"$\log\sigma^2$")
plt.title(r"$-\log P(X|\lambda,\sigma^2)=\lambda |X|_*$")

plt.subplot(1,2,2)
plt.tricontourf(avg, svg, ldg)
plt.xlabel(r"$\sigma_1(X)$")
plt.ylabel(r"$\log\sigma^2$")
plt.title(r"$-\log P(X|\lambda,\sigma^2)=\frac{\lambda}{\sqrt{\sigma^2}} |X|_*$")

plt.tight_layout()
plt.savefig("bimodal.pdf")
# Release the figure explicitly by handle rather than relying on "current".
plt.close(fig)
