import warnings
from itertools import product

import torch

from evaluate import evaluate_setup, evaluate, eval_seq_languge
from models import load_generator, INITIAL_PROMPT
from train_peft import train_peft_model
from utils import parse_args, get_results_pd, find_current_run, save_results_pd
from templates import get_templates
from data import load_split_dataset, PEFTDataset, TensorDataset, ourPEFTDataset
from examples import get_examples
import numpy as np

try:
    import wandb
except ImportError:
    wandb = None

if __name__ == '__main__':
    # Script entry point. For each seed, model and every combination of dataset /
    # seed / prediction method / example-selection method, evaluate each prompt
    # template (after PEFT training when args.method_boost_type selects a PEFT
    # method), persist each template's score via save_results_pd, and finally print
    # aggregate statistics (also logged to Weights & Biases when args.use_wandb is set).
    args = parse_args()
    if args.use_wandb and wandb is None:
        # Fail fast rather than crashing on `None.init(...)` after evaluation work.
        raise ImportError("use_wandb is set but the `wandb` package could not be imported.")
    if args.enhancement_function in ['iterative_pgd_with_projection']:
        # NOTE(review): presumably a "never split into batches" sentinel for this
        # enhancement function -- confirm with the consumer of enhancement_batch_size.
        args.enhancement_batch_size = int(1e6)

    results = get_results_pd(args.save_dir, force_rerun=args.force_rerun)
    # One row per configuration reaching the evaluation stage; one column per template.
    tensor_scores = torch.Tensor()
    # Loss / prediction statistics of templates evaluated in *this* process, kept
    # flat (1-D). A resumed configuration contributes fewer entries than it has
    # templates, so stacking these as rows would crash torch.concat on a width
    # mismatch; only their overall means are reported, which flattening preserves.
    tensor_loss_avg1 = torch.Tensor()
    tensor_loss_avg2 = torch.Tensor()
    tensor_loss_avg3 = torch.Tensor()
    tensor_pred_avg = torch.Tensor()

    # Most recent evaluation output / template list, used by the final report.
    # Initialised so the report cannot hit a NameError when everything was cached.
    evaluation_result = None
    templates = []

    bool_init_wandb = False
    run_num = 0
    for seed_i in range(args.number_of_seeds):
        args.seed = [seed_i]
        for model in args.models:
            for dataset, seed, prediction_method, selection_method in product(
                    args.dataset, args.seed, args.prediction_method, args.examples_selection_method):

                if args.training_num_shots == -1:
                    args.training_num_shots = args.num_shots[0]

                if selection_method == '0-shot':
                    num_shots_range = [0]
                else:
                    num_shots_range = [args.training_num_shots]

                for num_shots in num_shots_range:
                    if prediction_method in ["channel", "calibrate"]:
                        labels_loss = True
                        if not args.labels_loss:
                            warnings.warn(
                                f"Using {prediction_method} with labels_loss set to False is highly discouraged, "
                                f"setting to True.")
                    else:
                        labels_loss = args.labels_loss

                    templates = get_templates(dataset, num_shots, args, args.num_templates, args.templates_path,
                                              args.template_seed, pad_str='</s>')
                    if dataset in ['seq_language']:
                        # Only the first template is used for this dataset.
                        # NOTE(review): mixing this dataset with multi-template datasets in
                        # one run yields tensor_scores rows of different widths.
                        templates = [templates[0]]
                    method_name = f"{prediction_method}_{labels_loss}"
                    config = {'dataset': dataset, 'model': model, 'seed': seed,
                              'example_selection_method': selection_method, 'n_shots': num_shots,
                              'prediction_method': method_name, 'batch_size': args.eval_batch_size,
                              'precision': args.precision,
                              'template_seed': args.template_seed,
                              }

                    # Scores already stored for this config; fully-evaluated configs are skipped.
                    scores = find_current_run(config=config, results=results)
                    avg_losses1 = []
                    avg_losses2 = []
                    avg_losses3 = []
                    avg_pred = []
                    if len(scores) == len(templates):
                        continue

                    if args.use_wandb and not bool_init_wandb:
                        wandb.init(entity=args.wandb_entity, project=args.wandb_project, reinit=True, config=config)
                        bool_init_wandb = True

                    # Resume: only templates without a stored score are evaluated.
                    num_evaluated_templates = len(scores)
                    for template in templates[num_evaluated_templates:]:
                        run_num += 1
                        print('*' * 200)
                        print(seed_i)
                        print(model)
                        print(dataset)
                        print(num_shots)
                        print(template)
                        print(run_num)
                        print('*' * 200)

                        generator = load_generator(
                            model, cache_dir=args.cache_dir, precision=args.precision,
                            local_files_only=args.local_files_only, device_map=args.device_map,
                            args=args
                        )
                        if args.method_boost_type in ['IPT', 'LORA', 'prompt_tuning', 'prefix_tuning',
                                                      'our_prompt_tuning']:
                            train, val, labels_mp = load_split_dataset(dataset, cache_dir=args.cache_dir, args=args,
                                                                       tokenizer=generator.tokenizer)

                            labels = list(labels_mp.values())
                            selected_examples = get_examples(dataset, train, selection_method, seed,
                                                             num_shots,
                                                             example_ids=args.example_ids,
                                                             examples_path=args.examples_path)
                            examples = selected_examples["examples"]
                            # The selected demonstrations double as the PEFT training set.
                            few_shot_train = train.iloc[selected_examples["example_ids"]]
                            train_dataset = ourPEFTDataset(
                                [(inp.strip(), tgt.strip())
                                 for inp, tgt in zip(few_shot_train['input'], few_shot_train['target'])],
                                generator.tokenizer, labels, template, bool_is_test=False)

                            if args.use_context_for_instruciton and args.method_boost_type in ['prompt_tuning',
                                                                                               'our_prompt_tuning',
                                                                                               'IPT']:
                                # Built only to extract the in-context-learning prefix
                                # (text, token ids, mask) handed to load_generator below.
                                eval_dataset = TensorDataset(
                                    [x.strip() for x in val['input']],
                                    generator.tokenizer, labels, template,
                                    examples=examples,
                                    method=prediction_method,
                                    only_icl=True,
                                    args=args
                                )
                                ICL_text = eval_dataset.text_context
                                ICL_token_ids = eval_dataset.ICL_input_ids
                                ICL_mask = eval_dataset.ICL_mask
                                if args.method_boost_type in ['our_prompt_tuning', 'IPT'] and args.add_instruction:
                                    # Prepend an instruction prefix whose mask entries are -1.
                                    # NOTE(review): -1 presumably marks the prefix as a distinct
                                    # (learnable) segment inside load_generator -- confirm.
                                    if args.use_random_instruction == 0:
                                        instruction_tokens_ids = generator.tokenizer.encode(
                                            INITIAL_PROMPT[dataset])
                                        ICL_token_ids = instruction_tokens_ids + ICL_token_ids
                                        ICL_mask = [-1] * len(instruction_tokens_ids) + ICL_mask
                                    else:
                                        # NOTE(review): hard-coded placeholder token id 2 -- confirm
                                        # it is the intended id for generator.tokenizer.
                                        ICL_token_ids = [2] * args.num_of_learnable_format + ICL_token_ids
                                        ICL_mask = [-1] * args.num_of_learnable_format + ICL_mask

                                # Release the plain model before loading the ICL-conditioned one
                                # so both are never resident at the same time.
                                del generator
                                torch.cuda.empty_cache()
                                generator = load_generator(
                                    model, cache_dir=args.cache_dir, precision=args.precision,
                                    local_files_only=args.local_files_only,
                                    device_map=args.device_map,
                                    args=args,
                                    ICL_token_ids=ICL_token_ids,
                                    ICL_mask=ICL_mask,
                                    ICL_text=ICL_text
                                )
                            generator.model = train_peft_model(
                                generator.model, generator.tokenizer, train_dataset, args
                            )
                            # NOTE(review): evaluation uses args.num_shots[0] while the training
                            # examples above were selected with `num_shots` -- confirm intended.
                            eval_fn = eval_seq_languge if dataset in ['seq_language'] else evaluate_setup
                            evaluation_result = eval_fn(
                                dataset=dataset, generator=generator, seed=seed,
                                template=template,
                                prediction_method=prediction_method,
                                labels_loss=labels_loss,
                                selection_method=selection_method,
                                num_shots=args.num_shots[0],
                                example_ids=args.example_ids,
                                examples_path=args.examples_path,
                                batch_size=args.eval_batch_size,
                                cache_dir=args.cache_dir,
                                args=args,
                                few_shot_train=few_shot_train
                            )
                        else:
                            evaluation_result = evaluate_setup(dataset=dataset, generator=generator, seed=seed,
                                                               template=template,
                                                               prediction_method=prediction_method,
                                                               labels_loss=labels_loss,
                                                               selection_method=selection_method, num_shots=num_shots,
                                                               example_ids=args.example_ids,
                                                               examples_path=args.examples_path,
                                                               batch_size=args.eval_batch_size,
                                                               cache_dir=args.cache_dir,
                                                               args=args
                                                               )
                        # A fresh generator is loaded for every template; free this one first.
                        del generator
                        torch.cuda.empty_cache()

                        score = evaluation_result["score"]
                        results = save_results_pd(res_df=results, run_config=config, score=score, template=template,
                                                  save_dir=args.save_dir,
                                                  )

                        scores.append(score)
                        avg_losses1.append(float(evaluation_result["avg_loss_true_label1"]))
                        avg_losses2.append(float(evaluation_result["avg_loss_true_label2"]))
                        avg_losses3.append(float(evaluation_result["avg_loss_true_label3"]))
                        avg_pred.append(float(evaluation_result["probs"]))
                    tensor_scores = torch.concat((tensor_scores, torch.Tensor(scores).view(1, -1)), dim=0)
                    tensor_loss_avg1 = torch.concat((tensor_loss_avg1, torch.Tensor(avg_losses1)), dim=0)
                    tensor_loss_avg2 = torch.concat((tensor_loss_avg2, torch.Tensor(avg_losses2)), dim=0)
                    tensor_loss_avg3 = torch.concat((tensor_loss_avg3, torch.Tensor(avg_losses3)), dim=0)
                    tensor_pred_avg = torch.concat((tensor_pred_avg, torch.Tensor(avg_pred)), dim=0)

    if tensor_scores.numel() == 0:
        print('No configuration was evaluated in this run (all results already saved); nothing to report.')
    else:
        if args.use_wandb:
            wandb_dict = {}
            wandb_dict['mean_of_loss_avg1'] = tensor_loss_avg1.mean()
            wandb_dict['mean_of_loss_avg2'] = tensor_loss_avg2.mean()
            wandb_dict['mean_of_loss_avg3'] = tensor_loss_avg3.mean()
            wandb_dict['mean_of_scores'] = tensor_scores.mean()
            wandb_dict['std_of_scores'] = tensor_scores.std()
            # Rows are configurations (e.g. seeds) and columns are templates:
            # std(dim=0) is the per-template spread across configurations,
            # std(dim=1) the per-configuration spread across templates.
            # Key names are kept unchanged so logged runs stay comparable.
            wandb_dict['mean_of_format_std'] = tensor_scores.std(dim=0).mean()
            wandb_dict['mean_of_seeds_std'] = tensor_scores.std(dim=1).mean()
            if evaluation_result is not None:
                wandb_dict['example_ids'] = evaluation_result["example_ids"]
            wandb_dict['templates'] = [t.toJSON() for t in templates]
            for i, template_score in enumerate(tensor_scores.mean(dim=0).tolist()):
                wandb_dict['score_template_{}'.format(i)] = template_score
            wandb.log(wandb_dict)
            print(wandb_dict)
        print('Mean ', tensor_scores.mean(dim=0))
        print(f'Report: {tensor_scores.mean()}')
        if tensor_scores.shape[0] > 1:
            print('STD: ', tensor_scores.std(dim=0))
            print(f'Report std: {tensor_scores.std()}')
        print(templates)
        print('*' * 100)
        print('all scores:')
        print(tensor_scores)
