# -*- coding: utf-8 -*-

import sys
import os

# Import libraries, the code is built on PyTorch
import torch
import random
import warnings
import time

import multiprocessing
from functools import partial

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim

import multiprocessing as mp
from functools import partial

from cca_zoo.linear import CCA
from torch.utils.data import DataLoader, Dataset


# Silence every warning for the rest of the run. The filter is installed before
# the project-local imports below, so warnings raised while importing them are
# hidden as well.
warnings.filterwarnings("ignore")

from utils import *
from models import *
from estimators import *

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#############################
# Job index (change if on a cluster); not read anywhere else in the visible code.
taskid = 0

# Wall-clock start of the whole script, used for the final timing report.
tic_all = time.time()

# Define lists
# Information values handed to the data sampler as `info_tot`: 2, 4, ..., 10.
mi_lst = np.arange(2, 12, 2)
# CCA latent dimensions: 1..20 in unit steps, 30..100 in tens, 200..500 in hundreds.
dz_lst_cca = np.concatenate((
    np.arange(1, 21),
    np.arange(30, 101, 10),
    np.arange(200, 501, 100),
)).astype(int)

###############################
# parameters

# Optimisation settings. In the visible code only `n_iter` and `batch_size`
# are read, to size the dataset below.
opt_params = dict(
    epochs=1,
    n_iter=20000,
    batch_size=128,
)

# Data-generation settings.
data_params = dict(
    # One fifth of the n_iter * batch_size product.
    samples=int(opt_params['n_iter'] * opt_params['batch_size'] / 5),
    # Output width of the X teacher; also passed to the sampler as `total_size`.
    Nx=500,
    # Output width of the Y teacher.
    Ny=500,
    # Passed as `ncom` to the sampler and as `dz` to both teachers.
    ncom=10,
)
###############################
# Define a Teacher Model for X and Y
# Both teachers take `ncom` as their `dz`; they differ only in output width.
teacher_model_x = teacher(
    dz=data_params['ncom'],
    output_dim=data_params['Nx'],
)
teacher_model_y = teacher(
    dz=data_params['ncom'],
    output_dim=data_params['Ny'],
)

# Freeze: switch off gradient tracking on every parameter of both teachers.
for frozen_model in (teacher_model_x, teacher_model_y):
    for weight in frozen_model.parameters():
        weight.requires_grad_(False)
    
#############################
# Specify the folder name
folder_name = "Infinite_Data_Highdim"
# Create the output folder up front. np.save does not create missing parent
# directories, so without this the final save raises FileNotFoundError only
# after the entire sweep has already been computed.
os.makedirs(folder_name, exist_ok=True)

# Results of the run, keyed by "<MI>_<dz>_copied" / "<MI>_<dz>_teacher".
mi_cca_opt = {}

# Sweep over the information values. For each one, both datasets are drawn
# once and reused for every CCA latent dimension.
for mi_toUse in mi_lst:
    print(f"MI = {mi_toUse}", flush=True)

    # Keyword arguments common to both sampler calls.
    sampler_kwargs = dict(
        ncom=data_params['ncom'],
        total_size=data_params['Nx'],
        batch_size=data_params['samples'],
        info_tot=mi_toUse,
    )

    # Make datasets
    # Copied dataset (sampler called without teacher networks)
    x_copied, y_copied = sample_correlated_data(**sampler_kwargs)
    views = {"copied": (x_copied.numpy(), y_copied.numpy())}

    # Teacher dataset (sampler additionally receives the two frozen teachers)
    x_teacher, y_teacher = sample_correlated_data(
        **sampler_kwargs,
        mlp_x=teacher_model_x,
        mlp_y=teacher_model_y,
    )
    views["teacher"] = (x_teacher.numpy(), y_teacher.numpy())

    # CCA on the copied and the teacher data for every latent dimension
    for dz in dz_lst_cca:
        tic = time.time()

        # One CCA model per dataset. Construction, fitting and projection run
        # as separate passes, each visiting "copied" first and "teacher" second.
        models = {name: CCA(latent_dimensions=dz) for name in views}
        for name, model in models.items():
            model.fit(views[name])
        projected = {
            name: model.transform(views[name])
            for name, model in models.items()
        }

        # Estimate the information between the two projected views.
        for name, (x_proj, y_proj) in projected.items():
            mi_cca_opt[f"{mi_toUse}_{dz}_{name}"] = mut_info_optimized(x_proj, y_proj)

        # The reported time covers fit, transform and the two estimates.
        elapsed = round(time.time() - tic)
        mi_copied = round(mi_cca_opt[f"{mi_toUse}_{dz}_copied"], 3)
        mi_teacher = round(mi_cca_opt[f"{mi_toUse}_{dz}_teacher"], 3)
        print(
            f"Time for CCA = {elapsed} for MI = {mi_toUse} for dz = {dz}, "
            f"MI_copied = {mi_copied}, MI_teacher = {mi_teacher}",
            flush=True,
        )


# Save the data
temp_name_cca = "CCA"
temp_name_cca_opt = f"opt_{temp_name_cca}"
file_path_cca_opt = os.path.join(folder_name, temp_name_cca_opt)

# The results dict is handed to np.save as-is. np.save adds the ".npy" suffix
# and stores a non-array object via pickle, so reading it back needs
# np.load(..., allow_pickle=True).
np.save(file_path_cca_opt, mi_cca_opt)

# Total wall-clock time since the start of the script, in seconds.
total_elapsed = round(time.time() - tic_all, 3)
print(f"Done!, Total time = {total_elapsed}", flush=True)