from dataclasses import dataclass
import fire
import os
import torch
import torch.utils
import json
from typing import Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from datetime import datetime
import re
from typing import Iterable

import datasets

from torch.nn import functional as F
from fishfarm.tasks.evalplus import load_dataset
import fishfarm
from fishfarm.tasks.base import Task, TaskResult

import sys
sys.path.append('.')
from utils import get_vllm_model, eval_model, forward, backward, load_hf_params_to_vllm, load_base_params


# System prompt for the dispatcher/classification task. It is used as the
# system message both by the ClsTask instances built in get_evaluator and by
# get_prompt_and_answer. Because this is a regular (non-raw) string, each
# `\\boxed` below is a single-backslash `\boxed` at runtime, which is the form
# extract_ans searches for in the model output.
SYSTEM_MSG = """
# Analyze the given question and classify it into one of four categories: 'code', 'math', 'reasoning' or 'other'. Follow these guidelines:

1. Code: Questions asking for programming solutions, functions, algorithms. Often includes specific programming terms, language syntax, or data structures.
2. Math: Questions involving mathematical calculations, formulas, statistics. Often includes numbers, equations, or references to mathematical operations.
3. Reasoning: Questions requiring logical thinking, application of scientific knowledge, or critical analysis of information. Often presents statements that need evaluation based on general understanding. 
4. Other: Questions not clearly fit into above categories.
 
Instructions:
- Consider the primary focus, skills, and knowledge required to answer the question.
- If a question spans multiple categories, choose the most dominant one.
- Provide your final classification within \\boxed{} notation. Example: \\boxed{reasoning}

Format your response as follows:
Classification: \\boxed{category}
"""

# Hugging Face model identifier passed to AutoModelForCausalLM / AutoTokenizer
# in main().
MODEL_ID = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Path (relative to the working directory) where main() caches the SVD factors
# of the model weights; it is created on first run and re-loaded afterwards.
DECOMPOSED_PARAM_FILE = 'llama3_decomposed_params.pt'


def mean(iterable: Iterable[float]) -> float:
    """Return the arithmetic mean of the values produced by *iterable*.

    The iterable is consumed in a single pass, so generators are accepted.
    An empty iterable raises ``ZeroDivisionError``.
    """
    running_sum = 0.0
    n_items = 0
    # enumerate(start=1) leaves n_items equal to the number of values seen;
    # it stays 0 when the iterable yields nothing.
    for n_items, value in enumerate(iterable, start=1):
        running_sum += value
    return running_sum / n_items

def extract_ans(text):
    """Return the content of the first ``\\boxed{...}`` in *text*, or None.

    The captured content is everything up to the first closing brace, so it
    may be an empty string and nested braces are not supported.
    """
    found = re.search(r'\\boxed{([^}]*)}', text)
    return found.group(1) if found else None

@dataclass
class ClsSample:
    """One classification example: a question and its expected category."""
    # Question text; sent to the model as the user message.
    question: str
    # Expected category. get_evaluator assigns 'math', 'code' or 'reasoning'.
    label: str


class ClsTask(Task):
    """Task that scores a model on classifying questions into categories.

    Each sample's question is sent as a user message after the shared
    ``context_messages``; the model's answer is parsed with ``extract_ans``
    and compared with the sample's label by exact string equality.
    """

    def __init__(
        self,
        samples,
        context_messages,
    ):
        """Store the samples and the messages prepended to every request.

        Args:
            samples: Iterable of objects exposing ``question`` and ``label``
                attributes; it is materialised into a list.
            context_messages: Messages placed before the user question in
                every request (e.g. the system prompt).
        """
        self.samples = list(samples)
        self.context_messages = context_messages

    @property
    def num_samples(self) -> int:
        """Number of samples held by this task."""
        return len(self.samples)

    def evaluate(
        self,
        model,
        sample_ids=None,
    ):
        """Run *model* on the selected samples and compute accuracy.

        Args:
            model: Object whose ``generate(requests)`` returns one result per
                request, each exposing a ``generation`` string.
            sample_ids: Indices into ``self.samples`` to evaluate. ``None``
                (the default) evaluates every sample.

        Returns:
            A ``TaskResult`` whose ``aggregate_metrics`` holds ``"acc"`` and
            whose ``sample_details`` has one dict per evaluated sample with
            keys ``question``, ``label``, ``output``, ``prediction`` and
            ``correct``.

        Raises:
            ZeroDivisionError: If no samples are selected (the accuracy mean
                is taken over an empty sequence).
        """
        if sample_ids is None:
            sample_ids = range(len(self.samples))
        selected = [self.samples[sample_id] for sample_id in sample_ids]

        requests = []
        for sample in selected:
            # Copy the shared context so each request owns its message list.
            messages = list(self.context_messages)
            messages.append(fishfarm.Message(role="user", content=sample.question))
            requests.append(fishfarm.models.GenerationRequest(messages=messages))

        sample_details = []
        for sample, result in zip(selected, model.generate(requests)):
            output = result.generation
            # prediction is None when the output contains no \boxed{...}.
            prediction = extract_ans(output)
            sample_details.append(
                dict(
                    question=sample.question,
                    label=sample.label,
                    output=output,
                    prediction=prediction,
                    correct=sample.label == prediction,
                )
            )

        aggregate_metrics = {
            "acc": mean(
                # Anything that is not a plain bool counts as incorrect.
                float(sd["correct"]) if isinstance(sd["correct"], bool) else 0.0
                for sd in sample_details
            )
        }
        return TaskResult(aggregate_metrics=aggregate_metrics, sample_details=sample_details)
    
    
def get_evaluator(num_samples_per_task=400) -> Tuple:
    """Build the train and validation classification tasks.

    Questions are drawn from three sources and labelled by origin: the GSM8K
    test split ('math'), the MBPP problems returned by the fishfarm evalplus
    loader ('code'), and the ARC-Challenge test split ('reasoning').

    Args:
        num_samples_per_task: Maximum number of questions taken from each
            source. Exactly this many are used when the source is large
            enough (previously one fewer was taken because the counter was
            incremented before the limit check).

    Returns:
        A ``(train_task, valid_task)`` tuple of ``ClsTask`` objects. The
        combined sample list is split by position: even indices go to the
        train task, odd indices to the validation task.
    """
    # NOTE: the bare `load_dataset` is the fishfarm evalplus loader; its import
    # comes after, and therefore shadows, `datasets.load_dataset`.
    task_datasets = [
        datasets.load_dataset("gsm8k", "main", split='test'),
        load_dataset(source_dataset='mbpp'),
        datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split='test'),
    ]
    task_labels = [
        'math', 'code', 'reasoning'
    ]

    # Letters used to enumerate the multiple-choice options positionally.
    choices = ["A", "B", "C", "D", "E"]
    samples = []
    for dataset, label in zip(task_datasets, task_labels):
        for idx, sample in enumerate(dataset):
            if idx >= num_samples_per_task:
                break
            if label == 'math':
                question = sample["question"]
            elif label == 'code':
                question = sample.instruction
            else:
                # Render the multiple-choice options below the question text.
                option_lines = "".join(
                    "{}. {}\n".format(choices[i], opt)
                    for i, opt in enumerate(sample["choices"]['text'])
                )
                question = sample["question"] + "\n" + "Options:\n" + option_lines
            samples.append(ClsSample(question=question, label=label))

    # Interleaved split: even positions -> train, odd positions -> validation.
    splits = (samples[0::2], samples[1::2])
    return tuple(
        ClsTask(
            samples=split,
            context_messages=[
                fishfarm.Message("system", SYSTEM_MSG),
            ],
        )
        for split in splits
    )


def get_prompt_and_answer(tokenizer, samples, ix):
    """Render the chat prompt for ``samples[ix]`` and return it with its label.

    The conversation is the system prompt followed by the sample's question,
    formatted with the LLAMA3 chat template as text (not token ids) and ending
    with the generation prompt.

    Returns:
        A ``(prompt, answer)`` pair where ``answer`` is the sample's label.
    """
    sample = samples[ix]
    conversation = [
        {'role': 'system', 'content': SYSTEM_MSG},
        {'role': 'user', 'content': sample.question},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation=conversation,
        chat_template=fishfarm.chat_templates.LLAMA3,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered, sample.label


def main(num_iters: int=10000,
         test_interval: int=10,
         lr: float=2e-3,
         seed: int=42,
         kl_ref_coeff: float=0.3,
         use_kl_loss: bool=True,
         case_num: int=1,
         init_val: float=0.1,
         max_grad_norm: float=1e-3,
         test_only: bool=False,
         use_wandb: bool=False,
         custom_prefix: str=None):         
    """Train per-weight learnable vectors on the question-classification task.

    One learnable vector is created for every model parameter whose name
    contains 'mlp'. Each iteration generates answers for the whole training
    split with the vLLM model, rewards each answer +1/-1 by correctness, and
    accumulates a policy-gradient loss (optionally plus a KL penalty against
    the original weights) before one Adam step on the learnable vectors.

    Args:
        num_iters: Number of training iterations.
        test_interval: Evaluate, checkpoint and log every this many iterations.
        lr: Adam learning rate.
        seed: Only recorded in the wandb config; this function does not seed
            any random number generator with it.
        kl_ref_coeff: Weight of the KL term added to the loss.
        use_kl_loss: Whether to compute and add the KL term.
        case_num: Stored in the module-level global ``CASE_NUM`` and printed.
        init_val: Offset added to the small random initial values of the
            learnable vectors.
        max_grad_norm: Maximum gradient norm used for clipping.
        test_only: When true, the validation task is evaluated once before
            training. NOTE(review): training still runs afterwards.
        use_wandb: Whether to log to Weights & Biases.
        custom_prefix: Sub-directory name under ``results/dispatcher``; when
            left as None the path contains the literal text ``None``.
    """
    # Publish the case number as a module global; nothing else in the visible
    # part of this file reads it.
    global CASE_NUM
    CASE_NUM = case_num
    print(f'Case {CASE_NUM}')

    if use_wandb:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb
        _ = wandb.init(
            project="...",
            name=f"...",
            config={
                "lr": lr,
                "seed": seed,
                "max_grad_norm": max_grad_norm,
                "custom_prefix": custom_prefix,
                "init_val": init_val,                
            },
        )

    # Each run writes into its own timestamped directory.
    now = datetime.now()
    datetime_str = now.strftime("%Y%m%d-%H%M%S")
    log_dir = f'results/dispatcher/{custom_prefix}/{datetime_str}'
    os.makedirs(log_dir, exist_ok=True)

    # vLLM model used for generation/evaluation (built by utils.get_vllm_model).
    vllm_model = get_vllm_model()
    vllm_model.chat_template = fishfarm.chat_templates.LLAMA3
    
    train_eval, valid_eval = get_evaluator()

    # NOTE(review): despite its name, `test_only` does not return here, so
    # training proceeds after this evaluation. Confirm whether that is intended.
    if test_only:
        print('Model performance before any update:')
        eval_model(vllm_model, valid_eval)

    train_data = train_eval.samples
    # The Hugging Face copy of the model is pinned to the second CUDA device.
    gpu = torch.device('cuda:1')

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map='cuda:1', torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    base_params = model.state_dict()

    # Detached CPU copies of the MLP weights; they are handed to
    # load_base_params below when computing the reference log-probabilities.
    original_model_params = {k: v.clone().detach().cpu() for k, v
                              in base_params.items() if 'mlp' in k}

    # SVD factors are computed once (in float32) for every parameter whose name
    # does not contain 'norm', then cached on disk.
    if not os.path.exists(DECOMPOSED_PARAM_FILE):
        print('Decomposed params not found. Decomposing...')
        decomposed_params = {}
        for k, v in base_params.items():
            if 'norm' not in k:
                print(k)
                U, S, V = torch.svd(v.to(torch.float32))
                decomposed_params[f'{k}.U'] = U
                decomposed_params[f'{k}.S'] = S
                decomposed_params[f'{k}.V'] = V
        torch.save(decomposed_params, DECOMPOSED_PARAM_FILE)
    else:
        print('Decomposed params found. Loading...')
        # NOTE(review): torch.load unpickles the file; only load a cache file
        # that this script (or another trusted source) produced.
        decomposed_params = torch.load(DECOMPOSED_PARAM_FILE)
    # Cast the factors to bfloat16 and move them to the training device.
    for k, v in decomposed_params.items():
        decomposed_params[k] = v.to(torch.bfloat16).to(gpu)

    # One learnable vector per MLP weight, of length min(weight.shape),
    # initialised to init_val plus small Gaussian noise (std 0.01).
    # Presumably utils.forward uses these to rescale the singular values of the
    # corresponding weight -- TODO confirm against utils.
    learnable_params = {}
    num_params = 0
    for k, v in base_params.items():
        if 'mlp' in k:
            learnable_params[k] = torch.nn.Parameter(
                data=(
                    torch.randn(
                        min(v.shape), device=gpu, dtype=torch.bfloat16,
                    ) * 0.01 + init_val
                ), requires_grad=True,
            )
            num_params += learnable_params[k].numel()
    print(f'#params={num_params}')
    learnable_params_list = list(learnable_params.values())
    optimizer = torch.optim.Adam(learnable_params_list, lr=lr)

    # eval() mode, but gradients are still requested on the MLP weights so the
    # per-sample backward passes below populate their .grad.
    model.eval()
    for k in learnable_params:
        model.get_parameter(k).requires_grad_(True)

    train_size = len(train_data)
    best_val_acc = 0.
    for i in range(num_iters):
        
        
        # Every iteration uses the full training split as a single batch.
        batch_ix = range(train_size)
        # The comprehension's `i` is local to the comprehension and does not
        # overwrite the outer loop variable. The prompts do not depend on the
        # iteration, yet they are rebuilt every time.
        prompts = [
            get_prompt_and_answer(tokenizer, train_data, i)[0] for i in batch_ix
        ]

        # Compose the current weights, push them to vLLM, and sample answers.
        new_params = forward(
            model, base_params, decomposed_params, learnable_params)
        load_hf_params_to_vllm(new_params, vllm_model.llm)
        res = eval_model(vllm_model, train_eval, batch_ix)
        # Reward is +1 for a correct classification and -1 otherwise.
        rewards = [1. if x['correct'] else -1. for x in res.sample_details]

        print('Computing reference log probs...')
        if use_kl_loss:
            ref_log_probs_list = []
            with torch.no_grad():
                # Put the original MLP weights back into the HF model to score
                # the sampled outputs under the reference policy.
                load_base_params(
                    model=model,
                    base_params=original_model_params)
                for j, prompt in enumerate(prompts):
                    input_ids = tokenizer(
                        prompt, return_tensors='pt').input_ids.to(gpu)
                    prompt_length = input_ids.shape[-1]
                    # NOTE(review): assumes tokenizing prompt+output keeps the
                    # prompt's tokens as an exact prefix -- confirm.
                    output_ids = tokenizer(
                        prompt + res.sample_details[j]['output'],
                        return_tensors='pt',
                    ).input_ids.to(gpu)
                    generated_ids = output_ids[:, prompt_length:]

                    outputs = model(output_ids)
                    # Logits at position t predict token t+1, so this slice
                    # lines up with generated_ids.
                    logits = outputs.logits[:, prompt_length-1:-1]
                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                    # NOTE(review): selected_log_probs is unused in this loop;
                    # the full log-prob tensor is what gets stored.
                    selected_log_probs = log_probs.gather(
                        2, generated_ids.unsqueeze(-1)).squeeze(-1)
                    ref_log_probs_list.append(log_probs.detach().cpu())

                # Presumably restores the adapted weights that load_base_params
                # overwrote -- TODO confirm. The returned value is not used
                # again in this iteration.
                new_params = forward(
                    model, base_params, decomposed_params, learnable_params)
        
        print('Computing the policy gradient...')
        for j, prompt in enumerate(prompts):
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(gpu)
            prompt_length = input_ids.shape[-1]
            output_ids = tokenizer(
                prompt + res.sample_details[j]['output'], return_tensors='pt',
            ).input_ids.to(gpu)
            generated_ids = output_ids[:, prompt_length:]

            outputs = model(output_ids)
            logits = outputs.logits[:, prompt_length-1:-1]
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            # Log-probability of each generated token, summed over the answer.
            selected_log_probs = log_probs.gather(
                2, generated_ids.unsqueeze(-1)).squeeze(-1)
            log_likelihood = selected_log_probs.sum(axis=-1)

            # REINFORCE-style term: reward-weighted negative log-likelihood.
            pg = -log_likelihood * rewards[j] 
            loss = pg
            if use_kl_loss:
                ref_log_probs = ref_log_probs_list[j].to(gpu)
                # With log_target=True both arguments are log-probabilities;
                # the pointwise terms are summed over every position and entry.
                kl_div = F.kl_div(
                    input=log_probs,
                    target=ref_log_probs,
                    log_target=True,
                    reduction='sum')
                loss = loss + kl_ref_coeff*kl_div
                
            # Divide by the batch size so the accumulated gradient is an
            # average over the training samples.
            scaled_loss = loss / train_size
            scaled_loss.backward()
            # NOTE(review): log_dict is rebuilt for every sample, so after the
            # loop it holds the values of the last sample only.
            log_dict = {
                'pg': pg.item(),
                'loss': loss.item(),
                }
            if use_kl_loss:
                log_dict['kl_div'] = kl_div.item()

        # utils.backward: presumably transfers the gradients accumulated on the
        # model weights to learnable_params -- TODO confirm against utils.
        backward(model, base_params, decomposed_params, learnable_params)
        
        # Mean gradient of the first learnable vector before and after
        # clipping (the second value is a mean as well, despite its name).
        grad_mean = learnable_params_list[0].grad.mean().item()
        torch.nn.utils.clip_grad_norm_(learnable_params_list, max_grad_norm)
        grad_norm_mean = learnable_params_list[0].grad.mean().item()
        optimizer.step()
        optimizer.zero_grad()
        model.zero_grad()

        # `pg` here is the value left over from the last sample of the batch.
        print(
            f'Iter {i}: PG={pg.item()}, ' +
            f'param={learnable_params_list[0].mean()}, ' +
            f'grad={grad_mean}'
        )

        if i % test_interval == 0:
            # Re-compose the weights with the updated vectors and sync to vLLM
            # before evaluating.
            forward(model, base_params, decomposed_params, learnable_params)
            load_hf_params_to_vllm(model.state_dict(), vllm_model.llm)

            train_res = eval_model(vllm_model, train_eval)
            valid_res = eval_model(vllm_model, valid_eval)
            # Keep a separate checkpoint for the best validation accuracy.
            if valid_res.aggregate_metrics["acc"] > best_val_acc:
                best_val_acc = valid_res.aggregate_metrics["acc"]
                print('best_val_acc updated')
                torch.save(learnable_params, f'{log_dir}/learnable_params.pt')

            torch.save(learnable_params,
                       f'{log_dir}/learnable_params_latest.pt')
            
            log_dict.update({
                'iter': i,
                'best_val_acc': best_val_acc,
                'train_acc': train_res.aggregate_metrics["acc"],
                'valid_acc': valid_res.aggregate_metrics["acc"],
                'grad_mean': grad_mean,
                'grad_norm_mean': grad_norm_mean,
            })
            if use_wandb:
                wandb.log(log_dict)

            # Records are appended, so log.json ends up as a sequence of
            # pretty-printed JSON objects rather than one JSON document.
            with open(f'{log_dir}/log.json', 'a') as f:
                json_data = json.dumps(log_dict, indent=4)
                f.write(json_data)
                f.write('\n')



# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags.
if __name__ == '__main__':
    fire.Fire(main)
