from typing import List, Dict
from typing import Tuple
from collections import defaultdict, Counter
from tqdm import tqdm
import re
import numpy as np

import torch
import spacy
from nltk.tokenize.casual import casual_tokenize
from nltk.util import ngrams
from nltk.corpus import wordnet
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class POSCalculator:
    """Tally spaCy ``token.pos_`` tags over a collection of texts."""

    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")

    def calculate(self, instances: List[str]) -> Dict[str, int]:
        return self._calculate(instances)

    def _calculate(self, instances: List[str]) -> Dict[str, int]:
        """Return a ``defaultdict(int)`` mapping each POS tag to its token count.

        Texts are lower-cased before parsing; empty strings are not skipped.
        """
        tag_counts = defaultdict(int)
        docs = (self.nlp(text.lower()) for text in instances)
        for doc in docs:
            for token in doc:
                tag_counts[token.pos_] += 1
        return tag_counts

class AvgLenCalculator:
    """Mean number of tokens per text (NLTK casual tokenizer, lower-cased)."""

    def calculate(self, instances: List[str]) -> float:
        metric_value = self._calculate(instances)
        return metric_value

    def _calculate(self, instances: List[str]) -> float:
        """Return the average token count per text, or ``0.0`` for an empty input.

        Empty strings are not filtered out; they count as texts like any other.
        """
        lengths = [len(casual_tokenize(instance.lower())) for instance in instances]
        if not lengths:
            # Avoid ZeroDivisionError when there is nothing to average.
            return 0.0
        return sum(lengths) / len(lengths)

class LenDistCalculator:
    """Distribution of text lengths, expressed as percentages of all texts."""

    def calculate(self, instances: List[str]) -> Dict[int, float]:
        metric_value = self._calculate(instances)
        return metric_value

    def _calculate(self, instances: List[str]) -> Dict[int, float]:
        """Map each token length to the percentage of texts having that length.

        Lengths come from NLTK's casual tokenizer on lower-cased text. The
        result is a ``defaultdict(int)`` whose keys follow first-appearance
        order. An empty input yields an empty mapping.
        """
        lengths = [len(casual_tokenize(instance.lower())) for instance in instances]
        total = len(lengths)
        stat = defaultdict(int)
        # Counter preserves first-seen order, so key order matches a manual tally.
        for length, count in Counter(lengths).items():
            stat[length] = 100 * count / total
        return stat

class UniqueNGramCalculator:
    """Count distinct n-grams that occur at least ``th`` times across texts.

    Texts are lower-cased and split with NLTK's casual tokenizer. N-grams are
    taken per text, so none straddle two texts.
    """

    def __init__(self, n: int, th: int):
        self.n = n    # n-gram order
        self.th = th  # minimum occurrence count for an n-gram to be counted

    def calculate(self, instances: List[str]) -> int:
        return self._calculate(instances)

    def _calculate(self, instances: List[str]) -> int:
        occurrences = Counter()
        for text in instances:
            occurrences.update(ngrams(casual_tokenize(text.lower()), self.n))
        return sum(1 for freq in occurrences.values() if freq >= self.th)

class UniqueUniGramCalculator(UniqueNGramCalculator):
    """Number of distinct unigrams (each seen at least once)."""

    def __init__(self):
        super().__init__(1, 1)

class UniqueBiGramCalculator(UniqueNGramCalculator):
    """Number of distinct bigrams (each seen at least once)."""

    def __init__(self):
        super().__init__(2, 1)

class UniqueTriGramCalculator(UniqueNGramCalculator):
    """Number of distinct trigrams (each seen at least once)."""

    def __init__(self):
        super().__init__(3, 1)

class DistKCalculator:
    """Distinct-k: ratio of unique k-grams to total k-grams over a set of texts.

    Texts are lower-cased and tokenized with NLTK's casual tokenizer. K-grams
    are taken per text, so none straddle two texts.
    """

    def __init__(self, k: int):
        self.k = k

    def calculate(self, instances: List[str]) -> Tuple[float, int]:
        metric_value = self._calculate(instances)
        return metric_value

    def _calculate(self, instances: List[str]) -> Tuple[float, int]:
        """Return ``(distinct_ratio, unique_kgram_count)``.

        When no k-grams are produced (e.g. empty input), the result is
        ``(0.0, 0)`` instead of raising ``ZeroDivisionError``.
        """
        num_all_ngrams = 0
        all_ngram_set = set()
        for instance in instances:
            token_ngrams = list(ngrams(casual_tokenize(instance.lower()), self.k))
            num_all_ngrams += len(token_ngrams)
            all_ngram_set.update(token_ngrams)

        if num_all_ngrams == 0:
            return 0.0, 0
        return len(all_ngram_set) / num_all_ngrams, len(all_ngram_set)


class Dist1Calculator(DistKCalculator):
    """Distinct-1: unique-to-total ratio over unigrams."""

    def __init__(self):
        super().__init__(1)


class Dist2Calculator(DistKCalculator):
    """Distinct-2: unique-to-total ratio over bigrams."""

    def __init__(self):
        super().__init__(2)


class Dist3Calculator(DistKCalculator):
    """Distinct-3: unique-to-total ratio over trigrams."""

    def __init__(self):
        super().__init__(3)

class EntKCalculator:
    """Ent-k: Shannon entropy (in bits) of the k-gram distribution over texts.

    Texts are lower-cased and tokenized with NLTK's casual tokenizer. K-grams
    are taken per text, so none straddle two texts.
    """

    def __init__(self, k: int):
        self.k = k

    def calculate(self, instances: List[str]) -> Tuple[float, int]:
        metric_value = self._calculate(instances)
        return metric_value

    def _calculate(self, instances: List[str]) -> Tuple[float, int]:
        """Return ``(entropy, total_kgram_count)``.

        The entropy is always a Python ``float`` (``0.0`` when there are no
        k-grams). The second element counts k-grams with repetition.
        """
        counter = Counter()
        for instance in instances:
            counter.update(ngrams(casual_tokenize(instance.lower()), self.k))
        num_all_ngrams = sum(counter.values())

        ent = 0.0
        # Iterate in first-seen order so the float summation matches a list-built Counter.
        for count in counter.values():
            prob = float(count) / float(num_all_ngrams)
            ent -= prob * np.log2(prob)

        return float(ent), num_all_ngrams


class Ent1Calculator(EntKCalculator):
    """Ent-1: entropy (bits) of the unigram distribution."""

    def __init__(self):
        super().__init__(1)


class Ent2Calculator(EntKCalculator):
    """Ent-2: entropy (bits) of the bigram distribution."""

    def __init__(self):
        super().__init__(2)


class Ent3Calculator(EntKCalculator):
    """Ent-3: entropy (bits) of the trigram distribution."""

    def __init__(self):
        super().__init__(3)

class WordNetKCalculator:
    """Count WordNet categories seen more than ``th`` times across texts.

    Each token (NLTK casual tokenizer, lower-cased) is mapped to the name of
    the first hypernym of its first synset. If that synset has no hypernym,
    the synset's own name is used. Tokens with no synsets are ignored.
    NOTE(review): the threshold is strict (``count > th``), unlike
    UniqueNGramCalculator's ``>=``; confirm this is intended.
    """

    def __init__(self, th=1):
        self.th = th

    def calculate(self, instances):
        return self._calculate(instances)

    def _calculate(self, instances):
        category_counts = defaultdict(int)
        for text in instances:
            for token in casual_tokenize(text.lower()):
                synsets = wordnet.synsets(token)
                if not synsets:
                    continue
                head = synsets[0]
                try:
                    key = head.hypernyms()[0].name()
                except IndexError:
                    # No hypernym available: fall back to the synset itself.
                    key = head.name()
                category_counts[key] += 1

        return sum(1 for count in category_counts.values() if count > self.th)

class WordNet1Calculator(WordNetKCalculator):
    """WordNet categories occurring more than once."""

    def __init__(self):
        super().__init__(1)

class WordNet2Calculator(WordNetKCalculator):
    """WordNet categories occurring more than twice."""

    def __init__(self):
        super().__init__(2)

class ObjectConsistencyCalculator:
    """Average recall of annotated photo objects mentioned in generated captions.

    Each instance must provide ``"photo_description"`` (whose object list
    follows ``"objects in the photo: "``, comma-separated) and
    ``"first_deid_caption"``.
    """

    def calculate(self, instances):
        metric_value = self._calculate(instances)
        return metric_value

    def _calculate(self, instances):
        """Return the mean per-instance object recall, or ``0.0`` for no instances."""
        recall = []
        for instance in tqdm(instances, total=len(instances)):
            recall.append(
                self._object_recall(instance["photo_description"], instance["first_deid_caption"])
            )

        if not recall:
            return 0.0
        return sum(recall) / len(recall)

    @classmethod
    def _object_recall(cls, photo_description, caption):
        """Fraction of target objects whose words appear contiguously in ``caption``.

        Matching is case-insensitive and on whitespace-separated words.
        """
        target_objects = photo_description.lower().split("objects in the photo: ")[-1].split(', ')
        caption_tokens = caption.lower().split()
        hits = sum(1 for obj in target_objects if cls._mentions(obj, caption_tokens))
        # split() always yields at least one element, so the divisor is non-zero.
        return hits / len(target_objects)

    @staticmethod
    def _mentions(obj, caption_tokens):
        """True if ``obj``'s words occur as a contiguous run in ``caption_tokens``.

        Comparing word sequences (not single tokens) lets multi-word objects
        such as "ice cream" match. A plain ``obj in caption_tokens`` can
        never match them.
        """
        obj_tokens = obj.split()
        if not obj_tokens:
            return False
        width = len(obj_tokens)
        return any(
            caption_tokens[start:start + width] == obj_tokens
            for start in range(len(caption_tokens) - width + 1)
        )

class DialogueConsistencyCalculator:
    """Mean non-contradiction probability between dialogue context and generated captions.

    For each instance, the dialogue turns before ``[Sharing Image]`` are taken
    from ``instance["task2_prompt_input"]``. They are paired with the generated
    caption (text after ``'An image of '``), and the pair is scored by the
    ``ynie/roberta-large_conv_contradiction_detector_v0`` classifier.
    """

    # Text between a "Dialogue:\n" header and the next blank line.
    _DIALOGUE_PATTERN = re.compile(r'(?<=Dialogue:\n)([\s\S]*?)(?=\n\n)')

    def __init__(self, device: str = "cuda"):
        """Load the contradiction classifier onto ``device`` (default ``"cuda"``)."""
        consistent_clf_path = 'ynie/roberta-large_conv_contradiction_detector_v0'
        self.device = device
        self.consistent_tokenizer = AutoTokenizer.from_pretrained(consistent_clf_path)
        self.consistent_clf = AutoModelForSequenceClassification.from_pretrained(consistent_clf_path)
        self.consistent_clf.to(self.device)
        self.consistent_clf.eval()

    def calculate(self, instances):
        metric_value = self._calculate(instances)
        return metric_value

    def _calculate(self, instances):
        """Average ``non_contradiction`` probability over instances that have a dialogue.

        Instances whose prompt has no dialogue section are reported and
        skipped. Returns ``0.0`` if no instance could be scored.
        """
        total_result = []
        for instance in tqdm(instances, total=len(instances)):
            gen_caption = instance["first_deid_caption"]
            dialog = self._extract_premise(instance["task2_prompt_input"])
            if dialog is None:
                # No premise to score against: skip instead of scoring a
                # stale or missing dialogue.
                print("No match found")
                continue

            consistency_result = self.get_consistency(dialog, gen_caption.split('An image of ')[-1])
            total_result.append(consistency_result["non_contradiction"])

        if not total_result:
            return 0.0
        return sum(total_result) / len(total_result)

    @classmethod
    def _extract_premise(cls, prompt):
        """Join the dialogue turns before ``[Sharing Image]`` with ``' <TURN> '``.

        Each line keeps only the text after its last ``": "`` (dropping the
        speaker prefix). Returns ``None`` when ``prompt`` has no
        ``"Dialogue:\\n ... \\n\\n"`` section.
        """
        match = cls._DIALOGUE_PATTERN.search(prompt)
        if match is None:
            return None
        utterances = []
        for line in match.group(1).split('\n'):
            utterance = line.split(": ")[-1]
            if utterance == "[Sharing Image]":
                break
            utterances.append(utterance)
        return ' <TURN> '.join(utterances)

    def get_consistency(self, premise, hyp):
        """Score one (premise, hypothesis) pair with the contradiction classifier.

        Returns softmax probabilities as ``{'non_contradiction': p0, 'contradiction': p1}``.
        NOTE(review): assumes label index 0 is non-contradiction and 1 is
        contradiction; confirm against the model's ``id2label``.
        """
        encoded = self.consistent_tokenizer.encode_plus(
            premise, hyp, max_length=128, return_token_type_ids=True, truncation=True
        )
        # Build (1, seq_len) integer tensors directly (no float32 round-trip).
        input_ids = torch.tensor([encoded['input_ids']], dtype=torch.long, device=self.device)
        # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
        token_type_ids = torch.tensor([encoded['token_type_ids']], dtype=torch.long, device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], dtype=torch.long, device=self.device)

        with torch.no_grad():
            outputs = self.consistent_clf(input_ids,
                                          attention_mask=attention_mask,
                                          token_type_ids=token_type_ids,
                                          labels=None)
            # Batch size is one, so take the first (only) row of probabilities.
            predicted_prob = torch.softmax(outputs[0], dim=1)[0].tolist()

        return {'non_contradiction': predicted_prob[0], 'contradiction': predicted_prob[1]}

