# -*- coding: utf-8 -*-

import os
import re
import math
import json
import logging
from typing import Any, Dict, Optional, List, Union, Tuple, Callable
from collections import defaultdict

import scipy
import scipy.stats
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from torch.optim import Optimizer
from tqdm import tqdm
from boltons.iterutils import chunked_iter
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from accelerate import Accelerator


# Module-level logger shared by everything in this file; registered under the fixed name 'AnglE'.
logger = logging.getLogger('AnglE')


def categorical_crossentropy(y_true: torch.Tensor, y_pred: torch.Tensor, from_logits: bool = True):
    """Row-wise categorical cross-entropy against (possibly soft) targets.

    Args:
        y_true: target weights, same shape as `y_pred`; classes lie along dim 1.
        y_pred: logits when `from_logits` is True, otherwise floating-point
            probabilities.
        from_logits: whether `y_pred` still needs a (log-)softmax over dim 1.

    Returns:
        A tensor with one loss value per row (dim 1 is summed out).
    """
    if from_logits:
        log_probs = F.log_softmax(y_pred, dim=1)
    else:
        # `torch.log` is element-wise and takes no `dim` argument. Clamp to the smallest
        # positive normal value so that a zero probability paired with a zero target
        # contributes 0 instead of `0 * -inf = nan`.
        log_probs = torch.log(y_pred.clamp_min(torch.finfo(y_pred.dtype).tiny))
    return -(log_probs * y_true).sum(dim=1)


def cosine_loss(y_true: torch.Tensor, y_pred: torch.Tensor, tau: float = 20.0):
    """Pairwise ranking loss over the cosine similarity of consecutive rows.

    Rows `2k` and `2k + 1` of `y_pred` form one pair; the pair's label is read
    from `y_true[2k, 0]`. For every two pairs (i, j) with `label_i < label_j`
    the term `tau * (cos_i - cos_j)` enters a log-sum-exp together with a
    constant 0.

    Args:
        y_true: labels, indexed as `y_true[::2, 0]`.
        y_pred: embeddings, one row per text.
        tau: scale applied to the cosine similarities.

    Returns:
        A scalar tensor.
    """
    pair_labels = y_true[::2, 0]
    # order_mask[i, j] == 1 where pair i is labelled strictly lower than pair j
    order_mask = (pair_labels[:, None] < pair_labels[None, :]).float()

    unit_vecs = F.normalize(y_pred, p=2, dim=1)
    scaled_cos = torch.sum(unit_vecs[::2] * unit_vecs[1::2], dim=1) * tau

    pairwise_gap = scaled_cos[:, None] - scaled_cos[None, :]
    # Entries outside the mask are pushed to about -1e12 so they vanish inside the logsumexp.
    masked_gap = (pairwise_gap - (1 - order_mask) * 1e12).view(-1)

    baseline = torch.Tensor([0]).to(masked_gap.device)
    return torch.logsumexp(torch.concat((baseline, masked_gap), dim=0), dim=0)


def angle_loss(y_true: torch.Tensor, y_pred: torch.Tensor, tau: float = 1.0):
    """Pairwise ranking loss computed from a complex-valued quotient of paired embeddings.

    Each row of `y_pred` is split in half along dim 1 into a real and an
    imaginary part. Rows `2k` (z = a + bi) and `2k + 1` (w = c + di) form one
    pair whose label is `y_true[2k, 0]`. The per-pair score is
    `|sum(Re) + sum(Im)| * tau` of the rescaled quotient z / w; scores are then
    ranked exactly as in `cosine_loss` (log-sum-exp over the score gaps of
    every two pairs with `label_i < label_j`, plus a constant 0).

    Args:
        y_true: labels, indexed as `y_true[::2, 0]`.
        y_pred: embeddings, one row per text; dim 1 is chunked into two halves.
        tau: scale applied to the per-pair score.

    Returns:
        A scalar tensor.
    """
    pair_labels = y_true[::2, 0]
    order_mask = (pair_labels[:, None] < pair_labels[None, :]).float()

    real_part, imag_part = torch.chunk(y_pred, 2, dim=1)
    a, c = real_part[::2], real_part[1::2]
    b, d = imag_part[::2], imag_part[1::2]

    # z / w = (a + bi) / (c + di)
    #       = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
    # The denominator here is summed over the whole feature dimension.
    w_sq_norm = torch.sum(c**2 + d**2, dim=1, keepdim=True)
    quot_re = (a * c + b * d) / w_sq_norm
    quot_im = (b * c - a * d) / w_sq_norm

    # Rescale the quotient by the ratio of the two vectors' norms.
    z_norm = torch.sum(a**2 + b**2, dim=1, keepdim=True)**0.5
    w_norm = torch.sum(c**2 + d**2, dim=1, keepdim=True)**0.5
    norm_ratio = z_norm / w_norm
    quot_re = quot_re / norm_ratio
    quot_im = quot_im / norm_ratio

    stacked = torch.concat((quot_re, quot_im), dim=1)
    scores = torch.abs(torch.sum(stacked, dim=1)) * tau  # absolute delta angle

    score_gap = scores[:, None] - scores[None, :]
    # Entries outside the mask are pushed to about -1e12 so they vanish inside the logsumexp.
    masked_gap = (score_gap - (1 - order_mask) * 1e12).view(-1)

    baseline = torch.Tensor([0]).to(masked_gap.device)
    return torch.logsumexp(torch.concat((baseline, masked_gap), dim=0), dim=0)


def in_batch_negative_loss(y_true: torch.Tensor,
                           y_pred: torch.Tensor,
                           tau: float = 20.0,
                           similar_matrix: Optional[torch.Tensor] = None,
                           negative_weights: float = 0.0) -> torch.Tensor:
    """In-batch negative loss: softmax cross-entropy over the batch's cosine-similarity matrix.

    Every row of `y_pred` is scored against every other row; the target for row
    `i` is its pair partner (row `i + 1` for even `i`, row `i - 1` for odd `i`)
    whenever the labels of both rows are non-zero.

    Args:
        y_true: per-row labels. The in-place broadcasts below require it to be
            a column, i.e. shape (N, 1) as produced by `AngleDataLoader`.
        y_pred: embeddings, one row per text, N rows.
        tau: scale applied to the cosine similarities before the softmax.
        similar_matrix: optional (N, N) matrix added to the target matrix.
        negative_weights: when > 0, added to the logit of each zero-labelled
            pair partner.

    Returns:
        A scalar tensor: the mean of the row-wise cross-entropy.
    """
    device = y_true.device

    def make_target_matrix(y_true: torch.Tensor) -> torch.Tensor:
        # Builds an (N, N) 0/1 matrix whose entry (i, j) is 1 when j is the pair partner
        # of i and both rows carry a non-zero entry in `y_true`.
        # NOTE(review): assumes the labels are 0/1 after the int cast below — TODO confirm
        # (fractional labels would be truncated by `.int()`).
        idxs = torch.arange(0, y_pred.shape[0]).int().to(device)
        y_true = y_true.int()
        # `idxs_1` is a view of `idxs`, so the in-place ops further down also modify `idxs`;
        # `idxs_2` must therefore be derived from `idxs` before those ops run.
        idxs_1 = idxs[None, :]
        # partner index as a column: even i -> i + 1, odd i -> i - 1
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]

        # Column indices of zero-labelled rows become -2 ...
        idxs_1 *= y_true.T
        idxs_1 += (y_true.T == 0).int() * -2

        # ... and partner indices of zero-labelled rows become -1, so the two sentinels
        # never compare equal to each other or to a real (non-negative) index.
        idxs_2 *= y_true
        idxs_2 += (y_true == 0).int() * -1

        y_true = (idxs_1 == idxs_2).float()
        return y_true

    # Same construction applied to the inverted labels: marks the partner of every
    # zero-labelled row.
    neg_mask = make_target_matrix(y_true == 0)

    y_true = make_target_matrix(y_true)
    if similar_matrix is not None:
        y_true += similar_matrix

    # compute similarity
    y_pred = F.normalize(y_pred, dim=1, p=2)
    similarities = y_pred @ y_pred.T  # dot product
    # Push the diagonal to a very large negative value so a row never selects itself.
    similarities = similarities - torch.eye(y_pred.shape[0]).to(device) * 1e12
    similarities = similarities * tau

    if negative_weights > 0:
        similarities += neg_mask * negative_weights

    return categorical_crossentropy(y_true, similarities, from_logits=True).mean()


def compute_corrcoef(x, y):
    """Return the Spearman rank correlation coefficient between `x` and `y`."""
    spearman = scipy.stats.spearmanr(x, y)
    return spearman.correlation


def l2_normalize(vecs):
    """Scale every row of the 2-D array `vecs` to unit Euclidean length.

    Row norms are floored at 1e-8, so an all-zero row comes back as zeros
    instead of triggering a division by zero.
    """
    squared_sums = (vecs**2).sum(axis=1, keepdims=True)
    row_norms = squared_sums**0.5
    safe_norms = np.maximum(row_norms, 1e-8)
    return vecs / safe_norms


def optimal_threshold(y_true, y_pred):
    """Search for the decision threshold on `y_pred` that maximises accuracy.

    A sample counts as positive when `y_true > 0.5`. The threshold is
    parameterised as `tanh(t)`, which keeps it inside (-1, 1), and `t` is
    optimised with Powell's derivative-free method starting from `t = 1`.

    Args:
        y_true: NumPy array of gold scores.
        y_pred: NumPy array of predicted scores, same shape as `y_true`.

    Returns:
        `(threshold, accuracy)`; `threshold` is `tanh` of the optimiser's
        `result.x` and keeps whatever array shape SciPy returns for it.
    """
    # Imported here because the file only imports `scipy` and `scipy.stats` at the top;
    # `scipy.optimize` would otherwise resolve only as a side effect of another import.
    import scipy.optimize

    def negative_accuracy(t):
        return -np.mean((y_true > 0.5) == (y_pred > np.tanh(t)))

    result = scipy.optimize.minimize(negative_accuracy, 1, method='Powell')
    return np.tanh(result.x), -result.fun


class AngleDataLoader:
    """Turns a list of `{'text1', 'text2', 'label'}` records into tokenized training batches.

    Every record contributes two consecutive rows (text1, then text2) that both
    carry the record's label, so a batch built from `k` records has `2 * k`
    rows. One batch consumes up to `2 * batch_size` records.
    """

    def __init__(self, data: List[Dict], tokenizer: AutoTokenizer, max_length: int = 512, batch_size: int = 32):
        self.data, self.tokenizer = data, tokenizer
        self.max_length, self.batch_size = max_length, batch_size

    def __len__(self) -> int:
        """Number of batches `make_iter` yields for the current data."""
        records_per_batch = 2 * self.batch_size
        return math.ceil(len(self.data) / records_per_batch)

    def make_iter(self, random=False):
        """Yield `(encoded_batch, labels)` tuples.

        `encoded_batch` is the tokenizer output with an extra `'similar_matrix'`
        entry: a square matrix holding 1.0 at (i, j) when rows i != j contain the
        same (stripped) text. `labels` is a tensor with one `[label]` row per text.

        Args:
            random: shuffle the record order (via `np.random.shuffle`) first.
        """
        order = np.arange(0, len(self.data)).tolist()
        if random:
            np.random.shuffle(order)

        for index_chunk in chunked_iter(order, 2 * self.batch_size):
            records = [self.data[k] for k in index_chunk]

            batch_texts, batch_labels = [], []
            rows_by_text = defaultdict(list)  # stripped text -> row positions in this batch
            for record in records:
                label = record['label']
                for raw_text in (record['text1'], record['text2']):
                    cleaned = raw_text.strip()
                    # the row this text is about to occupy
                    rows_by_text[cleaned].append(len(batch_texts))
                    batch_texts.append(cleaned)
                    batch_labels.append([label])

            n_rows = len(batch_labels)
            similar_matrix = np.zeros((n_rows, n_rows))
            for rows in rows_by_text.values():
                if len(rows) < 2:
                    continue
                for row in rows:
                    for col in rows:
                        if row != col:
                            similar_matrix[row, col] = 1.0

            encoded = self.tokenizer(
                batch_texts,
                padding='longest',
                max_length=self.max_length,
                truncation=True,
                return_tensors='pt')
            encoded['similar_matrix'] = torch.Tensor(similar_matrix)
            yield encoded, torch.Tensor(batch_labels)


class AngleLoss:
    """Weighted combination of the cosine, in-batch-negative and angle objectives.

    `w1`, `w2`, `w3` weight `cosine_loss`, `in_batch_negative_loss` and
    `angle_loss` respectively; a term whose weight is not positive is skipped
    entirely. Each `*_tau` is handed to the matching loss function.
    """

    def __init__(self,
                 w1: float = 1.0,
                 w2: float = 1.0,
                 w3: float = 1.0,
                 cosine_tau: float = 20.0,
                 ibn_tau: float = 20.0,
                 angle_tau: float = 1.0):
        self.w1, self.w2, self.w3 = w1, w2, w3
        self.cosine_tau, self.ibn_tau, self.angle_tau = cosine_tau, ibn_tau, angle_tau

    def __call__(self,
                 y_true: torch.Tensor,
                 y_pred: torch.Tensor,
                 similar_matrix: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the weighted sum of the enabled terms (plain `0.` when none is enabled)."""
        # (weight, thunk) in accumulation order; a thunk only runs when its weight is positive.
        weighted_terms = (
            (self.w1, lambda: cosine_loss(y_true, y_pred, self.cosine_tau)),
            (self.w2, lambda: in_batch_negative_loss(
                y_true, y_pred, self.ibn_tau, similar_matrix=similar_matrix)),
            (self.w3, lambda: angle_loss(y_true, y_pred, self.angle_tau)),
        )
        total = 0.
        for weight, compute_term in weighted_terms:
            if weight > 0:
                total += weight * compute_term()
        return total


class AnglE(nn.Module):
    """Sentence encoder: a Hugging Face backbone (optionally LoRA-wrapped) plus a pooling step.

    The instance owns its tokenizer, can be trained with `fit`, scored with
    `evaluate`, and used for inference through `encode`. A JSON config named
    `cfg_file_name` is written next to the checkpoints so that
    `from_pretrained` can rebuild the model.
    """
    cfg_file_name = 'angle.config'

    def __init__(self,
                 model_name_or_path: str,
                 max_length: int = 512,
                 model_kwargs: Optional[Dict] = None,
                 lora_config_kwargs: Optional[Dict] = None,
                 apply_lora: bool = False,
                 pooling_strategy: str = 'cls',
                 train_mode: bool = True,
                 loss: Optional['AngleLoss'] = None,
                 **kwargs: Any):
        """Load tokenizer and backbone, optionally wrap the backbone with LoRA.

        Args:
            model_name_or_path: passed to `AutoTokenizer` / `AutoModel.from_pretrained`.
            max_length: truncation length used by `encode`.
            model_kwargs: extra keyword arguments for `AutoModel.from_pretrained`.
            lora_config_kwargs: overrides for the default LoRA config
                (r=8, lora_alpha=32, lora_dropout=0.1, feature-extraction task).
            apply_lora: wrap the backbone with PEFT LoRA adapters.
            pooling_strategy: one of `cls`, `cls_avg`, `last`, `avg`, `max`.
            train_mode: when True a loss object is attached as `self.loss`.
            loss: loss to use in train mode; defaults to `AngleLoss()`.
            **kwargs: accepted and ignored, so that a saved config containing
                extra keys can be splatted into the constructor.
        """
        super().__init__()
        self.max_length = max_length
        self.pooling_strategy = pooling_strategy
        self.train_mode = train_mode

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.backbone = AutoModel.from_pretrained(model_name_or_path, **model_kwargs)

        if apply_lora:
            base_lora_config = {
                'task_type': TaskType.FEATURE_EXTRACTION,
                'r': 8,
                'lora_alpha': 32,
                'lora_dropout': 0.1,
            }
            if lora_config_kwargs is not None:
                base_lora_config.update(lora_config_kwargs)
            peft_config = LoraConfig(**base_lora_config)
            self.backbone = get_peft_model(self.backbone, peft_config)
            self.backbone.print_trainable_parameters()

        if train_mode:
            self.loss = AngleLoss() if loss is None else loss

        # Everything needed to rebuild this instance; `fit` adds the checkpoint-naming keys.
        self.__cfg = {
            'model_name_or_path': model_name_or_path,
            'max_length': max_length,
            'model_kwargs': model_kwargs,
            'lora_config_kwargs': lora_config_kwargs,
            'pooling_strategy': pooling_strategy,
            'apply_lora': apply_lora
        }

    @staticmethod
    def find_pth_path(dirpath: str, config: Dict) -> str:
        """Locate the checkpoint file that belongs to a saved config.

        In `'best'` save mode the configured `best_file_name` is returned.
        Otherwise the `.pth` file with the highest number in its name is chosen.

        Raises:
            FileNotFoundError: no numbered `.pth` file exists in `dirpath`.
        """
        if config.get('save_mode') == 'best':
            return os.path.join(dirpath, config['best_file_name'])

        numbered = []
        for fname in os.listdir(dirpath):
            if not fname.endswith('.pth'):
                continue
            # Files such as a leftover `best.pth` carry no epoch number; skip them
            # instead of crashing on a failed regex match.
            found = re.search(r'\d+', fname)
            if found is None:
                continue
            numbered.append((int(found.group()), fname))
        if not numbered:
            raise FileNotFoundError(f'no epoch checkpoint (*.pth) found in {dirpath}')
        _, latest = max(numbered, key=lambda item: item[0])
        return os.path.join(dirpath, latest)

    @staticmethod
    def from_pretrained(model_name_or_path: str,
                        train_mode: bool = False,
                        model_kwargs: Optional[Dict] = None,
                        load_kwargs: Optional[Dict] = None) -> Optional['AnglE']:
        """Rebuild a model from a directory written by `fit`.

        Args:
            model_name_or_path: directory holding `angle.config` and the checkpoint.
            train_mode: forwarded to the constructor.
            model_kwargs: entries merged into the loaded config before it is
                passed to the constructor.
            load_kwargs: keyword arguments for `torch.load` (e.g. `map_location`).

        Returns:
            The restored model, or `None` when the path does not exist.
        """
        if not os.path.exists(model_name_or_path):
            # Only local directories are supported; make the historical "return None"
            # visible instead of silent.
            logger.warning('`%s` does not exist; nothing was loaded.', model_name_or_path)
            return None

        load_kwargs = {} if load_kwargs is None else load_kwargs
        config = AnglE.load_config(os.path.join(model_name_or_path, AnglE.cfg_file_name))
        if model_kwargs is not None:
            config.update(model_kwargs)
        angle = AnglE(**config, train_mode=train_mode)
        pth_path = AnglE.find_pth_path(model_name_or_path, config)
        logger.info('Load pretrained model from %s...', pth_path)
        # NOTE: `torch.load` unpickles the file. Only load checkpoints from trusted
        # sources, or restrict deserialization through `load_kwargs`.
        angle.load_state_dict(torch.load(pth_path, **load_kwargs))
        return angle

    def forward(self, **kwargs):
        """Run the backbone and pool `last_hidden_state` into one vector per input.

        Pooling is applied over all sequence positions as returned by the
        backbone; no attention mask is consulted here.

        Raises:
            NotImplementedError: unknown `pooling_strategy`.
        """
        hidden = self.backbone(**kwargs).last_hidden_state
        if self.pooling_strategy == 'cls':
            return hidden[:, 0]
        if self.pooling_strategy == 'cls_avg':
            return (hidden[:, 0] + torch.mean(hidden, dim=1)) / 2.0
        if self.pooling_strategy == 'last':
            return hidden[:, -1]
        if self.pooling_strategy == 'avg':
            return torch.mean(hidden, dim=1)
        if self.pooling_strategy == 'max':
            pooled, _ = torch.max(hidden, dim=1)
            return pooled
        raise NotImplementedError(
            'please specify pooling_strategy from [`cls`, `cls_avg`, `last`, `avg`, `max`]')

    def compute_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Delegate to the loss object attached in train mode."""
        return self.loss(y_true, y_pred, **kwargs)

    @staticmethod
    def load_config(fpath: str) -> Dict:
        """Read a JSON config file."""
        with open(fpath, 'r', encoding='utf-8') as reader:
            return json.load(reader)

    def save_config(self, fpath: str):
        """Write this model's config as JSON to `fpath`."""
        with open(fpath, 'w', encoding='utf-8') as writer:
            json.dump(self.__cfg, writer, ensure_ascii=False, indent=2)

    def fit(self,
            train_ds: 'AngleDataLoader',
            valid_ds: Optional['AngleDataLoader'] = None,
            save_best: bool = True,
            best_file_name: str = 'best.pth',
            output_dir: Optional[str] = None,
            epochs: int = 1,
            optimizer: Optional[Optimizer] = None,
            scheduler: Optional[Callable] = None,
            apply_scheduler: bool = False,
            learning_rate: float = 1e-5,
            max_grad_norm: Optional[float] = 1.0,
            weight_decay: float = 0.01,
            warmup_steps: int = 10000) -> Dict:
        """Train the model.

        Args:
            train_ds: training batches.
            valid_ds: optional validation batches, evaluated after every epoch.
            save_best: save only the checkpoint with the best validation
                Spearman correlation (`best_file_name`); otherwise save
                `<epoch>.pth` after every epoch.
            best_file_name: file name for the best checkpoint; must end in `.pth`.
            output_dir: where config and checkpoints go; nothing is written when None.
            epochs: number of passes over `train_ds`.
            optimizer: defaults to AdamW with no weight decay on biases / LayerNorm.
            scheduler: LR scheduler; stepped once per batch when `apply_scheduler`.
            apply_scheduler: enable the scheduler (a linear warm-up schedule is
                created when none is passed).
            learning_rate: learning rate of the default optimizer.
            max_grad_norm: gradient-clipping norm; None disables clipping.
            weight_decay: weight decay of the default optimizer.
            warmup_steps: warm-up steps of the default scheduler.

        Returns:
            History dict with `train_loss` and, when `valid_ds` is given,
            `val_corrcoef` and `val_accuracy` (one entry per epoch).
        """
        assert best_file_name.endswith('.pth'), '`best_file_name` have to end with .pth!'
        self.__cfg['save_mode'] = 'best' if save_best else 'epoch'
        self.__cfg['best_file_name'] = best_file_name
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            # Only persist the config when there is somewhere to put it.
            self.save_config(os.path.join(output_dir, AnglE.cfg_file_name))

        accelerator = Accelerator()
        device = accelerator.device
        if optimizer is None:
            model_params = list(self.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in model_params if not any(nd in n for nd in no_decay)],
                 'weight_decay': weight_decay},
                {'params': [p for n, p in model_params if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0}
            ]
            optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)

        if scheduler is None and apply_scheduler:
            scheduler = transformers.get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=len(train_ds) * epochs)

        # NOTE: `self` is deliberately rebound to whatever the accelerator returns.
        self, optimizer, train_ds = accelerator.prepare(self, optimizer, train_ds)

        history = {
            'train_loss': [],
            'val_corrcoef': [],
            'val_accuracy': []
        }
        best_corrcoef = 0
        for epoch in range(epochs):
            # `evaluate` switches to eval mode, so training mode has to be
            # re-enabled at the start of every epoch, not just once.
            self.train()
            pbar = tqdm(train_ds.make_iter(random=True), total=len(train_ds))
            total_loss = 0
            for X, y in pbar:
                X = X.to(device)
                y = y.to(device)

                optimizer.zero_grad()
                similar_matrix = X.pop('similar_matrix', None)
                output = self.forward(**X)
                loss = self.compute_loss(y, output, similar_matrix=similar_matrix)
                accelerator.backward(loss)
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
                optimizer.step()
                if apply_scheduler:
                    scheduler.step()
                loss_value = round(loss.detach().cpu().item(), 5)
                total_loss += loss_value
                pbar.set_description(f'Epoch {epoch} - Loss {loss_value}')

            # `max(..., 1)` keeps an empty loader from raising ZeroDivisionError.
            history['train_loss'].append(round(total_loss / max(len(train_ds), 1), 5))
            # evaluate
            if valid_ds is not None:
                corrcoef, accuracy = self.evaluate(valid_ds, device=device)
                history['val_corrcoef'].append(corrcoef)
                history['val_accuracy'].append(accuracy)
                if corrcoef > best_corrcoef:
                    best_corrcoef = corrcoef
                    if output_dir is not None and save_best:
                        torch.save(self.state_dict(), os.path.join(output_dir, best_file_name))
                        pbar.write('save best!')
                corrcoef = round(corrcoef, 5)
                accuracy = round(accuracy, 5)
                pbar.write(f'Epoch {epoch} - avg_loss {history["train_loss"][-1]} - val_accuracy {accuracy} - val_corrcoef {corrcoef} - best_val_corrcoef {round(best_corrcoef, 5)}')
            else:
                pbar.write(f'Epoch {epoch} - avg_loss {history["train_loss"][-1]}')
                history.pop('val_corrcoef', None)
                history.pop('val_accuracy', None)

            # Per-epoch checkpoints do not depend on validation, so write them
            # whether or not a validation set was supplied.
            if output_dir is not None and not save_best:
                torch.save(self.state_dict(), os.path.join(output_dir, f'{epoch}.pth'))
        return history

    def evaluate(self, data: 'AngleDataLoader', threshold: Optional[float] = None, device: Any = None):
        """Score the model on paired data.

        Args:
            data: evaluation batches.
            threshold: cosine threshold for the accuracy; searched with
                `optimal_threshold` when None.
            device: where to run the batches; defaults to the backbone's device.

        Returns:
            `(spearman_corrcoef, accuracy)` between the gold labels and the
            cosine similarity of each text pair.
        """
        if device is None:
            # Same default as `encode`; without it the batch would stay on the CPU.
            device = self.backbone.device
        self.eval()
        y_trues, y_preds = [], []
        for X, y in data.make_iter(random=False):
            X.pop('similar_matrix', None)
            # both rows of a pair carry the same label; keep one per pair
            y_trues.extend(y[::2, 0].detach().cpu().numpy())
            with torch.no_grad():
                X = X.to(device)
                x_vecs = self.forward(**X).detach().float().cpu().numpy()
            x_vecs = l2_normalize(x_vecs)
            pred = (x_vecs[::2] * x_vecs[1::2]).sum(1)
            y_preds.extend(pred)

        y_trues, y_preds = np.array(y_trues), np.array(y_preds)
        corrcoef = compute_corrcoef(y_trues, y_preds)
        if threshold is None:
            _, accuracy = optimal_threshold(y_trues, y_preds)
        else:
            accuracy = np.mean((y_trues > 0.5) == (y_preds > threshold))
        return corrcoef, accuracy

    def encode(self, sentences: Union[List[str], Tuple[str], str], device: Any = None):
        """Embed one sentence or a sequence of sentences.

        Args:
            sentences: a string or a list/tuple of strings.
            device: where to run; defaults to the backbone's device.

        Returns:
            The pooled embeddings as a tensor with one row per sentence.
        """
        if device is None:
            device = self.backbone.device
        self.eval()
        if isinstance(sentences, str):
            sentences = [sentences]
        tok = self.tokenizer(
            sentences, padding='longest', max_length=self.max_length, truncation=True, return_tensors='pt')
        tok = tok.to(device)
        with torch.no_grad():
            return self.forward(**tok)

    def export_onnx(self):
        """Placeholder; ONNX export is not implemented."""
        pass

