import torch
import torch.nn as nn
import torchvision
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from torch.utils.data import TensorDataset
from tqdm.auto import tqdm
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from utils import *
from flow import *
import scipy.stats as st
import random
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import matplotlib.colors as mcolors
import os
import torch.multiprocessing
from pathlib import Path

def gauss_2d(n, mu, sigma):
    """Draw ``n`` samples from a bivariate normal distribution.

    Parameters
    ----------
    n : int
        Number of samples.
    mu : array-like, shape (2,)
        Mean vector.
    sigma : array-like, shape (2, 2)
        Covariance matrix.

    Returns
    -------
    tuple of numpy.ndarray
        The first and second coordinates of the samples, each of shape (n,).
    """
    samples = np.random.multivariate_normal(mean=mu, cov=sigma, size=n)
    first_coord = samples[:, 0]
    second_coord = samples[:, 1]
    return first_coord, second_coord

def confidence_ellipse(x, y, ax, n_std=3.0, facecolor='none', mean = None, cov=None, **kwargs):
    """
    Create a plot of the covariance confidence ellipse of *x* and *y*.

    Parameters
    ----------
    x, y : array-like, shape (n, )
        Input data. Converted with ``np.asarray``. They are only used to
        estimate *mean* / *cov* when those are not supplied, but their sizes
        are always checked.

    ax : matplotlib.axes.Axes
        The axes object to draw the ellipse into.

    n_std : float
        The number of standard deviations to determine the ellipse's radii.

    facecolor : color, default 'none'
        Fill colour of the ellipse. The ellipse is drawn with ``alpha=0.2``
        and no edge colour.

    mean : sequence of two floats, optional
        Centre of the ellipse. Defaults to ``(mean(x), mean(y))``.

    cov : array-like, shape (2, 2), optional
        Covariance matrix that defines the ellipse. Defaults to
        ``np.cov(x, y)``.

    **kwargs
        Forwarded to `~matplotlib.patches.Ellipse`. Must not repeat
        ``alpha``, ``fc`` or ``ec``, which are set here.

    Returns
    -------
    matplotlib.patches.Ellipse
        The patch added to *ax*.

    Raises
    ------
    ValueError
        If *x* and *y* do not have the same number of elements.
    """
    # Accept any array-like input (lists, CPU tensors, ...). `.size` is
    # an int attribute only on NumPy arrays.
    x = np.asarray(x)
    y = np.asarray(y)
    if x.size != y.size:
        raise ValueError("x and y must be the same size")

    if cov is None:
        cov = np.cov(x, y)
    cov = np.asarray(cov)

    # The correlation matrix [[1, p], [p, 1]] has eigenvalues 1 + p and
    # 1 - p. Their square roots are the radii of the unit ellipse, whose
    # principal axes lie on the diagonals.
    pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
    ell_radius_x = np.sqrt(1 + pearson)
    ell_radius_y = np.sqrt(1 - pearson)
    ellipse = Ellipse((0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2,
                      alpha=0.2, fc=facecolor, ec="None", **kwargs)

    # Stretch by n_std standard deviations along each data axis.
    scale_x = np.sqrt(cov[0, 0]) * n_std
    scale_y = np.sqrt(cov[1, 1]) * n_std

    if mean is None:
        mean_x, mean_y = np.mean(x), np.mean(y)
    else:
        mean_x, mean_y = mean[0], mean[1]

    # `transforms` here is matplotlib.transforms: the module-level
    # `import matplotlib.transforms as transforms` shadows the earlier
    # torchvision import of the same name.
    # Rotate the unit ellipse onto the diagonals, scale it, then centre it.
    transf = (transforms.Affine2D()
              .rotate_deg(45)
              .scale(scale_x, scale_y)
              .translate(mean_x, mean_y))

    ellipse.set_transform(transf + ax.transData)
    return ax.add_patch(ellipse)

def _draw_distribution(ax, points, point_labels, color, n_groups, means, covs, legend_label, scale):
    """Scatter one 2-D point cloud and draw three nested ellipses per class.

    ``points[0]`` / ``points[1]`` are the x / y coordinates, and
    ``point_labels == i`` selects the points of class ``i``. Ellipses are
    drawn at ``1*scale``, ``2*scale`` and ``3*scale`` standard deviations
    using ``means[i]`` / ``covs[i]``. Only the innermost ellipse of class 0
    carries ``legend_label``, so the legend gets one entry per distribution.
    """
    ax.scatter(points[0], points[1], s=3, c=color)
    for i in range(n_groups):
        mask = point_labels == i
        for n_sigma in (1, 2, 3):
            extra = {'label': legend_label} if (i == 0 and n_sigma == 1) else {}
            confidence_ellipse(points[0][mask], points[1][mask], ax, n_std=n_sigma * scale,
                               facecolor=color, mean=means[i], cov=covs[i], **extra)


def plot_data(source, target, labels, labels2, source_color, target_color, numclass, numclass2, source_mean, source_cov, target_mean, target_cov, path = None, step = None, xlim=None, ylim=None):
    """Plot source and target point clouds with per-class confidence ellipses.

    A new 10x10 figure is created and left open (it becomes the current
    pyplot figure). The caller is responsible for clearing or closing it.

    Parameters
    ----------
    source, target : array-like
        Point clouds indexed as ``pts[0]`` (x) and ``pts[1]`` (y).
    labels, labels2 : array-like
        Class label of each source / target point.
    source_color, target_color : color
        Colours for the source / target points and ellipses.
    numclass, numclass2 : int
        Number of source / target classes to draw ellipses for.
    source_mean, source_cov, target_mean, target_cov : sequence
        Per-class mean (length 2) and covariance (2x2) for the ellipses.
    path : str, optional
        Output prefix. When given, the figure is saved to
        ``path + str(step) + ".png"``. When ``None``, nothing is saved.
    step : object, optional
        Suffix used in the output file name.
    xlim, ylim : pair of floats, optional
        Axis limits to apply before saving.

    Returns
    -------
    tuple
        ``(ax.get_xlim(), ax.get_ylim())`` after plotting.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    # Ellipses are drawn at 0.2, 0.4 and 0.6 standard deviations.
    scale = 0.2
    _draw_distribution(ax, source, labels, source_color, numclass,
                       source_mean, source_cov, 'source', scale)
    _draw_distribution(ax, target, labels2, target_color, numclass2,
                       target_mean, target_cov, 'target', scale)
    ax.legend()
    # `is not None` avoids the ambiguous-truth-value error that `!= None`
    # raises for NumPy-array limits.
    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    # Previously a missing path crashed with a TypeError on None + str.
    if path is not None:
        fig.savefig(path + str(step) + ".png")
    return ax.get_xlim(), ax.get_ylim()


# ---- Experiment configuration and optimiser state (module-level script) ----
# NOTE(review): color_palette is not used anywhere in this file as visible.
color_palette = list(mcolors.TABLEAU_COLORS.values())
source_color = 'lightcoral'
target_color = 'turquoise'

# Source layout: num_classes classes with num_images points each (see total_images below).
num_images = 40
num_classes = 2
# Target layout; num_classes2 is forwarded to plot_data.
# NOTE(review): num_images2 is not used in this file as visible.
num_images2 = 20
num_classes2 = 4

# Load pre-generated source/target data from a torch checkpoint; the keys read below must exist.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
data_path = "./data/2gauss_4gauss.tar"
device = 'cpu'
data = torch.load(data_path,map_location=device)
source_data = data["source_data"].to(device)
source_mean = data["source_mean"].to(device)
source_cov = data["source_cov"].to(device)
target_data = data["target_data"].to(device)
# NumPy copy of the target points, used for plotting.
target_data_cpu = data["target_data"].cpu().numpy()
target_mean = data["target_mean"].to(device)
target_cov = data["target_cov"].to(device)
# Source labels become a NumPy array (used for boolean masks); target labels stay as loaded.
labels = data["labels"].cpu().numpy()
labels2 = data["labels2"]

# Weights forwarded to gradient / compute_objective / compute_target_obj (defined in
# utils/flow); their exact roles are defined there.
alpha = 0.3
beta = 0.1
gamma = 0.5
# Initial step size; the loop decays it as step_size_0/(1+k/500).
step_size_0 = 0.03 #adaptive?
step_size = step_size_0
momentum=0.9
rmsprop = 0.9
# Optimiser selection: the update loop checks momentum_flag first, then rmsprop_flag,
# and otherwise takes the Adam-style branch.
momentum_flag = False
rmsprop_flag = True
adam_flag = False
adam_beta1 = 0.9
adam_beta2 = 0.999

# Scale of the Gaussian noise added to particle positions; halved every 500 steps in the loop.
noise_beta = 0.1

num_steps = 2500
# Numerical-stability constant for the RMSprop / Adam denominators.
eps=1e-7

total_images = num_images*num_classes
# NOTE(review): this is a float (total_images/3), while the loop's minibatches are
# integer-sized thirds of the index permutation, so dividing by it is only
# approximately a per-batch mean.
mini_batch = total_images/3
# grad_m = compute_gradm(beta, mean, K_mean)
# grad_c = compute_gradcov(gamma, cov, K_cov)
cov_shape = (total_images,source_cov[0].shape[0], source_cov[0].shape[1])

x_tau = torch.zeros(source_data.shape, device=device) # particle positions, same shape as source_data; indexed x_tau[:, i]
mean_tau = torch.zeros(source_data.shape, device=device) # per-particle means, same shape as source_data; indexed mean_tau[:, i]
cov_tau = torch.zeros(cov_shape, device=device) # per-particle covariances: (total_images, d, d), d taken from source_cov[0]

# Optimiser state per particle (RMSprop accumulators / Adam first moments / cov momentum).
last_x_grad = torch.zeros_like(x_tau)
last_mean_grad = torch.zeros_like(mean_tau)
last_cov_grad = torch.zeros_like(cov_tau)

# Adam second-moment buffers exist only when adam_flag is set.
if adam_flag:
    last_x_v = torch.zeros_like(x_tau)
    last_mean_v = torch.zeros_like(mean_tau)
    last_cov_v = torch.zeros_like(cov_tau)

# Objective value recorded after every optimisation step.
objectives = []

identity = torch.eye(source_cov[0].shape[0],source_cov[0].shape[1], device=device)
# NOTE(review): rms_x and rms_mean are not used in this file as visible.
rms_x = 0.0
rms_mean = 0.0
samemean=False
# NOTE(review): passes the source `labels` together with the target data, whereas
# compute_objective in the loop receives labels2 — confirm this is intended.
target_obj = compute_target_obj(target_data, target_mean, target_cov, alpha, beta, gamma, labels, size=total_images)

from datetime import datetime
import pytz

# Every run writes into its own folder named after the launch time (day_month_hour_minute).
now = datetime.now()
dt_string = now.strftime("%d_%m_%H_%M")
print(dt_string)
path = f"results/2gauss_4gauss/{dt_string}/"
Path(path).mkdir(parents=True, exist_ok=True)

# Record this run's hyper-parameters next to its outputs. Only the first enabled
# optimiser flag (momentum, then rmsprop, then adam) is logged.
setup_lines = [
    f"data: {data_path}\n",
    f"number of images for each class: {num_images}, number of classes: {num_classes}, minibatch: {mini_batch}\n",
    f"alpha: {alpha}, beta: {beta}, gamma: {gamma}, step_size: {step_size}\n",
]
if momentum_flag:
    setup_lines.append(f"use momentum, {momentum}\n")
elif rmsprop_flag:
    setup_lines.append(f"use rmsprop, {rmsprop}\n")
elif adam_flag:
    setup_lines.append(f"use adam, beta1: {adam_beta1}, beta2: {adam_beta2}\n")
setup_lines.append(f"noise level: {noise_beta}\n")

with open(path + "setup.txt", "w") as f:
    f.writelines(setup_lines)

# ---- Main optimisation loop ----
# Random permutation of particle indices. Each step processes one third of it, so
# three steps cover every particle once; it is reshuffled after the last third.
indices = np.arange(total_images)
np.random.shuffle(indices)
# Updates are computed by hand (via `gradient(...)` below), so autograd tracking is disabled.
with torch.no_grad():
    for k in range(num_steps):
        # This step's minibatch: the first, second or last third of `indices`.
        # NOTE(review): the thirds have sizes int(N/3), int(2N/3)-int(N/3) and N-int(2N/3)
        # (26/27/27 for N=80), not exactly `mini_batch` (= N/3, a float).
        if k%3==0:
            batches = indices[:int(total_images/3)].copy()
        elif k%3==1:
            batches = indices[int(total_images/3):int(2*total_images/3)].copy()
        else:
            batches = indices[int(2*total_images/3):].copy()
            # Reshuffle only after the last third has been copied out.
            np.random.shuffle(indices)

        if k%500==0 and k!=0:
            noise_beta = noise_beta/2 #decrease noise level by half every 500 steps
        # Step size decays as step_size_0 / (1 + k/500).
        step_size = step_size_0/(1+k/500)
        if k ==0:
            # First step: particles start at the source samples, and each particle takes
            # the mean/covariance of its source class.
            # NOTE(review): this binds x_tau to the same tensor object as source_data, so
            # the in-place updates below also modify source_data.
            x_tau = source_data
            for i in range(total_images):
                label =labels[i]
                mean_tau[:,i] = source_mean[label]
                cov_tau[i] = source_cov[label]
            # NOTE(review): the returned xlim/ylim are never passed to later plot_data calls.
            xlim, ylim = plot_data(x_tau.cpu().numpy(), target_data_cpu, labels, labels2, source_color, target_color, num_classes, num_classes2, source_mean = source_mean.cpu().numpy(), source_cov = source_cov.cpu().numpy(),target_mean = target_mean.cpu().numpy(), target_cov = target_cov.cpu().numpy(), path = path, step = "init")

        # Adam bias-correction factors for this step (only used by the Adam branch).
        adam_scale1 = 1-np.power(adam_beta1,k+1)
        adam_scale2 = 1-np.power(adam_beta2,k+1)
        # No-ops when the tensors already live on `device`.
        x_tau = x_tau.to(device)
        mean_tau = mean_tau.to(device)
        cov_tau = cov_tau.to(device)

        for i in batches:
            # psi_bar_* accumulate interactions of particle i with target points;
            # psi_* accumulate interactions with all current particles.
            psi_bar_x = torch.zeros(x_tau[:,i].shape, device=device)
            psi_bar_mean = torch.zeros(mean_tau[:,i].shape, device=device)
            psi_bar_cov = torch.zeros(cov_tau[i].shape, device=device)

            psi_x = torch.zeros(x_tau[:,i].shape, device=device)
            psi_mean = torch.zeros(mean_tau[:,i].shape, device=device)
            psi_cov = torch.zeros(cov_tau[i].shape, device=device)

            # Gaussian perturbation of particle i's position, used only when evaluating
            # gradients; it is not written back into x_tau.
            noise = noise_beta*torch.randn(2).to(device)
#             noise_mean = noise_beta*torch.randn(2).to(device)
            # NOTE(review): noise_cov is built (and symmetrised) but never used.
            noise_cov = torch.randn(2,2, device = device)*noise_beta
            noise_cov[1,0] = noise_cov[0,1]

            # Interaction with the target points that share this minibatch's indices.
            # NOTE(review): reuses the source-index minibatch to index target_data/labels2,
            # which assumes the target set has at least total_images points — confirm.
            for j in batches:
                label_j = labels2[j]
                psi1, psi2, psi3 = gradient(alpha, beta, gamma, x_tau[:,i]+noise, mean_tau[:,i], cov_tau[i], target_data[:,j], target_mean[label_j], target_cov[label_j])

                psi_bar_x += psi1.squeeze() #(1,100)->(100,)
                psi_bar_mean += psi2.squeeze()
                psi_bar_cov += psi3.squeeze()
            # Interaction with every current particle (including i itself).
            for j in range(total_images):
                psi1, psi2, psi3 = gradient(alpha, beta, gamma, x_tau[:,i]+noise, mean_tau[:,i], cov_tau[i], x_tau[:,j], mean_tau[:,j], cov_tau[j])

                psi_x += psi1.squeeze()
                psi_mean += psi2.squeeze()
                psi_cov += psi3.squeeze()

            # NOTE(review): divides by mini_batch (total_images/3, a float) rather than
            # len(batches).
            psi_bar_x = psi_bar_x/mini_batch
            psi_bar_mean = psi_bar_mean/mini_batch
            psi_bar_cov = psi_bar_cov/mini_batch

            psi_x = psi_x/total_images
            psi_mean = psi_mean/total_images
            psi_cov = psi_cov/total_images

            # update x
            if momentum_flag:
                # NOTE(review): last_x_grad is never updated in this branch, so the momentum
                # term is always zero.
                x_tau[:,i] += step_size*(psi_bar_x - psi_x)+momentum*last_x_grad[:,i]
            elif rmsprop_flag:
                if k==0:
                    # NOTE(review): initialised with the norm of the direction, not its
                    # element-wise square as in the running average below.
                    last_x_grad[:,i]=torch.norm(psi_bar_x - psi_x)
                else:
                    # NOTE(review): hard-codes 0.9/0.1 instead of `rmsprop` (same value today).
                    last_x_grad[:,i]=0.9*last_x_grad[:,i]+0.1*torch.square(psi_bar_x - psi_x)
                x_tau[:,i] += step_size*torch.div(psi_bar_x - psi_x, torch.sqrt(last_x_grad[:,i]+eps))
            else:
                # Adam-style update. NOTE(review): this branch runs whenever momentum_flag and
                # rmsprop_flag are both False, but last_x_v only exists when adam_flag is True.
                last_x_grad[:,i]=adam_beta1*last_x_grad[:,i]+(1-adam_beta1)*(psi_bar_x - psi_x)
                last_x_v[:,i]=adam_beta2*last_x_v[:,i]+(1-adam_beta2)*torch.square(psi_bar_x - psi_x)

                last_x_grad_scaled = last_x_grad[:,i]/adam_scale1
                last_x_v_scaled = last_x_v[:,i]/adam_scale2

                x_tau[:,i] += step_size*torch.div(last_x_grad_scaled, torch.sqrt(last_x_v_scaled)+eps)

            # update mean
            if momentum_flag:
                # NOTE(review): last_mean_grad is never updated in this branch, so the momentum
                # term is always zero.
                mean_tau[:,i] += step_size*(psi_bar_mean - psi_mean)+momentum*last_mean_grad[:,i]
            elif rmsprop_flag:
                if k==0:
                    # NOTE(review): initialised with the norm, not the element-wise square.
                    last_mean_grad[:,i]=torch.norm(psi_bar_mean - psi_mean)
                else:
                    last_mean_grad[:,i]=rmsprop*last_mean_grad[:,i]+(1-rmsprop)*torch.square(psi_bar_mean - psi_mean)
                mean_tau[:,i] += step_size*torch.div(psi_bar_mean - psi_mean, torch.sqrt(last_mean_grad[:,i]+eps))
            else:
                # Adam-style update (same caveat about last_mean_v as for x above).
                last_mean_grad[:,i]=adam_beta1*last_mean_grad[:,i]+(1-adam_beta1)*(psi_bar_mean - psi_mean)
                last_mean_v[:,i]=adam_beta2*last_mean_v[:,i]+(1-adam_beta2)*torch.square(psi_bar_mean - psi_mean)

                last_mean_grad_scaled = last_mean_grad[:,i]/adam_scale1
                last_mean_v_scaled = last_mean_v[:,i]/adam_scale2
                mean_tau[:,i] += step_size*torch.div(last_mean_grad_scaled, torch.sqrt(last_mean_v_scaled)+eps)

            # Covariance step: solve_lyapunov_cpu (defined elsewhere) maps the covariance
            # direction to a matrix L, and C is updated as A @ C @ A with
            # A = I + step_size*L + momentum*L_prev.
            # NOTE(review): momentum is applied here regardless of momentum_flag, and
            # A @ C @ A stays symmetric only if A is symmetric — presumably true of the
            # Lyapunov solution; confirm.
            lya_sol = solve_lyapunov_cpu(cov_tau[i], psi_bar_cov-psi_cov).to(device)
            #use momentum on cov
            cov_tau[i] =(identity+step_size*lya_sol+momentum*last_cov_grad[i])@cov_tau[i]@(identity+step_size*lya_sol+momentum*last_cov_grad[i])
            last_cov_grad[i]=lya_sol

        # Record the objective after this step as 0.5 * (ob + target_obj).
        ob = compute_objective(x_tau, mean_tau, cov_tau, target_data, target_mean, target_cov, alpha, beta, gamma, labels2, samemean, total_images)
        objectives.append(0.5*(ob+target_obj))

        # Every 20 steps: save the particle state, plot it, and plot the objective curve.
        if k%20==0:
            source_mean_cpu = mean_tau.float().cpu().numpy()
            source_cov_cpu = cov_tau.float().cpu().numpy()

            print("step: ",k)
            # Per-class median of the particle means (and attempted median of covariances).
            source_mean_plt = []
            source_cov_plt = []
            for c in range(num_classes):
                mask = labels==c
                source_mean_plt.append(np.median(source_mean_cpu[:,mask],axis=1))
                # NOTE(review): source_cov_cpu[mask] has shape (n, d, d); axis=1 yields (n, d),
                # not a per-class d x d median (axis=0 would). The result is overwritten below.
                source_cov_plt.append(np.median(source_cov_cpu[mask], axis=1))
            source_mean_plt = np.array(source_mean_plt)
            source_cov_plt = np.array(source_cov_plt)

            # NOTE(review): discards the median covariances above; the plot uses the
            # initial source_cov.
            source_cov_plt = source_cov.cpu().numpy()
            results_tau = {}
            results_tau["source_data"] = x_tau
            results_tau["source_mean"] = mean_tau
            results_tau["source_cov"] = cov_tau
            results_tau["labels"] = labels
            # Rebinds the global data_path (already logged to setup.txt) to the snapshot file.
            data_path = path+str(k)+".tar"

            torch.save(results_tau, data_path)
            plot_data(x_tau.cpu().numpy(), target_data_cpu, labels, labels2, source_color, target_color, num_classes, num_classes2, source_mean = source_mean_plt, source_cov = source_cov_plt, target_mean = target_mean.cpu().numpy(), target_cov = target_cov.cpu().numpy(), path = path, step = k)
            # Clear the current figure (the one plot_data just drew) and reuse it for the
            # objective curve.
            plt.clf()
            plt.plot(
                range(int(len(objectives))),
                torch.Tensor(objectives),

            )
            plt.ylabel('objective')
            plt.xlabel('steps')
            plt.savefig(path+str(k)+"_obj.png")
            plt.close('all')
