from src.models.modeling_gpt2 import ExtendedGPT2Config, ExtendedGPT2LMHeadModel
from src.models.modeling_gpt_neox import ExtendedGPTNeoXConfig, ExtendedGPTNeoXForCausalLM
from src.models.common import CausalLMConfig, CausalLM
from src.trainer.mimic import MimicArguments, MimicTrainer
from src.trainer.finetune import FinetuneArguments, FinetuneTrainer
from src.trainer.downstream import DownstreamArguments, DownstreamTrainer
from src.datasets.downstream import load_tokenized_downstream_dataset
from src.metrics.metrics import classification_accuracy
from src.utils.manual_seed import manual_seed

from transformers import PreTrainedModel, AutoTokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset
import torch

from typing import Optional, Union, Literal, Any
from dataclasses import dataclass, asdict, field
from time import sleep
from tqdm import tqdm
import argparse
import json
from functools import partialmethod

@dataclass
class ModelArguments:
    """Model selection and linear-attention options read from the "model" config section.

    ``train`` uses ``origin_model`` to pick the base architecture and copies
    every other field onto the linear model's config via ``config.update``.
    """
    # Base architecture; `train` accepts "gpt2" or any name starting with "pythia".
    origin_model: Literal["gpt2", "pythia-1b"] = "gpt2"
    # Directory containing a `pytorch_model.bin` used to initialise the linear
    # model; None means the linear model starts from a fresh initialisation.
    pretrained_path: Optional[str] = None
    # The remaining fields are interpreted by the model implementation, which is
    # not visible in this file; they are only forwarded to the config here.
    feature_type: str = "fourier"
    coef_type: str = "standard"
    adaptive_shift: Optional[float] = None
    # An int is immutable, so a plain default is sufficient (no default_factory).
    num_features: Union[int, list] = 64
    use_linear_attn: bool = True
    recurrence: bool = True

@dataclass
class DataArguments:
    """Pre-training corpus selection read from the "data" config section.

    ``train`` passes the name/subset/split to ``datasets.load_dataset`` and the
    two ratios to ``Dataset.train_test_split``.
    """
    # Positional dataset identifier handed to `load_dataset`.
    dataset_name: str = "HuggingFaceTB/cosmopedia"
    # Optional second positional argument to `load_dataset`; omitted when None.
    dataset_subset: Optional[str] = "stanford"
    # Optional `split=` keyword for `load_dataset`; omitted when None.
    dataset_split: Optional[str] = "train"
    # Forwarded as `train_size` / `test_size` to `train_test_split`.
    train_split_ratio: float = 0.8
    test_split_ratio: float = 0.2

def _convert_pretrained_to_state_dict(model_name, pretrained_path):
    """Convert a saved two-model checkpoint into a state dict for the linear model.

    The checkpoint at ``{pretrained_path}/pytorch_model.bin`` is expected to hold
    keys prefixed with ``origin_model.`` and ``linear_model.`` (presumably written
    by the mimic trainer -- confirm against the trainer's save logic).

    Key handling:
      * ``origin_model.*`` entries are dropped.
      * ``linear_model.*`` entries lose that prefix. For ``gpt2`` the backbone
        sub-modules (``h.``, ``ln_f.``, ``wte.``, ``wpe.``) are additionally
        re-rooted under ``transformer.``.
      * Any other key is kept unchanged, so a strict ``load_state_dict`` can
        report it instead of it being silently mapped onto another entry.

    Args:
        model_name: ``"gpt2"`` or a name starting with ``"pythia"``.
        pretrained_path: Directory containing ``pytorch_model.bin``.

    Returns:
        dict mapping the converted parameter names to their tensors.

    Raises:
        ValueError: If ``model_name`` is not a supported model family.
    """
    linear_prefix = "linear_model."

    # Validate the model family before paying for the checkpoint load.
    if model_name == "gpt2":
        rerooted_prefixes = ("h.", "ln_f.", "wte.", "wpe.")
    elif model_name.startswith("pythia"):
        rerooted_prefixes = ()
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    # NOTE(security): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    state_dict = torch.load(f"{pretrained_path}/pytorch_model.bin", map_location="cpu")

    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("origin_model."):
            continue
        if key.startswith(linear_prefix):
            # Strip only the leading prefix, never later occurrences in the key.
            new_key = key.removeprefix(linear_prefix)
            # An empty tuple never matches, so non-GPT-2 keys are left as is.
            if new_key.startswith(rerooted_prefixes):
                new_key = "transformer." + new_key
        else:
            new_key = key
        new_state_dict[new_key] = value

    return new_state_dict

def finetune(finetune_args: FinetuneArguments, train_dataset, test_dataset, tokenizer,
             model: CausalLM, device: str):
    """Fine-tune ``model`` as a causal LM with ``FinetuneTrainer``.

    Which parameters are trainable depends on ``finetune_args``:
      * ``train_only_head``: only ``model.lm_head``.
      * ``train_only_last_mlp``: only the last block's MLP and ``transformer.ln_f``.
      * otherwise: every parameter.

    Training and evaluation run only when ``do_train`` / ``do_eval`` are set;
    the trainer state and model are saved in every case. ``device`` is accepted
    for signature symmetry with the other stages but is not used here.
    """

    def _mark(parameters, trainable):
        # Set requires_grad on each parameter and return the total element count.
        total = 0
        for weight in parameters:
            weight.requires_grad = trainable
            total += weight.numel()
        return total

    if finetune_args.train_only_head:
        _mark(model.parameters(), False)
        num_params = _mark(model.lm_head.parameters(), True)
    elif finetune_args.train_only_last_mlp:
        _mark(model.parameters(), False)
        num_params = _mark(model.transformer.h[-1].mlp.parameters(), True)
        num_params += _mark(model.transformer.ln_f.parameters(), True)
    else:
        num_params = _mark(model.parameters(), True)
    print(f"Number of learnable parameters: {num_params}")

    # mlm=False selects the causal (next-token) language-modeling collator.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = FinetuneTrainer(
        model=model,
        args=finetune_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collator,
    )

    # False tells the trainer to start from scratch instead of resuming.
    checkpoint = finetune_args.checkpoint_path
    _resume_from_checkpoint = False if checkpoint is None else f"{checkpoint}"

    if finetune_args.do_train:
        trainer.train(resume_from_checkpoint=_resume_from_checkpoint)
    if finetune_args.do_eval:
        print(trainer.evaluate())

    trainer.save_state()
    trainer.save_model()

def train_to_mimic(mimic_args: MimicArguments, train_dataset, test_dataset, tokenizer, 
                   origin_model: CausalLM, linear_model: CausalLM, device: str):
    """Train the feature networks of ``linear_model`` against ``origin_model``.

    Both models are frozen first; only ``layer.attn.feature_net`` parameters of
    the linear model are re-enabled. Parameters whose name contains
    ``feature_net.coef`` get their own optimizer group so they can use
    ``mimic_args.coef_learning_rate`` (falling back to ``learning_rate``).

    Returns immediately when ``mimic_args.skip`` is set. ``device`` is not used.
    """
    if mimic_args.skip:
        print("[Message] Skipping the training to mimic.")
        return

    # Freeze everything in both models, origin first.
    for frozen_model in (origin_model, linear_model):
        for weight in frozen_model.parameters():
            weight.requires_grad = False

    # Re-enable gradients only for the per-layer feature networks.
    for layer_idx, layer in enumerate(linear_model.transformer.h):
        if hasattr(layer.attn, 'feature_net'):
            for weight in layer.attn.feature_net.parameters():
                weight.requires_grad = True
        else:
            print(f"[Message] The feature network is not found in the attention layer {layer_idx}.")

    # Split the parameters into the coefficient group and everything else,
    # preserving the order in which named_parameters() yields them.
    coef_params, other_params = [], []
    for name, weight in linear_model.named_parameters():
        if 'feature_net.coef' in name:
            coef_params.append(weight)
        else:
            other_params.append(weight)

    if mimic_args.coef_learning_rate is not None:
        coef_lr = mimic_args.coef_learning_rate
    else:
        coef_lr = mimic_args.learning_rate

    optimizer = torch.optim.AdamW(
        [
            {'params': other_params, 'lr': mimic_args.learning_rate, 'name': 'others'},
            {'params': coef_params, 'lr': coef_lr, 'name': 'feature_net.coef'},
        ],
        weight_decay=mimic_args.weight_decay,
        betas=(mimic_args.adam_beta1, mimic_args.adam_beta2),
        eps=mimic_args.adam_epsilon,
    )

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = MimicTrainer(
        origin_model=origin_model,
        linear_model=linear_model,
        args=mimic_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        optimizers=(optimizer, None),
        data_collator=collator,
    )

    # False tells the trainer to start from scratch instead of resuming.
    checkpoint = mimic_args.checkpoint_path
    _resume_from_checkpoint = False if checkpoint is None else f"{checkpoint}"

    if mimic_args.do_train:
        trainer.train(resume_from_checkpoint=_resume_from_checkpoint)
    if mimic_args.do_eval:
        print(trainer.evaluate(), flush=True)

    trainer.save_state()
    trainer.save_model()

def downstream_eval(task_name: str, downstream_args: DownstreamArguments, 
                    model: CausalLM, tokenizer, device):
    """Train and/or evaluate ``model`` on one downstream classification task.

    The task's tokenized train/test sets are loaded with the model's maximum
    sequence length. Parameters whose name contains ``feature_net.`` are placed
    in an optimizer group with learning rate 0; all other parameters use
    ``downstream_args.learning_rate``. Training and evaluation run only when
    ``do_train`` / ``do_eval`` are set; trainer state and model are always saved.

    Args:
        task_name: Name of the downstream task to load.
        downstream_args: Trainer/optimizer settings for this task.
        model: The causal LM to evaluate.
        tokenizer: Tokenizer used to encode the task data.
        device: Not used here.
    """
    # `max_position_embeddings` is the attribute `train` reads for both model
    # families; `n_positions` is kept as a fallback for configs that only
    # expose the GPT-2 specific name.
    model_config = model.config
    max_length = getattr(model_config, "max_position_embeddings", None)
    if max_length is None:
        max_length = model_config.n_positions
    train_dataset, test_dataset = load_tokenized_downstream_dataset(task_name, tokenizer, max_length)

    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if 'feature_net.' not in n], 
         'lr': downstream_args.learning_rate, 
         'name': 'others'},
        # Learning rate 0 keeps the feature-network parameters fixed.
        {'params': [p for n, p in model.named_parameters() if 'feature_net.' in n], 
         'lr': 0.,
         'name': 'feature_net'},
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters, 
        weight_decay=downstream_args.weight_decay, 
        betas=(downstream_args.adam_beta1, downstream_args.adam_beta2),
        eps=downstream_args.adam_epsilon
    )

    trainer = DownstreamTrainer(
        model=model,
        task_name=task_name,
        args=downstream_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        optimizers=(optimizer, None),
        compute_metrics=classification_accuracy,
    )
    if downstream_args.do_train:
        trainer.train()
    if downstream_args.do_eval:
        metrics = trainer.evaluate()
        print(metrics, flush=True)

    trainer.save_state()
    trainer.save_model()

def train(data_args: DataArguments, model_args: ModelArguments, savedirname: str,
          origin_finetune_args: FinetuneArguments, mimic_args: MimicArguments, 
          linear_finetune_args: FinetuneArguments, downstream_args_list: dict[str, DownstreamArguments],
          device: str):
    """Run the full pipeline: finetune origin -> mimic -> finetune linear -> downstream.

    Args:
        data_args: Corpus selection and train/test split ratios.
        model_args: Base architecture plus options copied onto the linear config.
        savedirname: Root output directory; the linear model is stored under
            ``{savedirname}/downstream/tmp`` before downstream evaluation.
        origin_finetune_args: Settings for fine-tuning the original model.
        mimic_args: Settings for the mimic stage.
        linear_finetune_args: Settings for fine-tuning the linear model.
        downstream_args_list: Mapping of task name to its downstream settings.
        device: Device both models are moved to.

    Raises:
        ValueError: If ``model_args.origin_model`` is not a supported family.
    """
    model_name = model_args.origin_model
    if model_name == "gpt2":
        _model_name = "gpt2"
        ConfigClass = ExtendedGPT2Config
        ModelClass = ExtendedGPT2LMHeadModel
    elif model_name.startswith("pythia"):
        _model_name = f"EleutherAI/{model_name}-deduped"
        ConfigClass = ExtendedGPTNeoXConfig
        ModelClass = ExtendedGPTNeoXForCausalLM
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(_model_name)
    # Reuse EOS as the padding token so padded batches can be built.
    tokenizer.pad_token = tokenizer.eos_token

    # Original model: pretrained weights with linear attention switched off.
    model_config: CausalLMConfig = ConfigClass.from_pretrained(_model_name)
    model_config.use_linear_attn = False
    origin_model = ModelClass.from_pretrained(_model_name, config=model_config)
    origin_model.to(device)

    # Linear model: same base config, overridden with every ModelArguments
    # field except `origin_model`.
    linear_config: CausalLMConfig = ConfigClass.from_pretrained(_model_name)
    model_args_dict = asdict(model_args)
    model_args_dict.pop("origin_model")
    linear_config.update(model_args_dict)
    # Build the linear model exactly once; a converted checkpoint, when given,
    # overwrites its freshly initialised weights.
    linear_model = ModelClass(linear_config)
    if model_args.pretrained_path is not None:
        state_dict = _convert_pretrained_to_state_dict(model_name, model_args.pretrained_path)
        linear_model.load_state_dict(state_dict)
        print(f"Loaded the pretrained model from {model_args.pretrained_path}")
    linear_model.to(device)

    print("Number of parameters:")
    params = sum(p.numel() for p in origin_model.parameters() if p.requires_grad)
    print("\t Original model:", params)
    params = sum(p.numel() for p in linear_model.parameters() if p.requires_grad)
    print("\t Linear model:", params)

    dataset_args = (data_args.dataset_name, )
    dataset_kwargs = {}
    if data_args.dataset_subset is not None:
        dataset_args += (data_args.dataset_subset, )
    if data_args.dataset_split is not None:
        dataset_kwargs["split"] = data_args.dataset_split
    # NOTE(security): trust_remote_code=True allows the dataset repository to
    # run its own loading script; only use datasets from trusted sources.
    dataset = load_dataset(*dataset_args, **dataset_kwargs, trust_remote_code=True)
    max_length = model_config.max_position_embeddings
    def tokenize_function(examples):
        # Pad/truncate every example to the model's maximum context length.
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)
    split_dataset = dataset.train_test_split(train_size=data_args.train_split_ratio, 
                                             test_size=data_args.test_split_ratio)
    del dataset
    train_dataset = split_dataset["train"]
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = split_dataset["test"]
    test_dataset = test_dataset.map(tokenize_function, batched=True)
    
    print("=== Finetuning the original model ===")
    finetune(origin_finetune_args, train_dataset, test_dataset, 
             tokenizer, origin_model, device)

    print("=== Train to mimic ===")
    train_to_mimic(mimic_args, train_dataset, test_dataset,
                   tokenizer, origin_model, linear_model, device)

    print("Deleting the original model ...", end=" ", flush=True)
    del origin_model
    # Presumably gives pending work time to release the model -- TODO confirm
    # whether this pause is still needed.
    sleep(10)
    print("Done", flush=True)
    torch.cuda.empty_cache()

    print("=== Finetuning the linear model ===")
    finetune(linear_finetune_args, train_dataset, test_dataset, 
             tokenizer, linear_model, device)
    
    del train_dataset, test_dataset
    torch.cuda.empty_cache()
    
    print("=== Downstream evaluation ===")
    # Snapshot the fine-tuned linear model so tasks can restart from it.
    linear_model.save_pretrained(f"{savedirname}/downstream/tmp")
    for task_name, downstream_args in downstream_args_list.items():
        print(f"--- Task: {task_name} ---")
        # Tasks that train more than the head reload the snapshot first;
        # head-only tasks reuse the model object currently in memory.
        if not downstream_args.train_only_head:
            config = ConfigClass.from_pretrained(f"{savedirname}/downstream/tmp")
            linear_model = ModelClass.from_pretrained(f"{savedirname}/downstream/tmp", config=config)
            linear_model.to(device)
        downstream_eval(task_name, downstream_args, linear_model, tokenizer, device)

def preprocess_config(config: dict[str, dict[str, Any]]):
    """Fill in a default ``learning_rate`` for every training stage.

    An explicitly configured (non-None) ``learning_rate`` is kept. Otherwise the
    default is ``1e-3`` when ``train_only_head`` or ``train_only_last_mlp`` is
    set, else ``5e-5`` for ``gpt2`` and ``5e-6`` for ``pythia-1b``.

    For downstream tasks the rate is resolved from the ``"common"`` section
    overlaid with the task's own section (task values win). The ``"common"``
    section itself is treated as shared defaults, not as a task, and is left
    untouched; it may be absent.

    Args:
        config: Parsed configuration; modified in place.

    Returns:
        The same ``config`` object.

    Raises:
        ValueError: If a default is needed for an unknown ``origin_model``.
    """
    model_name = config["model"].get("origin_model", "gpt2")

    def _get_learning_rate(_config: dict[str, Any]):
        if _config.get("learning_rate") is not None:
            return _config["learning_rate"]

        if _config.get("train_only_head", False) or _config.get("train_only_last_mlp", False):
            return 1e-3

        if model_name == "gpt2":
            return 5e-5
        elif model_name == "pythia-1b":
            return 5e-6
        else:
            raise ValueError(f"Unknown model name: {model_name}")

    config["origin_finetune"]["learning_rate"] = _get_learning_rate(config["origin_finetune"])
    config["linear_finetune"]["learning_rate"] = _get_learning_rate(config["linear_finetune"])

    downstream = config["downstream"]
    common = downstream.get("common", {})
    for task_name, task_config in downstream.items():
        # Writing a rate into "common" would leak into every later task and
        # mask that task's own train_only_* flags.
        if task_name == "common":
            continue
        task_config["learning_rate"] = _get_learning_rate(common | task_config)

    return config

def main(savedirname: str, config_path: str):
    """Load the JSON config, build all stage arguments and run ``train``.

    The resolved configuration is written to ``{savedirname}/config.json``
    before training starts. When ``misc.memory_history`` is set, CUDA memory
    history is recorded during training and dumped to
    ``{savedirname}/snapshot.pickle`` afterwards.

    Args:
        savedirname: Root output directory (created if it does not exist).
        config_path: Path to the JSON configuration file.
    """
    import os  # local import: only needed to create the output directory

    manual_seed(42)

    with open(config_path, "r") as f:
        config: dict[str, dict[str, Any]] = json.load(f)
    config = preprocess_config(config)
    
    allow_tqdm = int(config["misc"].get("allow_tqdm", True))
    if not allow_tqdm:
        # Globally disable progress bars by forcing disable=True on every tqdm.
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    config["misc"]["num_gpus"] = torch.cuda.device_count()
    
    # "common" holds defaults shared by all downstream tasks; it may be absent.
    config_downstream_common = config["downstream"].pop("common", {})

    model_args = ModelArguments(**config["model"])
    data_args = DataArguments(**config["data"])
    origin_finetune_args = FinetuneArguments(output_dir=f"{savedirname}/origin", 
                                             **config["origin_finetune"])
    mimic_args = MimicArguments(output_dir=f"{savedirname}/mimic", 
                                **config["mimic"])
    linear_finetune_args = FinetuneArguments(output_dir=f"{savedirname}/linear", 
                                             **config["linear_finetune"])
    # Task-specific values override the shared defaults, matching the
    # precedence used by preprocess_config when it resolves learning rates.
    downstream_args_dict = {task_name: DownstreamArguments(output_dir=f"{savedirname}/downstream/{task_name}", 
                                                           **(config_downstream_common | config["downstream"][task_name]))
                            for task_name in config["downstream"].keys()}

    loaded_config = {"model": asdict(model_args),
                     "data": asdict(data_args), 
                     "origin_finetune": origin_finetune_args.get_dict(),
                     "mimic": mimic_args.get_dict(),
                     "linear_finetune": linear_finetune_args.get_dict(),
                     "misc": config["misc"], 
                     "downstream": {task_name: downstream_args.get_dict() 
                                    for task_name, downstream_args in downstream_args_dict.items()}
                    }
    # The save directory may not exist yet on a fresh run.
    os.makedirs(savedirname, exist_ok=True)
    with open(f"{savedirname}/config.json", "w") as f:
        json.dump(loaded_config, f, indent=4)

    memory_history = config["misc"].get("memory_history", False)
    if memory_history:
        torch.cuda.memory._record_memory_history(max_entries=100_000)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train(data_args, model_args, savedirname, 
          origin_finetune_args, mimic_args, linear_finetune_args, downstream_args_dict,
          device)
    
    if memory_history:
        # Dumping the snapshot is best-effort; a failure must not fail the run.
        try:
            torch.cuda.memory._dump_snapshot(f"{savedirname}/snapshot.pickle")
        except Exception as e:
            print(f"Failed to capture memory snapshot {e}")

        # Stop recording memory history.
        torch.cuda.memory._record_memory_history(enabled=None)
    
if __name__ == "__main__":
    # Command-line entry point: output directory and path of the JSON config.
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--save_dir", default="./tmp", type=str)
    parser.add_argument("-c", "--config", default="./configs/train_config.json", type=str)
    args = parser.parse_args()

    main(savedirname=args.save_dir, config_path=args.config)