import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline, TfidfPipeline
from embodied_cd.trl.algos.core import stack_dicts, stats_to_np, stats_print
from embodied_cd.trl.models.core import generation
from embodied_cd.trl.ecoc.ecoc_trainer import custom_collate, ECoCTrainer


class ThinkWholeTrainer(ECoCTrainer):
    """ECoCTrainer variant that supervises both adapters with one rationale per sample.

    The model is switched between two adapters via ``set_adapter``:
    ``'reasoning_policy'`` (produces a rationale / "think") and
    ``'planning_policy'`` (produces the action). For every sample the reasoning
    adapter is trained to emit ``batch["thinks"][i]`` and the planning adapter
    is trained to emit the action both with and without that rationale in its
    prompt.

    All constructor arguments are forwarded unchanged to ``ECoCTrainer``.
    """

    def __init__(self, env_name, pre_model_name, base_tokenizer, base_model, tokenizer, model, dataset, output_dir, **params):
        super().__init__(env_name, pre_model_name, base_tokenizer, base_model, tokenizer, model, dataset, output_dir, **params)

    @staticmethod
    def _build_think_query(instruction, state, history):
        """Return the prompt text that asks the reasoning adapter for a rationale."""
        return (
            f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n"
            f"{PromptTemplate.correct_think_prompt}"
        )

    @staticmethod
    def _encode_think_query(pre_model_name, tokenizer, device, text):
        """Tokenize ``text`` in the prompt format expected by ``pre_model_name``.

        Checkpoints whose name contains ``instruct``/``Instruct`` go through the
        tokenizer's chat template (with a generation prompt); all others use a
        plain ``### Human: ... ### Assistant:`` prompt.

        Returns:
            The ``return_tensors="pt"`` token ids, moved to ``device``.
        """
        if 'instruct' in pre_model_name or 'Instruct' in pre_model_name:
            messages = [{"role": "user", "content": text}]
            query_id = tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
        else:
            query_id = tokenizer.encode(f"### Human: {text}\n### Assistant:", return_tensors="pt")
        return query_id.to(device)

    @classmethod
    def _setup_base_model(cls, pre_model_name, tokenizer, model, dataset, max_think_token):
        """Score the rationales that ``model``'s reasoning adapter generates zero-shot.

        For each sample, a rationale is generated with the ``'reasoning_policy'``
        adapter (only its first line is kept). It is compared with the reference
        ``data["think"]`` using ``SentenceSimilarityPipeline`` and ``TfidfPipeline``.

        Args:
            pre_model_name: Checkpoint name; selects chat-template vs. plain prompts.
            tokenizer: Tokenizer matching ``model``.
            model: Model exposing ``set_adapter`` and a ``'reasoning_policy'`` adapter.
            dataset: Iterable of dicts with ``instruction``, ``state``, ``history``
                and ``think`` keys.
            max_think_token: Generation length limit forwarded to ``generate``.

        Returns:
            ``0.5 * mean(similarity score) + 0.5 * mean(tf-idf score)`` over the dataset.

        Raises:
            ValueError: If ``dataset`` yields no samples.
        """
        # reward pipes
        cossim_pipe = SentenceSimilarityPipeline()
        tfidf_pipe = TfidfPipeline()
        # The adapter is not switched inside the loop, so select it once.
        model.set_adapter('reasoning_policy')

        cossim_score, tfidf_score, count = 0., 0., 0
        for data in tqdm.tqdm(dataset, desc="setup base"):
            state = PromptTemplate.preprocess(data["state"])
            think = data["think"]
            query_text = cls._build_think_query(data["instruction"], state, data["history"])
            query_id = cls._encode_think_query(pre_model_name, tokenizer, model.device, query_text)

            with torch.no_grad():
                # `generate` is called with `cls` standing in for `self`, as before.
                response_text = cls.generate(
                    cls, model, tokenizer, query_id, cls.gen1_params, max_think_token)
            # Keep only the first line of the generated rationale.
            response_text = response_text.strip().split("\n")[0]

            print_error(response_text)
            cossim_score += cossim_pipe(think, response_text)
            tfidf_score += tfidf_pipe(think, response_text)
            count += 1

        if count == 0:
            raise ValueError("_setup_base_model: dataset is empty, cannot compute a score")
        return 0.5 * cossim_score / count + 0.5 * tfidf_score / count

    def _setup_init_response(self):
        """Regenerate ``self.init_response`` for every sample of ``self.dataset``.

        Rationales come from the ``'reasoning_policy'`` adapter of
        ``self.base_model`` when one was given, otherwise from ``self.model``.
        The first line of each generation is stored under ``data["index"]``.
        ``train`` reads these (for ablation 0/2) to build its feedback prompt.
        """
        self.init_response = {}
        # 0. Select the seed-rationale model.
        if self.base_model is None:
            tokenizer, model = self.tokenizer, self.model
        else:
            tokenizer, model = self.base_tokenizer, self.base_model
        model.set_adapter('reasoning_policy')

        for data in tqdm.tqdm(self.dataset, desc="init batch"):
            state = PromptTemplate.preprocess(data["state"])
            query_text = self._build_think_query(data["instruction"], state, data["history"])
            # Both prompt formats go to the generating model's device; the chat
            # branch used self.device, which need not match base_model's device.
            query_id = self._encode_think_query(self.pre_model_name, tokenizer, model.device, query_text)

            with torch.no_grad():
                response_text = self.generate(
                    model, tokenizer, query_id, self.gen1_params, self.params["max_think_token"])
            self.init_response[data["index"]] = response_text.strip().split("\n")[0]

    def train(self):
        """Run the alternating reasoning / planning optimisation loop.

        For every epoch, batch and sample (each batch repeated
        ``params["inner_epochs"]`` times):

        1. ``'reasoning_policy'`` (``self.optimizer_r``): LM loss for emitting
           ``think_cat`` from the rationale prompt ("forward"). For
           ``ablation`` 0 or 2, it adds the loss for emitting it from a prompt
           that also contains the cached ``init_response`` and its score
           string ("backward_base").
        2. ``'planning_policy'`` (``self.optimizer_p``): mean of the LM losses
           for emitting the action with ``think_cat`` and with an empty think.

        After each epoch, the losses averaged over that epoch are printed and
        logged to wandb. Except after the last epoch, ``init_response`` is
        regenerated when ``ablation`` is 0, 2 or 3.
        """
        self.prev_neg_percent = 1.0

        for epoch in tqdm.tqdm(range(1, self.params["total_epochs"] + 1), desc="epoch"):
            # Accumulate over the whole epoch. This was previously reset per
            # batch, so only the last batch was logged.
            all_stats = []

            ####################### process optimize
            for batch in self.dataloader:
                # Never index past a short (e.g. final) batch.
                n_samples = min(self.params["batch_size"], len(batch["instructions"]))
                for _ in range(self.params["inner_epochs"]):
                    for i in range(n_samples):
                        instruction = batch["instructions"][i]
                        state = PromptTemplate.preprocess(batch["states"][i])
                        action = batch["actions"][i]
                        history = batch["histories"][i]
                        idx = batch["indexes"][i]
                        think_cat = batch["thinks"][i]

                        # 1: Rationale optimize
                        self.model.set_adapter('reasoning_policy')
                        loss_fr = torch.tensor(0.).to(self.device)
                        loss_brb = torch.tensor(0.).to(self.device)

                        _query = self._build_think_query(instruction, state, history)
                        _, _, _, cquery_id, cquery_label = self.querize(_query, think_cat)
                        loss_fr += self.model(cquery_id, labels=cquery_label).loss

                        if self.params["ablation"] in (0, 2):
                            init_response = self.init_response[idx]
                            _, score_str = self.get_score_str(init_response, think_cat)
                            _query = (
                                f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n"
                                f"Think: {init_response}\n{score_str} {PromptTemplate.correct_think_prompt}"
                            )
                            _, _, _, cquery_id, cquery_label = self.querize(_query, think_cat)
                            loss_brb += self.model(cquery_id, labels=cquery_label).loss

                        loss_r = loss_fr + loss_brb
                        self.optimizer_r.zero_grad()
                        loss_r.backward()
                        self.optimizer_r.step()

                        # 2: Action optimize (average of with / without rationale)
                        self.model.set_adapter('planning_policy')
                        loss_p = 0.
                        for plan_think in (think_cat, ""):
                            _query = self.plan_template(
                                instruction=instruction, state=state, think=plan_think, history=history)["query"]
                            _, _, _, cquery_id, cquery_label = self.querize(_query, action)
                            loss_p += self.model(cquery_id, labels=cquery_label).loss / 2

                        self.optimizer_p.zero_grad()
                        loss_p.backward()
                        self.optimizer_p.step()

                        # Detach so stored stats don't keep autograd graphs alive all epoch.
                        # NOTE(review): the 'foward' spelling is kept so logged runs stay comparable.
                        all_stats.append({
                            "loss/reasoning_policy": loss_r.detach(),
                            "loss/reasoning_policy/foward": loss_fr.detach(),
                            "loss/reasoning_policy/backward_base": loss_brb.detach(),
                            "loss/planning_policy": loss_p.detach(),
                        })

            ######################  logging
            if all_stats:
                train_stats = stack_dicts(all_stats)
                stats = stats_to_np({k: torch.mean(v, axis=0) for k, v in train_stats.items()})
                stats_print(stats)
            else:
                # Empty dataloader: nothing to aggregate.
                stats = {}

            ######################  exploration
            if epoch != self.params["total_epochs"] and self.params["ablation"] in (0, 2, 3):
                self._setup_init_response()

            # This trainer does not count negative explorations, so the rate is always 0.
            self.prev_neg_percent = 0.0
            stats["explore/neg_percent"] = self.prev_neg_percent
            print_warn(f"explore/neg_precent: {self.prev_neg_percent}")
            wandb.log(stats)
