import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline, TfidfPipeline
from embodied_cd.trl.algos.core import stack_dicts, stats_to_np, stats_print
from embodied_cd.trl.models.core import generation
from embodied_cd.trl.ecoc.ecoc_trainer import custom_collate, ECoCTrainer


class ThinkThreeStepTrainer(ECoCTrainer):
    def __init__(
        self,
        env_name,
        pre_model_name,
        base_tokenizer,
        base_model,
        tokenizer,
        model,
        dataset,
        output_dir,
        **params,
    ):
        """Build the trainer by delegating entirely to ``ECoCTrainer.__init__``.

        Every positional argument and all extra keyword ``params`` are
        forwarded unchanged and in the same order; this subclass adds no
        construction logic of its own.
        """
        super().__init__(
            env_name,
            pre_model_name,
            base_tokenizer,
            base_model,
            tokenizer,
            model,
            dataset,
            output_dir,
            **params,
        )

    @classmethod
    def _setup_base_model(cls, pre_model_name, tokenizer, model, dataset, max_think_token):
        """Score the initial responses of the reasoning adapter on ``dataset``.

        For every sample, two reference segments are built from
        ``data["think_list"]`` (items 0-2 and items 3-4). The model is queried
        once per segment with the matching entry of
        ``PromptTemplate.correct_think_prompts_step_3``; each generation is
        cut to its first line and to at most three (first step) or two
        (second step) ``". "``-separated sentences, and is then compared with
        the reference segments by both similarity pipelines.

        Args:
            pre_model_name: Name of the pretrained model. When it contains
                ``"instruct"`` or ``"Instruct"`` the tokenizer's chat template
                is used, otherwise a ``### Human: / ### Assistant:`` prompt.
            tokenizer: Tokenizer used to encode the queries.
            model: Model exposing ``set_adapter`` and ``device``.
            dataset: Iterable of samples with the keys ``"instruction"``,
                ``"state"``, ``"history"`` and ``"think_list"``.
            max_think_token: Forwarded to ``generate`` as the last argument.

        Returns:
            ``0.5 * mean(cosine score) + 0.5 * mean(tf-idf score)`` over all
            samples.

        Raises:
            ZeroDivisionError: If ``dataset`` yields no samples.
        """
        # reward pipes
        cossim_pipe = SentenceSimilarityPipeline()
        tfidf_pipe = TfidfPipeline()

        # The prompt format depends only on the model name, so decide it once.
        use_chat_template = 'instruct' in pre_model_name or 'Instruct' in pre_model_name

        cossim_score, tfidf_score, count = 0., 0., 0
        for data in tqdm.tqdm(dataset, desc="setup base"):
            instruction, state, history, think = data["instruction"], PromptTemplate.preprocess(data["state"]), data["history"], data["think_list"]
            think_list = [think[0] + " " + think[1] + " " + think[2], think[3] + " " + think[4]]

            # 1. Generate initial response
            model.set_adapter('reasoning_policy')
            init_response_list = []
            for j in range(len(think_list)):
                _query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompts_step_3[j]}"
                if use_chat_template:
                    query = [{"role": "user", "content": _query}]
                    query_id = tokenizer.apply_chat_template(
                        query, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
                else:
                    query = f"### Human: {_query}\n### Assistant:"
                    query_id = tokenizer.encode(query, return_tensors="pt").to(model.device)

                with torch.no_grad():
                    # ``generate`` is called through the class, with ``cls``
                    # standing in for ``self``.
                    response_text = cls.generate(
                        cls, model, tokenizer, query_id, cls.gen1_params, max_think_token)

                # Keep the first line only, then the leading sentences: three
                # for the first step, two for the second. When fewer sentences
                # are available only the first one is kept.
                sentences = response_text.strip().split("\n")[0].split(". ")
                max_sentences = 3 if j == 0 else 2
                if len(sentences) >= max_sentences:
                    response_text = ". ".join(sentences[:max_sentences])
                else:
                    response_text = sentences[0]

                # Terminate with a period. An empty generation is left empty
                # instead of raising IndexError on ``response_text[-1]``.
                if response_text and not response_text.endswith("."):
                    response_text = response_text + "."
                init_response_list.append(response_text)
            init_response = " ".join(init_response_list)

            print_error(init_response)
            cossim_score += cossim_pipe(think_list, init_response_list)
            tfidf_score += tfidf_pipe(think_list, init_response_list)
            count += 1

        return 0.5 * cossim_score / count + 0.5 * tfidf_score / count

    def _setup_init_response(self):
        """Generate and cache one initial response per reasoning step.

        Uses ``self.base_model`` / ``self.base_tokenizer`` when a base model is
        set, otherwise ``self.model`` / ``self.tokenizer``, with the
        ``reasoning_policy`` adapter active. For each sample of
        ``self.dataset`` and each of the ``self.params["total_think"]`` steps
        the model is prompted with the matching entry of
        ``PromptTemplate.correct_think_prompts_step_3``. The generation is cut
        to its first line and to its leading ``". "``-separated sentences, and
        the result is stored in ``self.init_response[data["index"]]``.

        ``self.neg_response`` and ``self.pos_response`` are reset to an empty
        list for every sample index.
        """
        self.init_response = {}
        # 0. Setting Seed Rationale Model
        if self.base_model is None:
            tokenizer, model = self.tokenizer, self.model
        else:
            tokenizer, model = self.base_tokenizer, self.base_model
        model.set_adapter('reasoning_policy')

        # The prompt format depends only on the model name, so decide it once.
        use_chat_template = (
            'instruct' in self.pre_model_name or 'Instruct' in self.pre_model_name)

        for data in tqdm.tqdm(self.dataset, desc="init batch"):  # iterate over dataset
            instruction, state, history = data["instruction"], PromptTemplate.preprocess(data["state"]), data["history"]

            response_list = []
            for j in range(self.params["total_think"]):
                _query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompts_step_3[j]}"
                # Inputs go to the device of the model that generates, which
                # may be the base model rather than the trained one.
                if use_chat_template:
                    query = [{"role": "user", "content": _query}]
                    query_id = tokenizer.apply_chat_template(
                        query, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
                else:
                    query = f"### Human: {_query}\n### Assistant:"
                    query_id = tokenizer.encode(query, return_tensors="pt").to(model.device)

                with torch.no_grad():
                    response_text = self.generate(
                        model, tokenizer, query_id, self.gen1_params, self.params["max_think_token"])

                # Keep the first line only, then the leading sentences: three
                # for the first step, two for later steps. When fewer
                # sentences are available only the first one is kept.
                # NOTE(review): steps with j >= 2 used to leave a raw list
                # here (hidden by a bare except); they are now treated like
                # the second step - confirm this is the intended handling.
                sentences = response_text.strip().split("\n")[0].split(". ")
                max_sentences = 3 if j == 0 else 2
                if len(sentences) >= max_sentences:
                    response_text = ". ".join(sentences[:max_sentences])
                else:
                    response_text = sentences[0]

                # Terminate with a period; an empty generation stays empty.
                # This explicit guard replaces the former bare ``except``.
                if response_text and not response_text.endswith("."):
                    response_text = response_text + "."

                response_list.append(response_text)
            self.init_response[data["index"]] = response_list

        self.neg_response = {}
        self.pos_response = {}
        for data in self.dataset:
            self.neg_response[data["index"]] = []
            self.pos_response[data["index"]] = []

    def train(self):
        """Train the reasoning and planning adapters on ``self.dataloader``.

        For every sample of every batch, repeated ``inner_epochs`` times:

        1. Forward reasoning: with the ``reasoning_policy`` adapter, the model
           is trained to produce each reference think segment from the
           instruction, state and action history (``self.optimizer_r``).
        2. Backward base-correction (only when ``params["ablation"]`` is 0 or
           2): the model is trained to produce the reference segment given the
           cached initial response, its score string and the reasoning trace
           accumulated so far (``self.optimizer_r``).
        3. Planning: with the ``planning_policy`` adapter, the model is
           trained to produce the action for each prefix of the reasoning
           trace, from empty to complete (``self.optimizer_p``).

        After each epoch the mean losses over all samples processed in that
        epoch are printed and logged to wandb. On every epoch except the last,
        the initial responses are regenerated when ``params["ablation"]`` is
        0, 2 or 3.
        """
        self.prev_neg_percent = 1.0
        total_think = self.params["total_think"]

        for epoch in tqdm.tqdm(range(1, self.params["total_epochs"]+1), desc="epoch"):  # iterate over epochs
            # Collected over the whole epoch so the logged means cover every
            # batch, and so the name exists even if the dataloader is empty.
            all_stats = []

            ####################### process optimize
            for batch in self.dataloader:  # iterate over dataloader
                # A final batch may hold fewer samples than ``batch_size``.
                num_samples = min(self.params["batch_size"], len(batch["instructions"]))
                for inner_epoch in range(self.params["inner_epochs"]):
                    for i in range(num_samples):
                        instruction = batch["instructions"][i]
                        state = PromptTemplate.preprocess(batch["states"][i])
                        action = batch["actions"][i]
                        history = batch["histories"][i]
                        idx = batch["indexes"][i]
                        think = batch["think_lists"][i]

                        """ 1: Rationale Optimize """
                        self.model.set_adapter('reasoning_policy')
                        loss_fr = torch.tensor(0., device=self.device)
                        loss_brb = torch.tensor(0., device=self.device)

                        """ 1-1: (Forward) Reasoning Optimize """
                        ########################################
                        for j in range(total_think):
                            think_segment = self._think_segment(think, j)

                            _query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompts_step_3[j]}"
                            _, _, _, cquery_id, cquery_label = self.querize(_query, think_segment)

                            # run model
                            model_output = self.model(cquery_id, labels=cquery_label)
                            loss_fr += model_output.loss / total_think

                        self.optimizer_r.zero_grad()
                        loss_fr.backward()
                        self.optimizer_r.step()

                        if self.params["ablation"] in [0, 2]:
                            """ 1-2: (Backward) Base-Correction Reasoning Optimize """
                            ########################################
                            init_response = self.init_response[idx]
                            reasoning_trace = ""
                            for j in range(total_think):
                                think_segment = self._think_segment(think, j)
                                _, score_str = self.get_score_str(init_response[j], think_segment)

                                _query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\nReasoning Trace: {reasoning_trace}\nThink: {init_response[j]}\n{score_str} {PromptTemplate.correct_think_prompts_step_3[j]}"
                                _, _, _, cquery_id, cquery_label = self.querize(_query, think_segment)

                                # run model
                                model_output = self.model(cquery_id, labels=cquery_label)
                                loss_brb += model_output.loss / total_think
                                # The trace for the next step contains every
                                # reference segment up to and including this one.
                                reasoning_trace = think_segment if j == 0 else reasoning_trace + " " + think_segment

                            self.optimizer_r.zero_grad()
                            loss_brb.backward()
                            self.optimizer_r.step()

                        # Reported only; each part was already optimized above.
                        loss_r = loss_fr + loss_brb

                        """ Base-2: Action Optimize """
                        self.model.set_adapter('planning_policy')
                        loss_p = 0.
                        ########################################
                        reasoning_trace = ""
                        # total_think + 1 passes: one per trace prefix, from
                        # the empty trace up to the complete one.
                        for j in range(total_think + 1):
                            _query = self.plan_template(instruction=instruction, state=state, think=reasoning_trace, history=history)["query"]
                            _, _, _, cquery_id, cquery_label = self.querize(_query, action)

                            # run model
                            model_output = self.model(cquery_id, labels=cquery_label)
                            loss_p += model_output.loss / (total_think + 1)
                            if j < total_think:
                                think_segment = self._think_segment(think, j)
                                reasoning_trace = think_segment if j == 0 else reasoning_trace + " " + think_segment

                        self.optimizer_p.zero_grad()
                        loss_p.backward()
                        self.optimizer_p.step()

                        # Detach so the logged values do not keep references
                        # to the autograd graph. Metric keys are unchanged.
                        all_stats.append({
                            "loss/reasoning_policy": loss_r.detach(),
                            "loss/reasoning_policy/foward": loss_fr.detach(),
                            "loss/reasoning_policy/backward_base": loss_brb.detach(),
                            "loss/planning_policy": loss_p.detach(),
                        })

            ######################  logging !!!!
            stats = {}
            if all_stats:
                train_stats = stack_dicts(all_stats)
                for k, v in train_stats.items():
                    stats[f'{k}'] = torch.mean(v, axis=0)
                stats = stats_to_np(stats)
                stats_print(stats)

            ######################  exploration !!
            if epoch != self.params["total_epochs"]:
                if self.params["ablation"] in [0, 2, 3]:
                    self._setup_init_response()

            # This trainer never counts negative samples, so the ratio is
            # always reported as zero.
            self.prev_neg_percent = 0.0
            stats["explore/neg_percent"] = self.prev_neg_percent
            print_warn(f"explore/neg_precent: {self.prev_neg_percent}")
            wandb.log(stats)

    @staticmethod
    def _think_segment(think, step):
        """Return the reference text for reasoning step ``step``.

        Step 0 joins ``think[0..2]``; every later step joins ``think[3..4]``.
        """
        if step == 0:
            return think[0] + " " + think[1] + " " + think[2]
        return think[3] + " " + think[4]
