import os
import copy
import time
import random
import datasets
import torch
import argparse
import numpy as np
import cma
from fastNLP import cache_results, Tester, DataSet, Instance
from transformers import (
    RobertaConfig,
    RobertaTokenizer,
    BertConfig,
    BertTokenizer,
    ElectraConfig,
    ElectraTokenizer,
    BartConfig,
    BartTokenizer,
    T5Config,
    T5Tokenizer,
    GPT2Config,
    GPT2Tokenizer,
    BartConfig as CPTConfig,
)
from models.modeling_roberta import RobertaForMaskedLM
from models.modeling_bart import BartForConditionalGeneration
from models.modeling_t5 import T5ForConditionalGeneration
from models.modeling_gpt2 import GPT2LMHeadModel
from models.modeling_bert import BertForMaskedLM
from models.modeling_electra import ElectraForMaskedLM
from models.modeling_cpt import CPTForMaskedLM
from utils import hinge_loss
from sklearn.metrics import f1_score
from datasets import load_from_disk

class ZOSA(torch.optim.Optimizer):
    """Derivative-free optimizer that only needs loss evaluations.

    The name presumably stands for "zeroth-order sharpness-aware"; the update
    implemented by ``step`` is, per parameter group:

    1. Draw ``m`` random +1/-1 direction sets and estimate a gradient ``g`` from
       finite differences of size ``epsilon`` around the current weights.
    2. Move the weights by ``rho * g / std(probed losses)`` (no move when the
       std is zero), and estimate a second gradient at that shifted point.
    3. Restore the weights and apply ``-lr / std(second probed losses) * g2``
       (plain ``lr`` when that std is zero).

    Args:
        params: Iterable of parameters (or parameter groups) to optimize.
        rho: Scale of the intermediate shift in step 2.
        epsilon: Finite-difference probe size.
        m: Number of random directions per gradient estimate.
        lr: Base learning rate of the final update.
    """

    def __init__(self, params, rho=0.05, epsilon=1e-3, m=4, lr=1e-3):
        defaults = dict(rho=rho, epsilon=epsilon, m=m, lr=lr)
        super(ZOSA, self).__init__(params, defaults)

    @staticmethod
    def _sample_directions(params, m):
        """Return ``m`` lists of +1/-1 tensors, one tensor per parameter."""
        directions = []
        for _ in range(m):
            # NOTE(review): this reseeds the *global* torch RNG from one of only
            # 10000 seeds before every draw. Kept unchanged so existing runs stay
            # reproducible, but it limits the set of possible directions.
            torch.manual_seed(torch.randint(0, 10000, (1,)))
            directions.append(
                [torch.randint(0, 2, p.size(), device=p.device, dtype=p.dtype) * 2 - 1 for p in params]
            )
        return directions

    @staticmethod
    def _restore(params, snapshot):
        """Copy the saved values in ``snapshot`` back into ``params``."""
        for p, saved in zip(params, snapshot):
            p.data.copy_(saved)

    @staticmethod
    def _probe_losses(params, directions, epsilon, closure):
        """Evaluate ``closure`` at ``params + epsilon * u`` for every ``u``.

        The weights are restored from a snapshot after each probe and also when
        ``closure`` raises, so an aborted probe never leaves them perturbed.
        """
        base = [p.data.clone() for p in params]
        losses = []
        try:
            for u in directions:
                for p, u_p in zip(params, u):
                    p.data.add_(epsilon * u_p)
                losses.append(closure())
                ZOSA._restore(params, base)
        finally:
            ZOSA._restore(params, base)
        return losses

    @staticmethod
    def _estimate_gradient(params, directions, losses, base_loss, epsilon, m):
        """Average the finite-difference estimates ``(loss - base_loss) * u / epsilon``."""
        grads = [torch.zeros_like(p) for p in params]
        for u, loss in zip(directions, losses):
            for g, u_p in zip(grads, u):
                g.add_((loss - base_loss) * u_p / (epsilon * m))
        return grads

    def step(self, closure):
        """Perform one update and return the loss at the updated weights.

        Args:
            closure: Zero-argument callable returning the scalar loss for the
                current parameter values. It is called ``2 * m + 2`` times per
                parameter group plus once at the end.

        Returns:
            The value of ``closure()`` after all groups have been updated.

        Raises:
            Whatever ``closure`` raises. In that case any temporary perturbation
            is undone before the exception propagates.
        """
        for group in self.param_groups:
            params = group['params']
            rho = group['rho']
            epsilon = group['epsilon']
            m = group['m']
            lr = group['lr']

            # First estimate, around the current weights.
            u_list = self._sample_directions(params, m)
            l0 = closure()
            li_list = self._probe_losses(params, u_list, epsilon, closure)
            g_t = self._estimate_gradient(params, u_list, li_list, l0, epsilon, m)

            li_tensor = torch.tensor(li_list, dtype=torch.float32)
            sigma_t = torch.std(li_tensor, unbiased=False)
            if sigma_t > 0:
                epsilon_sam = [rho * g / sigma_t for g in g_t]
            else:
                epsilon_sam = [torch.zeros_like(g) for g in g_t]

            # Second estimate, around the shifted weights. The original weights
            # are restored from a snapshot even if closure() raises in between.
            origin = [p.data.clone() for p in params]
            try:
                for p, eps in zip(params, epsilon_sam):
                    p.data.add_(eps)
                l_pert = closure()
                u_pert_list = self._sample_directions(params, m)
                li_pert_list = self._probe_losses(params, u_pert_list, epsilon, closure)
            finally:
                self._restore(params, origin)
            g_pert = self._estimate_gradient(params, u_pert_list, li_pert_list, l_pert, epsilon, m)

            li_pert_tensor = torch.tensor(li_pert_list, dtype=torch.float32)
            sigma_t_pert = torch.std(li_pert_tensor, unbiased=False)
            adaptive_lr = lr / sigma_t_pert if sigma_t_pert > 0 else lr

            for p, g in zip(params, g_pert):
                p.data.add_(-adaptive_lr * g)

        return closure()


# Command-line interface. Every option is declared as a (flag, add_argument
# keyword arguments) pair and registered in this order, so ``--help`` lists the
# options exactly as they appear below.
_ARGUMENT_SPECS = (
    ("--model_name", dict(
        default='roberta-large',
        choices=['roberta-base', 'roberta-large',
                 'bert-base-uncased', 'bert-large-uncased',
                 'google/electra-base-generator', 'google/electra-large-generator',
                 'facebook/bart-base', 'facebook/bart-large',
                 't5-small', 't5-base', 't5-large', 't5-3b',
                 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl',
                 'fnlp/cpt-large'],
        type=str,
    )),
    ("--task_name", dict(default='sst2', type=str)),
    ("--n_prompt_tokens", dict(default=50, type=int)),
    ("--intrinsic_dim", dict(default=500, type=int)),
    ("--k_shot", dict(default=16, type=int)),
    ("--batch_size", dict(default=32, type=int)),
    ("--budget", dict(default=8000, type=int)),
    ("--popsize", dict(default=20, type=int)),
    ("--bound", dict(default=0, type=int)),
    ("--sigma", dict(default=1, type=float)),
    ("--alpha", dict(default=1, type=float)),
    ("--print_every", dict(default=50, type=int)),
    ("--eval_every", dict(default=100, type=int)),
    ("--device", dict(default='cuda:0', type=str)),
    ("--alg", dict(default='ZOSA', choices=['CMA', 'ZOSA'], type=str)),
    ("--random_proj", dict(default='normal', type=str)),
    ("--seed", dict(default=4, type=int)),
    ("--loss_type", dict(default='ce', type=str)),
    ("--cat_or_add", dict(default='add', type=str)),
    ("--parallel", dict(action='store_true', help='Whether to allow parallel evaluation')),
    ("--inference_framework", dict(default='pt', type=str)),
    ("--onnx_model_path", dict(default=None, type=str)),
    # Hyper-parameters forwarded to the ZOSA optimizer.
    ("--rho", dict(default=0.05, type=float)),
    ("--epsilon", dict(default=1e-3, type=float)),
    ("--m", dict(default=4, type=int)),
    ("--lr", dict(default=1e-3, type=float)),
    ("--num_steps", dict(default=20000, type=int)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Pick the data loaders and metrics that match the selected model family.
# Each branch binds the same loader/metric names, except the CPT branch, which
# binds a different (Chinese-task) set; names from the other branches are then
# undefined at module level.
model_name = args.model_name
if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
    from dataloaders.dataloader_t5 import SST2Loader, AGNewsLoader, YelpPLoader, DBPediaLoader, RTELoader, MRPCLoader, SNLILoader
    from metrics.metrics_t5 import SST2Metric, AGNewsMetric, YelpPMetric, DBPediaMetric, RTEMetric, MRPCMetric, SNLIMetric
elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
    from dataloaders.dataloader_gpt import SST2Loader, AGNewsLoader, YelpPLoader, DBPediaLoader, RTELoader, MRPCLoader, SNLILoader
    from metrics.metrics_gpt import SST2Metric, AGNewsMetric, YelpPMetric, DBPediaMetric, RTEMetric, MRPCMetric, SNLIMetric
elif model_name in ['fnlp/cpt-large']:
    from dataloaders.dataloader_cpt import ChnSentLoader, AmazonLoader, THUCNewsLoader, BQLoader, CMNLILoader, CCPMLoader, TNewsLoader, OCNLILoader, LCQMCLoader, C3Loader
    from metrics.metrics_cpt import ChnSentMetric, AmazonMetric, THUCNewsMetric, BQMetric, CMNLIMetric, CCPMMetric, TNewsMetric, OCNLIMetric, LCQMCMetric, C3Metric
else:
    # Remaining choices (RoBERTa, BERT, ELECTRA, BART) share the default modules.
    from dataloaders.dataloader import SST2Loader, AGNewsLoader, YelpPLoader, DBPediaLoader, RTELoader, MRPCLoader, SNLILoader
    from metrics.metrics import SST2Metric, AGNewsMetric, YelpPMetric, DBPediaMetric, RTEMetric, MRPCMetric, SNLIMetric

# Unpack the parsed options into the module-level names used by the rest of
# the script.
task_name, n_prompt_tokens, intrinsic_dim = args.task_name, args.n_prompt_tokens, args.intrinsic_dim
k_shot, batch_size, budget = args.k_shot, args.batch_size, args.budget
bound, sigma, alpha = args.bound, args.sigma, args.alpha
if args.popsize > 0:
    popsize = args.popsize
else:
    # A non-positive --popsize falls back to 4 + 3 * ln(intrinsic_dim),
    # which is a float.
    popsize = 4 + 3 * np.log(intrinsic_dim)
device, alg, random_proj = args.device, args.alg, args.random_proj
seed, loss_type = args.seed, args.loss_type
print_every, eval_every = args.print_every, args.eval_every
cat_or_add, parallel = args.cat_or_add, args.parallel
inference_framework, onnx_model_path = args.inference_framework, args.onnx_model_path

# Validate the inference backend selection before any model is built.
if inference_framework not in ['pt', 'ort']:
    raise ValueError(f'inference_framework only supports "pt", "ort", got `{inference_framework}` instead.')
if inference_framework == 'ort':
    # Explicit raises instead of ``assert``: assertions are removed when Python
    # runs with -O, which would silently skip these checks.
    if onnx_model_path is None:
        raise ValueError('Path to onnx model is required, got None instead.')
    if not os.path.exists(onnx_model_path):
        raise FileNotFoundError(f'Invalid onnx model path `{onnx_model_path}`')

# Only the non-'add' mode starts from a stored prompt embedding file.
if cat_or_add == 'add':
    init_prompt_path = None
else:
    init_prompt_path = './nli_base_prompt.pt'

# Number of classes for each supported task; tasks that share a class count
# are grouped together. An unknown task raises ValueError.
_NUM_LABELS_BY_TASKS = (
    (('sst2', 'yelpp', 'rte', 'mrpc', 'chnsent', 'lcqmc', 'bq'), 2),
    (('snli', 'cmnli', 'ocnli'), 3),
    (('agnews', 'ccpm', 'c3'), 4),
    (('amazon',), 5),
    (('thucnews',), 10),
    (('dbpedia', 'tnews'), 14),
)
for _tasks, _label_count in _NUM_LABELS_BY_TASKS:
    if task_name in _tasks:
        num_labels = _label_count
        break
else:
    raise ValueError

# Seed the Python, NumPy and PyTorch random number generators with --seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

class LMForwardAPI:
    def __init__(self, model_name='roberta-large', n_prompt_tokens=50, task_name='sst2',
                 loss_type='hinge', init_prompt_path=None, rho=0.05, epsilon=1e-3, m=4, lr=1e-3):
        """Build the model, tokenizer, random projection, optimizer and metric.

        Besides its arguments this constructor reads the module-level settings
        ``inference_framework``, ``onnx_model_path``, ``cat_or_add``, ``device``,
        ``intrinsic_dim``, ``random_proj``, ``alpha``, ``sigma``, ``print_every``
        and ``eval_every``.

        Args:
            model_name: Pretrained model identifier; selects the model family.
            n_prompt_tokens: Number of prompt tokens passed to the model.
            task_name: Task identifier; selects the metric and its result key.
            loss_type: Stored as ``self.loss_type`` and used by ``calc_metric``.
            init_prompt_path: File to load the initial prompt embedding from
                (only read when ``cat_or_add == 'cat'``).
            rho, epsilon, m, lr: Hyper-parameters forwarded to ``ZOSA``.

        Raises:
            NotImplementedError: For an unsupported ``model_name`` or ``task_name``.
        """
        if model_name in ['roberta-base', 'roberta-large']:
            # NOTE(review): both RoBERTa choices load from the fixed location
            # 'path/roberta-large' rather than from `model_name` — confirm intended.
            self.config = RobertaConfig.from_pretrained('path/roberta-large')
            self.tokenizer = RobertaTokenizer.from_pretrained('path/roberta-large')
            self.model = RobertaForMaskedLM.from_pretrained(
                'path/roberta-large',
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
                inference_framework=inference_framework,
                onnx_model_path=onnx_model_path,
            )
            # Replace the LM head bias with an all-zero parameter.
            self.model.lm_head.bias = torch.nn.parameter.Parameter(torch.zeros(self.config.vocab_size))
        elif model_name in ['bert-base-uncased', 'bert-large-uncased']:
            self.config = BertConfig.from_pretrained(model_name)
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
            self.model = BertForMaskedLM.from_pretrained(
                model_name,
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
            )
        elif model_name in ['google/electra-base-generator', 'google/electra-large-generator']:
            self.config = ElectraConfig.from_pretrained(model_name)
            self.tokenizer = ElectraTokenizer.from_pretrained(model_name)
            self.model = ElectraForMaskedLM.from_pretrained(
                model_name,
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
            )
        elif model_name in ['facebook/bart-base', 'facebook/bart-large']:
            self.config = BartConfig.from_pretrained(model_name)
            self.tokenizer = BartTokenizer.from_pretrained(model_name)
            self.model = BartForConditionalGeneration.from_pretrained(
                model_name,
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
            )
        elif model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
            self.config = T5Config.from_pretrained(model_name)
            self.tokenizer = T5Tokenizer.from_pretrained(model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(
                model_name,
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
            )
        elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
            self.config = GPT2Config.from_pretrained(model_name)
            self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
            self.model = GPT2LMHeadModel.from_pretrained(
                model_name,
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
            )
        elif model_name in ['fnlp/cpt-large']:
            # CPT reuses the BART config class (imported as CPTConfig) and the BERT tokenizer.
            self.config = CPTConfig.from_pretrained(model_name)
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
            self.model = CPTForMaskedLM.from_pretrained(
                model_name,
                config=self.config,
                n_prompt_tokens=n_prompt_tokens,
            )
        else:
            raise NotImplementedError
        if inference_framework == 'ort':
            # NOTE(review): the 'normal' random-projection branch below calls
            # self.model.roberta.get_input_embeddings() for RoBERTa models, which
            # would fail after this assignment — confirm how 'ort' is meant to run.
            self.model.roberta = None
        if cat_or_add == 'cat':
            self.model.set_concat_prompt(True)
            if init_prompt_path is not None:
                print('Initialize prompt embedding from {}'.format(init_prompt_path))
                # The loaded embedding is flattened and kept on the CPU.
                self.init_prompt = torch.load(init_prompt_path).weight.cpu().reshape(-1)
            else:
                print('Initial prompt embedding not found. Initialize to zero embedding.')
                self.init_prompt = torch.zeros(n_prompt_tokens * self.config.hidden_size)
            print('Shape of initial prompt embedding: {}'.format(self.init_prompt.shape))
        else:
            self.init_prompt = None
        self.model.to(device)
        self.model.eval()
        # Bias-free projection from the intrinsic space to the flattened prompt embedding.
        self.linear = torch.nn.Linear(intrinsic_dim, n_prompt_tokens * self.config.hidden_size, bias=False)
        self.linear.to(device)
        if random_proj == 'normal':
            # Re-initialize the projection from a zero-mean normal whose std is
            # derived from the std of the model's input embedding weights.
            if model_name in ['roberta-base', 'roberta-large']:
                embedding = self.model.roberta.get_input_embeddings().weight.clone().cpu()
            elif model_name in ['bert-base-uncased', 'bert-large-uncased']:
                embedding = self.model.bert.get_input_embeddings().weight.clone().cpu()
            elif model_name in ['google/electra-base-generator', 'google/electra-large-generator']:
                embedding = self.model.electra.get_input_embeddings().weight.clone().cpu()
            elif model_name in ['facebook/bart-base', 'facebook/bart-large', 'fnlp/cpt-large']:
                embedding = self.model.model.get_input_embeddings().weight.clone().cpu()
            elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
                embedding = self.model.transformer.get_input_embeddings().weight.clone().cpu()
            else:
                embedding = self.model.get_input_embeddings().weight.clone().cpu()
            # mu_hat is only reported in the log line; the projection uses mean 0.
            mu_hat = np.mean(embedding.reshape(-1).detach().cpu().numpy())
            std_hat = np.std(embedding.reshape(-1).detach().cpu().numpy())
            mu = 0.0
            std = alpha * std_hat / (np.sqrt(intrinsic_dim) * sigma)
            print('[Embedding] mu: {} | std: {} [RandProj]  mu: {} | std: {}'.format(mu_hat, std_hat, mu, std))
            for p in self.linear.parameters():
                torch.nn.init.normal_(p, mu, std)
        # The trainable prompt starts at zero and is the only parameter given to ZOSA.
        self.prompt_embedding = torch.nn.Parameter(torch.zeros(n_prompt_tokens, self.config.hidden_size).to(device))
        self.optimizer = ZOSA([self.prompt_embedding], rho=rho, epsilon=epsilon, m=m, lr=lr)
        self.best_train_perf = 0.0
        self.best_dev_perf = 0.0
        # self.best_prompt = None
        self.best_prompt = np.zeros((n_prompt_tokens, self.config.hidden_size)) 
        self.num_call = 0
        self.print_every = print_every
        self.eval_every = eval_every
        self.loss_type = loss_type
        # Select the task metric, the key of its result to report ('acc' or
        # 'f1') and the name under which the result is looked up in `eval`.
        if task_name == 'sst2':
            self.metric = SST2Metric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'SST2Metric'
        elif task_name == 'agnews':
            self.metric = AGNewsMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'AGNewsMetric'
        elif task_name == 'yelpp':
            self.metric = YelpPMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'YelpPMetric'
        elif task_name == 'dbpedia':
            self.metric = DBPediaMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'DBPediaMetric'
        elif task_name == 'rte':
            self.metric = RTEMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'RTEMetric'
        elif task_name == 'mrpc':
            self.metric = MRPCMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'f1'
            self.metric_name = 'MRPCMetric'
        elif task_name == 'snli':
            self.metric = SNLIMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'SNLIMetric'
        elif task_name == 'chnsent':
            self.metric = ChnSentMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'ChnSentMetric'
        elif task_name == 'thucnews':
            self.metric = THUCNewsMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'THUCNewsMetric'
        elif task_name == 'lcqmc':
            self.metric = LCQMCMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'LCQMCMetric'
        elif task_name == 'cmnli':
            self.metric = CMNLIMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'CMNLIMetric'
        elif task_name == 'ocnli':
            self.metric = OCNLIMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'OCNLIMetric'
        elif task_name == 'amazon':
            self.metric = AmazonMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'AmazonMetric'
        elif task_name == 'bq':
            self.metric = BQMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'BQMetric'
        elif task_name == 'ccpm':
            self.metric = CCPMMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'CCPMMetric'
        elif task_name == 'tnews':
            self.metric = TNewsMetric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'TNewsMetric'
        elif task_name == 'c3':
            self.metric = C3Metric(target='labels', pred='logits', tokenizer=self.tokenizer)
            self.metric_key = 'acc'
            self.metric_name = 'C3Metric'
        else:
            raise NotImplementedError
        # The hinge-loss margin comes from the metric object.
        self.margin = self.metric.margin
        self.ce_loss = torch.nn.CrossEntropyLoss(reduction='mean')

    def calc_metric(self, logits, target):
        label_map = self.metric.label_map
        converted_target = target.clone()
        for key, val in label_map.items():
            converted_target[target == key] = val
        converted_target = converted_target.squeeze()
        interest_index = list(label_map.keys())
        logits = logits[:, interest_index]
        pred = logits.argmax(dim=-1)
        if self.metric_key == 'acc':
            perf = (pred == converted_target).sum() / len(target)
        elif self.metric_key == 'f1':
            perf = f1_score(converted_target.detach().cpu().numpy().tolist(),
                            pred.detach().cpu().numpy().tolist())
        else:
            raise KeyError(f'[Metric] Only support [acc, f1], got {self.metric_key} instead.')
        if self.loss_type == 'hinge':
            loss = hinge_loss(logits, converted_target, margin=self.margin, reduction='sum').item() / len(target)
        elif self.loss_type == 'ce':
            loss = self.ce_loss(logits, converted_target).item()
        elif self.loss_type == 'perf':
            loss = -1 * perf
        else:
            raise KeyError(f'[Loss] Only support [hinge, ce, perf], got {self.loss_type} instead.')
        return loss, perf

    def calc_loss(self):
        """Return the training loss of the current ``self.prompt_embedding``.

        Installs the prompt in the model, moves every tensor of the module-level
        ``train_data`` to ``device``, runs one forward pass without gradients and
        scores the logits with ``calc_metric``.
        """
        self.model.set_prompt_embedding(self.prompt_embedding)
        for name in train_data:
            train_data[name] = train_data[name].to(device)

        # Assemble the keyword arguments expected by the selected model family.
        inputs = {
            'input_ids': train_data['input_ids'],
            'attention_mask': train_data['attention_mask'],
        }
        if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
            inputs['decoder_input_ids'] = train_data['decoder_input_ids']
            inputs['decoder_attention_mask'] = train_data['decoder_attention_mask']
        elif model_name not in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
            # Every family other than T5 and GPT-2 is also given the mask position.
            inputs['mask_pos'] = train_data['mask_pos']

        with torch.no_grad():
            logits = self.model(**inputs)['logits']
        loss, _ = self.calc_metric(logits, train_data['labels'])
        return loss

    def train(self, num_steps, tolerance=1e-4, patience=5, budget=10000):
        loss_history = []
        eval_count = 0
        consecutive_small_changes = 0

        for step in range(num_steps):
            def closure():
                nonlocal eval_count
                eval_count += 1
                if eval_count > budget:
                    raise StopIteration("Evaluation count exceeded budget")
                return self.calc_loss()

            try:
                loss = self.optimizer.step(closure)
            except StopIteration as e:
                print(e)
                break

            self.num_call += 1
            loss_history.append(loss)

            # Convergence check
            if len(loss_history) > 1:
                loss_change = abs(loss_history[-1] - loss_history[-2])
                if loss_change < tolerance:
                    consecutive_small_changes += 1
                else:
                    consecutive_small_changes = 0

                if consecutive_small_changes >= patience:
                    print(f"Stopped early due to {patience} consecutive loss changes less than {tolerance}")
                    break

            if self.num_call % self.print_every == 0:
                print(f"[# API Calls {self.num_call}] loss: {round(float(loss), 4)}")
            
            if self.num_call % self.eval_every == 0:
                print('********* Evaluated on dev set *********')
                for k, v in dev_data.items():
                    dev_data[k] = v.to(device)
                with torch.no_grad():
                    if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
                        logits = self.model(
                            input_ids=dev_data['input_ids'],
                            attention_mask=dev_data['attention_mask'],
                            decoder_input_ids=dev_data['decoder_input_ids'],
                            decoder_attention_mask=dev_data['decoder_attention_mask'],
                        )['logits']
                    elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
                        logits = self.model(
                            input_ids=dev_data['input_ids'],
                            attention_mask=dev_data['attention_mask'],
                        )['logits']
                    else:
                        logits = self.model(
                            input_ids=dev_data['input_ids'],
                            attention_mask=dev_data['attention_mask'],
                            mask_pos=dev_data['mask_pos'],
                        )['logits']
                dev_loss, dev_perf = self.calc_metric(logits, dev_data['labels'])
                if dev_perf > self.best_dev_perf:
                    self.best_dev_perf = dev_perf
                    self.best_prompt = self.prompt_embedding.detach().cpu().numpy()
                print(f'Dev loss: {round(float(dev_loss), 4)}. Dev perf: {round(float(dev_perf), 4)}. Best dev perf: {round(float(self.best_dev_perf), 4)}')
                print('********* Done *********')

        else:
            print("Reached maximum number of iterations")

        return loss_history

    def eval(self, prompt_embedding=None, test_data=None):
        """Score a prompt (or a list of prompts) and track the best results.

        Args:
            prompt_embedding: ``None`` (use ``self.best_prompt``), a NumPy array,
                or a list of vectors. 1-D inputs are mapped through
                ``self.linear`` (plus ``self.init_prompt`` when set) and reshaped
                to ``(n_prompt_tokens, -1)``; arrays of other rank are used as
                given. The prompt is then repeated ``bsz`` times.
            test_data: When this is a fastNLP ``DataSet`` the prompt is scored on
                it with ``Tester``. Otherwise the module-level ``train_data`` is
                used, and every ``self.eval_every`` calls also ``dev_data``.

        Returns:
            The test metric when ``test_data`` is a ``DataSet``; otherwise the
            training loss, or the list of per-prompt losses when the module-level
            ``parallel`` flag is set.

        Raises:
            ValueError: For an unsupported ``prompt_embedding`` type, or when
                more than one prompt is supplied together with a ``DataSet``.
        """
        self.num_call += 1
        if prompt_embedding is None:
            prompt_embedding = self.best_prompt
        if test_data is None:
            bsz = len(dev_data['input_ids'])
        else:
            bsz = batch_size
        # Keep the caller's representation so it can be stored as best_prompt.
        tmp_prompt = copy.deepcopy(prompt_embedding)
        if isinstance(prompt_embedding, list):
            pe_list = []
            for pe in prompt_embedding:
                z = torch.tensor(pe, dtype=torch.float32, device=device)
                z = self.linear(z)
                if self.init_prompt is not None:
                    # init_prompt is stored on the CPU; move it next to z so the
                    # addition also works when `device` is a GPU.
                    z = z + self.init_prompt.to(z.device)
                pe_list.append(z.reshape(n_prompt_tokens, -1).repeat(bsz, 1, 1))
            prompt_embedding = torch.cat(pe_list)
            assert len(prompt_embedding) == len(train_data['input_ids'])
        elif isinstance(prompt_embedding, np.ndarray):
            prompt_embedding = torch.tensor(prompt_embedding, dtype=torch.float32, device=device)
            if prompt_embedding.dim() == 1:
                prompt_embedding = self.linear(prompt_embedding)
                if self.init_prompt is not None:
                    prompt_embedding = prompt_embedding + self.init_prompt.to(prompt_embedding.device)
                prompt_embedding = prompt_embedding.reshape(n_prompt_tokens, -1)
            prompt_embedding = prompt_embedding.repeat(bsz, 1, 1)
        else:
            raise ValueError(
                f'[Prompt Embedding] Only support [list, numpy.ndarray], got `{type(prompt_embedding)}` instead.'
            )
        # Convert to torch.nn.Parameter to match expected type
        prompt_embedding_param = torch.nn.Parameter(prompt_embedding)
        self.model.set_prompt_embedding(prompt_embedding_param)
        if isinstance(test_data, DataSet):
            if prompt_embedding_param.shape[0] > bsz:
                raise ValueError('Provide a single prompt embedding for testing.')
            test_tester = Tester(data=test_data, model=self.model, metrics=self.metric, batch_size=batch_size,
                                 num_workers=1, device=device, use_tqdm=True)
            results = test_tester.test()
            test_acc = results[self.metric_name][self.metric_key]
            return test_acc
        else:
            for k, v in train_data.items():
                train_data[k] = v.to(device)
            with torch.no_grad():
                if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
                    logits = self.model(
                        input_ids=train_data['input_ids'],
                        attention_mask=train_data['attention_mask'],
                        decoder_input_ids=train_data['decoder_input_ids'],
                        decoder_attention_mask=train_data['decoder_attention_mask'],
                    )['logits']
                elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
                    logits = self.model(
                        input_ids=train_data['input_ids'],
                        attention_mask=train_data['attention_mask'],
                    )['logits']
                else:
                    logits = self.model(
                        input_ids=train_data['input_ids'],
                        attention_mask=train_data['attention_mask'],
                        mask_pos=train_data['mask_pos'],
                    )['logits']
            if parallel:
                # Score each prompt on its own slice of `bsz` rows and keep the
                # one with the lowest loss.
                # NOTE(review): `pe_list` exists only when a list of prompts was
                # passed; a NumPy prompt combined with `parallel` fails here.
                all_losses, all_perfs = [], []
                for i in range(len(logits) // bsz):
                    tmp_logits = logits[i * bsz:i * bsz + bsz]
                    tmp_target = train_data['labels'][i * bsz:i * bsz + bsz]
                    tmp_loss, tmp_perf = self.calc_metric(tmp_logits, tmp_target)
                    all_losses.append(tmp_loss)
                    all_perfs.append(tmp_perf)
                loss = min(all_losses)
                best_sol = all_losses.index(loss)
                perf = all_perfs[best_sol]
                tmp_prompt = tmp_prompt[best_sol]
                prompt_embedding = pe_list[best_sol]
            else:
                loss, perf = self.calc_metric(logits, train_data['labels'])
            if perf > self.best_train_perf:
                self.best_train_perf = perf
            if self.num_call % self.print_every == 0:
                print(
                    '[# API Calls {}] loss: {}. Current perf: {}. Best perf so far: {}'.format(
                        self.num_call,
                        round(float(loss), 4),
                        round(float(perf), 4),
                        round(float(self.best_train_perf), 4)))
            if self.num_call % self.eval_every == 0:
                print('********* Evaluated on dev set *********')
                if parallel:
                    # Only the best prompt of this call is scored on the dev set.
                    self.model.set_prompt_embedding(prompt_embedding)
                for k, v in dev_data.items():
                    dev_data[k] = v.to(device)
                with torch.no_grad():
                    if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
                        logits = self.model(
                            input_ids=dev_data['input_ids'],
                            attention_mask=dev_data['attention_mask'],
                            decoder_input_ids=dev_data['decoder_input_ids'],
                            decoder_attention_mask=dev_data['decoder_attention_mask'],
                        )['logits']
                    elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
                        logits = self.model(
                            input_ids=dev_data['input_ids'],
                            attention_mask=dev_data['attention_mask'],
                        )['logits']
                    else:
                        logits = self.model(
                            input_ids=dev_data['input_ids'],
                            attention_mask=dev_data['attention_mask'],
                            mask_pos=dev_data['mask_pos'],
                        )['logits']
                    dev_loss, dev_perf = self.calc_metric(logits, dev_data['labels'])
                    if dev_perf > self.best_dev_perf:
                        self.best_dev_perf = dev_perf
                        self.best_prompt = copy.deepcopy(tmp_prompt)
                    print('Dev loss: {}. Dev perf: {}. Best dev perf: {}'.format(
                        round(float(dev_loss), 4),
                        round(float(dev_perf), 4),
                        round(float(self.best_dev_perf), 4)))
                    print('********* Done *********')
            if parallel:
                return all_losses
            else:
                return loss

# Module-level tokenizer for the selected model. Each row lists the model names
# of one family, the tokenizer class to use and a fixed source to load from
# (None means "load from `model_name`").
_TOKENIZER_TABLE = (
    (['roberta-base', 'roberta-large'], RobertaTokenizer, 'path/roberta-large'),
    (['bert-base-uncased', 'bert-large-uncased', 'fnlp/cpt-large'], BertTokenizer, None),
    (['google/electra-base-generator', 'google/electra-large-generator'], ElectraTokenizer, None),
    (['facebook/bart-base', 'facebook/bart-large'], BartTokenizer, None),
    (['t5-small', 't5-base', 't5-large', 't5-3b'], T5Tokenizer, None),
    (['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'], GPT2Tokenizer, None),
)
for _family, _tokenizer_cls, _fixed_source in _TOKENIZER_TABLE:
    if model_name in _family:
        tokenizer = _tokenizer_cls.from_pretrained(model_name if _fixed_source is None else _fixed_source)
        break
else:
    # No family matched the requested model.
    raise NotImplementedError

# Map from label token id to class index; only defined for SST-2, empty otherwise.
label_token_to_id = {10999: 0, 10765: 1} if task_name == 'sst2' else {}

# Cache file for the processed data, keyed by model, task, prompt length and seed.
cache_fn = f"caches/data_{model_name.replace('/', '-')}_{task_name}_{n_prompt_tokens}_{seed}.pt"

# Task name -> loader class. The CPT model has its own task set; every other
# model shares the second table.
if model_name in ['fnlp/cpt-large']:
    DataLoader = dict(
        chnsent=ChnSentLoader,
        thucnews=THUCNewsLoader,
        lcqmc=LCQMCLoader,
        cmnli=CMNLILoader,
        ocnli=OCNLILoader,
        amazon=AmazonLoader,
        bq=BQLoader,
        ccpm=CCPMLoader,
        tnews=TNewsLoader,
        c3=C3Loader,
    )
else:
    DataLoader = dict(
        sst2=SST2Loader,
        agnews=AGNewsLoader,
        yelpp=YelpPLoader,
        dbpedia=DBPediaLoader,
        rte=RTELoader,
        mrpc=MRPCLoader,
        snli=SNLILoader,
    )

@cache_results(cache_fn, _refresh=True)
def get_data(task_name, tokenizer, splits):
    """Load the data bundle for ``task_name`` through its registered loader.

    Args:
        task_name: Key into the module-level ``DataLoader`` table.
        tokenizer: Tokenizer handed to the loader.
        splits: Accepted for interface compatibility but NOT used; the value
            is overwritten below based on ``task_name``.

    Returns:
        Whatever the loader's ``my_load`` returns for the chosen splits.

    Raises:
        KeyError: If ``task_name`` is not in ``DataLoader``.
    """
    # NOTE(review): `_refresh=True` presumably makes fastNLP recompute and
    # overwrite the cache on every call -- confirm this is intended.
    uses_test_split = task_name in ['agnews', 'yelpp', 'dbpedia', 'snli']
    splits = ['train', 'test'] if uses_test_split else ['train', 'validation']
    loader = DataLoader[task_name](tokenizer=tokenizer, n_prompt_tokens=n_prompt_tokens)
    return loader.my_load(splits)

def construct_true_few_shot_data(train_data, k_shot, task_name):
    """Sample a k-shot training set and a k-shot dev set from ``train_data``.

    Instances are visited in a random order (``np.random.shuffle``). For each
    label, the first ``k_shot`` instances go to the new training set and the
    next ``k_shot`` go to the new dev set; further instances of that label
    are dropped.

    Label handling:
      * list-valued ``labels``: the first element is mapped through the
        module-level ``label_token_to_id``; unmapped or empty lists are
        skipped.
      * otherwise ``labels`` must be an ``int`` in ``{0, 1, 2, 3}`` for
        ``agnews`` or ``{0, 1}`` for every other task, else it is skipped.

    The new instances carry the input fields of the current model family
    (selected by the module-level ``model_name``) plus ``mask_pos`` whenever
    the source dataset has it, and an integer ``labels`` field. Previously
    only ``input_ids``/``attention_mask``/``mask_pos`` were copied, so the
    T5 branch asked ``set_input`` for decoder fields that had never been
    added to the new datasets.

    Args:
        train_data: Source dataset; must support ``len``, integer indexing
            and expose ``field_arrays``.
        k_shot: Number of instances per label for each of the two outputs.
        task_name: Task identifier; only used to pick the valid int labels.

    Returns:
        Tuple ``(new_train_data, new_dev_data)``.

    Raises:
        ValueError: If either resulting dataset is empty.
        KeyError: If a required input field is absent from the data
            (raised by ``set_input``).
    """
    if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
        input_fields = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
    elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
        input_fields = ["input_ids", "attention_mask"]
    else:
        input_fields = ["input_ids", "attention_mask", "mask_pos"]

    # Fields to carry over: this family's inputs, plus mask_pos (which the
    # previous implementation always copied). Fields the source does not have
    # are left out here; a missing *required* input is reported by set_input
    # below instead of failing on an unrelated field.
    wanted_fields = list(input_fields)
    if "mask_pos" not in wanted_fields:
        wanted_fields.append("mask_pos")
    source_fields = train_data.field_arrays
    copied_fields = [name for name in wanted_fields if name in source_fields]

    valid_int_labels = [0, 1, 2, 3] if task_name == 'agnews' else [0, 1]

    train_label_count = {}
    dev_label_count = {}
    new_train_data = DataSet()
    new_dev_data = DataSet()
    all_indices = list(range(len(train_data)))
    np.random.shuffle(all_indices)

    for index in all_indices:
        instance = train_data[index]
        labels = instance['labels']

        if isinstance(labels, list):
            if not labels:
                continue
            label_token = labels[0]
            label = label_token_to_id.get(label_token, -1)
            print(f"Instance {index}: Label token = {label_token}, Mapped label = {label}")
        else:
            label = labels
            if not isinstance(label, int) or label not in valid_int_labels:
                continue

        # -1 marks a label token with no entry in label_token_to_id.
        if label < 0:
            continue

        train_label_count.setdefault(label, 0)
        dev_label_count.setdefault(label, 0)

        # Fill the training quota for this label first, then the dev quota.
        if train_label_count[label] < k_shot:
            destination, counts = new_train_data, train_label_count
        elif dev_label_count[label] < k_shot:
            destination, counts = new_dev_data, dev_label_count
        else:
            continue

        fields = {name: instance[name] for name in copied_fields}
        fields['labels'] = label
        destination.append(Instance(**fields))
        counts[label] += 1

    if len(new_train_data) == 0 or len(new_dev_data) == 0:
        raise ValueError("No instances added to train or dev data. Check label format and dataset.")

    new_train_data.set_input(*input_fields)
    new_dev_data.set_input(*input_fields)

    new_train_data.set_target("labels")
    new_dev_data.set_target("labels")

    print("new_train_data fields:", new_train_data.field_arrays.keys())
    print("new_dev_data fields:", new_dev_data.field_arrays.keys())

    return new_train_data, new_dev_data

# Load the data and pick the held-out split: the tasks listed below are
# evaluated on their 'test' split, every other task on 'validation'.
data_bundle = get_data(task_name=task_name, tokenizer=tokenizer, splits=['train', 'validation', 'test'])
_eval_split = 'test' if task_name in ['agnews', 'yelpp', 'dbpedia', 'snli'] else 'validation'
train_data = data_bundle.get_dataset('train')
test_data = data_bundle.get_dataset(_eval_split)

print("train_data fields:", train_data.field_arrays.keys())
print("First sample:", train_data[0])

# Replace the full training set by the k-shot train/dev pair sampled from it.
train_data, dev_data = construct_true_few_shot_data(train_data, k_shot, task_name)

# SST-2 only: rebuild the test set so that every kept instance has an integer
# class label. Single-element list labels are mapped through
# label_token_to_id (and rewritten in place); int labels must already be 0/1.
# Everything else is dropped.
if task_name == 'sst2':
    valid_instances = []
    for _example in test_data:
        _raw = _example['labels']
        if isinstance(_raw, list):
            if len(_raw) != 1:
                continue
            _mapped = label_token_to_id.get(_raw[0], -1)
            if _mapped < 0:
                continue
            _example['labels'] = _mapped
        elif not (isinstance(_raw, int) and _raw in [0, 1]):
            continue
        valid_instances.append(_example)
    test_data = DataSet(valid_instances)
    test_data.set_target("labels")
    test_data.set_input("input_ids", "attention_mask", "mask_pos")

# Configure padding values on every split that has the corresponding field:
# input_ids pad with the tokenizer's pad id (0 when the tokenizer has none),
# attention_mask pads with 0.
for ds in (train_data, dev_data, test_data):
    _present = ds.field_arrays
    if 'input_ids' in _present:
        _pad_id = tokenizer.pad_token_id
        ds.set_pad_val('input_ids', 0 if _pad_id is None else _pad_id)
    if 'attention_mask' in _present:
        ds.set_pad_val('attention_mask', 0)

# Report the size and the first instance of each split.
for _header, _split in (
    ('# of train data: {}', train_data),
    ('\n# of dev data: {}', dev_data),
    ('\n# of test data: {}', test_data),
):
    print(_header.format(len(_split)))
    print('Example:')
    print(_split[0])

# Convert the few-shot train/dev DataSets into plain dicts of tensors.
# The set of fields depends on the model family; a field is only converted
# when the dataset actually has it. test_data is left untouched here.
if model_name in ['t5-small', 't5-base', 't5-large', 't5-3b']:
    _tensor_fields = ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'labels']
elif model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
    _tensor_fields = ['input_ids', 'attention_mask', 'labels']
else:
    _tensor_fields = ['input_ids', 'attention_mask', 'mask_pos', 'labels']

train_data_dict = {}
dev_data_dict = {}
for field in _tensor_fields:
    for _split, _converted in ((train_data, train_data_dict), (dev_data, dev_data_dict)):
        if field not in _split.field_arrays:
            continue
        if field == 'labels':
            # Labels are read instance by instance and stored as int64.
            _converted[field] = torch.tensor(
                [instance['labels'] for instance in _split], dtype=torch.long
            )
        else:
            # Fetch the whole column at once through the field array.
            _converted[field] = torch.tensor(_split[field].get(list(range(len(_split)))))
train_data = train_data_dict
dev_data = dev_data_dict

# Build the forward API object from the run configuration. rho, epsilon, m
# and lr come straight from the parsed command-line args.
_forward_api_kwargs = dict(
    model_name=model_name,
    n_prompt_tokens=n_prompt_tokens,
    task_name=task_name,
    loss_type=loss_type,
    init_prompt_path=init_prompt_path,
    rho=args.rho,
    epsilon=args.epsilon,
    m=args.m,
    lr=args.lr,
)
model_forward_api = LMForwardAPI(**_forward_api_kwargs)

# Run the selected optimizer and report the wall-clock time it took.
# Only 'CMA' and 'ZOSA' are supported; anything else is rejected up front.
if alg not in ('CMA', 'ZOSA'):
    raise ValueError(f"Unsupported algorithm: {alg}")

if alg == 'CMA':
    cma_opts = dict(
        seed=seed,
        popsize=popsize,
        maxiter=budget if parallel else budget // popsize,
        verbose=-1,
    )
    if bound > 0:
        cma_opts['bounds'] = [-1 * bound, 1 * bound]
    # Search starts from the all-zero vector of length intrinsic_dim.
    es = cma.CMAEvolutionStrategy(intrinsic_dim * [0], sigma, inopts=cma_opts)
    print('Population Size: {}'.format(es.popsize))
    print('{} Evaluation.'.format('Parallel' if parallel else 'Serial'))
    if parallel:
        # Tile the training tensors once per population member so that a
        # whole population can be passed to a single eval call.
        _copies = es.popsize
        train_data['input_ids'] = train_data['input_ids'].repeat(_copies, 1)
        train_data['attention_mask'] = train_data['attention_mask'].repeat(_copies, 1)
        if 'mask_pos' in train_data:
            train_data['mask_pos'] = train_data['mask_pos'].repeat(_copies)
        train_data['labels'] = train_data['labels'].repeat(_copies)
    start_time = time.time()
    while not es.stop():
        solutions = es.ask()
        if parallel:
            # One call scores the full population.
            fitnesses = model_forward_api.eval(solutions)
        else:
            # Score the candidates one at a time.
            fitnesses = [model_forward_api.eval(x) for x in solutions]
        es.tell(solutions, fitnesses)
else:
    # alg == 'ZOSA'
    start_time = time.time()
    model_forward_api.train(args.num_steps, tolerance=1e-9, patience=2000, budget=args.budget)

end_time = time.time()
print('Done. Elapsed time: {} (mins)'.format((end_time - start_time) / 60))

# Final evaluation on the held-out split; the score is printed rounded to
# four decimal places.
print('Evaluate on test data...')
test_acc = model_forward_api.eval(test_data=test_data)
_rounded_test_acc = round(test_acc, 4)
print('Test acc: {}'.format(_rounded_test_acc))