import os.path

import torch

from metrics.attention_head import normalize_token_avg_attn_entropies
from utils.prompt import get_prompt_head_entropy, get_prompts, get_prompt_outputs, get_prompt_outputs_attention
from utils.output import *#get_representations, get_lss, get_equidistance, get_expodistance, get_pairwise_similarity, get_mean_norm, get_representation_pca, get_attention_variance
from utils.utils import timestamp
from utils.align import compute_alignment

class Experiment:
    """Collect per-prompt, per-model metrics and optionally save them to disk.

    ``run`` tokenizes each prompt once, runs every model on it, and stores the
    enabled metrics under ``out[prompt_index][model_type][metric_name]``.
    """

    def __init__(self,
                 models,
                 model_types,
                 tokenizer,
                 model_name,
                 out_dir="out",
                 hidden_reps=False,
                 pca=False,
                 plot_pca=False,
                 lss=False,
                 diffs=False,
                 equidistance=False,
                 expodistance=False,
                 expodistance2=False,
                 norm=False,
                 head_entropy=False,
                 similarity=False,
                 attentions=False,
                 attention_var=False,
                 attention_var2=False,
                 classifier=False,
                 rep_classifier=False,
                 rep_var=False,
                 rep_cs=False,
                 rep_class_cs=False,
                 similarity_norm=False,
                 alignment=False):
        """
        :param models:          sequence of models; ``run`` evaluates each one on every prompt
        :param model_types:     labels indexed in parallel with ``models``; used as keys of the
                                per-prompt results and joined into the output file name
        :param tokenizer:       callable invoked as ``tokenizer(prompt, return_tensors='pt')``
                                whose result exposes ``input_ids``

        :param model_name:      name of model for outputs (sub-directory of ``out_dir``)
        :param out_dir:         output directory

        Metric flags (True/False to include/exclude). ``run`` currently acts on:
        ``attentions``, ``hidden_reps``, ``pca``, ``plot_pca``, ``lss``, ``diffs``,
        ``norm`` and ``alignment``.

        The remaining flags (``equidistance``, ``expodistance``, ``expodistance2``,
        ``head_entropy``, ``similarity``, ``attention_var``, ``attention_var2``,
        ``classifier``, ``rep_classifier``, ``rep_var``, ``rep_cs``, ``rep_class_cs``,
        ``similarity_norm``) are stored on the instance but are not consulted by ``run``.
        """
        self.models          = models
        self.model_types     = model_types
        self.tokenizer       = tokenizer
        self.model_name      = model_name
        self.out_dir         = out_dir

        self.hidden_reps     = hidden_reps
        self.pca             = pca
        self.plot_pca        = plot_pca
        self.attentions      = attentions
        self.lss             = lss
        self.diffs           = diffs
        self.equidistance    = equidistance
        self.expodistance    = expodistance
        self.expodistance2   = expodistance2
        self.norm            = norm
        self.attention_var   = attention_var
        self.attention_var2  = attention_var2
        self.head_entropy    = head_entropy
        self.classifier      = classifier
        self.rep_var         = rep_var
        self.rep_cs          = rep_cs
        self.rep_class_cs    = rep_class_cs
        self.rep_classifier  = rep_classifier
        self.similarity      = similarity
        self.similarity_norm = similarity_norm
        self.alignment       = alignment

    def run(self, dataset, data_name, start=None, end=None, save=True):
        """Run every model on every selected prompt and gather the enabled metrics.

        :param dataset:     passed through to ``get_prompts``
        :param data_name:   dataset name; stored under ``"DATA_NAME"`` and used in the output path
        :param start:       passed to ``get_prompts``, which returns the effective start index
        :param end:         passed to ``get_prompts``, which returns the effective end index
        :param save:        when True, write the results with ``torch.save``

        :return:            ``(outputs, out_file)`` where ``outputs`` is
                            ``{"DATA_NAME": data_name, "OUT": {prompt_index: {model_type: metrics}}}``
                            and ``out_file`` is the saved path, or ``None`` when ``save`` is False
        """
        prompts, start, end = get_prompts(dataset, data_name, start=start, end=end)

        out = {}
        for i, prompt in zip(range(start, end), prompts):
            timestamp(f"Running prompt {i + 1} of {end}")
            out[i] = {}

            # Tokenize once per prompt; the same ids are fed to every model.
            input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids

            for idx, model in enumerate(self.models):
                model_type = self.model_types[idx]
                out[i][model_type] = self._collect_metrics(model, input_ids, data_name)
                timestamp(f"Ended {model_type}")

        outputs = {
            "DATA_NAME": data_name,
            "OUT": out,
        }

        # Keep out_file defined on the save=False path so the return below cannot fail.
        out_file = None
        if save:
            out_file = self._format_out_dir(data_name, start=start, end=end)
            torch.save(outputs, out_file)

        return outputs, out_file

    def _collect_metrics(self, model, input_ids, data_name):
        """Run ``model`` on ``input_ids`` and return a dict of the enabled metrics."""
        metrics = {}

        model_output    = get_prompt_outputs_attention(model, input_ids)
        representations = get_representations(model_output.hidden_states)

        if self.attentions:
            metrics["attentions"] = model_output.attentions

        if self.hidden_reps:
            metrics["hidden_states"] = model_output.hidden_states

        if self.pca:
            metrics["pca"] = get_representation_pca(representations)

        if self.plot_pca:
            # Inside this method ``plot_pca`` is the module-level function, not the flag.
            plot_pca(representations, data_name, dim=-2)

        if self.lss:
            metrics["lss"] = get_lss(representations)

        if self.diffs:
            # NOTE(review): the original intent was "remove first/last layer/token", but the
            # chained slice trims the FIRST axis twice (two entries from each end) and never
            # touches a token axis. Behavior kept as-is; confirm the layout of
            # ``representations`` before changing it.
            metrics["diffs"] = get_diffs(representations[1:-1][1:-1])

        if self.norm:
            # Drops the first and last entries along the first axis.
            metrics["norm"] = get_mean_norm(representations[1:-1])

        # The flags equidistance, expodistance, expodistance2, attention_var,
        # attention_var2, head_entropy, classifier, rep_var, rep_classifier,
        # similarity, similarity_norm, rep_cs and rep_class_cs are currently not
        # computed here.

        if self.alignment:
            input_ids_cuda = input_ids.to('cuda')
            # NOTE(review): ExperimentAlign.run calls compute_alignment with an extra
            # ``chunks`` argument and unpacks seven return values; confirm this
            # two-value call still matches the helper's signature.
            ratio_mat, frob_mat = compute_alignment(model, input_ids_cuda)
            metrics["alignment_ratio"] = ratio_mat
            metrics["alignment_frob"] = frob_mat

        return metrics

    def _format_out_dir(self, data_name, start=None, end=None):
        """Build the output file path and make sure its directory exists.

        The file is placed in ``<out_dir>/<model_name>/<data_name>/`` and named
        ``[alignment_]<data_name>_<start>_<end>_<model types joined by '_'>.pt``;
        the ``alignment_`` prefix is added when ``self.alignment`` is truthy.

        :return: full path of the output file
        """
        prefix = "alignment_" if self.alignment else ""
        model_suffix = "_".join(self.model_types)
        file_name = f"{prefix}{data_name}_{start}_{end}_{model_suffix}.pt"

        experiment_out_dir = os.path.join(self.out_dir, self.model_name, data_name)
        # makedirs also creates missing parents (including out_dir itself) and
        # does not fail when the directory already exists.
        os.makedirs(experiment_out_dir, exist_ok=True)

        return os.path.join(experiment_out_dir, file_name)

    # def _get_prompt_outputs(self, prompt):
    #     return get_prompt_outputs(self.model, self.tokenizer, prompt)

class ExperimentAlign:
    """Compute alignment matrices for each prompt and model and optionally save them.

    Results are stored under ``out[prompt_index][model_type]`` with the keys
    ``alignment_ratio``, ``alignment_frob``, ``alignment_frob_sing``,
    ``alignment_uu``, ``alignment_vv``, ``alignment_uv`` and ``alignment_vu``.
    """

    def __init__(self,
                 models,
                 model_types,
                 tokenizer,
                 model_name,
                 out_dir="out",
                 alignment=True):
        """
        :param models:          sequence of models; ``run`` evaluates each one on every prompt
        :param model_types:     labels indexed in parallel with ``models``; used as keys of the
                                per-prompt results, and the label of the last model is used in
                                the output file name
        :param tokenizer:       callable invoked as ``tokenizer(prompt, return_tensors='pt')``
                                whose result exposes ``input_ids``

        :param model_name:      name of model for outputs (sub-directory of ``out_dir``)
        :param out_dir:         output directory

        :param alignment:       True/False to include/exclude the alignment computation;
                                also controls the ``alignment_`` file-name prefix
        """
        self.models          = models
        self.model_types     = model_types
        self.tokenizer       = tokenizer
        self.model_name      = model_name
        self.out_dir         = out_dir

        self.alignment       = alignment

    def run(self, dataset, data_name, start=None, end=None, save=True):
        """Run the alignment computation for every model on every selected prompt.

        :param dataset:     passed through to ``get_prompts``
        :param data_name:   dataset name; stored under ``"DATA_NAME"`` and used in the output path
        :param start:       passed to ``get_prompts``, which returns the effective start index
        :param end:         passed to ``get_prompts``, which returns the effective end index
        :param save:        when True, write the results with ``torch.save``

        :return:            ``(outputs, out_file)`` where ``outputs`` is
                            ``{"DATA_NAME": data_name, "OUT": {prompt_index: {model_type: results}}}``
                            and ``out_file`` is the saved path, or ``None`` when ``save`` is False
        :raises ValueError: if ``save`` is True but ``models`` is empty, because the output
                            file is named after the last model's type
        """
        models = self.models
        model_types = self.model_types

        if save and len(models) == 0:
            raise ValueError("Cannot save results: no models to name the output file after")

        # Special case kept from the original: for gsm8k with end == 5, prompt
        # index 4 is skipped and one extra prompt is requested in its place.
        skip_4 = data_name == 'gsm8k' and end == 5
        if skip_4:
            end = 6

        prompts, start, end = get_prompts(dataset, data_name, start=start, end=end)

        out = {}
        for i, prompt in zip(range(start, end), prompts):
            if skip_4 and i == 4:
                continue
            timestamp(f"Running prompt {i + 1} of {end}")
            out[i] = {}
            print(prompt)

            input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids
            num_tokens = input_ids.shape[1]
            print(num_tokens)
            # Chunk count grows with the prompt length and with the prompt index.
            # NOTE(review): the dependence on ``i`` looks unusual - confirm it is intended.
            chunks = 2 * (num_tokens // 20) + 5 + i

            print("Number of chunks:", chunks)

            for idx, model in enumerate(models):
                model_type = model_types[idx]
                results = {}
                out[i][model_type] = results

                if self.alignment:
                    input_ids_cuda = input_ids.to('cuda')
                    (ratio_mat, frob_mat, frob_mat_sing,
                     uu_mat, vv_mat, uv_mat, vu_mat) = compute_alignment(model, input_ids_cuda, chunks)

                    results["alignment_ratio"] = ratio_mat
                    results["alignment_frob"] = frob_mat
                    results["alignment_frob_sing"] = frob_mat_sing

                    results["alignment_uu"] = uu_mat
                    results["alignment_vv"] = vv_mat
                    results["alignment_uv"] = uv_mat
                    results["alignment_vu"] = vu_mat

                timestamp(f"Ended {model_type}")

        outputs = {
            "DATA_NAME": data_name,
            "OUT": out,
        }

        # Keep out_file defined on the save=False path so the return below cannot fail.
        out_file = None
        if save:
            # The file is named after the last model's type (what the original loop
            # variable held after the loop), taken explicitly so it does not depend
            # on any prompt having been processed.
            last_model_type = model_types[len(models) - 1]
            out_file = self._format_out_dir(data_name, last_model_type, start=start, end=end)
            torch.save(outputs, out_file)

        return outputs, out_file

    def _format_out_dir(self, data_name, model_type, start=None, end=None):
        """Build the output file path and make sure its directory exists.

        The file is placed in ``<out_dir>/<model_name>/<data_name>/`` and named
        ``[alignment_]<data_name>_<model_type>_<start>_<end>.pt``; the
        ``alignment_`` prefix is added when ``self.alignment`` is truthy.

        :return: full path of the output file
        """
        prefix = "alignment_" if self.alignment else ""
        file_name = f"{prefix}{data_name}_{model_type}_{start}_{end}.pt"

        experiment_out_dir = os.path.join(self.out_dir, self.model_name, data_name)
        # makedirs also creates missing parents (including out_dir itself) and
        # does not fail when the directory already exists.
        os.makedirs(experiment_out_dir, exist_ok=True)

        return os.path.join(experiment_out_dir, file_name)

