# -*- coding: utf-8 -*-
"""
Created on Mon Jan  6 12:33:33 2025

@author: baran
"""
import numpy as np
from optimal import opt_eval
from final import final_eval
from prompt_maker import input_maker
from sum_call import get_summary

def get_optimal_super_arm_reward(deployments, prompt, task):
    """Return the result of ``opt_eval(deployments, prompt, task)`` as-is.

    Per the original author's note, ``opt_eval`` queries the Azure API and
    tries every deployment combination to find the optimal reward; that
    behaviour lives in the ``optimal`` module and is not visible here.
    Not called anywhere else in this file.
    """
    optimum = opt_eval(deployments, prompt, task)
    return optimum

def get_reward(deployment, cat, prompt, task):
    """Return the result of ``final_eval(deployment, cat, prompt, task)`` as-is.

    Per the original author's note, ``final_eval`` calls the Azure API to
    score the chosen deployment; its return type is defined in the ``final``
    module. The main loop passes the value straight to the learner's
    ``train`` as the observed reward.
    """
    observed = final_eval(deployment, cat, prompt, task)
    return observed
    

def get_regret(deployments, prompt, task, selected):
    """Return the result of ``opt_eval(deployments, prompt, task, selected)`` as-is.

    The main loop unpacks this result as ``(regret, reward)`` for the
    ``selected`` arm. An earlier version computed regret here as
    ``get_optimal_super_arm_reward(...) - reward``; that subtraction now
    presumably happens inside ``opt_eval`` — confirm in ``optimal``.
    """
    outcome = opt_eval(deployments, prompt, task, selected)
    return outcome
# def sigmoid(x):
#     return np.where(x >= 0,
#                     1 / (1 + np.exp(-x)),
#                     np.exp(x) / (1 + np.exp(x)))

# def oracle(budget, g_list):
#     return np.argsort(g_list)[-budget:]


        
    
    
    
import scipy as sp
import torch
import torch.nn as nn
import torch.optim as optim
from backpack import backpack, extend
from backpack.extensions import BatchGrad
import argparse
from embed import get_context
class Network(nn.Module):
    """Two-layer MLP reward model: Linear(dim, hidden) -> ReLU -> Linear(hidden, 1).

    The layers live in ``self.model`` (an ``nn.Sequential``), which
    ``NeuralUCBDiag`` uses directly.

    Args:
        dim: Input feature size.
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, dim=100, hidden_size=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x):
        """Return the predicted reward for ``x`` (last dimension ``dim``), shape ``(..., 1)``.

        Without this override, calling a ``Network`` instance falls through to
        ``nn.Module.forward`` and raises ``NotImplementedError``.
        """
        return self.model(x)


class NeuralUCBDiag:
    """Neural-UCB contextual bandit with a diagonal design-matrix approximation.

    A small MLP predicts the reward of each arm from its context vector.
    Exploration uses the per-parameter gradient of that prediction, divided
    element-wise by ``U``. ``U`` is a running sum of squared gradients of the
    arms chosen so far, initialised to ``lamdba``, and stands in for the full
    covariance matrix as a diagonal approximation.

    Args:
        style: Stored on the instance but not consulted by ``selection``,
            which always applies the UCB rule.
        dim: Length of each context vector.
        lamdba: Initial value of every entry of ``U``; also the SGD weight
            decay in ``train``. (Spelling kept for caller compatibility.)
        nu: Scale applied to the exploration variance.
        hidden: Hidden-layer width of the MLP.
    """

    def __init__(self, style, dim, lamdba=1, nu=1, hidden=100):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Use the bare nn.Sequential on every device so CPU and CUDA runs
        # share a single, callable code path.
        self.func = extend(Network(dim, hidden_size=hidden).model.to(self.device))
        self.context_list = []  # one (1, dim) float tensor per train() call
        self.reward = []        # rewards aligned with context_list
        self.lamdba = lamdba
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        self.U = lamdba * torch.ones((self.total_param,), device=self.device)
        # Not read anywhere in this class.
        self.U_random = lamdba * torch.ones((self.total_param,), device=self.device)
        self.nu = nu
        self.num_rounds = 1000  # not read anywhere in this class
        self.style = style
        # Kept for compatibility; train() computes the squared error directly.
        self.loss_func = extend(nn.MSELoss().to(self.device))
        self.len = 0  # not read anywhere in this class

    def update_params(self, g_list):
        """Add the squared gradients in ``g_list`` to ``U`` in place; returns 0."""
        for g in g_list:
            self.U += g * g
        return 0

    def selection(self, context):
        """Choose an arm by the UCB rule and fold its gradient into ``U``.

        Args:
            context: Sequence of per-arm context vectors; anything
                ``np.array`` can stack into a 2-D numeric array.

        Returns:
            Index of the arm with the highest ``mu + 0.5 * sigma`` score, where
            ``sigma = sqrt(sum(lamdba * nu * g**2 / U))`` and ``g`` is that
            arm's gradient of the predicted reward.
        """
        tensor = torch.from_numpy(np.array(context)).float().to(self.device)
        mu = self.func(tensor)
        g_list = []
        sampled = []
        for fx in mu:
            # Gradient of this arm's prediction w.r.t. every network parameter;
            # retain_graph so the shared forward graph survives for later arms.
            self.func.zero_grad()
            fx.backward(retain_graph=True)
            g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
            g_list.append(g)
            sigma = torch.sqrt(torch.sum(self.lamdba * self.nu * g * g / self.U))
            # Fixed 0.5 exploration coefficient; the Thompson-sampling variant
            # selected by `style` is not implemented.
            sampled.append(fx.item() + 0.5 * sigma.item())
        arm = int(np.argmax(sampled))
        print(arm)
        print(sampled)
        # Only the chosen arm's gradient enters the diagonal design matrix.
        self.U += g_list[arm] * g_list[arm]
        return arm

    def train(self, context, reward):
        """Record ``(context, reward)`` and take a few SGD steps on the history.

        Args:
            context: Context vector of the chosen arm; anything
                ``np.asarray`` accepts (array, list, CPU tensor).
            reward: Observed reward, convertible with ``float``.

        Returns:
            If three steps are reached, the mean of those three per-sample
            squared errors. Otherwise, when a full pass over a history shorter
            than three samples averages ``<= 1e-3``, that pass's mean loss.
        """
        self.context_list.append(
            torch.from_numpy(np.asarray(context).reshape(1, -1)).float())
        self.reward.append(reward)
        optimizer = optim.SGD(self.func.parameters(), lr=1e-2, weight_decay=self.lamdba)
        length = len(self.reward)
        index = np.arange(length)
        np.random.shuffle(index)
        cnt = 0
        tot_loss = 0
        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx].to(self.device)
                target = torch.tensor(float(self.reward[idx]), device=self.device)
                optimizer.zero_grad()
                delta = self.func(c) - target
                loss = delta * delta
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                tot_loss += loss.item()
                cnt += 1
                # Hard cap of three SGD steps per call.
                if cnt >= 3:
                    return tot_loss / 3
            if batch_loss / length <= 1e-3:
                return batch_loss / length

# Command-line configuration. parse_known_args() collects unrecognised
# arguments in `unknown` instead of exiting on them.
parser = argparse.ArgumentParser(description='NeuralUCB')
parser.add_argument('--size', default=100, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularzation')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ucb', metavar='ts|ucb', help='TS or UCB')
# type=int: without it a value given on the command line stays a str and
# range(no_tasks) raises TypeError.
parser.add_argument('--number_tasks', default=2, type=int, help='number of chained tasks')

args, unknown = parser.parse_known_args()
no_tasks = args.number_tasks


class DynamicVariables:
    """Empty attribute bag; one learner per task is attached to it as ``l_<i>``."""


dyn_vars = DynamicVariables()
num_rounds = args.size




# Cumulative-curve accumulators updated by the main loop.
regrets = []
summ = 0
rew = 0
rewards_list = []
total_reward = 0
input_reports = input_maker()[0:args.size]

# Arm descriptions for the summary task (order: assistant, gpt4).
summary_description_array = [
    "You are an expert diagnosis analyst. Use your knowledge base to summarize inputted smoking diagnosis.",
    "Summarize the main points of the medical report of a patient, this summary will be used for research purposes only.",
]
# Arm descriptions for the diagnosis task.
# NOTE(review): this order is general / telecom / medical, while the diagnosis
# models are listed as base / finetune_med / finetune_tele. Confirm the
# mismatch is intentional if get_context pairs description j with arm j.
diagnosis_description_array = [
    "General use LLM which does not specialize in any task specifically.",
    "LLM specializing for answering multiple choice telecommunications questions.",
    "LLM specializing on medical reports and trained to do medical diagnosis for research purposes.",
]

# System prompts shared by every deployment of a task.
_SUMMARY_PROMPT = "You are to summarize an inputted medical report for diagnosis purposes, this summary will be used for research purposes only."
_DIAGNOSIS_PROMPT = "You are an medical diagnosis agent whose primary goal is to give diagnosis based on medical reports. For experimentation purposes only."
prompts = [_SUMMARY_PROMPT, _DIAGNOSIS_PROMPT]

# arm key -> (deployment name, system prompt). An earlier configuration also
# had a "small" ("SLM") diagnosis arm.
deployments_1 = {
    arm_key: (deployment_name, _DIAGNOSIS_PROMPT)
    for arm_key, deployment_name in (
        ("base", "gpt-35-turbo"),
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
    )
}
deployments_0 = {
    arm_key: (deployment_name, _SUMMARY_PROMPT)
    for arm_key, deployment_name in (
        ("base", "gpt-35-turbo"),
        ("assistants", "Assistant"),
    )
}
deploy = [deployments_0, deployments_1]  # indexed by task number
cat = ''
rewards = 0
emb_size = 384  # context-vector length handed to each learner

# One independent bandit learner per task, stored as dyn_vars.l_<i>.
for i in range(no_tasks):
    setattr(dyn_vars, f'l_{i}', NeuralUCBDiag(args.style, emb_size, args.lamdba, args.nu, args.hidden))
# Corpus passed to get_context: the summary-arm descriptions, then the
# diagnosis-arm descriptions, then the raw input reports. The main loop later
# appends each generated summary to the end of this list.
input_reports = list(input_reports)
documents = [*summary_description_array, *diagnosis_description_array, *input_reports]

from sentence_transformers import SentenceTransformer

# Sentence-embedding model handed to get_context ("all-MiniLM-L6-v2" was an
# earlier alternative; a Doc2Vec embedder was tried before that).
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# Main bandit loop. For each report, the tasks run in order:
#   task 0 picks a summarization deployment;
#   task 1 picks a diagnosis deployment and is fed the summary from
#   get_summary(report, cat), where `cat` is the category task 0 just chose.
# Only the last task is scored for regret; earlier tasks are trained on
# get_reward().
# Rounds are capped at the number of available reports: a --size larger than
# what input_maker() returned would otherwise raise IndexError.
loss = float('nan')  # printed as-is if no task trains (no_tasks == 0)
for t in range(min(num_rounds, len(input_reports))):
    for i in range(no_tasks):
        if i == 0:
            models = ["assistants", "base"]
            prompt_to_model = input_reports[t]
            task = 'summary'
            n_arms = len(summary_description_array)
        else:
            # Order noted originally as: gpt4, finetuned medical, finetuned telecom.
            models = ["base", "finetune_med", "finetune_tele"]
            prompt_to_model = get_summary(input_reports[t], cat).replace('\n', '')
            task = 'diagnosis'
            documents = documents + [prompt_to_model]
            n_arms = len(diagnosis_description_array)
        # One context vector per arm of this task.
        context = [
            get_context(documents, t, i, j, len(summary_description_array),
                        len(diagnosis_description_array), len(input_reports), inp_model)
            for j in range(n_arms)
        ]

        learner = getattr(dyn_vars, f'l_{i}')
        arm_select = learner.selection(context)
        selected = models[arm_select]
        # Both fine-tuned diagnosis models are evaluated under one category.
        cat = "finetune" if selected in ("finetune_med", "finetune_tele") else selected
        dep = deploy[i]
        fin_prompt = prompt_to_model
        deployment = dep[selected]

        if i == no_tasks - 1:
            # NOTE(review): regret is always scored against deployments_1
            # (the diagnosis deployments), i.e. it assumes the last task is
            # diagnosis.
            reg, reward = get_regret(deployments_1, fin_prompt, task, selected)
            rewards += int(reward)
            rewards_list.append(rewards)
            print(selected)
            print(reward)
            print(reg)
            print("Done")
            summ += reg
            regrets.append(summ)
        else:
            reward = get_reward(deployment, cat, fin_prompt, task)

        loss = learner.train(context[arm_select], reward)
    if t % 5 == 0:
        print('{}: {:.3f}, {:.3f}, {:.3f}'.format(t, summ, rewards, loss))

import pickle

# Persist the cumulative-regret and cumulative-reward curves, in that order.
for _out_path, _curve in (("regrets_seqgpt.pkl", regrets), ("rewards_seqgpt.pkl", rewards_list)):
    with open(_out_path, "wb") as _out:
        pickle.dump(_curve, _out)