import random
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import numpy as np
import json
import matplotlib.pyplot as plt
import argparse
from utils import fix_seed
import os
from tqdm import tqdm
import torch
import argparse
from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForCausalLM, LlamaModel, LlamaTokenizer, LlamaConfig
import numpy as np
from mpl_toolkits import mplot3d
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the number of non-zero mask entries per row.

    Args:
        last_hidden_states: Hidden states to pool. Presumably shaped
            (batch, seq, hidden) -- the code only requires that dim 1 is the
            one being pooled; confirm against callers.
        attention_mask: Mask with one entry per pooled position (non-zero
            means "keep"). A trailing axis is added so it broadcasts over
            the last dimension of ``last_hidden_states``.

    Returns:
        The masked sum over dim 1 divided by the per-row mask sum.
    """
    # Trailing axis lets the mask broadcast across the feature dimension.
    keep = attention_mask.bool().unsqueeze(-1)
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            true/false spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string was falsy under ``type=bool`` as well; keep that.
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


def parse_arguments(argv=None):
    """Parse the command-line options of the demo-construction script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Zero-shot-CoT")
    parser.add_argument(
        "--task", type=str, default="gsm8k",
        choices=["aqua", "gsm8k", "commonsensqa", "addsub", "multiarith", "strategyqa", "svamp", "singleeq", "coin_flip", "last_letters"], help="dataset used for experiment"
    )
    # Integer options now carry integer defaults (they used to be strings
    # that argparse converted through ``type``); the parsed values are equal.
    parser.add_argument("--over", type=int, default=0, help="used to generate unique basis")
    parser.add_argument("--num_basis", type=int, default=8, help="number of basis used for experiment")
    parser.add_argument("--max_ra_len", type=int, default=5, help="maximum number of reasoning chains")
    parser.add_argument("--pred_file", type=str, default="log/gsm8k_zero_shot_cot.log", help="use the reasoning chains generated by zero-shot-cot.")
    parser.add_argument("--demo_save_dir", type=str, default="demos/gsm8k_basis_8", help="where to save the contructed demonstrations")
    parser.add_argument("--emb_save_dir", type=str, default="embedding/gsm8k", help="where to save the embedding")
    parser.add_argument("--question_save_dir", type=str, default="question/gsm8k", help="where to save the question")
    parser.add_argument("--random_seed", type=int, default=192, help="random seed")
    parser.add_argument("--encoder", type=str, default="all-MiniLM-L6-v2", help="which encoder for embedding")
    # ``type=bool`` made ``--debug False`` evaluate to True; parse the text.
    parser.add_argument("--debug", type=_str2bool, default=True, help="debug mode")
    return parser.parse_args(argv)


# Local checkpoint locations, keyed by the value accepted for ``--encoder``.
_SENTENCE_TRANSFORMER_PATHS = {
    "all-MiniLM-L6-v2": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/MiniLM-L6-v2",
    "all-mpnet-base-v2": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/all-mpnet-base-v2",
    "t5-base": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/t5/sentence-t5-base",
    "t5-large": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/t5/sentence-t5-large",
    "t5-xl": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/t5/sentence-t5-xl",
    "t5-xxl": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/t5/sentence-t5-xxl",
}
_E5_PATHS = {
    "e5-small": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/e5/e5-small",
    "e5-base": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/e5/e5-base",
    "e5-large": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/e5/e5-large",
}
_LLAMA_PATHS = {
    "llama-7b": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/llama/7b",
    "llama-13b": "/home/notebook/code/personal/S9052827/auto-cot/checkpoint/llama/13b",
}


def _load_encoder(name):
    """Load the model for ``name`` and return ``(encoder, tokenizer)``.

    ``tokenizer`` is ``None`` for SentenceTransformer checkpoints, which are
    driven through ``encoder.encode``.

    Raises:
        ValueError: if ``name`` has no configured checkpoint. Previously an
            unknown name (e.g. "llama-30b") only failed later, on an unbound
            ``encoder`` variable.
    """
    if name in _SENTENCE_TRANSFORMER_PATHS:
        return SentenceTransformer(_SENTENCE_TRANSFORMER_PATHS[name]), None
    if name in _E5_PATHS:
        tokenizer = AutoTokenizer.from_pretrained(_E5_PATHS[name])
        return AutoModel.from_pretrained(_E5_PATHS[name]), tokenizer
    if name in _LLAMA_PATHS:
        tokenizer = LlamaTokenizer.from_pretrained(_LLAMA_PATHS[name])
        return LlamaForCausalLM.from_pretrained(_LLAMA_PATHS[name]), tokenizer
    known = sorted(list(_SENTENCE_TRANSFORMER_PATHS) + list(_E5_PATHS) + list(_LLAMA_PATHS))
    raise ValueError(
        "unsupported encoder {!r}; choose one of: {}".format(name, ", ".join(known))
    )


def _pick_device(preferred_index):
    """Return ``cuda:<preferred_index>`` if present, else ``cuda:0``, else CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if preferred_index >= torch.cuda.device_count():
        return torch.device("cuda:0")
    return torch.device("cuda:{}".format(preferred_index))


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _read_predictions(pred_file, question_file, keep_gold):
    """Parse a zero-shot-CoT log and write the raw questions to ``question_file``.

    A record is completed by a line containing "GT :". Before that line the
    parser picks up the last line containing "Q: " (question), the text
    accumulated from the last "A: " line up to a line containing both
    "Therefore" and "the answer" (rationale), and the "pred_mode" line
    (predicted answer).

    Returns:
        ``(questions, rationales, pred_answers, gold_answers)``. Each question
        has "\\nA:" appended; ``gold_answers`` stays empty unless ``keep_gold``.

    Raises:
        ValueError: if a "GT :" line is reached before a question, rationale
            and prediction have been seen.
    """
    questions, rationales, pred_answers, gold_answers = [], [], [], []
    c_question = c_rationale = c_pred_ans = None

    with open(pred_file, "r", encoding="utf-8") as fp:
        _ensure_parent_dir(question_file)
        # The log is read as UTF-8, so write the questions as UTF-8 too
        # instead of relying on the platform default encoding.
        with open(question_file, "w", newline='\n', encoding="utf-8") as f:
            answer_seg = ""
            for line_no, line in enumerate(fp, start=1):
                if "Q: " in line:
                    c_question = line.strip()

                if "A: " in line:
                    answer_seg = line
                elif "Therefore" in line and "the answer" in line:
                    c_rationale = answer_seg
                elif answer_seg != "":
                    answer_seg += line

                if "pred_mode" in line:
                    c_pred_ans = line.split(":")[1].strip()

                if "GT :" in line:
                    if c_question is None or c_rationale is None or c_pred_ans is None:
                        raise ValueError(
                            "{}:{}: 'GT :' line reached before a complete "
                            "question/rationale/prediction record".format(pred_file, line_no)
                        )
                    c_gold_ans = line.split(":")[1].strip()

                    # NOTE(review): fields are not reset between records, so a
                    # record lacking its own "Therefore ... the answer" or
                    # "pred_mode" line reuses the previous record's value --
                    # confirm the log format always provides them.
                    c_rationale = c_rationale.replace("A: Let's think step by step.", "Let's think step by step.")
                    f.write(c_question + '\n')
                    c_question = c_question + "\nA:"

                    questions.append(c_question)
                    rationales.append(c_rationale)
                    pred_answers.append(c_pred_ans)
                    if keep_gold:
                        gold_answers.append(c_gold_ans)
                    answer_seg = ""

    return questions, rationales, pred_answers, gold_answers


def _embed_with_hidden_states(corpus, encoder, tokenizer, device):
    """Embed each text as the mean of the model's last hidden-state layer.

    Each text is tokenized and run through ``encoder`` on its own; the last
    entry of ``hidden_states`` for that single sequence is averaged over
    dim 0. Returns the stacked embeddings as a NumPy array, one row per text.
    """
    encoder.to(device)
    embeddings = []
    # Inference only: skip autograd bookkeeping so activations are freed
    # after every forward pass.
    with torch.no_grad():
        for text in tqdm(corpus):
            input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
            last_layer = encoder(input_ids, output_hidden_states=True).hidden_states[-1][0]
            embeddings.append(torch.mean(last_layer, dim=0).cpu().numpy())
    return np.array(embeddings)


def _select_demos(args, dist, questions, rationales, pred_answers, gold_answers):
    """Pick, for every basis, the highest-scoring question that passes the filters.

    ``dist[:, i]`` holds the score of every question against basis ``i``.
    Candidates are visited in descending score order (ties keep corpus order)
    and the first one satisfying all filters becomes the demo for that basis.
    A basis with no acceptable candidate contributes no demo.
    """
    arithmetic_tasks = ("gsm8k", "multiarith", "singleeq", "addsub", "svamp")
    demos = []
    used_questions = set()  # only consulted when --over is non-zero

    for basis_idx in range(args.num_basis):
        print("Basis ", basis_idx + 1)
        scores = dist[:, basis_idx]
        ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)

        for idx in ranking:
            c_question = questions[idx]
            c_rationale = rationales[idx].strip()
            c_pred_ans = pred_answers[idx].strip()

            if len(c_question.strip().split()) > 60:
                continue
            if len(c_rationale.replace("\n\n", "\n").split("\n")) > args.max_ra_len:
                continue
            # ``endswith`` instead of ``[-1]`` so an empty rationale is
            # skipped rather than raising IndexError.
            if c_pred_ans == "" or not c_rationale.endswith("."):
                continue

            if args.task in arithmetic_tasks:
                # The predicted answer must show up in the final sentence or
                # among the last ten whitespace-separated tokens.
                in_last_sentence = c_pred_ans in c_rationale.split(".")[-2]
                in_last_tokens = c_pred_ans in c_rationale.split()[-10:]
                if not (in_last_sentence or in_last_tokens):
                    continue

            if args.over:
                # Never reuse a question already chosen for an earlier basis.
                if c_question in used_questions:
                    continue
                used_questions.add(c_question)

            c_rationale = " ".join(
                c_rationale.replace("\n\n", "\n").replace("\n", " ").strip().split()
            )
            c_gold_ans = gold_answers[idx] if args.debug else None

            demos.append({
                "question": c_question,
                "rationale": c_rationale,
                "pred_ans": c_pred_ans,
                "gold_ans": c_gold_ans,
            })
            print(c_question)
            print(c_rationale)
            print(c_pred_ans)
            print(c_gold_ans)
            print("")
            break

    return demos


def _plot_embeddings(corpus_embeddings, prompt_base, seed, out_path):
    """Save a 3-D PCA scatter of the embeddings (green) and the bases (red stars)."""
    pca_model = PCA(n_components=3, random_state=seed)
    transformed = pca_model.fit_transform(corpus_embeddings)
    centers = pca_model.transform(prompt_base)

    fig = plt.figure()
    ax = plt.axes(projection='3d')

    # Data for three-dimensional scattered points
    ax.scatter3D(transformed[:, 0], transformed[:, 1], transformed[:, 2], c='green', s=30)
    ax.scatter3D(centers[:, 0], centers[:, 1], centers[:, 2], c='red', s=60, marker='*')
    plt.savefig(out_path, dpi=600)
    plt.close(fig)


def main():
    """Build demonstrations by projecting question embeddings onto SVD bases.

    Steps:
      1. Parse the CLI, fix the random seed and load the requested encoder.
      2. Parse the zero-shot-CoT log (``--pred_file``) and write the raw
         questions to ``--question_save_dir``.
      3. Embed every question, derive ``--num_basis`` basis vectors from the
         SVD of the embedding matrix and score every question against each.
      4. For each basis keep the best-scoring question that passes the
         length/answer filters, and dump the result as JSON to
         ``--demo_save_dir``.
      5. Save a 3-D PCA plot to ``<demo_save_dir>.png``.

    Raises:
        ValueError: for an unsupported encoder, a log with no complete
            record, or ``--num_basis`` larger than the number of questions.
    """
    args = parse_arguments()
    fix_seed(args.random_seed)

    encoder, tokenizer = _load_encoder(args.encoder)
    num_basis = args.num_basis

    questions, rationales, pred_answers, gold_answers = _read_predictions(
        args.pred_file, args.question_save_dir, args.debug
    )
    if not questions:
        raise ValueError("no complete question record found in {!r}".format(args.pred_file))
    if num_basis > len(questions):
        # Only len(questions) rows of U exist, so more bases cannot be built.
        raise ValueError(
            "--num_basis ({}) exceeds the number of parsed questions ({})".format(
                num_basis, len(questions)
            )
        )

    if args.encoder in _LLAMA_PATHS:
        corpus_embeddings = _embed_with_hidden_states(questions, encoder, tokenizer, _pick_device(1))
        # NOTE(review): this removes each embedding's own mean over its
        # feature dimensions (axis=1), not the dataset mean per feature --
        # confirm this is the intended normalisation.
        corpus_embeddings = corpus_embeddings - np.mean(corpus_embeddings, axis=1, keepdims=True)
    elif args.encoder in _E5_PATHS:
        corpus_embeddings = _embed_with_hidden_states(questions, encoder, tokenizer, _pick_device(0))
    else:
        corpus_embeddings = encoder.encode(questions)

    # Full SVD is required here: the first ``num_basis`` ROWS of the square
    # U matrix are multiplied with the embedding matrix.
    # NOTE(review): rows of U (not the right singular vectors) define the
    # bases -- kept as in the original method; confirm against the paper.
    left_vectors = np.linalg.svd(corpus_embeddings)[0]
    prompt_base = np.dot(left_vectors[:num_basis], corpus_embeddings)

    # Score of every question against every basis: (num_questions, num_basis).
    dist = np.dot(corpus_embeddings, prompt_base.T)

    demos = _select_demos(args, dist, questions, rationales, pred_answers, gold_answers)

    _ensure_parent_dir(args.demo_save_dir)
    with open(args.demo_save_dir, 'w', encoding="utf-8") as write_f:
        json.dump({"demo": demos}, write_f, indent=4, ensure_ascii=False)

    _plot_embeddings(corpus_embeddings, prompt_base, args.random_seed, args.demo_save_dir + ".png")

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()