import numpy as np
import argparse
import pickle
import os
import time
import torch
import pandas as pd 
import scipy as sp
import torch.nn as nn
import torch.optim as optim
import random
from collections import defaultdict
import scipy.io as sio
import scipy.sparse as spp
import scipy as sp
from sklearn.preprocessing import normalize
import json
from math import sqrt
from sys import exit

# Experiment hyper-parameter defaults. They are not read by the code visible here;
# presumably a driver elsewhere consumes them — verify before removing.
arg_size, arg_shuffle, arg_seed = 1, 1, 0
arg_nu, arg_lambda, arg_hidden = 1, 0.0001, 100

# Compute on the first CUDA device when available, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(dev)


# class Network(nn.Module):
#     def __init__(self, dim, hidden_size=100):
#         super(Network, self).__init__()
#         self.fc1 = nn.Linear(dim, hidden_size)
#         self.activate = nn.ReLU()
#         self.fc2 = nn.Linear(hidden_size, 1)
#         self.output = nn.Sigmoid()
#         self.fc3 = nn.Linear(50,25)
#         self.fc4 = nn.Linear(25,1)
#     def forward(self, x):
#         return self.output(self.fc2(self.activate(self.fc1(x))))

class Network(nn.Module):
    """Two-layer MLP producing one sigmoid-squashed value per input row.

    Architecture: Linear(dim -> hidden_size) -> ReLU -> Linear(hidden_size -> 1) -> Sigmoid.
    """

    def __init__(self, dim, hidden_size=100):
        super().__init__()
        # Keep fc1 created before fc2: the default initialisers draw from the global
        # RNG, so this order determines the initial weights for a given seed.
        self.fc1 = nn.Linear(dim, hidden_size)
        self.activate = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.output = nn.Sigmoid()

    def forward(self, x):
        hidden = self.activate(self.fc1(x))
        logit = self.fc2(hidden)
        return self.output(logit)
    


class ConservativeFastCB:
    def __init__(self, dim, n_arm, gamma=50, lr=0.001, m=100, alpha=0.1):
        """Set up the learner's estimator, optimizer and bookkeeping lists.

        Args:
            dim: length of each arm's feature vector (input width of the estimator).
            n_arm: number of arms K.
            gamma: exploration scale used by select(); larger values put less
                probability on arms whose predicted value exceeds the best one.
            lr: SGD learning rate for the estimator.
            m: hidden-layer width of the estimator network.
            alpha: conservative slack; select() compares against (1 + alpha)
                times the cumulative baseline reward.
        """
        self.K = n_arm
        self.m = m
        self.d = dim

        self.current_loss = 0  # last per-sample loss computed by train()
        self.t = 1  # round counter, incremented by update()
        self.rewards = []  # observed rewards, parallel to context_list
        self.context_list = []  # observed contexts, each stored as a (1, dim) CPU tensor
        self.mu = self.K  # NOTE(review): not read anywhere in the visible code
        self.gamma = gamma
        self.lr = lr
        self.estimator = Network(self.d, hidden_size=m).to(device)
        # Honor the caller's learning rate. It used to be ignored in favour of a
        # hard-coded 0.005; pass lr=0.005 to reproduce that older behaviour.
        self.optimizer = torch.optim.SGD(self.estimator.parameters(), lr=self.lr, weight_decay=0.)
        self.alpha = alpha

        self.all_rewards = []  # per-round values accumulated by select()
        self.baseline_rewards = []  # baseline arm's true reward for each round

        
        
        
    def select(self, context, rewards, true_rewards, baseline_arm, t):
        self.features=torch.from_numpy(context).float().to(device)
        f_l = []
        for k in range(self.K):
            f=self.estimator(self.features[k])
            f_l.append(f.item())
        #print(self.features[k])
        arm_best = np.argmin(f_l)
        p_l = [0]*self.K
        sum_p = 0.0 
        for k in range(self.K):
            if k != arm_best:
                # gamma = self.K
                if f_l[arm_best] == 0:
                    p_l[k] = 0
                else:
                    p_l[k] = (f_l[arm_best])/(self.K*(f_l[arm_best]) + self.gamma * (f_l[k] - f_l[arm_best]))
                    # print(f'{f_l[arm_best]=},{f_l[k]=},{k=},{p_l[k]=}')
                sum_p += p_l[k]
        p_l[arm_best] = 1 - sum_p
        arm_to_pull=np.random.choice(np.arange(self.K), p=p_l)
        self.baseline_rewards.append(true_rewards[baseline_arm])
        if f_l[arm_to_pull] + np.sum(self.all_rewards) + np.sqrt(t  * np.log(1/0.1)) <= (1+self.alpha)*sum(self.baseline_rewards):
            self.all_rewards.append(f_l[arm_to_pull])
            #self.rewards.append(rewards[arm_to_pull])
            
            return arm_to_pull
        
        #self.all_rewards.append(rewards[arm_to_pull])
        self.all_rewards.append(true_rewards[baseline_arm])
        return 'baseline'#,f_l[arm_to_pull]#, sigma_l[arm_to_pull]
    
    def update(self, context, reward):
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        new_context = torch.from_numpy(context.reshape(1, -1)).float().to(device)
        self.rewards.append(reward)
        self.t+=1


    def train(self, t):
        # if t==100:
        #     self.optimizer = torch.optim.SGD(self.estimator.parameters(), lr = 0.005, weight_decay=0.)
        # if t==500:
        #     self.optimizer = torch.optim.SGD(self.estimator.parameters(), lr = 0.001, weight_decay=0.)
        # if t==1000:
        #     self.optimizer = torch.optim.SGD(self.estimator.parameters(), lr = 0.0005, weight_decay=0.)
        # if t==1500:
        #     self.optimizer = torch.optim.SGD(self.estimator.parameters(), lr = 0.0001, weight_decay=0.)
        length = len(self.rewards)
        if length == 0:
            return 0
        index = np.arange(length)
        np.random.shuffle(index)
        cnt = 0
        tot_loss = 0
        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx].to(device)
                y = self.rewards[idx]
                y_hat = self.estimator(c)
                self.current_loss = y*torch.log(1/y_hat) + (1 - y)* torch.log(1/(1 - y_hat))
                #self.current_loss = delta * delta
                self.optimizer.zero_grad() 
                    #gradient descent
                if self.t==1:
                    self.current_loss.backward(retain_graph=True)    
                else:
                    self.current_loss.backward()

                self.optimizer.step()
                batch_loss +=  self.current_loss.item()
                tot_loss +=  self.current_loss.item()
                cnt += 1
                if cnt >= 1000:
                    return tot_loss / 1000
            if batch_loss / length <= 1e-3:
                return batch_loss / length