import os
import math
import time
import json
import pickle
import random
from typing import Optional, Union

import torch
import numpy as np

from .utils import visualize
from .utils import get_timestamp, histogram
from .trainer_settings import TrainerSettings

from .config import configurations 
from .trainer_settings import TrainerSettings
from .similarity_learner import ModelConfig, SimilarityLearner

class Trainer:
    """
    Usage example:
    from ml.trainer import Trainer
    t = Trainer("concepts/text-embedding-3-small", allow_save=False)
    t.kick_tires()
    t.train()
    """
    def __init__(self, configuration:Optional[Union[dict, str]]=None, **kwargs):
        """
          Build a trainer: resolve the configuration, load the dataset and
          initialize a fresh model.

          configuration: name of an entry in `configurations`, a configuration
            dict, or None to use "concepts/text-embedding-3-small".
          kwargs: keys that exist in the configuration override it (dict values
            are merged with `dict.update`, other values are replaced). All
            kwargs are also passed to `TrainerSettings`.

          Raises RuntimeError when a configuration name is not registered.
        """
        # local import: only needed here to take a private copy of the config
        import copy

        if configuration is None:
            base_config = configurations["concepts/text-embedding-3-small"]
        elif isinstance(configuration, str):
            if configuration not in configurations:
                raise RuntimeError(f"Configuration '{configuration}' not found in:\n{configurations}")
            base_config = configurations[configuration]
        else:
            base_config = configuration

        # Work on a private copy: the kwargs overrides below and init_model()
        # both write into the config, which must not leak into the shared
        # `configurations` registry or into the caller's dict.
        self.config = copy.deepcopy(base_config)
        self.data = None

        # override existing config attributes with kwargs
        for key, value in kwargs.items():
            if key not in self.config:
                continue
            if isinstance(self.config[key], dict):
                self.config[key].update(value)
            else:
                self.config[key] = value

        self.timestamp = get_timestamp()

        # settings specific to Trainer
        self.settings = TrainerSettings(**kwargs)
        self.device = torch.device(self.settings.device_type)

        self.load_data()
        self.reset_training()
    def load_data(self, train_split:float|int=0.8, verbose=True):
        """
          Loads the dataset from numpy dumps generated by ndarray.tofile
          https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tofile.html
        """
        data_files = self.config["data"]["dataset"]
        endpoint = self.config['data'].get('endpoint', 'a -> b')

        # data from multiple files will be joined by rows
        for data_file in data_files:
            with open(data_file, "rb") as data_file:
                if verbose:
                    print(f"Loading dataset from {data_file.name}")
                # load the data from the file
                # the file should contain a dictionary with keys 'x' and 'y'
                # where 'x' is a list of input sequences and 'y' is a list of target sequences
                # both are numpy arrays
                data = pickle.load(data_file)

            # shuffle data to allocate similar fraction of classes to train and validation sets
            random.shuffle(data['data'])
            train_split = int(len(data['data']) * train_split) if train_split <= 1 else train_split
            
            assert train_split and train_split < len(data['data']), f"train_split ({train_split}) must be a positive integer less than the size of the dataset ({len(data['data'])})"

            splits = {}
            splits['train'] = data['data'][:train_split]
            splits['val'] = data['data'][train_split:]

            append = False
            if not self.data:
                self.data = {
                    "train": {},
                    "val": {}
                }
            else:
                # assuming that a portion of the data has already been loaded
                append = True

            for split in splits:
                # get the input embeddings for the split
                for key in ['a', 'b']:
                    # stack embeddings by rows into a matrix
                    split_data = np.stack([self._get_embedding(pair[key], data) for pair in splits[split]], dtype=np.float32)
                    if append:
                        # append split_data to self.data[split][key]
                        self.data[split][key] = np.vstack([self.data[split][key], split_data])
                    else:
                        self.data[split][key] = split_data

                # get the endpoint (target) values for the split
                split_data = np.array([pair[endpoint] for pair in splits[split]], dtype=np.float32).reshape(-1, 1)
                if append:
                    self.data[split]['y'] = np.vstack([self.data[split]['y'], split_data])
                    # shuffle the data after appending
                    idx = np.arange(self.data[split]['y'].shape[0])
                    np.random.shuffle(idx)
                    self.data[split]['a'] = self.data[split]['a'][idx]
                    self.data[split]['b'] = self.data[split]['b'][idx]
                    self.data[split]['y'] = self.data[split]['y'][idx]
                else:
                    self.data[split]['y'] = split_data

                assert self.data[split]['a'].shape[0] == self.data[split]['b'].shape[0] == self.data[split]['y'].shape[0], \
                    f"split '{split}' has mismatched shapes: a {self.data[split]['a'].shape}, b {self.data[split]['b'].shape}, y {self.data[split]['y'].shape}"

        print(f"Size of dataset: train ({len(self.data['train']['y'])}), validation: ({len(self.data['val']['y'])})")

    def _get_embedding(self, key:str, data:dict) -> np.ndarray:
        key = key.replace('[', '').replace(']', '').strip()
        return data['embeddings'][data['index'][key]]

    def train(self, noplot=True):
        """
          Run the training loop and return one summary record per evaluation.

          Each iteration draws `gradient_accumulation_steps` batches, accumulates
          their gradients and takes a single optimizer step. Every
          `eval_interval` iterations (master process only) the train/val loss is
          estimated, weights and gradients are visualized and, when saving is
          allowed, a checkpoint is written. The loop stops once `iter_num`
          exceeds `max_iters`, or after `early_stop` consecutive evaluations
          without a better validation loss.

          noplot: forwarded to `visualize.weights` and `visualize.grad`.

          Returns a list of dicts with the keys model, dataset, iter_num,
          train_loss, val_loss, overfit, n_params and endpoint.
        """
        # fetch the very first batch
        A, B, Y = self.get_batch()

        t0 = time.time()
        raw_model = self.model.module if self.settings.ddp else self.model # unwrap DDP container if needed

        summary = []        # one human-readable status line per evaluation
        summary_data = []   # one structured record per evaluation (returned)

        data_config = self.config["data"]
        dataset_type = data_config["dataset_name"]
        embedding_model = data_config["embedding_model"]
        # same default as load_data, so a config without 'endpoint' trains as well as it loads
        endpoint = data_config.get('endpoint', 'a -> b')

        train_config = self.config["train"]
        decay_lr = train_config["decay_lr"]
        learning_rate = train_config["learning_rate"]
        gradient_accumulation_steps = train_config["gradient_accumulation_steps"]
        grad_clip = train_config["grad_clip"]
        max_iters = train_config["max_iters"]

        eval_interval = self.settings.eval_interval
        ddp = self.settings.ddp
        out_dir = self.settings.out_dir
        early_stop = self.settings.early_stop

        print(f"Starting training at iteration {self.iter_num}")
        self.save_settings()

        no_improvement_counter = 0

        while True:
            # print average statistics of all rows (embeddings) in A, B
            if self.settings.master_process:
                for batch_key, batch in (('A', A), ('B', B)):
                    text = [f"Batch {batch_key}:"]
                    for measure in ['min', 'max', 'median', 'mean', 'std']:
                        reduced = getattr(batch, measure)(dim=1)
                        if measure not in ('mean', 'std'):
                            # min/max/median along a dim return (values, indices)
                            reduced = reduced.values
                        text.append(f"{measure}: {reduced.mean().item():.4f}")
                    print(visualize.Color.GREEN + ", ".join(text) + visualize.Color.ENDC)
                print(visualize.Color.GREEN + f"Y histogram: {json.dumps(histogram(Y), indent=4)}" + visualize.Color.ENDC)

            # determine and set the learning rate for this iteration
            lr = self.get_lr() if decay_lr else learning_rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            # NOTE(review): with eval_only the loop exits here, before any
            # evaluation has been run — confirm this is the intended behaviour.
            if self.iter_num == 0 and self.settings.eval_only:
                break

            # Start every iteration from clean gradients. Without this the
            # gradients of all previous iterations (and of kick_tires) keep
            # accumulating into each optimizer step. Clearing here rather than
            # after the step keeps this iteration's gradients available to
            # visualize.grad in the evaluation block below.
            self.optimizer.zero_grad(set_to_none=True)

            # forward backward update, with optional gradient accumulation to simulate larger batch size
            # and using the GradScaler if data type is float16
            for micro_step in range(gradient_accumulation_steps):
                if ddp:
                    # in DDP training we only need to sync gradients at the last micro step;
                    # toggling this flag is what the model.no_sync() context manager does
                    self.model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
                with self.settings.ctx:
                    ab_similarity, loss = self.model(A, B, Y)
                    loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation

                # fetch the next batch before running the backward pass
                A, B, Y = self.get_batch()
                # backward pass, with gradient scaling if training in fp16
                self.scaler.scale(loss).backward()

            # clip the gradient
            if grad_clip != 0.0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            # step the optimizer and scaler if training in fp16
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # evaluate the loss on train/val sets and write checkpoints
            if self.iter_num % eval_interval == 0 and self.settings.master_process:
                losses = self.estimate_loss()
                overfit = (losses['val'] - losses['train']) / losses['train'] * 100
                train_loss = losses['train']
                status = f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} (previous best {self.best_val_loss:.4f}, overfit: {overfit:.2f}%)"
                summary_data.append({
                    'model': embedding_model,
                    'dataset': dataset_type,
                    'iter_num': self.iter_num,
                    'train_loss': losses['train'].item(),
                    'val_loss': losses['val'].item(),
                    'overfit': overfit.item(),
                    # read from the unwrapped model so this also works under DDP
                    'n_params': math.prod(raw_model.module_dict['linear'].weight.shape),
                    'endpoint': endpoint,
                })
                summary.append(status)
                print(status)

                visualize.weights(self.model, iter_num=self.iter_num, out_dir=out_dir, loss=train_loss, noplot=noplot)
                visualize.grad(self.model, iter_num=self.iter_num, out_dir=out_dir, loss=train_loss, noplot=noplot)

                improved = bool(losses['val'] < self.best_val_loss)
                if improved:
                    no_improvement_counter = 0
                else:
                    no_improvement_counter += 1

                if improved or (self.settings.always_save_checkpoint and self.iter_num % self.settings.checkpoint_save_interval == 0):
                    if improved:
                        self.best_val_loss = losses['val']
                    # never checkpoint the untrained model at iteration 0
                    if self.iter_num > 0 and self.settings.allow_save:
                        checkpoint = {
                            'model': raw_model.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'model_args': self.config["model"],
                            'iter_num': self.iter_num,
                            'best_val_loss': self.best_val_loss,
                            'config': self.config,
                            'settings': self.settings,
                        }
                        print(f"Saving checkpoint at {self.iter_num} to '{out_dir}'")
                        torch.save(checkpoint, os.path.join(out_dir, f"{dataset_type}_{embedding_model}_{self.timestamp}_ckpt_{self.iter_num:03}_{losses['val']:.3f}.pt"))

            # timing and logging
            t1 = time.time()
            dt = t1 - t0
            t0 = t1
            if self.iter_num % self.settings.log_interval == 0 and self.settings.master_process:
                # get loss as float. note: this is a CPU-GPU sync point
                # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
                lossf = loss.item() * gradient_accumulation_steps
                print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
            self.iter_num += 1

            # termination conditions
            if self.iter_num > max_iters or (early_stop and no_improvement_counter >= early_stop):
                break

        print(f"Finished training loop at iteration {self.iter_num} (endpoint: {endpoint})")
        print("\n".join(summary))
        self.save_settings(summary=summary)
        return summary_data

    def reset_training(self):
        self.iter_num = 0
        self.best_val_loss = float('inf')
        self.init_model()

    def save_settings(self, **kwargs):
        filename = os.path.join(self.settings.model_dir, f"{self.config['data']['embedding_model']}_{self.config['data']['dataset_name']}_{self.timestamp}.json")
        with open(filename, "w") as f:
            f.write(json.dumps({"config": self.config, "settings": self.settings.json(), **kwargs}, indent=4))

    def save_model(self, filename_prefix:str):
        """
          Save the model configuration and state to a file.
        """
        # save the model config and state
        self.model.save(filename_prefix, iter_num=self.iter_num, timestamp=self.timestamp)

    def kick_tires(self):
        """
          Smoke test: make one forward and one backward pass on a single
          training batch, print the (scaled) loss and mean similarity, and
          visualize the linear layer's weights and the gradients.

          No optimizer step is taken. Gradients are cleared before and after
          the pass so that what is visualized comes from this batch only and
          nothing leaks into a later training step.
        """
        A, B, Y = self.get_batch()

        gas = self.config["train"]["gradient_accumulation_steps"]
        self.model.train()
        # drop gradients left over from earlier passes
        self.optimizer.zero_grad(set_to_none=True)

        with self.settings.ctx:
            # make one forward pass
            ab_similarity, loss = self.model(A, B, Y)
            loss = loss / gas
        # make one backward pass; kept outside the autocast context, as in train()
        self.scaler.scale(loss).backward()

        print(f"loss = {loss.item():.4f}, similarity = {ab_similarity.mean().item():.4f}")

        visualize.activation([self.model.module_dict['linear'].weight], noplot=True)
        visualize.grad(self.model, noplot=True)

        # the smoke test must not contribute gradients to a later optimizer step
        self.optimizer.zero_grad(set_to_none=True)

    def init_model(self):
        """
          Create a new SimilarityLearner together with the GradScaler and
          optimizer needed for training, and store them on the trainer as
          `model`, `scaler` and `optimizer`.

          The model's `in_dim` is always taken from the width of the loaded
          training embeddings. An `out_dim` of -1 means "same as in_dim". The
          resolved model configuration is written back to config['model'].
        """
        print("Initializing a new model from scratch")

        # make sure the input dimension matches embedding size
        model_config = self.config.get('model', {})
        model_config['in_dim'] = self.data['train']['a'].shape[1]
        # .get(): a config without 'model' / 'out_dim' is tolerated above, so it
        # must not fail with a KeyError here; a missing out_dim is left for
        # ModelConfig to resolve
        if model_config.get('out_dim') == -1:
            # set out_dim to in_dim
            model_config['out_dim'] = model_config['in_dim']
        self.config['model'] = model_config

        model = SimilarityLearner(ModelConfig(**model_config, device=self.device))
        model.to(self.device)

        # initialize accessories that are required for training

        # initialize a GradScaler. If enabled=False scaler is a no-op
        # grad scaler is required to compensate for the loss of accuracy
        # when calculating gradients with half precision floats
        self.scaler = torch.amp.GradScaler('cuda', enabled=(self.settings.dtype == 'float16'))

        train_config = self.config["train"]
        self.optimizer = model.configure_optimizers(
          train_config["weight_decay"],
          train_config["learning_rate"],
          (train_config["beta1"], train_config["beta2"]),
          self.settings.device_type
        )
        self.model = model

    def get_lr(self) -> float:
        """
          Get learning rate depending on current iteration
          and training configuration.
        """
        it = self.iter_num
        c = self.config["train"]
        warmup_iters = c["warmup_iters"]
        learning_rate = c["learning_rate"]
        lr_decay_iters = c["lr_decay_iters"]
        min_lr = c["min_lr"]

        # 1) linear warmup for warmup_iters steps
        if it < warmup_iters:
            return learning_rate * it / warmup_iters
        # 2) if it > lr_decay_iters, return min learning rate
        if it > lr_decay_iters:
            return min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
        return min_lr + coeff * (learning_rate - min_lr)

    @torch.no_grad()
    def estimate_loss(self) -> dict:
        """
          Estimate loss in training and validation sets by evaluating
          the model on {eval_iters} samples (batches).
        """
        out = {}
        eval_iters = self.settings.eval_iters

        # set model to prediction mode (don't keep info for gradients)
        self.model.eval()
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                a, b, targets = self.get_batch(split)
                with self.settings.ctx:
                    similarity, loss = self.model(a, b, targets)
                losses[k] = loss.item()
            out[split] = losses.mean()
        # set model to training mode (keep info for gradients)
        self.model.train()
        return out

    def get_batch(self, split:str="train") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = self.config["train"]["batch_size"]
        device = self.device

        data = self.data[split]

        # randomly sample {batch_size} embeddings
        ix = torch.randint(len(data['a']), (batch_size,))

        result = {}
        for key in ['a', 'b', 'y']:
            result[key] = torch.stack([torch.from_numpy(data[key][i, :]) for i in ix])
            if self.settings.device_type == 'cuda':
                # pin arrays to move them to GPU asynchronously (non_blocking=True)
                result[key] = result[key].pin_memory().to(device, non_blocking=True)
            else:
                result[key] = result[key].to(device)

        return tuple(result[key] for key in ['a', 'b', 'y'])
