import torch as th
import numpy as np
import os 
from components.prompt_generator import get_prompt,get_traj_placeholder,get_step_placeholder,get_GRF_traj_placeholder,get_GRF_step_placeholder
import random 
import re
import time
from itertools import combinations
import google.generativeai as genai
import google.ai.generativelanguage as glm
from scipy.stats import kendalltau
import random

from openai import OpenAI

def find_smallest_indices_torch(lst, k):
    """Randomly pick up to ``k`` positions where ``lst`` attains its minimum.

    The minimum is taken with ``th.min(lst, dim=0)``. Every position equal to
    it is a candidate, and candidates are reported shifted by one
    (``position + 1``). If there are more than ``k`` candidates, a random
    subset of size ``k`` is drawn with ``random.sample``. Otherwise all
    candidates are returned, in random order.

    Args:
        lst: tensor to search (used as a 1-D tensor by the caller in this file).
        k: maximum number of indices to return.

    Returns:
        A list of Python ints. It is empty when ``lst`` has no elements, or
        when no element compares equal to the minimum (e.g. a NaN minimum).
    """
    # th.min(..., dim=0) raises on an empty reduction dimension; an empty
    # input simply has no minimum positions.
    if lst.numel() == 0:
        return []
    min_value = th.min(lst, dim=0).values
    # +1 shifts the 0-based positions by one.
    candidates = ((lst == min_value).nonzero(as_tuple=True)[0] + 1).tolist()
    return random.sample(candidates, min(k, len(candidates)))

def get_indices_of_sorted_values(tensor):
    """Return the indices that sort ``tensor`` in ascending order.

    Sorting happens along the last dimension (the ``th.sort`` default); only
    the permutation is returned, the sorted values are discarded.
    """
    return th.sort(tensor).indices

def get_kendal(d1, d2):
    """Average pairwise Kendall tau between per-slot mean profiles of two tensors.

    For every index ``i`` along dimension 2, a two-element profile is built:
    ``[mean(d1[0, :, i, :]), mean(d2[0, :, i, :])]``. Kendall's tau is then
    computed between the profiles of every pair of slots ``(i, j)``, and the
    mean over all pairs is returned.

    Args:
        d1: 4-D tensor (a ValueError is raised otherwise). Only batch element 0
            is used; dims 1 and 3 are averaged out. Its dim-2 size sets the
            number of slots.
        d2: tensor indexed the same way; it needs at least as many dim-2 slots
            as ``d1``. Dims 1 and 3 may differ in size from ``d1``.

    Returns:
        The mean tau as a float. It is NaN when ``d1`` has fewer than two slots
        (there is no pair to compare), and NaN also propagates from any pair
        whose tau is undefined.
    """
    _, _, n_r, _ = d1.shape  # also enforces a 4-D d1
    if n_r < 2:
        # No pairs -> the average is undefined (previously ZeroDivisionError).
        return float('nan')
    # .item() hands plain floats to scipy, so CUDA tensors work too.
    d_mean = [
        [th.mean(d1[0, :, i, :]).item(), th.mean(d2[0, :, i, :]).item()]
        for i in range(n_r)
    ]
    taus = []
    for a, b in combinations(range(n_r), 2):
        tau, _ = kendalltau(d_mean[a], d_mean[b])
        taus.append(tau)
    return sum(taus) / len(taus)
    
# LLM backends accepted by get_preference (in addition to 'human').
_LLM_MODELS = ('gemini', 'gpt-35-turbo', 'gpt4')

# Chat-completions model id queried for each OpenAI-backed gpt_model value.
_OPENAI_MODEL_IDS = {
    'gpt-35-turbo': "gpt-3.5-turbo-0125",
    'gpt4': "gpt-4o-2024-05-13",  # previously "gpt-4-turbo-2024-04-09"
}

# Step count at or above which the rule-based labeller treats a trajectory
# as the losing side.
_AUTO_STEP_LIMIT = 110

# Stride between consecutive agents' entries in a saved state row, used by
# the `scalability` end-of-episode check.
_STATE_STRIDE = 4

_DIGITS_RE = re.compile(r'\d+')


def _build_backend(gpt_model, key):
    """Create the client for *gpt_model*; returns ``(gemini_model, openai_client)``.

    Slots that *gpt_model* does not use are None (both are None for 'human').
    """
    if gpt_model == 'gemini':
        safety_settings = [
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]
        genai.configure(api_key=key)
        # Must be a keyword argument: the second positional parameter of
        # GenerativeModel is generation_config, not safety_settings.
        return genai.GenerativeModel('gemini-pro', safety_settings=safety_settings), None
    if gpt_model in _OPENAI_MODEL_IDS:
        return None, OpenAI(api_key=key)
    return None, None


def _query_llm(gpt_model, prompt, gemini_model, openai_client):
    """Send *prompt* as a single user message and return the reply text."""
    if gpt_model == 'gemini':
        return gemini_model.generate_content(prompt).text
    chat_completion = openai_client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=_OPENAI_MODEL_IDS[gpt_model],
    )
    return chat_completion.choices[0].message.content


def _parse_traj_choice(text):
    """Return the first integer after the first '#' in *text*, or None if absent."""
    hash_pos = text.find('#')
    if hash_pos == -1:
        return None
    numbers = _DIGITS_RE.findall(text[hash_pos:])
    return int(numbers[0]) if numbers else None


def _rule_based_traj_choice(placeholder):
    """Label a trajectory pair with a fixed rule ('human' + automatic mode).

    Returns 1 or 2 for the preferred trajectory, or 0 for a tie. Returns None
    when the rule gives no label (equal enemy values and both trajectories
    shorter than ``_AUTO_STEP_LIMIT`` steps).
    """
    # The original author read r_f_a_t_* / r_f_e_t_* as ally / enemy health;
    # presumably health remaining at the end of each trajectory - confirm.
    t1_allies_health = float(placeholder['r_f_a_t_1'])
    t2_allies_health = float(placeholder['r_f_a_t_2'])
    t1_enemy_health = float(placeholder['r_f_e_t_1'])
    t2_enemy_health = float(placeholder['r_f_e_t_2'])
    t1_steps = int(placeholder['step_1'])
    t2_steps = int(placeholder['step_2'])

    t1_long = t1_steps >= _AUTO_STEP_LIMIT
    t2_long = t2_steps >= _AUTO_STEP_LIMIT
    if t1_long or t2_long:
        # A trajectory that hit the step limit loses; both hitting it is a tie.
        if t1_long and t2_long:
            return 0
        return 2 if t1_long else 1
    if t1_enemy_health < t2_enemy_health:
        return 1 if t1_allies_health >= t2_allies_health else 0
    if t1_enemy_health > t2_enemy_health:
        return 2 if t1_allies_health <= t2_allies_health else 0
    # NOTE(review): equal enemy health yields no label, so this repeat adds
    # nothing for the pair - confirm that is intended.
    return None


def _read_int_from_stdin():
    """Read lines from stdin until one parses as an int; echo and return it."""
    while True:
        try:
            value = int(input())
        except ValueError as e:
            # Only bad input is retried; EOFError propagates instead of
            # looping forever.
            print(e)
            continue
        print(value)
        return value


def _read_step_ranks_from_stdin(final_prompt, n_agents):
    """Show *final_prompt* and read one integer per agent from stdin.

    The whole prompt is shown again if any entry fails to parse.
    """
    while True:
        print(final_prompt)
        print('if 1 is better, input 1, else input 2')
        ranks = []
        try:
            for agent_cnt in range(n_agents):
                print('rank {} agent (1~{})'.format(agent_cnt + 1, n_agents))
                ranks.append(int(input()))
        except ValueError as e:
            print(e)
            continue
        return ranks


def _fill_step_ranks(text, row, n_agents):
    """Parse per-agent ranks from an LLM reply into *row* (modified in place).

    At most 9 groups are read. For each '#', the character right after it is
    taken as the rank. The digits between the next '{' and '}' are taken as
    1-based agent ids, and each such agent gets that rank in column
    ``2 + id - 1``. Agents left at 0 afterwards receive successive ranks after
    the current maximum.
    """
    for _ in range(9):
        where_rank = text.find('#')
        if where_rank == -1:
            break
        rank = text[where_rank + 1]
        rank_start = text.find('{')
        rank_end = text.find('}')
        numbers = _DIGITS_RE.findall(text[rank_start + 1:rank_end])
        text = text[rank_end + 1:]
        for number in numbers:
            if 0 < int(number) <= n_agents and int(rank) <= n_agents:
                row[int(number) + 1] = int(rank)
    for agent in range(n_agents):
        if row[2 + agent] == 0:
            # Recomputed per agent, so several unranked agents get distinct ranks.
            row[2 + agent] = max(row[2:]) + 1


def _add_transitive_pairs(rows):
    """Return extra ``[t1, t2, pref]`` rows implied by chains of unanimous labels.

    A run of consecutive rows with pref 1.0 that actually link up
    (each row's t1 equals the previous row's t2) implies that the first t1
    beats every later t2. Pref 0.0 runs are chained backwards in the same way.
    The continuity check keeps a skipped pair from bridging unrelated
    trajectories.
    """
    extra = []
    for idx, (t1, t2, pref) in enumerate(rows):
        if pref == 1.0:
            end = t2
            for nxt_t1, nxt_t2, nxt_pref in rows[idx + 1:]:
                if nxt_pref != 1.0 or nxt_t1 != end:
                    break
                end = nxt_t2
                extra.append([t1, end, 1.0])
        elif pref == 0.0:
            start = t1
            for prv_t1, prv_t2, prv_pref in reversed(rows[:idx]):
                if prv_pref != 0.0 or prv_t2 != start:
                    break
                start = prv_t1
                extra.append([start, t2, 0.0])
    return extra


def _rows_to_tensor(rows, width):
    """Stack a list of equal-length rows into a float32 ``(len(rows), width)`` tensor."""
    return th.tensor(rows, dtype=th.float32).reshape(-1, width)


def get_preference(gpt_model='gemini',
                   n_pref=1000,
                   n_epi=150,
                   traj_or_step='traj',
                   model_name='model_1',
                   step_n_pref=50,
                   seq='1',
                   scenario='3m',
                   pref_per_epi=2,
                   n_agents=3,
                   replay_buffer_save_path="replay_buffer",
                   preference_save_path="get_llm_pref",
                   n_repeat=1,
                   key='',
                   envs='sc2',
                   scalability=False,
                   automatic=True):
    """Collect trajectory- or step-level preference labels and save them with ``th.save``.

    Episode files are read from
    ``{replay_buffer_save_path}/{gpt_model}_{scenario}_{model_name}_{seq}/``.
    Prompts are built with ``components.prompt_generator``, and labels come
    from the backend chosen by ``gpt_model``. Results are written every time
    the row count is a multiple of 10, and once at the end, to
    ``{preference_save_path}/{gpt_model}_{scenario}_{model_name}_{seq}_{traj_or_step}.pt``.

    'traj' mode compares episode ``i`` with ``i + 1`` for ``i`` in
    ``1 .. n_epi - 2`` and produces rows ``(t1, t2, pref)``. ``pref`` is the
    fraction of the ``n_repeat`` answers equal to 1, with answers equal to 0
    (tie) counting one half. At most ``n_pref`` queried rows are kept.
    Transitively implied rows are then appended.

    'step' mode visits episodes ``1 .. n_epi`` in random order and picks up to
    ``pref_per_epi`` steps per episode. The LLM backends pick steps at the
    minimum of ``kendall[0, 1:end_state, 0]``; 'human' picks steps where some
    agent's action id exceeds 6. It produces rows
    ``(episode, step, value_agent_1, ..., value_agent_n)``, at most
    ``step_n_pref`` of them.

    Args:
        gpt_model: 'gemini', 'gpt-35-turbo', 'gpt4' or 'human'.
        n_pref: cap on directly queried trajectory pairs ('traj').
        n_epi: number of saved episodes.
        traj_or_step: 'traj' or 'step'.
        model_name, seq, scenario: parts of the load/save paths; ``scenario``
            is also forwarded to the prompt builders.
        step_n_pref: cap on step rows ('step').
        pref_per_epi: steps sampled per episode ('step').
        n_agents: number of agents (rank columns in 'step' mode).
        replay_buffer_save_path: root directory of the saved episodes.
        preference_save_path: output directory.
        n_repeat: queries per pair/step.
        key: API key forwarded to the Gemini / OpenAI client.
        envs: 'gfootball' selects the GRF placeholder builders.
        scalability: in LLM 'step' mode, end the episode at the first state
            row whose per-agent entries (stride ``_STATE_STRIDE``) sum to 0.
        automatic: in 'human' 'traj' mode, label with a fixed rule instead of
            reading stdin.

    Raises:
        ValueError: if ``gpt_model`` or ``traj_or_step`` is not recognised.
    """
    if gpt_model not in _LLM_MODELS and gpt_model != 'human':
        raise ValueError('unknown gpt_model: {!r}'.format(gpt_model))
    if traj_or_step not in ('traj', 'step'):
        raise ValueError('traj_or_step must be "traj" or "step", got {!r}'.format(traj_or_step))

    save_dir = '{}/{}_{}_{}_{}'.format(replay_buffer_save_path, gpt_model, scenario, model_name, seq)
    out_path = '{}/{}_{}_{}_{}_{}.pt'.format(preference_save_path, gpt_model, scenario,
                                             model_name, seq, traj_or_step)
    gemini_model, openai_client = _build_backend(gpt_model, key)

    if traj_or_step == 'traj':
        traj_placeholder = get_GRF_traj_placeholder if envs == 'gfootball' else get_traj_placeholder
        rows = []
        for t1, t2 in [(i, i + 1) for i in range(1, n_epi - 1)]:
            placeholder = traj_placeholder(t1, t2, scenario=scenario, save_traj_dir=save_dir)
            final_prompt = get_prompt(placeholder, scenario=scenario)

            results = []
            for _ in range(n_repeat):
                retry = True
                while retry:
                    try:
                        if gpt_model == 'human':
                            print(final_prompt)
                            print('if 1 is better, input 1, else input 2, if same input 0')
                            if automatic:
                                choice = _rule_based_traj_choice(placeholder)
                            else:
                                choice = _read_int_from_stdin()
                            retry = False
                        else:
                            reply = _query_llm(gpt_model, final_prompt, gemini_model, openai_client)
                            choice = _parse_traj_choice(reply)
                            # An unparseable reply is re-queried (no sleep, no cap).
                            retry = choice is None
                        if choice is not None:
                            results.append(choice)
                    except Exception as e:
                        # Best effort: report, back off, then give up on this repeat.
                        print('error')
                        print(e)
                        time.sleep(10)
                        retry = False

            if results:
                pref = (results.count(1) + 0.5 * results.count(0)) / len(results)
                rows.append([float(t1), float(t2), float(pref)])
                print(len(rows) - 1)
            if len(rows) % 10 == 0:
                th.save(_rows_to_tensor(rows, 3), out_path)
            if len(rows) >= n_pref:
                break

        rows.extend(_add_transitive_pairs(rows))
        pref_result = _rows_to_tensor(rows, 3)
        total_cnt = len(rows)
    else:
        step_placeholder = get_GRF_step_placeholder if envs == 'gfootball' else get_step_placeholder
        pref_result = th.zeros((step_n_pref, 2 + n_agents))
        total_cnt = 0
        epi = list(range(1, n_epi + 1))
        for episode_num in th.randperm(len(epi)):
            episode = epi[episode_num]
            terminated = th.load(save_dir + '/{}_terminated.pt'.format(episode))
            kendall = th.load(save_dir + '/{}_kendalltau.pt'.format(episode))
            end_state = th.where(terminated == 1)[1].item()

            if gpt_model == 'human':
                actions = th.load(save_dir + '/{}_actions.pt'.format(episode))[0]
                # NOTE(review): only action ids strictly greater than 6 count -
                # confirm whether id 6 should be included.
                candidates = [s for s in range(end_state)
                              if th.any(th.gt(actions[s, :n_agents], 6)).item()]
                if candidates:
                    chosen_step = random.sample(candidates, min(pref_per_epi, len(candidates)))
                else:
                    chosen_step = []
            else:
                if scalability:
                    state = th.load(save_dir + '/{}_state.pt'.format(episode))[0]
                    for i in range(state.shape[0]):
                        health_sum = sum(state[i, agent * _STATE_STRIDE].item()
                                         for agent in range(n_agents))
                        if health_sum == 0:
                            end_state = i
                            break
                chosen_step = find_smallest_indices_torch(kendall[0, 1:end_state, 0], pref_per_epi)

            for step in chosen_step:
                if total_cnt >= step_n_pref:
                    break
                placeholder = step_placeholder(episode, step, scenario=scenario, save_traj_dir=save_dir)
                final_prompt = get_prompt(placeholder, scenario=scenario, prompt_type='step')
                row = pref_result[total_cnt]  # view: writes land in pref_result
                row[2:] = 0
                for _ in range(n_repeat):
                    try:
                        row[0] = episode
                        row[1] = step
                        if gpt_model == 'human':
                            ranks = _read_step_ranks_from_stdin(final_prompt, n_agents)
                            for agent in range(n_agents):
                                row[2 + agent] = ranks[agent]
                        else:
                            reply = _query_llm(gpt_model, final_prompt, gemini_model, openai_client)
                            _fill_step_ranks(reply, row, n_agents)
                    except Exception as e:
                        print('!!!!!! error !!!!!!!!!')
                        print(e)
                        time.sleep(10)
                # A row counts only if the first agent column was filled.
                if row[2] != 0:
                    print(total_cnt)
                    total_cnt += 1

            if total_cnt % 10 == 0:
                th.save(pref_result[:total_cnt], out_path)
            if total_cnt >= step_n_pref:
                break

    th.save(pref_result[:total_cnt], out_path)