import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    Qwen2ForCausalLM,
    Qwen2Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    PeftModel, 
    PeftConfig
)
from datetime import datetime
import pytz
from gsm8k.po_rewrite.test_po_ft import process_generated_prompt_no_question
from gsm8k.evaluation.metrics import parse_answer_2, extract_answers, em

# Chat-format prompt pieces. All four templates are assembled from the same
# system/user prefix so the shared text is spelled out exactly once.

# System + user turn followed by the opening of the assistant turn
# (placeholder: {input}); the assistant text itself is not included.
QUERY_TEMPLATE = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "{input}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Assistant text followed by the end-of-turn marker (placeholder: {output}).
LABEL_TEMPLATE = "{output}<|im_end|>"

# Query plus the assistant text, with the assistant turn left open.
TEMPLATE = QUERY_TEMPLATE + "{output}"

# Query plus the assistant text, with the assistant turn closed.
COMPLETE_TEMPLATE = QUERY_TEMPLATE + LABEL_TEMPLATE


class Qwen2JointConfig(PretrainedConfig):
    """Configuration holder for the joint two-model setup.

    Every constructor argument is stored unchanged as an attribute of the same
    name; any extra keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        encoding_dim: Size of the projected condition vector fed to the
            hypernetworks.
        input_dim: Width of the hidden state that gets projected down to
            ``encoding_dim``.
        embedding_dim: Stored on the config as-is.
        bottleneck: Rank of the generated low-rank weight factors.
        hyperlambda: Scale applied to the generated low-rank update.
        pg_t: Value used as the sampling ``temperature`` when the first model
            generates text.
        model1_path: Checkpoint path/name of the first model.
        model2_path: Checkpoint path/name of the second model.
        tokenizer_path: Checkpoint path/name of the tokenizer.
        shared_layers: Index into the first model's hidden-state tuple from
            which the condition vector is pooled.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen2_joint"

    def __init__(
            self,
            encoding_dim=8,
            input_dim=3584,
            embedding_dim=3584,
            bottleneck=16,
            hyperlambda=0.01,
            pg_t=0,
            model1_path="Qwen2.5-7B",
            model2_path="Qwen2.5-7B",
            tokenizer_path="Qwen2.5-7B",
            shared_layers=28,
            **kwargs
    ):
        super().__init__(**kwargs)
        # Ordered (attribute, value) table; applied top to bottom.
        joint_settings = (
            ("encoding_dim", encoding_dim),
            ("input_dim", input_dim),
            ("embedding_dim", embedding_dim),
            ("bottleneck", bottleneck),
            ("model1_path", model1_path),
            ("model2_path", model2_path),
            ("tokenizer_path", tokenizer_path),
            ("hyperlambda", hyperlambda),
            ("pg_t", pg_t),
            ("shared_layers", shared_layers),
        )
        for attr_name, attr_value in joint_settings:
            setattr(self, attr_name, attr_value)


class HyperParamNet(nn.Module):
    """Turn a condition vector into a batch of weight matrices.

    The condition passes through ``linear1``, a ReLU and ``linear2``; the flat
    result is then viewed as ``(-1, dim, bottleneck)``.

    Args:
        linear1: Module applied to the condition before the ReLU.
        linear2: Module whose output is reshaped into the weight matrices; it
            must emit ``dim * bottleneck`` values per row for the view to work.
        dim: Second axis of the returned tensor.
        bottleneck: Third axis of the returned tensor.
    """

    def __init__(self, linear1, linear2, dim, bottleneck):
        super().__init__()
        self.linear1 = linear1
        self.linear2 = linear2
        self.dim = dim
        self.bottleneck = bottleneck
        # Externally supplied condition; stays None until set_condition_var().
        self.condition_var = None

    def set_condition_var(self, condition_var):
        """Remember ``condition_var`` so other modules can read it later."""
        self.condition_var = condition_var

    def forward(self, condition_var):
        """Return weights of shape ``(batch_size, dim, bottleneck)``."""
        flat_weights = self.linear2(F.relu(self.linear1(condition_var)))
        return flat_weights.view(-1, self.dim, self.bottleneck)


class HyperNet(nn.Module):
    """Bundle of the modules that generate one pair of low-rank factors.

    Holds a projection of the hidden state into the condition space plus two
    ``HyperParamNet`` generators:

    * ``down_hypernet`` yields tensors viewed as ``(batch, input_dim, bottleneck)``;
    * ``up_hypernet`` yields tensors viewed as ``(batch, bottleneck, embedding_dim)``.

    Both generator linears take ``encoding_dim + 1`` inputs; the extra slot
    matches the single column ``HyperLora`` appends to the condition when it
    has a layer index. Their weights start as N(0, 1e-7) with zero bias, so the
    generated factors are close to zero at initialisation.

    Args:
        input_dim: Input width of the wrapped linear layer.
        embedding_dim: Output width of the wrapped linear layer.
        bottleneck: Rank of the generated factors.
        hidden_size: Input width of ``condition_proj``.
        encoding_dim: Output width of ``condition_proj``.
    """

    def __init__(self, input_dim, embedding_dim, bottleneck=16, hidden_size=3584, encoding_dim=8):
        super().__init__()
        generator_in = encoding_dim + 1
        self.condition_proj = nn.Linear(hidden_size, encoding_dim)

        # Down-projection generator (input_dim -> bottleneck).
        self.pre_down_linear = nn.Identity()
        self.down_linear = self._near_zero_linear(generator_in, input_dim * bottleneck)

        # Up-projection generator (bottleneck -> embedding_dim).
        self.pre_up_linear = nn.Identity()
        self.up_linear = self._near_zero_linear(generator_in, embedding_dim * bottleneck)

        self.down_hypernet = HyperParamNet(
            self.pre_down_linear, self.down_linear, input_dim, bottleneck
        )
        # dim/bottleneck are swapped here so the output views to (bottleneck, embedding_dim).
        self.up_hypernet = HyperParamNet(
            self.pre_up_linear, self.up_linear, bottleneck, embedding_dim
        )

    @staticmethod
    def _near_zero_linear(in_features, out_features):
        """Create a linear layer with N(0, 1e-7) weights and a zero bias."""
        layer = nn.Linear(in_features, out_features)
        layer.weight.data.normal_(0, 1e-7)
        layer.bias.data.zero_()
        return layer


class HyperLora(nn.Module):
    """Wrap a linear layer and add a hypernetwork-generated low-rank update.

    The output is ``linear(x) + hyperlambda * ((x @ W_down) @ W_up)`` where
    ``W_down`` / ``W_up`` come from ``hypernet1`` / ``hypernet2`` evaluated on
    the condition stored in ``hypernet1.condition_var`` and averaged over the
    condition's batch axis.

    Args:
        linear: The wrapped module, applied to ``x`` unchanged.
        hypernet1: Callable producing ``(batch, input_dim, bottleneck)``; must
            expose a ``condition_var`` attribute.
        hypernet2: Callable producing ``(batch, bottleneck, embedding_dim)``.
        hyperlambda: Scale of the low-rank term.
        idx: Optional layer index; when not None it is appended to the
            condition as one extra column.
    """

    def __init__(self, linear: nn.Module, hypernet1=None, hypernet2=None, hyperlambda=0.01, idx=None):
        super().__init__()
        self.linear = linear
        self.hypernet1 = hypernet1
        self.hypernet2 = hypernet2
        # Registered but not applied anywhere in forward().
        self.dropout = nn.Dropout(p=0.1)
        self.hyperlambda = hyperlambda
        self.idx = idx

    def forward(self, x):
        """Apply the wrapped layer plus the scaled low-rank correction.

        Raises:
            ValueError: If no condition has been stored on ``hypernet1``.
        """
        cond = self.hypernet1.condition_var
        if cond is None:
            raise ValueError("condition_var has not been set in hypernet1.")

        if self.idx is not None:
            # The unpacking also insists on a 2-D (batch_size, encoding_dim) condition.
            rows, _ = cond.shape
            layer_column = torch.full((rows, 1), self.idx, dtype=cond.dtype, device=cond.device)
            cond = torch.cat([cond, layer_column], dim=1)  # (batch_size, encoding_dim + 1)

        per_sample_down = self.hypernet1(cond)  # (batch_size, input_dim, bottleneck)
        per_sample_up = self.hypernet2(cond)    # (batch_size, bottleneck, embedding_dim)

        # One shared pair of factors for the whole batch: average over samples.
        down = per_sample_down.mean(dim=0)  # (input_dim, bottleneck)
        up = per_sample_up.mean(dim=0)      # (bottleneck, embedding_dim)

        squeezed = x @ down        # (..., bottleneck)
        low_rank = squeezed @ up   # (..., embedding_dim), same leading dims as x
        return self.linear(x) + self.hyperlambda * low_rank


class Qwen2ForJointLM(PreTrainedModel):
    """Joint wrapper around a LoRA-tuned prompt generator and a frozen solver.

    ``model1`` (wrapped with a PEFT LoRA adapter) reads a meta prompt and
    generates a task prompt. A pooled hidden state of ``model1`` is projected
    into a condition vector that drives the ``HyperLora`` wrappers installed on
    every ``o_proj`` linear layer of ``model2``. ``model2``'s own parameters are
    frozen before the wrappers are attached, so only the hypernetworks (and
    ``model1``'s adapter) remain trainable.
    """

    config_class = Qwen2JointConfig

    def __init__(self, config: Qwen2JointConfig):
        """Load the tokenizer and both models, then install HyperLora on model2.

        Args:
            config: Supplies checkpoint paths, hypernetwork sizes, the LoRA
                scale ``hyperlambda``, the sampling temperature ``pg_t`` and
                the hidden-state index ``shared_layers``.
        """
        super().__init__(config)
        self.tokenizer = Qwen2Tokenizer.from_pretrained(
            config.tokenizer_path,
            padding_side="left"
        )
        # The EOS token doubles as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_length = 512
        self.shared_layers = config.shared_layers

        model1 = Qwen2ForCausalLM.from_pretrained(config.model1_path)
        lora_config = LoraConfig(
            r=16,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model1 = get_peft_model(model1, lora_config)
        self.pg_t = config.pg_t
        self.model2 = Qwen2ForCausalLM.from_pretrained(config.model2_path)
        # Freeze model2 *before* attaching hypernetworks so they stay trainable.
        for param in self.model2.parameters():
            param.requires_grad = False
        self.hypernets = nn.ModuleList()

        target_modules = ["o_proj"]
        # Iterate over a snapshot: the loop replaces sub-modules in place, and
        # mutating the tree while the named_modules() generator is still live
        # only works by accident of dict-iteration details.
        for name, module in list(self.model2.named_modules()):
            parts = name.split('.')
            if parts[-1] not in target_modules or not isinstance(module, nn.Linear):
                continue
            # Names deeper than three components carry the layer number in the
            # third slot (e.g. "model.layers.<N>.self_attn.o_proj"). Parse it
            # with int() rather than eval() so no name text is ever executed.
            layer_id = int(parts[2]) if len(parts) > 3 else -1
            hypernet = HyperNet(
                input_dim=module.in_features,
                embedding_dim=module.out_features,
                bottleneck=config.bottleneck,
                hidden_size=config.input_dim,
                encoding_dim=config.encoding_dim
            )
            self.hypernets.append(hypernet)
            hyper_lora = HyperLora(
                linear=module,
                hypernet1=hypernet.down_hypernet,
                hypernet2=hypernet.up_hypernet,
                hyperlambda=config.hyperlambda,
                idx=layer_id
            )
            parent = self._get_parent_module(self.model2, name)
            setattr(parent, parts[-1], hyper_lora)

        print("Replaced with HyperLora.")

    def _get_parent_module(self, model, module_name):
        """Return the module that directly owns the dotted ``module_name``."""
        parent = model
        for comp in module_name.split('.')[:-1]:
            parent = getattr(parent, comp)
        return parent

    def batch_tokenize(self, batch_inputs, batch_outputs, model="model2"):
        """Render chat-formatted (input, output) pairs and tokenize them.

        Args:
            batch_inputs: User-turn texts.
            batch_outputs: Assistant-turn texts, paired with ``batch_inputs``.
            model: ``"model1"`` selects ``COMPLETE_TEMPLATE`` (assistant turn
                closed with ``<|im_end|>``); any other value selects
                ``TEMPLATE`` (assistant turn left open).

        Returns:
            ``(input_ids, attention_mask)`` padded/truncated to
            ``self.max_length``.
        """
        template = COMPLETE_TEMPLATE if model == "model1" else TEMPLATE
        full_templates = [template.format(input=input_text, output=output_text)
                          for input_text, output_text in zip(batch_inputs, batch_outputs)]

        full_encodings = self.tokenizer(full_templates,
                                        return_tensors="pt",
                                        padding="max_length",
                                        truncation=True,
                                        max_length=self.max_length)
        input_ids = full_encodings["input_ids"]
        input_attention = full_encodings["attention_mask"]

        return input_ids, input_attention

    def batch_tokenize_labels(self, batch_outputs):
        """Tokenize ``LABEL_TEMPLATE``-formatted outputs into loss labels.

        Sequences are padded/truncated to ``self.max_length`` and every
        position holding the pad token id is replaced by ``-100``.

        Returns:
            Integer tensor of label ids with padding masked out.
        """
        label_templates = [LABEL_TEMPLATE.format(output=output_text)
                           for output_text in batch_outputs]

        label_tokens = self.tokenizer(label_templates,
                                      return_tensors="pt",
                                      padding="max_length",
                                      truncation=True,
                                      max_length=self.max_length)

        # Vectorised masking; yields the same values as a per-token Python
        # loop without building the tensor one element at a time.
        label_ids = label_tokens["input_ids"]
        labels = label_ids.masked_fill(label_ids == self.tokenizer.pad_token_id, -100)

        return labels

    def get_rewards(self, generated_responses, answers):
        """Score each response with 1 (exact match) or 0.

        A response earns 1 only when it is non-empty and ``em`` accepts the
        answers extracted from it against the parsed reference answer.
        """
        return [
            1 if response and em(extract_answers(response), parse_answer_2(reference)) else 0
            for response, reference in zip(generated_responses, answers)
        ]

    def forward(self, meta_prompt, meta_input_ids, meta_attention_mask, question, answer, labels):
        """Run one joint step over both models.

        Steps:
            1. Encode the meta prompt with ``model1`` and pool one hidden-state
               layer into the condition vector shared by all hypernetworks.
            2. Let ``model1`` generate a prompt (no grad) and post-process it.
            3. Re-score the generated prompt with ``model1`` under labels.
            4. Feed ``generated prompt + question`` with the reference answer
               to ``model2`` under ``labels``.
            5. Let ``model2`` generate greedily (no grad) and turn the decoded
               text into 0/1 rewards.

        Args:
            meta_prompt: Meta prompt strings, one per sample.
            meta_input_ids: Token ids of the meta prompts for ``model1``.
            meta_attention_mask: Attention mask matching ``meta_input_ids``.
            question: Question strings, one per sample.
            answer: Reference answer strings, one per sample.
            labels: Labels passed to ``model2`` for its loss.

        Returns:
            ``(outputs2, outputs1, model1_labels_train, rewards)`` where
            ``outputs1`` is the labelled ``model1`` pass from step 3.
        """
        model1_input_ids = meta_input_ids.to(self.model1.device)
        model1_attention_mask = meta_attention_mask.to(self.model1.device)

        # --- 1) condition vector from model1's hidden states -----------------
        outputs1 = self.model1(
            input_ids=model1_input_ids,
            attention_mask=model1_attention_mask,
            labels=None,
            output_hidden_states=True
        )

        hidden_states = outputs1.hidden_states
        shared_layer = hidden_states[self.shared_layers]
        # NOTE(review): plain mean over the sequence axis; padded positions are
        # not masked out. Confirm this is intended before changing it.
        condition_var = shared_layer.mean(dim=1)
        # Only the first hypernet's projection is used; every hypernet then
        # receives the same projected condition.
        condition_var = self.hypernets[0].condition_proj(condition_var)
        for hypernet in self.hypernets:
            hypernet.up_hypernet.set_condition_var(condition_var)
            hypernet.down_hypernet.set_condition_var(condition_var)

        # --- 2) model1 generates the task prompt ------------------------------
        with torch.no_grad():
            # Sampling requires a strictly positive float temperature. The
            # config default (pg_t=0) would make generate() raise, so a
            # non-positive temperature falls back to greedy decoding.
            use_sampling = self.pg_t is not None and self.pg_t > 0
            generation_kwargs = {
                "max_new_tokens": 100,
                "top_p": 1,
                "num_return_sequences": 1,
                "num_beams": 1,
                "do_sample": use_sampling,
                "repetition_penalty": 1.0,
                "pad_token_id": self.tokenizer.eos_token_id
            }
            if use_sampling:
                generation_kwargs["temperature"] = float(self.pg_t)

            generated_ids = self.model1.generate(
                input_ids=model1_input_ids,
                attention_mask=model1_attention_mask,
                **generation_kwargs
            )

        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # Remove the meta prompt text from each decoded sequence before post-processing.
        generated_texts = [process_generated_prompt_no_question(t.replace(mp, '')) for mp, t in zip(meta_prompt, decoded_texts)]

        # --- 3) labelled model1 pass over its own generation ------------------
        model1_input_ids_train, model1_attention_mask_train = self.batch_tokenize(meta_prompt, generated_texts, "model1")
        model1_input_ids_train = model1_input_ids_train.to(self.model1.device)
        model1_attention_mask_train = model1_attention_mask_train.to(self.model1.device)
        # NOTE(review): inputs and labels are both left-padded to max_length,
        # which presumably is what lines the label tokens up with the tail of
        # the full sequence -- confirm against the training loop.
        model1_labels_train = self.batch_tokenize_labels(generated_texts)
        outputs1 = self.model1(
            input_ids=model1_input_ids_train,
            attention_mask=model1_attention_mask_train,
            labels=model1_labels_train,
        )

        # --- 4) labelled model2 pass on prompt + question ---------------------
        concatenated_inputs = [gp + f"\nQuestion: {q}\nLet's think step by step.\n" for gp, q in zip(generated_texts, question)]
        model2_input_ids, model2_attention_mask = self.batch_tokenize(concatenated_inputs, answer)
        model2_input_ids = model2_input_ids.to(self.model2.device)
        model2_attention_mask = model2_attention_mask.to(self.model2.device)
        outputs2 = self.model2(
            input_ids=model2_input_ids,
            attention_mask=model2_attention_mask,
            labels=labels,
        )

        # --- 5) greedy model2 generation -> rewards ---------------------------
        with torch.no_grad():
            # The raw concatenated text is tokenized here, without the chat template.
            full_encodings = self.tokenizer(concatenated_inputs,
                                            return_tensors="pt",
                                            padding="max_length",
                                            truncation=True,
                                            max_length=self.max_length)
            input_ids = full_encodings["input_ids"].to(self.model2.device)
            input_attention = full_encodings["attention_mask"].to(self.model2.device)

            generation_kwargs = {
                "max_new_tokens": 256,
                "top_p": 1,
                "num_return_sequences": 1,
                "num_beams": 1,
                "do_sample": False,
                "repetition_penalty": 1.0,
                "pad_token_id": self.tokenizer.eos_token_id
            }

            generated_ids = self.model2.generate(
                input_ids=input_ids,
                attention_mask=input_attention,
                **generation_kwargs
            )
            model2_responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            rewards = self.get_rewards(model2_responses, answer)

        return outputs2, outputs1, model1_labels_train, rewards
