import logging
from typing import Optional, Dict
from dataclasses import dataclass
import copy

import torch
import os
import json
from torch import Tensor
from torch import nn
from torch.nn import functional as F
import torch.distributed as dist

from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import ModelOutput
from transformers import Qwen2VLForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from qwen25_vl_model import Qwen25VLForEmbedding

logger = logging.getLogger(__name__)


from dataclasses import dataclass, field
from typing import Optional, Union

@dataclass
class RankerOutput:
    """Bundle of training-time results produced by the ranker.

    Every field defaults to ``None``. In this file the only producer is
    ``AutoModelForRanking.compute_outputs``, which fills the fields with
    torch tensors even though the annotations mention ``float``/``list``.
    """

    # Scalar training loss.
    loss: Union[float, None] = None
    # Per-group scores the loss was computed from (filled from ``scores``).
    logits: Union[list, float, None] = None
    # Weight attached to the positive candidate of each group.
    pos_weights: Union[list, float, None] = None
    # Weight mass attached to the negative candidates of each group.
    neg_weights: Union[list, float, None] = None


def get_yes_or_no_linear(model_path, random_init=False):
    """Create a bias-free two-output linear head for "yes" / "no" scoring.

    The tokenizer and the model are both loaded from ``model_path``. The head
    has as many input features as one row of the model's ``lm_head`` weight.

    Args:
        model_path: Location passed to both ``from_pretrained`` calls.
        random_init: When true, the head is Xavier-normal initialised instead
            of being copied from the ``lm_head`` rows.

    Returns:
        ``torch.nn.Linear`` whose output 0 corresponds to "yes" and output 1
        to "no" (when not randomly initialised).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Qwen25VLForEmbedding.from_pretrained(model_path)

    yes_id = tokenizer.convert_tokens_to_ids("yes")
    no_id = tokenizer.convert_tokens_to_ids("no")

    vocab_rows = model.lm_head.weight.data
    answer_rows = (vocab_rows[yes_id], vocab_rows[no_id])

    hidden_dim = answer_rows[0].size(0)
    head = torch.nn.Linear(hidden_dim, 2, bias=False)
    with torch.no_grad():
        if not random_init:
            # Row 0 <- "yes", row 1 <- "no".
            for row_idx, row in enumerate(answer_rows):
                head.weight[row_idx] = row
        else:
            torch.nn.init.xavier_normal_(head.weight)
    return head


def get_tokens_linear(model_path, tokens_list_path, token_num):
    """Create a bias-free linear head seeded from several ``lm_head`` rows.

    Args:
        model_path: Location passed to the tokenizer and model
            ``from_pretrained`` calls.
        tokens_list_path: UTF-8 JSON file holding a list of token strings;
            its first two entries must be ``'yes'`` and ``'no'``.
        token_num: How many leading tokens of the list to use (must be > 1).

    Returns:
        ``torch.nn.Linear`` with one output per selected token, where output
        ``i`` carries the ``lm_head`` row of the ``i``-th token.

    Raises:
        AssertionError: If the list does not start with 'yes', 'no' or if
            ``token_num`` is not greater than 1.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Qwen25VLForEmbedding.from_pretrained(model_path)

    with open(tokens_list_path, 'r', encoding='utf-8') as handle:
        vocabulary = json.load(handle)

    assert vocabulary[0] == 'yes' and vocabulary[1] == 'no'
    assert token_num > 1
    selected = vocabulary[:token_num]

    token_ids = [tokenizer.convert_tokens_to_ids(word) for word in selected]
    vocab_rows = model.lm_head.weight.data
    rows = [vocab_rows[token_id] for token_id in token_ids]

    hidden_dim = rows[0].size(0)
    head = torch.nn.Linear(hidden_dim, len(selected), bias=False)
    with torch.no_grad():
        for row_idx, row in enumerate(rows):
            head.weight[row_idx] = row
    return head

def get_token(model_path):
    """Return the tokenizer ids of "yes" and "no" as a ``(yes, no)`` tuple.

    Args:
        model_path: Location passed to ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return tuple(tokenizer.convert_tokens_to_ids(word) for word in ("yes", "no"))


@dataclass
class EncoderOutput(ModelOutput):
    """Output container with four optional tensor fields.

    Nothing in the visible part of this file constructs or reads this class.
    """
    # presumably query-side representations - TODO confirm against callers
    q_reps: Optional[Tensor] = None
    # presumably document-side representations - TODO confirm against callers
    d_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None


class AutoModelForRanking(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        train_args = None,
        data_args = None,
        temperature = None,
        pooling: str='last',
        normalize: bool = False,
        use_lora: bool = False,
        lora_config: str = None,
        attn_type: str = None,
        train_type: str = 'point',
        **kwargs,
    ):
        """Load the backbone, build the classifier head and training state.

        Args:
            model_name_or_path: Checkpoint used for the backbone, the
                tokenizer lookups and the classifier initialisation.
            train_args: Training arguments object. Despite the ``None``
                default it is dereferenced unconditionally (``kl_weight``,
                ``loss_method``, ``token_list_path``, ``token_num``,
                ``random_init``, ``freeze_classifier``, ``reference_model``,
                ``per_device_train_batch_size``, ``diy_cl_weight_ratio``,
                ``limit_ratio``, ``sft_temperature``).
            data_args: Data arguments object, stored as-is.
            temperature: Stored as-is.
            pooling: Stored as-is.
            normalize: Stored as-is.
            use_lora: When truthy the backbone is wrapped with
                ``get_peft_model``.
            lora_config: Passed straight to ``get_peft_model``.
            attn_type: Stored as-is; the backbone is always loaded with
                ``attn_implementation="flash_attention_2"``.
            train_type: Stored as-is.
            **kwargs: Forwarded to the backbone ``from_pretrained`` call and,
                if used, to the reference model ``from_pretrained`` call.
        """
        super().__init__()

        # Backbone first so that it is the first registered sub-module.
        self.lm = Qwen25VLForEmbedding.from_pretrained(
            model_name_or_path,
            attn_implementation="flash_attention_2",
            **kwargs,
        )
        self.config = self.lm.config

        # Plain configuration values.
        self.pooling = pooling
        self.data_args = data_args
        self.train_args = train_args
        self.temperature = temperature
        self.attn_type = attn_type
        self.normalize = normalize
        self.use_lora = False
        self.kl_weight = train_args.kl_weight
        self.loss_method = train_args.loss_method
        self.true_token, self.false_token = get_token(model_name_or_path)

        # Classifier head: a multi-token head when both a token list and a
        # token count are configured, otherwise the two-way yes/no head.
        self.token_list_path = train_args.token_list_path
        self.token_num = train_args.token_num
        has_token_list = self.token_list_path is not None and self.token_num is not None
        if has_token_list:
            self.classifier = get_tokens_linear(
                model_name_or_path, self.token_list_path, self.token_num
            )
        else:
            self.classifier = get_yes_or_no_linear(
                model_name_or_path, self.train_args.random_init
            )
        self.classifier.weight.requires_grad = not self.train_args.freeze_classifier

        # Optional reference model, needed for a KL term or for 'dpo'.
        needs_reference = self.kl_weight is not None or self.loss_method == 'dpo'
        if needs_reference:
            reference_path = self.train_args.reference_model
            if reference_path is not None:
                self.init_lm = AutoModel.from_pretrained(reference_path, **kwargs)
            else:
                # Snapshot taken before any LoRA wrapping below.
                self.init_lm = copy.deepcopy(self.lm)
            self.init_lm.eval()
            if self.kl_weight is not None:
                self.kl_loss = torch.nn.KLDivLoss(reduction="batchmean")

        # Optional LoRA wrapping of the backbone.
        if use_lora:
            self.use_lora = True
            self.lm.enable_input_require_grads()
            self.lm = get_peft_model(self.lm, lora_config)
            self.lm.print_trainable_parameters()

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        batch_rows = self.train_args.per_device_train_batch_size
        self.register_buffer('target_label', torch.zeros(batch_rows, dtype=torch.long))

        self.train_type = train_type

        self.diy_cl_weight_ratio = self.train_args.diy_cl_weight_ratio
        self.limit_ratio = self.train_args.limit_ratio
        self.sft_temperature = self.train_args.sft_temperature

    def margin_loss(self, pos, neg, margin=0.1):
        return torch.mean(F.relu(margin + neg - pos))

    def count_images(self, inputs_ids, image_token_id):
        mask = inputs_ids == image_token_id  
        shifted_mask = torch.cat(
            [torch.zeros((inputs_ids.size(0), 1), dtype=torch.bool, device=inputs_ids.device), 
            mask[:, :-1]], 
            dim=1
        )
        new_image_starts = mask & (~shifted_mask) 
        num_images_per_batch = new_image_starts.sum(dim=1)  
        total_images = num_images_per_batch.sum().item()

        return total_images

    def compute_logits(self, batch, sub_batchsize=None):
        if sub_batchsize is None:
            outputs = self.lm(**batch, output_hidden_states=True)
            last_hidden_state = outputs.last_hidden_state[:,-1]
            logits = self.classifier(last_hidden_state)
            return logits
        all_logits = []
        pixel_index = 0
        img_index = 0
        for i in range(0, len(batch['input_ids']), sub_batchsize):
            end_inx = min(i + sub_batchsize, len(batch['input_ids']))
            sub_features = {}
            for k, v in batch.items():
                num_images = 0
                if k == 'pixel_values':
                    num_images = self.count_images(batch['input_ids'][i:end_inx], self.lm.config.image_token_id)
                    pixel_size = torch.sum(batch['input_ids'][i:end_inx] == self.lm.config.image_token_id).item() * 4
                    if pixel_size == 0:
                        continue
                    sub_features[k] = v[pixel_index:pixel_index+pixel_size]
                    pixel_index = pixel_index + pixel_size
                    sub_features['image_grid_thw'] = batch['image_grid_thw'][img_index:img_index+num_images]
                    img_index += num_images
                elif k == 'image_grid_thw':
                    continue
                else:
                    sub_features[k] = v[i:end_inx]
            sub_outputs  = self.lm(**sub_features)
            sub_last_hidden_state = sub_outputs.last_hidden_state[:,-1]
            sub_logits = self.classifier(sub_last_hidden_state)

            all_logits.append(sub_logits)
            
        all_logits = torch.cat(all_logits, 0)
        return all_logits.contiguous()

    def compute_outputs(self, batch, sub_batchsize=None):
        """Score ``batch`` and, in training mode, turn the scores into a loss.

        Outside training (``self.training`` is False) the classifier logits
        from ``compute_logits`` are returned unchanged.

        In training a ``RankerOutput`` is returned. The logit rows are treated
        as consecutive groups of ``1 + data_args.neg_per_ins`` candidates in
        which the first row is the positive one. Column 0 of the logits is the
        classifier's first output (the "yes" row of the head built in
        ``__init__``). ``self.loss_method`` (lower-cased) picks the objective:

        * ``'sft'``: cross-entropy over the classifier outputs of every row,
          with label 0 for the positive row of each group and 1 otherwise.
        * ``'cl'``: cross-entropy across each group of column-0 logits with
          target index 0.
        * ``'clsft'``: the average of the two losses above.
        * ``'diy.<weight>.<direction>'``: ``mean(pos_w * pos_dir +
          sum(neg_w * neg_dir))``. The weights are computed under
          ``torch.no_grad()`` and the directions are built from the logits
          outside it.

        Args:
            batch: Model inputs, forwarded to ``compute_logits``.
            sub_batchsize: Optional chunk size, forwarded to
                ``compute_logits``.

        Returns:
            The logits tensor when not training, otherwise a ``RankerOutput``
            holding ``loss``, the per-group ``scores`` (as ``logits``),
            ``pos_weights`` and ``neg_weights``.

        Raises:
            NotImplementedError: For an unknown loss method, weight setting
                or direction setting, or when ``train_type`` is not
                ``'point'`` for the ``cl`` / ``clsft`` / ``diy`` methods.
            ValueError: If a ``diy`` loss method does not consist of exactly
                three dot-separated parts (tuple unpacking fails).
            AssertionError: If only one of the ``diy`` weight / direction
                settings is ``'cl_r'``.
        """
        logits = self.compute_logits(batch, sub_batchsize=sub_batchsize)
        if self.training:
            pos_weights = None
            neg_weights = None
            if self.loss_method.lower() == 'sft':
                # Label 1 everywhere, then label 0 on the first row of each group.
                labels = torch.ones(logits.size(0)).long().to(logits.device)
                labels[torch.arange(0, len(labels), 1+self.data_args.neg_per_ins)] = 0
                loss = self.cross_entropy(logits, labels)
                with torch.no_grad():
                    # Probability of output 0 per row, regrouped to (groups, 1 + neg_per_ins).
                    scores = F.softmax(logits.float(), dim=-1)[:, 0]
                    scores = scores.reshape(-1, (1+self.data_args.neg_per_ins),)
                    pos_weights = (1 - scores[:, 0])
                    neg_weights = scores[:, 1: ].sum(dim=-1)

            elif self.loss_method.lower() == 'cl':
                if self.train_type != 'point':
                    raise NotImplementedError

                # Keep only column 0 and regroup to (groups, 1 + neg_per_ins).
                logits = logits[:,0]
                logits = logits.view(
                    -1,
                    (1+self.data_args.neg_per_ins),
                )
                # The positive candidate sits at index 0 of every group.
                labels = torch.zeros(logits.size(0)).long().to(logits.device)
                loss = self.cross_entropy(logits, labels)
                with torch.no_grad():
                    scores = logits
                    _weights = F.softmax(scores.float(), dim=-1)
                    pos_weights = 1-_weights[:, 0]
                    neg_weights = _weights[:, 1: ].sum(dim=-1)

            elif self.loss_method.lower() == 'clsft':
                if self.train_type != 'point':
                    raise NotImplementedError

                # Row-wise part: same labels as the 'sft' branch.
                labels_sft = torch.ones(logits.size(0)).long().to(logits.device)
                labels_sft[torch.arange(0, len(labels_sft), 1+self.data_args.neg_per_ins)] = 0
                loss_sft = self.cross_entropy(logits, labels_sft)

                with torch.no_grad():
                    # Note: unlike the 'sft' branch there is no .float() cast here.
                    scores = F.softmax(logits, dim=-1)[:, 0]
                    scores = scores.reshape(-1, (1+self.data_args.neg_per_ins),)
                    pos_weights = (1 - scores[:, 0])
                    neg_weights = scores[:, 1: ].sum(dim=-1)

                # Group-wise part: same construction as the 'cl' branch.
                logits_cl = logits[:,0]
                logits_cl = logits_cl.view(
                    -1,
                    (1+self.data_args.neg_per_ins),
                )
                labels_cl = torch.zeros(logits_cl.size(0)).long().to(logits_cl.device)
                loss_cl = self.cross_entropy(logits_cl, labels_cl)
                loss = (loss_sft+loss_cl)/2

                with torch.no_grad():
                    # Add the group-softmax mass on the negatives to both weights.
                    _weights = F.softmax(logits_cl, dim=-1)[:, 1:].sum(dim=-1)
                    pos_weights += _weights
                    neg_weights += _weights

            elif self.loss_method.lower().startswith('diy'):
                if self.train_type != 'point':
                    raise NotImplementedError

                # Expected form: 'diy.<weight_setting>.<direction_setting>'.
                _, weight_setting, direction_setting = self.loss_method.lower().strip().split('.')
                # 'cl_r' is only accepted when used for both settings.
                if weight_setting == 'cl_r':
                    assert direction_setting == 'cl_r'
                if direction_setting == 'cl_r':
                    assert weight_setting == 'cl_r'

                # Weights are built without gradient tracking.
                with torch.no_grad():
                    if weight_setting == 'cl_y' or weight_setting == 'cl_r':
                        # Softmax across each group of column-0 logits.
                        scores = logits[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                        _weights = F.softmax(scores.float(), dim=-1)
                        pos_weights = 1-_weights[:, 0]
                        neg_weights = _weights[:, 1: ]
                    elif weight_setting == 'cl_limit':
                        scores = logits[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                        _weights = F.softmax(scores.float(), dim=-1)
                        pos_weights = 1-_weights[:, 0]
                        neg_weights = _weights[:, 1: ]

                        # Gate: drop the positive term once its group probability
                        # exceeds 1 - limit_ratio, and drop negative terms whose
                        # group probability is below limit_ratio.
                        pos_weights_sft_gd = torch.ones_like(_weights[:, 0])
                        pos_weights_sft_gd_mask = _weights[:, 0] > (1 - self.limit_ratio)
                        pos_weights_sft_gd[pos_weights_sft_gd_mask] = 0.0
                        neg_weights_sft_gd = torch.ones_like(_weights[:, 1: ])
                        neg_weights_sft_gd_mask = _weights[:, 1:] < self.limit_ratio
                        neg_weights_sft_gd[neg_weights_sft_gd_mask] = 0.0

                        pos_weights = pos_weights * pos_weights_sft_gd
                        neg_weights = neg_weights * neg_weights_sft_gd
                    elif weight_setting == 'cl_nofloat':
                        # Same as 'cl_y' but without the .float() cast before softmax.
                        scores = logits[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                        _weights = F.softmax(scores, dim=-1)
                        pos_weights = 1-_weights[:, 0]
                        neg_weights = _weights[:, 1: ]
                    elif weight_setting == 'sft':
                        # Per-row probability of output 0, regrouped.
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                        pos_weights = 1 - scores[:, 0]
                        neg_weights = scores[:, 1: ]
                    elif weight_setting == 'cl_wogd_eq_sft':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        # Constant weights: 1 for the positive, 1/neg_per_ins per negative.
                        pos_weights = torch.ones_like(scores[:, 0])
                        neg_weights = torch.ones_like(scores[:, 1: ])/self.data_args.neg_per_ins
                    elif weight_setting == 'cl_wogd_neq_sft':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        # Softmax over the negatives' probabilities, scaled by sft_temperature.
                        neg_weights = F.softmax(scores[:, 1:] / self.sft_temperature, dim=-1)
                    elif weight_setting == 'cl_wogd_eq_sft_wgd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        neg_weights = torch.ones_like(scores[:, 1: ])/self.data_args.neg_per_ins

                        # Same limit_ratio gate as 'cl_limit', applied to the
                        # per-row probabilities instead of the group softmax.
                        pos_weights_sft_gd = torch.ones_like(scores[:, 0])
                        pos_weights_sft_gd_mask = scores[:, 0] > (1 - self.limit_ratio)
                        pos_weights_sft_gd[pos_weights_sft_gd_mask] = 0.0
                        neg_weights_sft_gd = torch.ones_like(scores[:, 1: ])
                        neg_weights_sft_gd_mask = scores[:, 1:] < self.limit_ratio
                        neg_weights_sft_gd[neg_weights_sft_gd_mask] = 0.0

                        pos_weights = pos_weights * pos_weights_sft_gd
                        neg_weights = neg_weights * neg_weights_sft_gd
                    elif weight_setting == 'cl_wogd_neq_sft_wgd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        # NOTE(review): no sft_temperature here, unlike 'cl_wogd_neq_sft' - confirm intended.
                        neg_weights = F.softmax(scores[:, 1:], dim=-1)

                        pos_weights_sft_gd = torch.ones_like(scores[:, 0])
                        pos_weights_sft_gd_mask = scores[:, 0] > (1 - self.limit_ratio)
                        pos_weights_sft_gd[pos_weights_sft_gd_mask] = 0.0
                        neg_weights_sft_gd = torch.ones_like(scores[:, 1: ])
                        neg_weights_sft_gd_mask = scores[:, 1:] < self.limit_ratio
                        neg_weights_sft_gd[neg_weights_sft_gd_mask] = 0.0

                        pos_weights = pos_weights * pos_weights_sft_gd
                        neg_weights = neg_weights * neg_weights_sft_gd
                    elif weight_setting == 'cl_wogd_neq_sft_wgd_sft_gd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        neg_weights = F.softmax(scores[:, 1:] / self.sft_temperature, dim=-1) 

                        # Multiply by the 'sft'-style weights (1 - p for the positive, p for negatives).
                        pos_weights_sft_gd = 1-scores[:, 0]
                        neg_weights_sft_gd = scores[:, 1: ] 

                        pos_weights = pos_weights * pos_weights_sft_gd
                        neg_weights = neg_weights * neg_weights_sft_gd
                    elif weight_setting == 'cl_wogd_neq_cl_wgd_sft_gd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        # Softmax over the whole group (positive included), negatives kept.
                        neg_weights = F.softmax(scores / self.sft_temperature, dim=-1)[:,1:]

                        pos_weights_sft_gd = 1-scores[:, 0]
                        neg_weights_sft_gd = scores[:, 1: ] 

                        pos_weights = pos_weights * pos_weights_sft_gd
                        neg_weights = neg_weights * neg_weights_sft_gd
                    elif weight_setting == 'cl_wogd_neq_cl_wgd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        neg_weights = F.softmax(scores / self.sft_temperature, dim=-1)[:, 1: ]
                    elif weight_setting == 'sft_wogd_neq_cl_wgd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        # Group softmax over the per-row probabilities.
                        weights = F.softmax(scores / self.sft_temperature, dim=-1)
                        pos_weights = 1-weights[:, 0]
                        neg_weights = weights[:, 1: ]
                    elif weight_setting == 'flatnce':
                        # Raw column-0 logits; softmax over the negatives only.
                        scores = logits[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                        pos_weights = torch.ones_like(scores[:, 0])
                        neg_weights = F.softmax(scores[:, 1:].float(), dim=-1)
                    elif weight_setting == 'cl_wogd_neq_cl_gd':
                        scores = F.softmax(logits.float(), dim=-1)[:, 0].reshape(
                            -1, (1+self.data_args.neg_per_ins),
                        )
                        pos_weights = torch.ones_like(scores[:, 0])
                        neg_weights = F.softmax(scores[:, 1:] / self.sft_temperature, dim=-1) 

                        # Second factor: softmax over the whole group.
                        cl_neg_weight = F.softmax(scores / self.sft_temperature, dim=-1) 

                        pos_weights_sft_gd = 1-cl_neg_weight[:, 0]
                        neg_weights_sft_gd = cl_neg_weight[:, 1: ] 

                        pos_weights = pos_weights * pos_weights_sft_gd
                        neg_weights = neg_weights * neg_weights_sft_gd
                    else:
                        raise NotImplementedError
                
                # Directions are built outside no_grad, from the logits themselves.
                if direction_setting == 'cl_y' or direction_setting == 'cl_r':
                    # Lower the loss by raising the positive's column-0 logit
                    # and lowering the negatives' column-0 logits.
                    directions = logits[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                    pos_directions = -directions[:, 0]
                    neg_directions = directions[:, 1: ]
                elif direction_setting == 'sft':
                    # Uses the difference between output 1 and output 0 of each row.
                    directions_0 = logits[:, 0].reshape(-1, (1+self.data_args.neg_per_ins),)
                    directions_1 = logits[:, 1].reshape(-1, (1+self.data_args.neg_per_ins),)
                    pos_directions = directions_1[:, 0] - directions_0[:, 0]
                    neg_directions = directions_0[:, 1: ] - directions_1[:, 1:]
                else:
                    raise NotImplementedError

                pos_loss = pos_weights * pos_directions
                neg_loss = (neg_weights * neg_directions).sum(dim=-1)
                # Collapse per-negative weights to one value per group for reporting.
                neg_weights = neg_weights.sum(dim=-1)
                loss = (pos_loss + neg_loss).mean()

            else:
                raise NotImplementedError
                
            ranker_out = RankerOutput(loss=loss, logits=scores, pos_weights=pos_weights, neg_weights=neg_weights)
            return ranker_out
        return logits


    def forward(self, batch, **kwargs):
        sub_batchsize=kwargs.pop('sub_batchsize', None)
        ranker_out = self.compute_outputs(batch, sub_batchsize=sub_batchsize)
        return ranker_out

    def gradient_checkpointing_enable(self, **kwargs):
        kwargs['gradient_checkpointing_kwargs'] = {'use_reentrant': False}
        self.lm.gradient_checkpointing_enable(**kwargs)
        self.lm.gradient_checkpointing=True

    def save_pretrained_new(self, output_path):
        print('output path', output_path)
        self.lm.save_pretrained(output_path)
        print('saving value head')
        torch.save(self.classifier.state_dict(), os.path.join(output_path, 'classifier.bin'))

    def save_pretrained(self, output_path):
        print('output path', output_path)
        self.lm.save_pretrained(output_path)
        print('saving value head')
        torch.save(self.classifier.state_dict(), os.path.join(output_path, 'classifier.bin'))

    def load_pretrained(self, output_path):
        self.lm = self.lm.from_pretrained(output_path)
        self.pooer.load_state_dict(pooler_states)
