import random
import numpy as np


import torch
import time

seed = 42
# Seed every random number generator the script may draw from so repeated runs
# see the same data: torch (CPU, then the current CUDA device), Python's
# `random`, and NumPy's global generator.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)
del _seed_fn

def main():
    lamb = 1.0
    d = 3500
    cov = torch.eye(d, dtype=torch.float64).cuda() * lamb
    cov_inv = torch.eye(d, dtype=torch.float64).cuda() * lamb ** -1
    hidden_mean = torch.zeros(d, dtype=torch.float64).cuda()

    hidden_mean_counter = 0
    all_hidden_states = []
    for g in range(1024):
        print('g', g)
        hidden_states = []
        for i in range(2):
            hidden_states.append(torch.randn(d, dtype=torch.float32).cuda() * 10)
            all_hidden_states.append(hidden_states[-1])

            if g > 0:
                bonuses_1 = (hidden_states[-1].to(torch.float64) @ cov_inv_centered_1 @ hidden_states[-1].to(torch.float64))
                # bonuses_2 = (hidden_states[-1] @ cov_inv_centered_2 @ hidden_states[-1])

                if torch.any(bonuses_1 < 0):
                    print('bonuses 1 is negative')
                    breakpoint()
                # if torch.any(bonuses_2 < 0):
                #     print('bonuses 2 is negative')
                #     exit()

                # print('bonuses 1', bonuses_1)
                # print('bonuses 2', bonuses_2)
    
        hidden_states = torch.stack(hidden_states)
        for hidden_state in hidden_states:
            chosen_samp = hidden_state.unsqueeze(1).to(torch.float64)
            middle_part = torch.inverse(1 + chosen_samp.t() @ cov_inv @ chosen_samp)
            cov_inv = cov_inv - cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv

            delta = hidden_state.to(torch.float64) - hidden_mean
            hidden_mean = hidden_mean + delta / (hidden_mean_counter + 1)
            hidden_mean_counter += 1

        all_hidden_states_tensor = torch.stack(all_hidden_states).to(torch.float64)
        # hidden_mean = torch.mean(all_hidden_states_tensor, dim=0)
        cov = cov + hidden_states.t().to(torch.float64) @ hidden_states.to(torch.float64)
        cov_centered = torch.eye(d, dtype=torch.float64).cuda() * lamb + (all_hidden_states_tensor - hidden_mean.to(torch.float64)).t() @ (all_hidden_states_tensor - hidden_mean.to(torch.float64))
        
        start = time.time()
        cov_inv_centered_1 = cov_inv.to(torch.float64) - (cov_inv.to(torch.float64) @ hidden_mean.unsqueeze(1).to(torch.float64) @ hidden_mean.unsqueeze(0).to(torch.float64) @ cov_inv.to(torch.float64)) / (-1/hidden_mean_counter + hidden_mean.t().to(torch.float64) @ cov_inv.to(torch.float64) @ hidden_mean.to(torch.float64))
        cov_inv_centered_1 = cov_inv_centered_1.to(torch.float64)
        print('denominator 1', (-1/hidden_mean_counter + hidden_mean.t().to(torch.float32) @ cov_inv.to(torch.float32) @ hidden_mean.to(torch.float32)))
        print('time 1', time.time() - start)

        # start = time.time()
        # cov_inv_centered_2 = cov_inv - (cov_inv @ hidden_mean.unsqueeze(1) @ hidden_mean.unsqueeze(0) @ cov_inv) / (-1 + hidden_mean.t() @ cov_inv @ hidden_mean)
        # denominator_2 = -1 + hidden_mean.t() @ cov_inv @ hidden_mean
        # print('denominator 2', denominator_2)
        # for j in range(hidden_mean_counter - 1):
        #     denominator_2 = (-1 + hidden_mean.t() @ cov_inv_centered_2 @ hidden_mean)
        #     cov_inv_centered_2 = cov_inv_centered_2 - cov_inv_centered_2 @ hidden_mean.unsqueeze(1) @ hidden_mean.unsqueeze(0) @ cov_inv_centered_2 / (-1 + hidden_mean.t() @ cov_inv_centered_2 @ hidden_mean)
        # print('denominator 2', denominator_2)
        # print('time 2', time.time() - start)

        approx_identity = cov_inv @ cov.to(torch.float64)
        print('error compared to identity', torch.norm(approx_identity - torch.eye(d, dtype=torch.float32).cuda()))
        print('error compared to identity centered 1', torch.norm(cov_inv_centered_1 @ cov_centered.to(torch.float64) - torch.eye(d, dtype=torch.float64).cuda()))
        # print('iterative mean vs. non-iterative mean', torch.norm(hidden_mean_it - hidden_mean))
        # print('error compared to identity centered 2', torch.norm(cov_inv_centered_2 @ cov_centered - torch.eye(d, dtype=torch.float32).cuda()))
        # breakpoint()

if __name__ == "__main__":
    main()