import os
import sys
import torch

import numpy as np

from src.models import LanguageModelSpecification
from src.utils import set_seed
from src.args import parse_args
from src.task_generator import generate_linear_task, generate_circle_task, generate_moon_task
from src.datagenerator import generate_grid_data
from src.prompt import batch_prompt_generation

from tqdm import tqdm


def _load_task_data(args):
    """Generate the in-context training set selected by ``args.task_mode``.

    Args:
        args: Parsed command-line namespace. Reads ``task_mode``,
            ``num_samples``, ``precision`` and ``seed``. For the linear task it
            also reads ``num_classes``.

    Returns:
        The ``(data, labels)`` pair returned by the matching task generator.

    Raises:
        ValueError: If ``args.task_mode`` is not one of the supported modes.
    """
    if args.task_mode == "linear_classification":
        return generate_linear_task(
            num_classes=args.num_classes,
            mode="train",
            num_samples=args.num_samples,
            precision=args.precision,
            randseed=args.seed,
        )
    if args.task_mode == "circle_classification":
        return generate_circle_task(
            num_samples=args.num_samples,
            noise=0.03,
            mode="train",
            precision=args.precision,
            randseed=args.seed,
        )
    if args.task_mode == "moon_classification":
        return generate_moon_task(
            num_samples=args.num_samples,
            mode="train",
            precision=args.precision,
            randseed=args.seed,
        )
    raise ValueError(
        f"Invalid task mode {args.task_mode!r}. The task modes defined in this "
        "repository are: ['linear_classification', 'circle_classification', "
        "'moon_classification']."
    )


def _parse_prediction(output):
    """Convert one raw model generation into an integer class label.

    The label is read from the character at index 1 of ``output[0]``.
    ``'B'`` maps to 0, ``'T'`` maps to 1, and any other character is
    converted with ``int``.

    Raises:
        IndexError, KeyError, TypeError: If ``output`` does not have the
            indexable structure assumed above.
        ValueError: If the character is not ``'B'``, ``'T'`` or an integer.
    """
    # NOTE(review): index 1 presumably skips a leading space/token emitted by
    # the model before the answer; confirm against response_generation.
    token = output[0][1]
    if token == "B":
        return 0
    if token == "T":
        return 1
    return int(token)


def _save_results(args, res_dict):
    """Save ``res_dict`` under ``<cwd>/data_records/pred_results``.

    The dict is stored with ``np.save`` as a pickled 0-d object array.
    Reload it with ``np.load(path, allow_pickle=True).item()``.
    """
    save_root = os.path.join(os.getcwd(), "data_records", "pred_results")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_root, exist_ok=True)

    # NOTE(review): a model_name containing path separators would redirect the
    # output path (an absolute name even discards save_root); confirm that
    # model names are plain identifiers.
    filename = f"{args.model_name}_{args.prompt_mode}_binary_{args.task_mode}_{args.exp_name}_preds.npy"
    np.save(os.path.join(save_root, filename), res_dict)


def main():
    """Run one in-context classification experiment end to end.

    The steps are:

    1. Parse CLI args, seed the RNGs and select the CUDA device.
    2. Build the training set for ``args.task_mode`` and the query points
       returned by ``generate_grid_data``.
    3. Build one prompt per query and collect the model's parsed label for
       each.
    4. Save the training data, the queries and the predictions to disk.

    Exits with a non-zero status and no saved file if any model output cannot
    be parsed into a label.
    """
    ### Environment settings
    args = parse_args()
    set_seed(args.seed)

    torch.cuda.set_device(args.gpu_id)

    ### Data generation
    data, labels = _load_task_data(args)
    queries = generate_grid_data(data)

    ### Model settings
    model = LanguageModelSpecification(model_name=args.model_name, max_seq_len=512, max_batch_size=1, temperature=0.2)
    prompts = batch_prompt_generation(in_context_data=list(data),
                                      in_context_labels=list(labels),
                                      queries=queries,
                                      prompt_mode=args.prompt_mode,
                                      algorithm=args.ml_alg)

    predictions = []
    for prompt in tqdm(prompts):
        output = model.response_generation(prompt=prompt)
        try:
            predictions.append(_parse_prediction(output))
        except (IndexError, KeyError, TypeError, ValueError) as err:
            # Abort rather than skip, so that predictions stay aligned
            # one-to-one with queries. Print the whole output: output[0] may
            # itself be what failed. sys.exit(str) writes to stderr and exits
            # with status 1.
            sys.exit(f"Could not parse model output {output!r}: {err}")

    res_dict = {
        "train_data": data,
        "train_labels": labels,
        "query_data": queries,
        "predictions": predictions
    }

    # save results
    _save_results(args, res_dict)
    
    
if __name__ == "__main__":
    # Run the experiment only when this file is executed as a script, not when it is imported.
    main()