import argparse
from src.api import *
from src.utils import *
from src.agent import *
from src.env import *
import os


def _build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser declaring every option read by
        ``main`` (directly or through the ``args`` namespace it hands to
        the project helpers).
    """
    parser = argparse.ArgumentParser()

    # LLM backend selection and credentials.
    parser.add_argument('--llm', type=str, default='claude', choices=['chatgpt', 'deepseek', 'claude', 'grok', 'gemini', 'llama', 'qwen'])
    parser.add_argument('--model', type=str, default='claude-3-7-sonnet-20250219', help='deepseek-chat, gpt-4o grok-3-latest gemini-2.0-flash llama4-maverick qwen-plus claude-3-7-sonnet-20250219')
    parser.add_argument('--max_tokens', type=int, default=8192)
    parser.add_argument('--ds_api_key', type=str)
    parser.add_argument('--openai_api_key', type=str)
    parser.add_argument('--gemini_api_key', type=str)
    parser.add_argument('--grok_api_key', type=str)
    parser.add_argument('--claude_api_key', type=str)
    parser.add_argument('--llama_api_key', type=str)
    parser.add_argument('--qwen_api_key', type=str)

    # Benchmark / evaluation settings.
    parser.add_argument('--data_path', type=str, nargs='?',
                        help='path to the image dataset (datasets or datasets/ILSVRC/Data/CLS-LOC)')
    parser.add_argument('--benchmark', default='nasbench201', type=str, help='transbench101 / nasbench201 / nasbench101')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--timeout', default=1200, type=int, help='max time for each proxy to run')
    parser.add_argument('--device', default="cuda:0", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
    parser.add_argument('--repeats', default=5, type=int, nargs='?', help='times of calculating single training-free metric')
    parser.add_argument('--input_samples', default=16, type=int, nargs='?',
                        help='input batch size for training-free metric')
    parser.add_argument('--name', default='default', type=str)
    parser.add_argument('--dataset', default='cifar10', help='(cifar10/cifar100/ImageNet16)')
    parser.add_argument('--temperature', default=0.0, type=float)

    # Outer-loop schedule and persistence.
    parser.add_argument('--episodes', default=100, type=int)
    parser.add_argument('--steps', default=50, type=int)
    parser.add_argument('--save', default=1, type=int,
                        help='save the population every N steps (0 disables periodic saving)')
    parser.add_argument('--prompt_path', default='./prompt', type=str)

    # Agent hyper-parameters.
    parser.add_argument('--history_len', default=5, type=int)
    parser.add_argument('--num_action', default=3, type=int, help='num of prompts')
    parser.add_argument('--num_hidden', default=256, type=int)
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--critic_lr', default=1e-2, type=float)
    parser.add_argument('--gamma', default=0.9, type=float)

    parser.add_argument('--log_dir', default='./logs/logs.txt', type=str)
    parser.add_argument('--output_path', default='./output', type=str)
    parser.add_argument('--task', default='object', type=str, help='(jigsaw/autoencoding/scene/object)')
    parser.add_argument('--structure', default='micro', help='(macro/micro)')
    # NOTE(review): --beta and --decay have no default, so they are None
    # unless given on the command line; --decay is forwarded to topk().
    parser.add_argument('--beta', type=float)
    parser.add_argument('--decay', type=float)
    parser.add_argument('--num_layers', type=int, default=2)
    return parser


def _mean_abs_score(pop):
    """Return the mean of ``abs(ind['score'])`` over ``pop`` (NaN if empty)."""
    count = len(pop)
    if count == 0:
        # An empty population has no meaningful mean; report NaN instead of
        # raising ZeroDivisionError in the middle of a run.
        return float('nan')
    return sum(abs(ind['score']) for ind in pop) / count


def main():
    """Parse the command line and run the episode/step loop.

    For every episode an initial population is requested with
    ``get_new_pop(None, ...)`` and evaluated. Then, for every step, the
    agent picks an action, a new population is requested for that action,
    the environment returns ``(next_state, reward)``, the transition is
    recorded, the population is merged through ``topk`` and the agent is
    updated. The population is saved periodically (see ``--save``) and
    once more at the end of each episode.
    """
    args = _build_parser().parse_args()
    api = ChatAPI(
        deepseek_api_key=args.ds_api_key,
        openai_api_key=args.openai_api_key,
        gemini_api_key=args.gemini_api_key,
        grok_api_key=args.grok_api_key,
        llama_api_key=args.llama_api_key,
        claude_api_key=args.claude_api_key,
        qwen_api_key=args.qwen_api_key,
    )

    prompts = read_prompts(args.prompt_path, args)
    # The agent's input size is history_len * (num_action + 1).
    state_dim = args.history_len * (args.num_action + 1)
    ac = ActorCritic(state_dim, args.num_hidden, args.num_action, args.actor_lr,
                     args.critic_lr, args.gamma, args.device, args.num_layers)
    env = StrategyGenEnv(args.num_action, args.history_len)

    logger = get_logger(args.log_dir)

    for episode in range(args.episodes):
        print('initialize')
        pop = get_new_pop(None, api, 0, prompts, args)
        state = env.reset()
        print('start evaluate')
        env.evaluate_strategy(pop, args)
        # Transitions accumulate over the whole episode.
        transition_dict = {'states': [], 'actions': [], 'rewards': [], 'next_states': []}

        for step in range(args.steps):
            action = ac.take_action(state)
            logger.info('Receving new pop')
            new_pop = get_new_pop(pop, api, action, prompts, args)
            logger.info('Has received new pop')
            next_state, reward = env.step(action, new_pop, args)
            print(reward)

            transition_dict['states'].append(state)
            transition_dict['actions'].append(action)
            transition_dict['rewards'].append(reward)
            transition_dict['next_states'].append(next_state)

            logger.info(f'Current action : {action}')
            logger.info(f'New pop correlation : {reward}')

            # NOTE(review): the literal 5 is presumably the number of
            # individuals kept by topk -- confirm against its definition.
            pop = topk(pop, new_pop, 5, args.decay)

            # `--save 0` disables periodic saving instead of crashing with
            # ZeroDivisionError on the modulo.
            if args.save and step % args.save == 0:
                save_pop(pop, args.output_path, episode)

            mean = _mean_abs_score(pop)
            logger.info(f'Generation {step + 1} correlation : {mean}\n')

            state = next_state
            # NOTE(review): update() receives every transition recorded so
            # far in this episode on each step, not only the newest one --
            # confirm this is intended.
            ac.update(transition_dict)

        save_pop(pop, args.output_path, episode)

# Run the loop only when this file is executed as a script, so importing the
# module does not trigger argument parsing or any work.
if __name__ == "__main__":
    main()