# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 20:13:13 2024

@author: User
"""
import torch
from torch.utils.data import Dataset
from torch.distributions import MultivariateNormal
import numpy as np 
from copy import deepcopy

def transform_points(data, transform, params_dict):
    """Apply one named transform to a batch of points.

    Parameters
    ----------
    data : torch.Tensor
        Points to transform. The concatenating and matrix transforms work
        along dim 1, so a (num_samples, num_features) layout is assumed.
    transform : str
        One of 'none', 'sigmoid', 'concat_self', 'concat_self_noisy',
        'cube', 'randmat', 'scale'.
    params_dict : dict
        Per-transform parameters, looked up by the ``transform`` name:

        - 'concat_self': ``params[0]`` = number of extra copies of ``data``
          appended along dim 1.
        - 'concat_self_noisy': ``params[0]`` = number of noise blocks
          appended along dim 1. ``params[1]`` = amplitude; each block is
          ``torch.rand_like(data) * params[1]`` (uniform, not added to data).
        - 'scale': value multiplied with ``data``.

    Returns
    -------
    torch.Tensor
        The transformed points. 'randmat' always returns float64.

    Raises
    ------
    ValueError
        If ``transform`` is not one of the recognised names.
    """
    if transform == 'none':
        return data
    if transform == 'sigmoid':
        # torch.nn.functional.sigmoid is deprecated in favour of torch.sigmoid.
        return torch.sigmoid(data)
    if transform == 'concat_self':
        n_copies = params_dict[transform][0]
        # Original block followed by n_copies duplicates, side by side.
        return torch.cat([data] + [data] * n_copies, dim=1)
    if transform == 'concat_self_noisy':
        params = params_dict[transform]
        # One rand_like draw per block, in order, to keep RNG usage unchanged.
        noise_blocks = [torch.rand_like(data) * params[1] for _ in range(params[0])]
        return torch.cat([data] + noise_blocks, dim=1)
    if transform == 'cube':
        return data ** 3.0
    if transform == 'randmat':
        n_features = data.shape[1]
        while True:
            rand_mat = torch.rand(1) * torch.rand(n_features, n_features)
            try:
                # Singular draws make inv raise torch.linalg.LinAlgError,
                # a RuntimeError subclass; resample in that case.
                torch.linalg.inv(rand_mat)
            except RuntimeError:
                continue
            break
        # NOTE(review): only exactly-singular draws are rejected; an
        # ill-conditioned matrix is still accepted.
        return torch.matmul(data.double(), rand_mat.double())
    if transform == 'scale':
        return params_dict[transform] * data
    raise ValueError(f"Unknown transform: {transform!r}")
    
    
class MultivariateNormalDataset(Dataset):
    """Paired samples (x, y) drawn jointly from a correlated Gaussian.

    The stacked vector [x, y] (length ``2 * dim``) is sampled from N(0, C).
    C has unit diagonal and covariance ``rho`` between x_i and y_i for each
    i; every other pair is uncorrelated. The y-transforms and then the
    x-transforms are applied, and both arrays are stored as NumPy arrays.

    Parameters
    ----------
    N : int
        Number of samples.
    dim : int
        Dimension of x and of y before any transforms.
    rho : float
        Per-coordinate correlation between x and y. |rho| < 1 is needed for
        C to be positive definite.
    params_dict : dict
        Parameters forwarded to ``transform_points``.
    transforms_x, transforms_y : sequence of str, optional
        Transform names applied in order to x / y. Default ``['none']``.
    """

    def __init__(self, N, dim, rho, params_dict, transforms_x=None, transforms_y=None):
        self.N = N
        self.rho = rho
        self.dim = dim
        # None sentinel instead of a shared mutable default list.
        self.x_transforms = ['none'] if transforms_x is None else transforms_x
        self.y_transforms = ['none'] if transforms_y is None else transforms_y
        self.params_dict = params_dict
        self.dist = self.build_dist

        samples = self.dist.sample((N, ))
        # Columns [0, dim) are x; columns [dim, 2*dim) are y.
        self.x = samples[:, :dim]
        self.y = samples[:, dim:]

        self.transform_both()

    def __getitem__(self, ix):
        """Return the transformed (x, y) pair for sample ``ix``."""
        # x and y are stored separately, so index each one. Slicing x alone
        # would drop y entirely.
        return self.x[ix], self.y[ix]

    def transform_both(self):
        """Apply the y-transforms, then the x-transforms, and convert to NumPy."""
        for name in self.y_transforms:
            self.y = transform_points(self.y, name, self.params_dict)

        for name in self.x_transforms:
            self.x = transform_points(self.x, name, self.params_dict)

        self.x = self.x.numpy()
        self.y = self.y.numpy()

    def __len__(self):
        return self.N

    @property
    def build_dist(self):
        """Zero-mean MultivariateNormal over [x, y] with covariance ``cov_matrix``."""
        mu = torch.zeros(2 * self.dim)
        return MultivariateNormal(mu, self.cov_matrix)

    @property
    def cov_matrix(self):
        """(2*dim, 2*dim) covariance: identity plus ``rho`` on the x_i/y_i cross terms."""
        cov = torch.zeros((2 * self.dim, 2 * self.dim))
        cov[torch.arange(self.dim), torch.arange(self.dim, 2 * self.dim)] = self.rho
        cov[torch.arange(self.dim, 2 * self.dim), torch.arange(self.dim)] = self.rho
        cov[torch.arange(2 * self.dim), torch.arange(2 * self.dim)] = 1.0
        return cov

    @property
    def true_mi(self):
        """Mutual information -0.5 * log det(C), in nats.

        For this covariance, det(C) = (1 - rho**2) ** dim. A float32 ``det``
        underflows to 0 (giving MI = inf) once ``dim`` is moderately large,
        so the log-determinant is computed directly in float64.
        """
        _, logdet = np.linalg.slogdet(self.cov_matrix.double().numpy())
        return -0.5 * logdet
    
    
    
    
class GaussianAdditionDataset(Dataset):
    """Paired samples (x, y) where y is x corrupted by independent Gaussian noise.

    x ~ N(0, I_dim). Depending on ``mode``:

    - 'normal': y = x + e,              e ~ N(0, (1/SNR) I)
    - 'scale' : y = sqrt(SNR) * x + e,  e ~ N(0, I)

    Both modes give a per-coordinate signal-to-noise ratio of ``SNR``. The
    y-transforms and then the x-transforms are applied, and both arrays are
    stored as NumPy arrays. Samples use NumPy's global RNG.

    Parameters
    ----------
    N : int
        Number of samples.
    dim : int
        Dimension of x and y before any transforms.
    SNR : float
        Signal-to-noise ratio per coordinate (must be > 0 for 'normal').
    params_dict : dict
        Parameters forwarded to ``transform_points``.
    transforms_x, transforms_y : sequence of str, optional
        Transform names applied in order to x / y. Default ``['none']``.
    mode : {'normal', 'scale'}
        How the noise is combined with x (see above).

    Raises
    ------
    ValueError
        If ``mode`` is not 'normal' or 'scale'.
    """

    def __init__(self, N, dim, SNR, params_dict, transforms_x=None, transforms_y=None, mode='normal'):
        if mode not in ('normal', 'scale'):
            # Previously an unknown mode left self.y unset and failed later
            # with an opaque AttributeError.
            raise ValueError(f"Unknown mode: {mode!r} (expected 'normal' or 'scale')")
        self.N = N
        self.dim = dim
        self.SNR = SNR
        self.mode = mode
        self.params_dict = params_dict
        # None sentinel instead of a shared mutable default list.
        self.x_transforms = ['none'] if transforms_x is None else transforms_x
        self.y_transforms = ['none'] if transforms_y is None else transforms_y

        # x is drawn before the noise, so RNG consumption order is unchanged.
        x = np.random.normal(0., 1, [self.N, self.dim])
        if mode == 'normal':
            y = x + np.random.normal(0., np.sqrt(1 / self.SNR), [self.N, self.dim])
        else:  # mode == 'scale'
            y = (np.sqrt(self.SNR) * x) + np.random.normal(0., 1.0, [self.N, self.dim])
        self.x = torch.from_numpy(x)
        self.y = torch.from_numpy(y)

        self.transform_both()

    def __getitem__(self, ix):
        """Return the transformed (x, y) pair for sample ``ix``."""
        # Required by torch Dataset; without it indexing or DataLoader raises.
        return self.x[ix], self.y[ix]

    def transform_both(self):
        """Apply the y-transforms, then the x-transforms, and convert to NumPy."""
        for name in self.y_transforms:
            self.y = transform_points(self.y, name, self.params_dict)

        for name in self.x_transforms:
            self.x = transform_points(self.x, name, self.params_dict)

        self.x = self.x.numpy()
        self.y = self.y.numpy()

    def __len__(self):
        return self.N

    @property
    def true_mi(self):
        """Mutual information of the untransformed (x, y), in nats: dim/2 * log(1 + SNR)."""
        return self.dim * 0.5 * np.log(1 + self.SNR)
        
    