# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

import numpy as np

class Instance:
    """A linear bandit problem instance plus quantities derived from it.

    Parameters
    ----------
    W : np.ndarray
        Arm matrix with one row per arm (``self.K = len(W)``). ``W @ theta``
        must be defined, so presumably shape (K, d) -- TODO confirm with callers.
    Z : np.ndarray
        Instrument matrix. For ``'special_compliance'`` it is indexed by arm
        index ``0..K-1``, so it is assumed to have (at least) K rows.
    theta : np.ndarray
        Parameter vector; a 1-D array is reshaped into a column vector.
    Gamma : np.ndarray or None
        Instrument-to-arm map. For ``'compliance'`` its columns must sum to 1.
        For ``'special_compliance'`` the passed value is replaced by an
        estimate obtained by simulation (see ``_estimate_special_compliance_gamma``).
    instance_type : str
        ``'compliance'``, ``'special_compliance'``, or any other value, which
        skips the Gamma handling and uses ``||theta||^2`` as ``subgaussian``.
    sigma_x : float
        Stored as ``self.sigma_x``; not used inside this class.
    sigma_y : float
        Stored as ``self.sigma_y``; enters ``subgaussian`` for the compliance
        types and is the std of the index noise in the special-compliance
        simulation.
    horizon : int
        Number of simulated samples used to estimate Gamma for
        ``'special_compliance'`` (default matches the previous hard-coded value).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Randomness source for the special-compliance simulation. ``None`` uses
        the global ``np.random`` state, as the original code did.

    Raises
    ------
    ValueError
        If ``instance_type == 'compliance'`` and the columns of ``Gamma`` do
        not sum to 1.
    """

    def __init__(self, W, Z, theta, Gamma, instance_type, sigma_x=0.1, sigma_y=1,
                 horizon=100000, rng=None):
        self.W = W
        self.K = len(W)  # number of arms (rows of W)
        self.Z = Z
        if theta.ndim == 1:
            theta = theta.reshape(-1, 1)
        self.theta = theta
        self.Gamma = Gamma
        # Feature dimension = number of columns of W. (Previously W.shape[0],
        # which merely duplicated self.K.) shape[-1] keeps 1-D W working.
        self.d = W.shape[-1]
        self.opt = np.argmax((W @ theta).flatten())  # index of the best arm
        self.opt_w = W[self.opt]
        self.type = instance_type
        self.sigma_y = sigma_y
        self.sigma_x = sigma_x

        theta_sq_norm = np.linalg.norm(theta, 2) ** 2
        if instance_type == 'compliance':
            col_sums = Gamma.sum(axis=0)
            # Explicit check instead of `assert`, which is stripped under -O.
            if not np.allclose(col_sums, 1):
                raise ValueError(
                    "Gamma columns must sum to 1 for a 'compliance' instance; "
                    f"got column sums {col_sums}"
                )
            # NOTE(review): 2*sigma_y/2 simplifies to sigma_y; kept verbatim in
            # case a different expression (e.g. sigma_y**2) was intended -- confirm.
            self.subgaussian = 8 * theta_sq_norm + 2 * self.sigma_y / 2
        elif instance_type == 'special_compliance':
            self.Gamma = self._estimate_special_compliance_gamma(horizon, rng)
            self.subgaussian = 8 * theta_sq_norm + 2 * self.sigma_y
        else:
            self.subgaussian = theta_sq_norm

    def _estimate_special_compliance_gamma(self, horizon, rng):
        """Estimate Gamma by regressing simulated chosen arms on instruments.

        Draws ``horizon`` instrument indices uniformly from ``range(K)``, adds
        N(0, sigma_y) noise to each index, maps it to the nearest arm index,
        and returns the least-squares solution of ``Z_sample @ G ~= X_sample``.
        """
        if rng is None:
            rng = np.random  # global state, identical stream to the original code
        # p is kept explicit (though uniform): with/without p, numpy consumes
        # different random draws, so dropping it would change seeded results.
        z_idx = rng.choice(self.K, p=np.ones(self.K) / self.K, size=horizon)
        Z_sample = self.Z[z_idx]
        epsilon = rng.normal(0, self.sigma_y, size=horizon)
        # |noisy index - arm index| for every arm, shape (horizon, K).
        dist = np.abs((z_idx + epsilon)[:, None] - np.arange(self.K))
        X_sample = self.W[np.argmin(dist, axis=1)]
        # lstsq equals inv(Z^T Z) @ Z^T X when Z_sample has full column rank,
        # but is numerically more stable and still defined when it does not.
        gamma_hat, *_ = np.linalg.lstsq(Z_sample, X_sample, rcond=None)
        return gamma_hat

            
