import numpy as np
# import scipy as sp
import torch
import torch.nn as nn
# import torch.optim as optim
# from datetime import datetime
from sklearn.preprocessing import normalize
# from numpy.random import multivariate_normal
import time

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.nn.functional import dropout

from torch.distributions import constraints


class Linear_TS:
    """Linear Thompson Sampling over one or more independent linear models.

    Each model ``i`` keeps ridge-regression statistics

    * ``S[i] = l2_reg * I + sum_t x_t x_t^T``,
    * ``b[i] = sum_t y_t x_t``,
    * ``Sinv[i] = S[i]^{-1}``,
    * ``N[i]`` = number of stored observations,

    and, after every update, a parameter vector ``theta[i]`` sampled around the
    ridge estimate ``Sinv[i] @ b[i]``.  Features can optionally be reduced from
    ``raw_d`` to ``d`` dimensions before entering the model (average pooling or
    a random Gaussian / sparse Gaussian projection).

    Args:
        raw_d: Dimension of the raw feature vectors.
        pooling_step: If > 0, reduce features before modelling; for
            ``"avg_pool"`` it is the pooling window and stride.
        beta_threshold: L2 distance between successive checkpoints above which
            ``check_and_update_checkpoint`` resets all collected statistics.
        explore_var_coef: Exploration scale; multiplies the per-coordinate std
            in the diagonal sampler and the covariance in the full sampler.
        l2_reg: Ridge regularisation (initial diagonal of ``S``).
        model_num: Number of independent models, indexed ``0..model_num-1``.
        p_val: Drop probability used to sparsify the ``"sparse_gaussian_proj"``
            matrix.
        dim_reduce_method: ``"avg_pool"``, ``"gaussian_proj"`` or
            ``"sparse_gaussian_proj"``.
        device: Stored as ``self.device``; not used by the code in this class
            (statistics are created on the default device).
            NOTE(review): confirm whether callers rely on it.
        diag_flag: If True, sample ``theta`` from independent normals using only
            the diagonal of ``Sinv``; otherwise use a full multivariate normal.

    Raises:
        NotImplementedError: If ``dim_reduce_method`` is not one of the above.
    """

    def __init__(self, raw_d,
                 pooling_step=-1, beta_threshold=0.01, explore_var_coef=1,
                 l2_reg=1, model_num=1, p_val=0.3,
                 dim_reduce_method="avg_pool", device="cuda", diag_flag=True):
        self.raw_d = raw_d
        self.device = device
        self.explore_var_coef = explore_var_coef
        self.diag_flag = diag_flag

        # Last parameter vector seen by check_and_update_checkpoint
        # (attribute name spelling kept for compatibility with existing callers).
        self.current_chechpoint = None
        self.beta_threshold = beta_threshold

        if dim_reduce_method not in ("avg_pool", "gaussian_proj", "sparse_gaussian_proj"):
            raise NotImplementedError
        self.dim_reduce_method = dim_reduce_method

        if pooling_step > 0:
            if dim_reduce_method == "avg_pool":
                # AvgPool1d(stride=kernel, ceil_mode=True) emits exactly
                # ceil(raw_d / pooling_step) values; the model dimension must match.
                self.d = (raw_d + pooling_step - 1) // pooling_step
            else:
                self.d = (raw_d // pooling_step) + 1
            print("--- [TS dim reduction]: ", dim_reduce_method, ", reduced to dim: ", self.d)

            if dim_reduce_method == "avg_pool":
                self.pooling_module = nn.AvgPool1d(kernel_size=pooling_step, stride=pooling_step, padding=0, ceil_mode=True)
            elif dim_reduce_method == "gaussian_proj":
                # Random projection raw_d -> d with N(0, 1/d) entries.
                self.pooling_module = nn.Linear(in_features=self.raw_d, out_features=self.d, bias=False, dtype=torch.float32)
                nn.init.normal_(self.pooling_module.weight, mean=0, std=1.0 / np.sqrt(self.d))
            else:  # "sparse_gaussian_proj"
                print("P_val: ", p_val)
                self.pooling_module = nn.Linear(in_features=self.raw_d, out_features=self.d, bias=False, dtype=torch.float32)
                nn.init.normal_(self.pooling_module.weight, mean=0, std=1.0 / np.sqrt(self.d))
                # dropout zeroes each entry with prob. p_val and rescales survivors
                # by 1 / (1 - p_val); multiplying by (1 - p_val) undoes the
                # rescaling, so the net effect is random zeroing only.
                with torch.no_grad():
                    self.pooling_module.weight.copy_((1 - p_val) * dropout(self.pooling_module.weight, p=p_val, training=True, inplace=False))
        else:
            self.d = raw_d
        self.pooling_step = pooling_step

        self.model_num = model_num
        self.l2_reg = l2_reg
        self._reset_statistics()

    def _reset_statistics(self):
        """(Re)initialise every model's statistics to the ridge prior."""
        l2_reg, d, model_num = self.l2_reg, self.d, self.model_num
        self.S = {i: l2_reg * torch.eye(d) for i in range(model_num)}
        self.b = {i: torch.zeros(d) for i in range(model_num)}
        self.Sinv = {i: (1 / l2_reg) * torch.eye(d) for i in range(model_num)}
        self.theta = {i: torch.zeros(d) for i in range(model_num)}
        self.N = torch.zeros(model_num)

    def get_ranking_score(self, items, i=0):
        """Score each row of ``items`` with model ``i`` (``reduce(items) @ theta[i]``).

        Args:
            items: Feature tensor whose last dimension is ``raw_d``
                (presumably ``(num_items, raw_d)`` — TODO confirm with callers).
            i: Index of the model to score with.

        Returns:
            Flattened 1-D tensor of scores.
        """
        if self.pooling_step > 0:
            items = self.pooling_module(items)
        exploit_score = torch.matmul(items, self.theta[i].reshape(-1, 1))
        return exploit_score.reshape(-1, )

    def store_info(self, x, y, i=0):
        """Add one observation ``(x, y)`` to model ``i`` and resample ``theta[i]``.

        Args:
            x: One feature vector; any shape that flattens to ``raw_d`` elements
                (``d`` when no reduction is configured).
            y: Scalar reward.
            i: Index of the model to update.
        """
        start_time = time.time()
        print("Begin store info, current samples: ", self.N[i])
        # S and b are running sums: recording autograd history here (e.g. via
        # the nn.Linear projection, whose weight requires grad) would chain every
        # update onto an ever-growing graph and leak memory.
        with torch.no_grad():
            if self.pooling_step > 0:
                # Feed a single row so AvgPool1d receives the 2-D input it requires.
                x = self.pooling_module(x.reshape(1, -1)).reshape(-1, )
            else:
                x = x.reshape(-1, )
            self.b[i] += y * x
            self.N[i] += 1
            self.S[i] += torch.outer(x, x)
            print("[Outer prod]: ", time.time() - start_time)

            # Full inverse of S[i] (see _update_inverse); must target model i.
            self.Sinv[i], self.theta[i] = self._update_inverse(self.b[i], x, i=i)
            print("[Inverse]: ", time.time() - start_time)

    def _update_inverse(self, b, x, i=0):
        """Invert ``S[i]`` and sample a new parameter vector for model ``i``.

        Args:
            b: Reward-weighted feature sum of model ``i``.
            x: Latest feature vector (only used by the commented-out
                Sherman-Morrison alternative).
            i: Index of the model whose ``S[i]`` is inverted.

        Returns:
            ``(inv_S, theta)``: the inverse of ``S[i]`` and a sampled ``theta``
            (on the CPU).
        """
        inv_S = torch.linalg.inv(self.S[i])
        # Incremental alternative:
        # inv_S = self.sherman_morrison_inverse(previous_S_inv=self.Sinv[i], x=x)
        mean_theta = torch.matmul(inv_S, b).reshape(-1, )

        cpu_theta = mean_theta.detach().cpu()
        cpu_inv_S = inv_S.detach().cpu()

        if self.diag_flag:
            # NOTE(review): diag(Sinv) is a variance but is used as a std here,
            # whereas the full branch treats explore_var_coef * Sinv as a
            # covariance. Kept unchanged because it sets the exploration scale.
            theta = torch.normal(mean=cpu_theta, std=self.explore_var_coef * torch.diag(cpu_inv_S))
        else:
            try:
                mn = MultivariateNormal(loc=cpu_theta, covariance_matrix=self.explore_var_coef * cpu_inv_S)
                theta = mn.sample()
            except (ValueError, RuntimeError):
                # Covariance failed the positive-definite constraint check
                # (ValueError) or the Cholesky factorisation (LinAlgError, a
                # RuntimeError): fall back to the diagonal sampler.
                print("=" * 30)
                print("Positive Definite Cov. Matrix Exception.")
                theta = torch.normal(mean=cpu_theta, std=self.explore_var_coef * torch.diag(cpu_inv_S))
        return inv_S, theta

    @staticmethod
    def sherman_morrison_inverse(previous_S_inv, x):
        """Return ``(A + x x^T)^{-1}`` given ``previous_S_inv = A^{-1}`` (Sherman-Morrison)."""
        outer_prod_m = torch.outer(x, x)
        print("[Frac term]")
        frac_term = \
            (torch.matmul(previous_S_inv, torch.matmul(outer_prod_m, previous_S_inv))) / (1 + (torch.matmul(x.reshape(1, -1), torch.matmul(previous_S_inv, x.reshape(-1, 1)))))
        print("[Frac term done]")
        return previous_S_inv - frac_term

    def get_sampled_theta_val(self, i=0):
        """Map the sampled ``theta[i]`` back to the raw feature space.

        Returns:
            * No reduction: ``theta[i]`` flattened, shape ``(raw_d,)``.
            * ``"avg_pool"``: each pooled weight repeated over its window and
              truncated to ``raw_d`` entries, shape ``(raw_d,)``.
              NOTE(review): AvgPool1d divides by the window size, so this vector
              reproduces ``get_ranking_score`` only up to that per-window factor.
            * Projections: ``self.reverse_proj_module(theta)`` if such a module
              has been attached to the instance; otherwise ``theta @ W`` (``W``
              the projection weight), which gives exactly the same scores as
              ``get_ranking_score``. Shape ``(1, raw_d)``.
        """
        start_time = time.time()
        with torch.no_grad():
            if self.pooling_step <= 0:
                this_theta = self.theta[i].reshape(-1, )
                print("[Sampling theta]: ", time.time() - start_time)
                return this_theta

            if self.dim_reduce_method == "avg_pool":
                expanded_theta = self.theta[i].reshape(-1, 1).expand(-1, self.pooling_step)
                expanded_theta = expanded_theta.reshape(-1, )[:self.raw_d]
                print("[Sampling theta]: ", time.time() - start_time)
                return expanded_theta

            # "gaussian_proj" / "sparse_gaussian_proj"
            low_dim_theta = self.theta[i].reshape(1, -1)
            reverse_proj = getattr(self, "reverse_proj_module", None)
            if reverse_proj is not None:
                return reverse_proj(low_dim_theta)
            # nn.Linear computes x @ W.T, hence x @ W.T @ theta == x @ (theta @ W).T
            return torch.matmul(low_dim_theta, self.pooling_module.weight)

    def refresh_collected_data(self):
        """Discard every model's collected statistics and restart from the prior."""
        self._reset_statistics()

    def check_and_update_checkpoint(self, these_parameters):
        """Store a new checkpoint and reset statistics when parameters drift.

        Args:
            these_parameters: Parameter tensor whose first dimension is ``raw_d``.

        Returns:
            True if a new checkpoint was stored: on the first call, or when the
            L2 distance to the previous checkpoint exceeds ``beta_threshold``
            (all collected statistics are then reset). False otherwise.
        """
        assert these_parameters.shape[0] == self.raw_d
        # Keep a detached copy: a view would alias the caller's tensor, so
        # in-place parameter updates would move the checkpoint as well and the
        # measured distance would stay zero.
        parameters = these_parameters.detach().reshape(-1, ).clone()
        if self.current_chechpoint is None:
            self.current_chechpoint = parameters
            return True

        distance = torch.linalg.vector_norm((self.current_chechpoint - parameters), ord=2)
        print("-" * 30)
        print(f"Current distance: {distance}. The beta_threshold: {self.beta_threshold}")
        print("-" * 30)

        if distance > self.beta_threshold:
            self.current_chechpoint = parameters
            self.refresh_collected_data()
            return True
        return False


