# %% import

import time
import numpy as np
import matplotlib.pyplot as plt
from utils2 import sampled_sphere
from utils2 import Tukey_Depth, Projection_Depth, SW, Sinkhorn
from DW import DW
import ot
import csv
from math import *
from sklearn.datasets import make_circles
from SRW import SubspaceRobustWasserstein
from Optimization.projectedascent import ProjectedGradientAscent
from Optimization.frankwolfe import FrankWolfe
import RobustOT
import torch

# %% definitions
def T(x,d,dim=2):
    """Fragmentation map used to build the second measure.

    Every point is pushed 2 units further from the origin along each of the
    first ``dim`` axes, in the direction of the sign of that coordinate
    (coordinates equal to 0 are left unchanged since ``np.sign(0) == 0``).
    The remaining ``d - dim`` axes are untouched.

    Parameters
    ----------
    x : ndarray
        Points whose last axis has length ``d`` (broadcast against a
        length-``d`` mask).
    d : int
        Ambient dimension.
    dim : int, default 2
        Number of leading axes to fragment; must satisfy ``1 <= dim <= d``.

    Returns
    -------
    ndarray
        ``x + 2 * sign(x) * mask`` where ``mask`` is 1 on the first ``dim``
        axes and 0 elsewhere.
    """
    assert 1 <= dim <= d
    assert int(dim) == dim
    # Integer mask, so integer inputs stay integer as before.
    mask = np.zeros(d, dtype=int)
    mask[:dim] = 1
    return x + 2 * np.sign(x) * mask
def fragmented_hypercube(n,d,dim):
    """Sample the "fragmented hypercube" pair of empirical measures.

    The source is ``n`` points drawn uniformly on ``[-1, 1]^d``; the target is
    another ``n`` uniform points on ``[-1, 1]^d`` passed through the
    fragmentation map ``T`` on the first ``dim`` axes.

    Parameters
    ----------
    n : int
        Number of points in each sample.
    d : int
        Ambient dimension.
    dim : int
        Number of fragmented axes, ``1 <= dim <= d``.

    Returns
    -------
    a, b : ndarray of shape (n,)
        Uniform weights ``1/n`` for source and target (two distinct arrays).
    X, Y : ndarray of shape (n, d)
        Source and target samples.
    """
    assert 1 <= dim <= d
    assert int(dim) == dim

    weights = np.full(n, 1. / n)

    # Source first, then target: keeps the order of draws from the global RNG.
    source = np.random.uniform(-1, 1, size=(n, d))
    target = T(np.random.uniform(-1, 1, size=(n, d)), d, dim)

    return weights, weights.copy(), source, target


def _contaminated_sample(data, a):
    """Draw one clean source/target pair and corrupt a fraction ``a`` of the source.

    Parameters
    ----------
    data : str
        ``'gaussian'`` (two unit-covariance 2-D Gaussians, means 0 and 10,
        500 points each), ``'circles'`` (inner/outer ring of
        ``make_circles(1000, factor=0.2, noise=0.02)``); any other value
        selects ``fragmented_hypercube(500, 2, 2)``.
    a : float
        Proportion of source points to corrupt; ``int(a * n)`` distinct points
        are affected.

    Returns
    -------
    w_X, w_Y : ndarray
        Uniform weights of the source and target samples.
    X_temp : ndarray
        Contaminated copy of the source sample.
    Y : ndarray
        Target sample (never contaminated).
    """
    if data == 'gaussian':
        n = m = 500
        X = np.random.multivariate_normal(np.zeros(2), np.eye(2), size=n)
        Y = np.random.multivariate_normal(np.zeros(2) + 10, np.eye(2), size=m)
        w_X = np.full(n, 1. / n)
        w_Y = np.full(m, 1. / m)
    elif data == 'circles':
        Z, y = make_circles(n_samples=1000, factor=0.2, noise=0.02)
        X, Y = Z[y == 1], Z[y == 0]
        w_X = np.full(len(X), 1. / len(X))
        w_Y = np.full(len(Y), 1. / len(Y))
    else:
        w_X, w_Y, X, Y = fragmented_hypercube(500, 2, 2)

    n, d = X.shape
    n_out = int(a * n)
    # Sample WITHOUT replacement: NumPy's default (replace=True) can pick the
    # same index twice, so fewer than n_out distinct points would be corrupted.
    sigma = np.random.choice(n, size=n_out, replace=False)

    X_temp = X.copy()
    if data == 'gaussian':
        # Replace the selected points by uniform outliers on [-10, 20)^d.
        X_temp[sigma] = np.random.uniform(-10, 20, size=(n_out, d))
    elif data == 'circles':
        # Shift the selected points along an isotropic random direction by a
        # random length in [0, 1).  Indices are distinct, so fancy-index += is safe.
        directions = np.random.standard_normal((n_out, d))
        directions /= np.linalg.norm(directions, axis=1, keepdims=True)
        X_temp[sigma] += np.random.rand(n_out, 1) * directions
    else:
        # Replace the selected points by uniform outliers on [-4, 4)^d.
        X_temp[sigma] = np.random.uniform(-4, 4, size=(n_out, d))
    return w_X, w_Y, X_temp, Y


def _robust_distances(X_temp, Y, w_X, w_Y, K, eps):
    """Compute the seven discrepancies between ``X_temp`` and ``Y``.

    Returns
    -------
    tuple of 7 scalars
        ``(DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass)``, in the order
        used by :func:`experiment`.
    """
    M = ot.dist(X_temp, Y)  # POT default metric: squared Euclidean
    wass = ot.emd2(w_X, w_Y, M=M)
    sliced = SW(X_temp, Y, ndirs=K, p=2, max_sliced=False)
    # DW with its default data depth (presumably Tukey, cf. the name DRW_T — confirm in DW).
    drw_t = DW(X_temp, Y, ndirs=K, n_alpha=20, eps=eps)
    drw_pd = DW(X_temp, Y, ndirs=K, n_alpha=20, data_depth='Projection', eps=eps)

    # ROBOT: exact OT on the cost matrix truncated at half its maximum.
    robot = ot.emd2(w_X, w_Y, M=np.minimum(M, np.max(M) / 2))

    # SRW: the cost ||x_i||^2 + ||y_j||^2 - 2<x_i, y_j> that used to be rebuilt
    # by hand is exactly M, so the initial step size is taken from M directly.
    params = {'reg': 0, 'step_size_0': 1. / np.max(M), 'max_iter': 100, 'threshold': 0.01,
              'max_iter_sinkhorn': 100, 'threshold_sinkhorn': 1e-3, 'use_gpu': False}
    algo = ProjectedGradientAscent(**params)
    srw_solver = SubspaceRobustWasserstein(X_temp, Y, w_X, w_Y, algo, k=1)
    srw_solver.run()
    srw = srw_solver.get_value()

    # The second return value of ROTSolver.solve() is unused here.
    rot_value, _ = RobustOT.ROTSolver(X_temp, Y).solve()

    return drw_t, drw_pd, sliced, srw, robot, rot_value, wass


def experiment(data="gaussian", K=1000, eps=0.2, n_rep=100, prop_anom=None):
    """Measure how OT-type discrepancies react to contamination of the source sample.

    For each of ``n_rep`` repetitions and each proportion ``a`` in
    ``prop_anom``, a clean pair (X, Y) is drawn, ``int(a * n)`` distinct
    source points are corrupted, and seven discrepancies between the
    corrupted source and ``Y`` are recorded.

    Parameters
    ----------
    data : str, default "gaussian"
        ``'gaussian'``, ``'circles'``; any other value (e.g. ``'frag'``)
        uses the fragmented hypercube.
    K : int, default 1000
        Number of directions (``ndirs``) passed to ``SW`` and ``DW``.
    eps : float, default 0.2
        ``eps`` passed to ``DW``.
    n_rep : int, default 100
        Number of independent repetitions.
    prop_anom : array-like of float, optional
        Contamination proportions; defaults to ``np.linspace(0, 0.2, 20)``.

    Returns
    -------
    DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass : ndarray
        One table of shape ``(n_rep, len(prop_anom))`` per method.
    X_temp, Y : ndarray or None
        Contaminated source and target of the last draw (``None`` when no
        draw was made, i.e. ``n_rep == 0`` or ``prop_anom`` is empty).
    """
    if prop_anom is None:
        prop_anom = np.linspace(0, 0.2, 20)
    n_a = len(prop_anom)

    # One (n_rep, n_a) table per method, ordered as _robust_distances returns them.
    tables = tuple(np.zeros((n_rep, n_a)) for _ in range(7))
    X_temp = Y = None
    for s in range(n_rep):
        print(f"Fitting model... {s/n_rep:6.1%}\r", end='', flush=True)
        for k, a in enumerate(prop_anom):
            w_X, w_Y, X_temp, Y = _contaminated_sample(data, a)
            values = _robust_distances(X_temp, Y, w_X, w_Y, K, eps)
            for table, value in zip(tables, values):
                table[s, k] = value

    DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass = tables
    return DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass, X_temp, Y



import os

DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass, X_temp, Y = experiment(data='frag')

# Stack the seven (n_rep, len(prop_anom)) result tables row-wise, in the order
# DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass, as one float32 tensor.
Robustness_frag = torch.cat(
    [torch.as_tensor(table, dtype=torch.float32)
     for table in (DRW_T, DRW_PD, Sliced, SRW, ROBOT, RobustOTG, Wass)],
    dim=0,
)

# torch.save does not create missing directories; create the output folder
# up front so a long run is not lost at the very last step.
os.makedirs('results', exist_ok=True)
torch.save(Robustness_frag, 'results/Robustness_frag.pt')

