import os, random, argparse, sys, pickle, time
import torch
from tqdm import tqdm, trange
import numpy as np
import pandas as pd
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from datasets import Dataset 
from torch.utils.data import DataLoader
import wandb
from dataclasses import dataclass, field

def count_parameters(model):
    """Return how many scalar elements live in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is off (frozen weights) are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

# NOTE: this binds ``logging`` to transformers' logging utilities, not the
# stdlib ``logging`` module.
from transformers.utils import logging
# Raise transformers' log verbosity to INFO for this process.
logging.set_verbosity_info()
# Module-level logger; used directly for the prealign accuracy and
# "Training is finished" messages in this file.
logger = logging.get_logger("transformers")
"""
This code is designed for alignment search
for large models, i.e., >1B parameters.

We test it with Alpaca 7B, which is based on
the LLaMA 7B model, but it should extend to
larger models as well if computational resources
allow.
"""
# Relative path for a cache directory. It is not referenced in the code visible
# here; presumably it is consumed by model/tokenizer loading elsewhere (TODO confirm).
CACHE_DIR = "../.cache/"

class AlpacaAligner(object):
    """Train and evaluate a Boundless-DAS intervention on a wrapped causal LM.

    From how it is called below, ``model`` is expected to support three calls:

    * a plain forward (``input_ids``, ``labels``) returning ``.logits``;
    * a source pass (``output_rotated_hidden_states_only=True``) returning
      ``.rotated_hidden_states``;
    * an intervened pass (``source_hidden_states``, ``intervention_ids``,
      ``labels``) returning ``.loss`` and ``.logits``.

    Its inner ``.model`` carries ``rotate_layer``, ``intervention_boundaries``
    and ``temperature``. When ``n_gpu > 1`` the model is assumed to be wrapped
    (e.g. ``torch.nn.DataParallel``) so the real module lives at ``.module``.
    """

    def __init__(
        self, model,
        is_master,
        logger,
        args,
        lr=5e-5,
        apex_enable=False,
        n_gpu=1,
        gpu_id=0,
        early_stopping=5,
        do_statistic=False,
        model_name="",
        device="cuda"
    ):
        """Store the run configuration and, on the master process, start a wandb run.

        Args:
            model: the (possibly multi-GPU wrapped) alignable model.
            is_master: whether this process writes logs, checkpoints and wandb data.
            logger: logger used for progress messages.
            args: run configuration. ``is_wandb``, ``task_name`` and
                ``wandb_username`` are read here; the whole object is sent to
                ``wandb.config`` when wandb is enabled.
            lr: stored on the instance (the optimizer itself is passed to ``train``).
            apex_enable, gpu_id, do_statistic: accepted for interface
                compatibility; not used by this class.
            n_gpu: number of GPUs; ``> 1`` means ``model`` is wrapped and
                exposes the real module as ``.module``.
            early_stopping: stored on the instance; not consulted by ``train``.
            model_name: wandb run name.
            device: device every tensor in a batch is moved to.
        """
        self.model = model
        num_params = count_parameters(model)
        logger.info(f'Number of Alpaca-7B model params: {num_params}')
        self.is_master = is_master
        self.logger = logger
        self.is_wandb = args.is_wandb
        self.model_name = model_name

        self.lr = lr
        self.n_gpu = n_gpu
        self.device = device

        self.early_stopping = early_stopping

        # wandb is only ever initialised on the master process; every other
        # method must therefore gate wandb calls on ``self.is_master``.
        if args.is_wandb and is_master:
            wandb.init(
                project=f"Boundless-DAS-{args.task_name}",
                entity=args.wandb_username,
                name=model_name,
            )
            wandb.config.update(args)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _inner_model(self):
        """Return the inner ``.model`` that holds the intervention parameters.

        Unwraps the multi-GPU ``.module`` wrapper when ``n_gpu > 1``. All
        access to ``rotate_layer``, ``intervention_boundaries`` and
        ``temperature`` goes through this method.
        """
        wrapped = self.model.module if self.n_gpu > 1 else self.model
        return wrapped.model

    def _to_device(self, inputs):
        """Move every tensor value of the batch dict ``inputs`` to ``self.device`` in place; return it."""
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(self.device)
        return inputs

    def _intervened_forward(self, inputs):
        """Run the source pass, then the base pass with the source's rotated hidden states swapped in."""
        source_hidden_states = self.model(
            input_ids=inputs['source_input_ids'],
            output_rotated_hidden_states_only=True,
        ).rotated_hidden_states
        return self.model(
            input_ids=inputs['input_ids'],
            source_hidden_states=source_hidden_states,
            intervention_ids=inputs['intervention_ids'],
            labels=inputs['labels'],
        )

    @staticmethod
    def _last_position_hits(labels, logits):
        """Per row, say whether the argmax prediction at the final position equals the final label."""
        return labels[:, -1] == torch.argmax(logits[:, -1], dim=-1)

    def _evaluate(self, dataloader, intervene=True):
        """Return last-position accuracy over ``dataloader``, rounded to 2 decimals.

        Switches the model to eval mode and leaves it there; callers that keep
        training must call ``self.model.train()`` afterwards.

        Args:
            dataloader: iterable of batch dicts.
            intervene: if True, use the intervened forward (dev/test); if
                False, use the plain forward (prealign).

        An empty dataloader yields 0.0 and a warning rather than a
        ``ZeroDivisionError``, so end-of-training checkpointing still happens.
        """
        total_count = 0
        correct_count = 0
        self.model.eval()
        with torch.no_grad():
            for inputs in dataloader:
                inputs = self._to_device(inputs)
                if intervene:
                    outputs = self._intervened_forward(inputs)
                else:
                    outputs = self.model(
                        input_ids=inputs['input_ids'],
                        labels=inputs['labels'],
                    )
                hits = self._last_position_hits(inputs['labels'], outputs.logits)
                total_count += len(hits)
                correct_count += int(hits.sum().item())
        if total_count == 0:
            self.logger.warning("Evaluation dataloader is empty; reporting accuracy 0.0")
            return 0.0
        return round(correct_count / total_count, 2)

    @staticmethod
    def _append_line(path, line):
        """Append ``line`` followed by a newline to the text file at ``path``."""
        with open(path, 'a') as handle:
            print(line, file=handle)

    def _log_eval(self, key, step, accuracy, eval_log_path):
        """Record an accuracy in wandb under ``key``, or append ``step,accuracy`` to the eval log."""
        if self.is_wandb:
            wandb.log({key: accuracy}, step=step)
        else:
            self._append_line(eval_log_path, '{},{}'.format(step, accuracy))

    def _log_train_step(self, step, loss_value, step_accuracy, train_log_path):
        """Record one training step in wandb, or append ``step,loss,accuracy`` to the train log."""
        if self.is_wandb:
            inner = self._inner_model()
            boundaries = torch.clamp(inner.intervention_boundaries, 1e-3, 1)
            wandb.log(
                {
                    "train/loss": loss_value,
                    "train/step_accuracy": step_accuracy,
                    "train/temperature": inner.temperature.item(),
                    "train/unified_boundary": boundaries[0].item(),
                    "train/unified_boundary (dummy)": boundaries[1].item(),
                },
                step=step,
            )
        else:
            self._append_line(
                train_log_path, '{},{},{}'.format(step, loss_value, step_accuracy)
            )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def save_model(self, output_dir, model_name):
        """Save the rotation layer, intervention boundaries and temperature to ``output_dir/model_name``."""
        inner = self._inner_model()
        torch.save({
            'rotate_layer': inner.rotate_layer.state_dict(),
            'intervention_boundaries': inner.intervention_boundaries,
            'temperature': inner.temperature,
        }, os.path.join(output_dir, model_name))

    def prealign_eval(self, prealign_dataloader, output_dir):
        """Measure task accuracy with no intervention, as a sanity check before alignment.

        Logs the accuracy on every process. Only the master process records it:
        in wandb when enabled, otherwise in ``output_dir/prealign_log.txt``.
        Leaves the model in eval mode.
        """
        current_acc = self._evaluate(prealign_dataloader, intervene=False)
        self.logger.info(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}")

        # wandb is only initialised on the master process, so non-master
        # processes must not touch it (nor the shared log file).
        if not self.is_master:
            return
        if self.is_wandb:
            wandb.log(
                {
                    "eval/prealign_accuracy": current_acc
                },
                step=0
            )
        else:
            with open(os.path.join(output_dir, 'prealign_log.txt'), 'w') as log_prealign:
                print(f'prealign_accuracy,{current_acc}', file=log_prealign)

    def train(
        self, train_dataloader, dev_dataloader, test_dataloader,
        optimizer, scheduler, output_dir,
        log_step, valid_steps, epochs,
        gradient_accumulation_steps,
    ):
        """Train the intervention parameters, then evaluate on the test set.

        Each micro-batch runs the intervened forward and backpropagates
        ``loss / gradient_accumulation_steps``. The optimizer and scheduler step
        once per ``gradient_accumulation_steps`` micro-batches. After each
        optimizer step, the boundary temperature follows a linear schedule
        from 50.0 down to 0.1.

        On the master process, the following happens:

        * Every ``log_step`` micro-batches, the training loss and accuracy are logged.
        * Every ``valid_steps`` micro-batches, the dev set is evaluated and
          ``pytorch-rotate-best.bin`` is saved whenever dev accuracy improves.
        * At the end, the test set is evaluated and ``pytorch-rotate-last.bin``
          is saved.
        """
        train_log_path = os.path.join(output_dir, 'train_log.txt')
        eval_log_path = os.path.join(output_dir, 'eval_log.txt')
        if self.is_master and not self.is_wandb:
            # Truncate both logs and write their CSV headers.
            with open(train_log_path, 'w') as log_train:
                print('step,loss,accuracy', file=log_train)
            with open(eval_log_path, 'w') as log_eval:
                print('step,accuracy', file=log_eval)

        # Alignment is trained with the model in train mode; eval-mode
        # alignment has not been explored.
        self.model.train()
        self.model.zero_grad()
        inner = self._inner_model()
        accumulation = max(1, int(gradient_accumulation_steps))

        train_iterator = trange(
            0, int(epochs), desc="Epoch"
        )
        total_step = 0
        best_eval_acc = -1
        target_total_step = len(train_dataloader) * int(epochs)
        temperature_start = 50.0
        temperature_end = 0.1
        # One entry per micro-batch (at least one so indexing never fails). The
        # schedule lives on the temperature's own device, so assigning it does
        # not silently move the parameter to the CPU.
        temperature_schedule = torch.linspace(
            temperature_start, temperature_end, max(target_total_step, 1)
        ).to(device=inner.temperature.device, dtype=torch.bfloat16)
        last_schedule_index = len(temperature_schedule) - 1
        inner.temperature.data = temperature_schedule[0]

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
            for inputs in epoch_iterator:
                inputs = self._to_device(inputs)
                outputs = self._intervened_forward(inputs)
                # DataParallel returns one loss per replica; average them.
                loss = outputs.loss.mean() if self.n_gpu > 1 else outputs.loss

                step_accuracy = (
                    self._last_position_hits(inputs['labels'], outputs.logits)
                    .float()
                    .mean()
                    .item()
                )

                if self.is_master and total_step % log_step == 0:
                    self._log_train_step(total_step, loss.item(), step_accuracy, train_log_path)

                # Validation runs on its own cadence, independent of log_step.
                if self.is_master and total_step != 0 and total_step % valid_steps == 0:
                    current_acc = self._evaluate(dev_dataloader)
                    self._log_eval("eval/accuracy", total_step, current_acc, eval_log_path)
                    if current_acc > best_eval_acc:
                        best_eval_acc = current_acc
                        self.save_model(output_dir, 'pytorch-rotate-best.bin')
                    self.model.train()

                epoch_iterator.set_postfix({'loss': round(loss.item(), 2)})

                # Backpropagate every micro-batch so gradients really accumulate;
                # scale so the accumulated gradient is a mean over the group.
                (loss / accumulation).backward()
                if (total_step + 1) % accumulation == 0:
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    inner.temperature.data = temperature_schedule[
                        min(total_step, last_schedule_index)
                    ]

                total_step += 1

        self.logger.info("Training is finished ...")

        # End-of-training evaluation on the held-out test set.
        if self.is_master:
            current_acc = self._evaluate(test_dataloader)
            self._log_eval("test/accuracy", total_step, current_acc, eval_log_path)
            if self.is_wandb:
                wandb.finish()
            self.save_model(output_dir, 'pytorch-rotate-last.bin')
        