import numpy as np
import os
import sys
import time
from time import perf_counter as tpc
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import warnings


# Silence UserWarnings for the whole process (this filter is not limited to
# torch: every UserWarning is ignored). It sits before the project imports
# below, presumably so warnings raised at import time are hidden as well.
warnings.filterwarnings(action='ignore', category=UserWarning)


from config import config
from data import Data
from model_gpt import ModelGPT
from model_gpt_cp import ModelGPTCP
from model_gpt_lr2 import ModelGPTLR2
from model_gpt_mt import ModelGPTMT
from model_gpt_cp_refactor_attempt import ModelGPTCPSparse
# from model_gpt_tt import ModelGPTTT
from model_gpt_rk1base import ModelGPTRK1B
from trainer import Trainer
from utils import init
from utils import init_log
from utils import init_path
from utils import log
from utils import plot_loss




class Llmtelora:
    def __init__(self, args, dev=False):
        self.args = args
        self.dev = dev

        self.init()
        if not self.dev:
            self.init_writer()

    def init(self):
        self.fold = f'result/{self.args.name}'
        self.file_log = f'{self.fold}/log.txt'

        init(self.args)
        if not self.dev:
            init_path(self.args.name, root='result', rewrite=self.args.rewrite)
            init_log(self.args, self.file_log)

        self.epoch = 0         # Current epoch of learning
        self.times = []        # Total worktimes for each epoch
        self.losses_trn = []   # Loss on train data for each epoch
        self.losses_tst = []   # Loss on test data for each epoch
        self.current_step = 0

    def init_writer(self):
        self.writer = SummaryWriter(f'{self.fold}/tensorboard')

        params = '\n'.join([f'|{key}|{value}|'
            for key, value in vars(self.args).items()])
        self.writer.add_text(
            'Config', '|param|value|\n|-|-|\n%s' % params)

    def load(self, model_name='model', model_class=None):
        fpath = f'result/{self.args.name}/{model_name}.pt'
        params = torch.load(fpath, map_location=torch.device(self.args.device))
        self.set_model(params, model_class=model_class)

        self.log('\n+++ Model is loaded from the file\n')

        fpath = os.path.join(self.fold, 'result.npz')
        data = np.load(fpath, allow_pickle=True).get('data').item()
        self.epoch = data['epoch']
        self.times = data['times']
        self.losses_trn = data['losses_trn']
        self.losses_tst = data['losses_tst']

        self.log('\n+++ Data are loaded from the file\n')

    def log(self, text, kind=''):
        if not self.dev:
            log(text, kind, self.file_log)

    def log_writer_current(self):
        self.writer.add_scalar('losses/loss_trn_epoch_last',
            self.loss, self.iter)

    def log_writer_epoch(self):
        self.writer.add_scalar('losses/loss_trn',
            self.losses_trn[-1], self.epoch)
        self.writer.add_scalar('losses/loss_tst',
            self.losses_tst[-1], self.epoch)
        self.writer.add_scalar('times/time_epoch',
            self.times[-1], self.epoch)
        self.writer.add_scalar('times/time_total',
            np.sum(self.times), self.epoch)

    def run(self):
        self.log('Start training...\n\n')

        for self.epoch in range(1, self.args.epochs+1):
            _t = tpc()
            self._run_trn()
            self._run_tst()
            self.trainer.epoch()
            self.times.append(tpc() - _t)
            self.log_writer_epoch()

        plot_loss(self.losses_trn, os.path.join(self.fold, f'loss_trn.png'))
        plot_loss(self.losses_tst, os.path.join(self.fold, f'loss_tst.png'))

        self.writer.close()

        text = ''
        text += '-' * 43 + '\n'
        text += f'DONE  |     TIME {np.sum(self.times):-8.2e} sec (total)\n'
        text += '-' * 43 + '\n'
        self.log(text)

    def save(self, model_name='model', save_model=True, save_result=True):
        os.makedirs(f'{self.fold}', exist_ok=True)
        
        if save_model:
            torch.save(self.model.state_dict(), f'{self.fold}/{model_name}.pt')
            self.log('\n+++ Model is saved to the file\n')

        if save_result:
            fpath = os.path.join(self.fold, 'result.npz')
            np.savez_compressed(fpath, data={
                'config': {n: v for n, v in vars(self.args).items()
                    if isinstance(v, (bool, int, float, str))},
                'epoch': self.epoch,
                'times': self.times,
                'losses_trn': self.losses_trn,
                'losses_tst': self.losses_tst})
            self.log('\n+++ Results are saved to the file\n')

    def set_data(self):
        self.data = Data(
            batch_trn=self.args.batch_trn,
            batch_tst=self.args.batch_tst,
            block_size=self.args.block_size,
            d=None if self.args.mode in ['bs'] else self.args.d)

    def set_model(self, params=None, model_class=None):
        if model_class is not None:
            # Special case for dev-mode:
            Model = model_class
        else:
            match self.args.mode:
                case 'bs':
                    Model = ModelGPT
                case 'cp':
                    Model = ModelGPTCP
                case 'cp_sparse':
                    Model = ModelGPTCPSparse
                case 'lr2':
                    Model = ModelGPTLR2
                case 'mt':
                    # TODO: replace it later with ModelGPTMT
                     Model = ModelGPTRK1B
                case 'tt':
                    # TODO:
                    raise NotImplementedError()
                    Model = ModelGPTTT
                case _:
                    raise NotImplementedError(f'Unknown mode for computation')

        self.model = Model(self, build_nanogpt_config(self.args))
        if params is not None:
            self.model.load_state_dict(params)

        p = self.model.get_num_params() / 1.E+6
        text = f'Model is created (params {p:.1f} M)\n'
        self.log(text)

    def set_trainer(self):
        self.trainer = Trainer(self.args)
        self.trainer.init(self.model)

    def show_demo(self, output_length=200):
        text_inp = self.data.encode(self.args.prompt_demo)
        #output = self.model.generate(text_inp, output_length, speculative=False)
        output = self.model.generate_w_speculative(text_inp, output_length)
        text_out = self.data.decode(output)

        text = f'\n\n----\nDemo      for generate method\n'
        text += f'Input  : {self.args.prompt_demo}\n'
        text += f'Output : {text_out}\n'
        self.log(text)

    def _run_trn(self):
        """One epoch of model training."""
        _t0 = tpc()
        loss_full = 0.
        
        self.iter = 0
        self.loss = None
        self.loss_trn_epoch = []

        desc = f'Train # {self.epoch:-2d}'
        steps = self.data.size_trn
        tqdm_ = tqdm(desc=desc, unit='step', total=steps, file=sys.stdout)

        for data_x, data_y in self.data.loader_trn:
            loss = self.trainer(self.model, data_x, data_y)

            self.iter += 1
            self.loss = loss
            self.loss_trn_epoch.append(self.loss)

            self.log_writer_current()

            loss_full += loss
            text = f'Loss: {loss:-8.2e}'
            tqdm_.set_postfix_str(text, refresh=True)
            tqdm_.update(1)
            self.current_step += 1

        loss_full /= self.data.size_trn

        tqdm_.close()    
        time.sleep(0.00001)

        self.losses_trn.append(loss_full)

        text = ''
        text += f'TRN   # {self.epoch:-2d}> '
        text += f'TIME {tpc() - _t0:-8.2e} | '
        text += f'LOSS TOTAL {loss_full:-8.2e} | '
        self.log(text)

    def _run_tst(self):
        """Check the accuracy of the model on the test data set."""
        _t0 = tpc()
        loss_full = 0.

        for data_x, data_y in self.data.loader_tst:
            loss = self.trainer.test(self.model, data_x, data_y)
            loss_full += loss

        loss_full /= self.data.size_tst

        self.losses_tst.append(loss_full)

        text = ''
        text += f'TST   # {self.epoch:-2d}> '
        text += f'TIME {tpc() - _t0:-8.2e} | '
        text += f'LOSS TOTAL {loss_full:-8.2e} '
        self.log(text)


def parse_var(s):
    """Split a ``key=value`` string into a ``(key, value)`` tuple.

    Intended for command-line declarations such as::

        foo=hello
        foo="hello world"

    Only the first '=' separates key from value, so ``'a=b=c'`` yields
    ``('a', 'b=c')``. Blanks around the key are stripped; the value is
    returned verbatim.

    Raises:
        ValueError: If ``s`` contains no '=' at all.
    """
    key, sep, value = s.partition('=')
    if not sep:
        # Without a separator there is no value to return.
        raise ValueError(f'Expected a "key=value" pair, got {s!r}')
    return (key.strip(), value)


def parse_vars(items):
    """Turn an iterable of ``key=value`` strings into a dictionary.

    ``None`` or an empty iterable yields ``{}``. Each entry is split with
    ``parse_var``; when a key repeats, its last occurrence wins.
    """
    pairs = (parse_var(entry) for entry in (items or ()))
    return dict(pairs)

def build_nanogpt_config(args):
    """Assemble the model configuration object from the parsed arguments.

    The result exposes, as plain attributes, ``block_size``, ``vocab_size``,
    ``n_layer``, ``n_head``, ``n_embd``, ``dropout``, ``d`` and ``r`` copied
    from ``args``, plus ``bias`` fixed to False.
    """
    attrs = {
        'block_size': args.block_size,
        'vocab_size': args.vocab_size,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'n_embd': args.n_embd,
        'dropout': args.dropout,
        'bias': False,
        'd': args.d,
        'r': args.r,
    }
    config_cls = type('Config', (), attrs)
    return config_cls()


if __name__ == '__main__':
    # Script entry point: build everything from the config, train, save the
    # results and finally log a demo generation.
    args = config()
    manager = Llmtelora(args)
    manager.set_data()

    params = None
    if args.from_pretrained:
        device = torch.device(args.device)
        params = torch.load(args.from_pretrained, map_location=device)
        print('Params are loaded from pretrained')

    manager.set_model(params)
    manager.set_trainer()
    manager.run()
    manager.save(save_model=args.save_model)
    manager.show_demo()