import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline, TfidfPipeline
from embodied_cd.trl.algos.core import stack_dicts, stats_to_np, stats_print
from embodied_cd.trl.models.core import generation
from embodied_cd.trl.ecoc.ecoc_trainer import custom_collate, ECoCTrainer


class FeedbackTrainer(ECoCTrainer):
    """Trainer that fits ``model`` on scored rationales sampled from two policies.

    Each epoch, :meth:`explore` generates one short rationale per think prompt
    from both ``base_model`` and ``curr_model`` and pairs it with the score
    string returned by ``get_score_str`` for the dataset's reference think.
    :meth:`train` then minimises the loss ``model`` reports for those
    (prompt + rationale, score string) pairs.

    ``default_params``, ``generate``, ``querize``, ``get_score_str``,
    ``gen2_params`` and ``gen3_params`` are not defined in this class; they
    are presumably provided by ``ECoCTrainer`` -- verify against the parent.
    """

    def __init__(self, env_name, pre_model_name, base_tokenizer, base_model, curr_tokenizer, curr_model, tokenizer, model, dataset, output_dir, **params):
        """Store the models and data, merge hyper-parameters, build the optimizer.

        Args:
            env_name: Environment identifier (stored, not used in this class).
            pre_model_name: Name of the pretrained model (stored only).
            base_tokenizer: Tokenizer used to encode prompts for ``base_model``.
            base_model: First policy sampled in :meth:`explore`.
            curr_tokenizer: Tokenizer used to encode prompts for ``curr_model``.
            curr_model: Second policy sampled in :meth:`explore`.
            tokenizer: Tokenizer paired with ``model`` (stored; presumably
                consumed by the inherited ``querize`` -- TODO confirm).
            model: The model optimised in :meth:`train`; its ``device`` is
                where all prompt tensors are placed.
            dataset: Iterable of samples with the keys ``instruction``,
                ``state``, ``history``, ``think_list`` and ``index``.
            output_dir: Output location (stored only).
            **params: Overrides applied on top of ``default_params``.

        NOTE(review): the parent initialiser is not called here; confirm
        that this is intended.
        """
        self.env_name = env_name
        self.pre_model_name = pre_model_name

        # Work on a copy: updating ``default_params`` in place would leak this
        # instance's overrides into every other trainer sharing that object.
        self.params = dict(self.default_params)
        self.params.update(params)
        # Read from the merged params so values left at their default are
        # reported instead of raising KeyError on the raw kwargs.
        print_warn(f"Max Think Tokens: {self.params['max_think_token']}")
        print_warn(f"Total Think: {self.params['total_think']}")

        self.base_tokenizer = base_tokenizer
        self.base_model = base_model
        self.curr_tokenizer = curr_tokenizer
        self.curr_model = curr_model

        self.tokenizer = tokenizer
        self.model = model
        self.device = model.device

        self.dataset = dataset
        self.output_dir = output_dir

        # setup
        self._setup_model()

    def _setup_model(self) -> None:
        """Create the dataloader, the AdamW optimizer and the similarity pipeline."""
        # Fixed order and full batches only: ``train`` indexes every batch
        # with ``range(batch_size)``, so a short final batch must be dropped.
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.params["batch_size"],
            shuffle=False,
            collate_fn=custom_collate,
            drop_last=True,
        )

        # Optimizer over the parameters of the trained model.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.params["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.01,
        )

        # reward pipe (not referenced in this class; presumably used by the
        # inherited scoring code -- TODO confirm)
        self.cossim_pipe = SentenceSimilarityPipeline()

    def _feedback_loss(self, instruction, state, history, responses):
        """Return the mean loss of ``self.model`` over one sample's rationales.

        Args:
            instruction: Instruction text of the sample.
            state: Pre-processed state text of the sample.
            history: Previous-actions text of the sample.
            responses: Sequence of ``(rationale, score_str)`` pairs, one per
                think prompt, as stored by :meth:`explore`.

        Returns:
            The per-prompt losses averaged over ``total_think``.
        """
        total_think = self.params["total_think"]
        loss = 0.
        for j in range(total_think):
            rationale, score_str = responses[j]
            _query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompts[j]}\nRationale: {rationale}"
            # Only the last two of the five values returned by ``querize``
            # are needed: the model input ids and their labels.
            _, _, _, cquery_id, cquery_label = self.querize(_query, score_str)

            # run model
            model_output = self.model(cquery_id, labels=cquery_label)
            loss += model_output.loss / total_think
        return loss

    def train(self):
        """Run the training loop.

        For every epoch: refresh the sampled rationales with :meth:`explore`,
        then take one optimizer step per sample (repeated ``inner_epochs``
        times per batch) on the sum of the base-policy and current-policy
        losses, and finally print the mean losses of the epoch.
        """
        for epoch in tqdm.tqdm(range(1, self.params["total_epochs"]+1), desc="epoch"): # iterate over epochs
            self.explore()

            # Collected across the whole epoch so that logging covers every
            # batch and the list exists even when no batch is produced.
            all_stats = []

            ####################### process optimize
            for batch in self.dataloader: # iterate over dataloader
                for _ in range(self.params["inner_epochs"]):
                    for i in range(self.params["batch_size"]):
                        instruction = batch["instructions"][i]
                        state = PromptTemplate.preprocess(batch["states"][i])
                        history = batch["histories"][i]
                        idx = batch["indexes"][i]

                        loss_base = self._feedback_loss(
                            instruction, state, history, self.base_response[idx])
                        loss_curr = self._feedback_loss(
                            instruction, state, history, self.curr_response[idx])

                        loss = loss_base + loss_curr
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                        # Detach so the logged values do not keep autograd
                        # history alive until the end of the epoch.
                        all_stats.append({
                            "loss/base": loss_base.detach(),
                            "loss/curr": loss_curr.detach(),
                            "loss/total_loss": loss.detach(),
                        })

            ######################  logging !!!!
            if not all_stats:
                # Happens when the dataset holds fewer than ``batch_size``
                # samples (drop_last=True) or ``inner_epochs`` is 0.
                print_warn(f"Epoch {epoch}: no optimization step was taken; nothing to log.")
                continue
            train_stats = stack_dicts(all_stats)
            stats = {f'{k}': torch.mean(v, dim=0) for k, v in train_stats.items()}
            stats = stats_to_np(stats)
            stats_print(stats)

    def _rollout_rationales(self, model, tokenizer, gen_params, instruction, state, history, think):
        """Generate and score one rationale per think prompt with ``model``.

        Args:
            model: Policy to sample from.
            tokenizer: Tokenizer matching ``model``.
            gen_params: Generation settings forwarded to ``self.generate``.
            instruction: Instruction text of the sample.
            state: State text of the sample.
            history: Previous-actions text of the sample.
            think: Reference thinks, indexed by think-prompt position.

        Returns:
            A list of ``(rationale, score_str)`` tuples of length
            ``total_think``.
        """
        response_list = []
        for j in range(self.params["total_think"]):
            _query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompts[j]}"
            query = f"### Human: {_query}\n### Assistant:"
            query_id = tokenizer.encode(query, return_tensors="pt").to(self.device)

            with torch.no_grad():
                response_text = self.generate(
                    model, tokenizer, query_id, gen_params, self.params["max_think_token"])

            # Keep only the first sentence of the first line.
            response_text = response_text.strip().split("\n")[0]
            response_text = response_text.split(". ")[0]
            # ``endswith`` instead of ``[-1]``: an empty generation must not
            # raise IndexError; it simply becomes ".".
            if not response_text.endswith("."):
                response_text = response_text + "."
            _, score_str = self.get_score_str(response_text, think[j])
            response_list.append((response_text, score_str))
        return response_list

    def explore(self):
        """Sample scored rationales from both policies for every dataset item.

        Rebuilds ``self.base_response`` and ``self.curr_response``, each
        mapping a sample's ``index`` to its list of ``(rationale, score_str)``
        tuples.
        """
        self.base_model.set_adapter('reasoning_policy')
        self.curr_model.set_adapter('reasoning_policy')

        self.base_response, self.curr_response = {}, {}
        for data in tqdm.tqdm(self.dataset, desc="init batch"): #iterate over dataset
            instruction = data["instruction"]
            # The same transformed state is shared by both policies.
            state = PromptTemplate.randomize(data["state"], 1.0)
            history = data["history"]
            think = data["think_list"]

            # generate base-model
            self.base_response[data["index"]] = self._rollout_rationales(
                self.base_model, self.base_tokenizer, self.gen2_params,
                instruction, state, history, think)

            # generate current-model
            self.curr_response[data["index"]] = self._rollout_rationales(
                self.curr_model, self.curr_tokenizer, self.gen3_params,
                instruction, state, history, think)
