import torch
import torch.nn as nn
from dataclasses import dataclass
from datetime import datetime
import copy
from typing import Optional, Dict, Any

from utils.utils import set_seed, parse_args, DefaultArgs, gumbel_softmax
from utils.data_gen import generate_freq_data
from models.models import CSDatastructureModel, CSMLPQueryModel, Embedder

@dataclass
class RunState:
    """Stores the state of a training run including model states and metrics.

    Filled in by ``train_model`` and serialized as a whole with ``torch.save`` by
    ``save_run_state``. The ``*_last`` fields hold the most recent states and the
    ``*_best_hard`` fields hold deep copies taken at the lowest hard-mask (argmax)
    evaluation loss. ``loss_best`` is only ever reset to ``inf`` in this file, and
    ``step_best`` / ``*_state_best`` are never written here.
    """
    args: Optional["Args"] = None  # the run configuration; main() passes its Args instance
    step_best: int = 0
    step_best_hard: int = 0  # training step at which loss_best_hard was reached
    step_last: int = 0  # step of the most recent checkpoint / end of training
    loss_best: float = float('inf')
    loss_best_hard: float = float('inf')  # lowest hard-eval loss since the last curriculum reset
    loss_last: float = float('inf')  # training loss recorded at the end of training
    
    # Model states
    # *_last: state_dict() results (not deep-copied); *_best_hard: deep copies.
    query_model_state_last: Optional[Dict] = None
    query_model_state_best: Optional[Dict] = None
    query_model_state_best_hard: Optional[Dict] = None
    optimizer_state_last: Optional[Dict] = None
    optimizer_state_best: Optional[Dict] = None
    optimizer_state_best_hard: Optional[Dict] = None
    data_model_state_last: Optional[Dict] = None
    data_model_state_best: Optional[Dict] = None
    data_model_state_best_hard: Optional[Dict] = None
    embedder_state_last: Optional[Dict] = None
    embedder_state_best: Optional[Dict] = None
    embedder_state_best_hard: Optional[Dict] = None

@dataclass
class Args(DefaultArgs):
    """Configuration parameters for training.

    Parsed from the command line by ``parse_args(Args)`` in ``main``. The
    ``seed``, ``device`` and ``out_dir`` fields read by this script are not
    declared here; presumably they come from ``DefaultArgs`` -- confirm in
    ``utils.utils``. ``n_stream`` is mutated in place by the stream-length
    curriculum in ``train_model``.
    """
    # Training parameters
    batch_size: int = 512  # streams per training batch (passed to generate_freq_data)
    n_steps: int = 2000000  # total number of optimizer steps
    weight_decay: float = 0.0001  # AdamW weight decay, shared by every parameter group
    eval_batch_size: int = 10000  # batch size of the periodic hard-mask evaluation

    # Learning rates
    query_model_lr: float = 0.0001  # query MLP (query_model.query_model)
    datastructure_model_lr: float = 0.00005  # data-structure model
    embedder_lr: float = 0.0  # embedder; with lr 0.0 AdamW leaves its weights unchanged
    prednet_lr: float = 0.00001  # prediction head (query_model.pred_model), if present

    # Architecture - Query MLP
    # The query_mlp_* / arch fields are not read in this file; presumably they are
    # consumed by CSMLPQueryModel(args).
    query_model_arch: str = 'mlp_only_query'
    query_mlp_hidden_dim: int = 1024
    query_mlp_num_layers: int = 3
    query_mlp_act_fn: str = 'relu'
    share_data_query: bool = True  # masks come from the data model's outputs; the query MLP is dropped
    pred_network: str = 'scalar'  # not read in this file (presumably used by CSMLPQueryModel)
    adaptive: bool = True  # not read in this file -- TODO confirm where it is used

    # Architecture - Datamodel Transformer
    datamodel_n_layer: int = 3  # n_layer of CSDatastructureModel
    n_embd: int = 32  # embedding width (Embedder output, data-model input)
    data_model_sample_rate: int = 4  # training n_resample: only 1/n of the batch is encoded, states are tiled

    # Task parameters
    n_state: int = 10  # size of the data-structure state (n_state of CSDatastructureModel)
    n_stream: int = 2  # current stream length; grown by the curriculum in train_model
    final_n_stream: int = 2  # cap for n_stream; curriculum is a no-op while n_stream == final_n_stream
    n_inc: int = 5  # how much n_stream grows per curriculum step
    n_interval: int = 2000  # training steps between curriculum steps
    n_vals: int = 10  # number of distinct values; the Embedder is built for n_vals values
    zipf_alpha: float = 0.5  # passed to generate_freq_data (presumably the Zipf exponent)
    dist_type: str = 'zipf'  # distribution name passed to generate_freq_data
    shuffle_ordering: bool = True  # passed through to generate_freq_data
    fix_value: bool = True  # passed through to CSDatastructureModel
    max_queries: int = 2  # lookup rounds per query; > 1 enables the pred_model head

    # Gumbel parameters
    gumbel_temp: float = 2.0  # temperature for gumbel_softmax and the data model
    gumbel_int_noise: float = 1.0  # not read in this file
    gumbel_noise_scale: float = 1.0  # noise scale for gumbel_softmax and the data model

    # Logging and saving
    save_model: bool = True  # create out_dir and write out.pk checkpoints
    save_freq: int = 2000  # steps between checkpoints written by train_model
    log_freq: int = 500  # steps between log lines / hard-mask evaluations
    run_eval_freq: int = 2000  # not read in this file; evaluation is driven by log_freq

def query_and_compute(
    query_model: CSMLPQueryModel,
    data_model: CSDatastructureModel,
    embedder: Embedder,
    raw_stream: torch.Tensor,
    raw_queries: torch.Tensor,
    args: Args,
    hard: bool = False,
    n_resample: int = 1
):
    """Process stream data and compute query results.

    Encodes the stream into data-structure states with ``data_model``, then runs
    ``args.max_queries`` lookup rounds. Each round builds a mask over the last
    (state) dimension of ``states`` and reads ``(states * mask).sum(-1)``. The mask
    comes from the data model's own per-query outputs when
    ``args.share_data_query`` is set, and from ``query_model`` otherwise. With more
    than one round, ``query_model.pred_model`` combines the retrieved values into a
    final prediction.

    Args:
        query_model: Supplies ``pred_model`` (used when ``args.max_queries > 1``)
            and, when queries are not shared with the data model, the mask logits.
        data_model: Called once; returns ``(states, update_masks, update_vals, all_outs)``.
        embedder: Called on ``arange(args.n_vals)`` and, in the non-shared path,
            on ``raw_queries``.
        raw_stream: Batch of streams; only the first ``len(raw_stream) // n_resample``
            rows are encoded.
        raw_queries: 2-D integer tensor ``(B, N)`` of queried values (used as indices).
        args: Run configuration.
        hard: If True, masks are one-hot argmax; otherwise relaxed ``gumbel_softmax``.
        n_resample: Encoded states are tiled ``n_resample`` times along dim 0.
            NOTE(review): assumes ``raw_stream`` repeats with period
            ``len(raw_stream) // n_resample`` and divides evenly so the tiled states
            have ``B`` rows -- confirm against ``generate_freq_data``.

    Returns:
        ``(preds, masks, states, update_masks, update_vals)``.

        - ``preds``: ``(B, max_queries, N)``. When ``max_queries > 1``, the
          ``pred_model`` output (presumably one extra slice) is appended along
          dim 1, so ``preds[:, -1]`` is the final prediction.
        - ``masks``: ``(B, max_queries, N, mask_dim)``.
        - ``states``: the tiled states.
        - ``update_masks`` / ``update_vals``: passed through from ``data_model``.
    """
    # Embed every possible value 0..n_vals-1 and hand the whole table to the data model.
    embeds = embedder(torch.arange(args.n_vals).to(args.device))
    
    ## Generate data structure: states is a tensor of size batch_size X n_stream X n_state
    ## where states[b, t] denotes the state of the data structure at time step t in the stream
    states, update_masks, update_vals, all_outs = data_model(
        hard=hard,
        temp=args.gumbel_temp,
        embeds=embeds,
        stream=raw_stream[:len(raw_stream) // n_resample],
        noise_scale=args.gumbel_noise_scale
    )
    # Tile the encoded states along the batch dim to line up with all query rows.
    states = states.repeat(n_resample, 1, 1)
    
    masks = []
    values = []
    B, N = raw_queries.shape

    ## Iteratively query data structure
    for i in range(args.max_queries):
        if args.share_data_query:
            # Look up round i's output row for each queried value; the last column is dropped.
            mask_i = all_outs[i][raw_queries.flatten()].view(B, N, -1)[:, :, :-1]
        else:
            # NOTE(review): embedder(raw_queries) does not depend on i and could be hoisted
            # out of the loop if the embedder is deterministic.
            queries = embedder(raw_queries)
            # The query model also sees the values/masks retrieved in earlier rounds.
            mask_i = query_model(queries.view(B*N, args.n_embd), values, masks, i).view(B, N, -1)

        # Hard: exact one-hot argmax (no gradient path). Soft: Gumbel-softmax relaxation.
        mask_i = (torch.nn.functional.one_hot(mask_i.argmax(dim=-1), mask_i.size(-1)).float()
                 if hard else gumbel_softmax(mask_i, temperature=args.gumbel_temp,
                                           hard=False, noise_scale=args.gumbel_noise_scale))

        # Read: mask-weighted sum over the last (state) dimension.
        value_i = (states * mask_i).sum(-1)
        masks.append(mask_i.view(B*N, -1))
        values.append(value_i.view(B*N, -1))

    # (B*N, Q, 1) -> (B, N, Q) -> (B, Q, N)
    preds = torch.stack(values, dim=1).view(B, N, args.max_queries).transpose(2, 1)
    # (B*N, Q, mask_dim) -> (B, N, Q, mask_dim) -> (B, Q, N, mask_dim)
    masks = torch.stack(masks, dim=1).view(B, N, args.max_queries, -1).transpose(2, 1)

    ## If we have multiple queries, we use a small mlp (pred_model) to predict from the retrieved elements
    if args.max_queries > 1:
        # One row of max_queries retrieved values per (batch, query) pair.
        final_pred_input = preds.transpose(2, 1).view(-1, args.max_queries)
        final_pred = query_model.pred_model(final_pred_input)
        final_pred = final_pred.view(B, N, -1).transpose(2, 1)
        preds = torch.cat([preds, final_pred], dim=1)

    return preds, masks, states, update_masks, update_vals

def create_optimizer(
    args: Args,
    query_model: CSMLPQueryModel,
    data_model: CSDatastructureModel,
    embedder: Embedder
) -> torch.optim.Optimizer:
    """Build an AdamW optimizer with one named parameter group per sub-model.

    Each group gets its own learning rate from ``args``. ``args.weight_decay`` is
    shared by all groups. Sub-modules that are missing or set to ``None`` are
    skipped instead of raising ``AttributeError``. An example is
    ``query_model.query_model``, which ``main`` discards when
    ``args.share_data_query`` is set.

    Args:
        args: Run configuration providing the per-group learning rates.
        query_model: Supplies the ``query_model`` and optional ``pred_model`` groups.
        data_model: Data-structure model; always optimized.
        embedder: Value embedder; always optimized (possibly with lr 0.0).

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    # (group name, module, learning rate); the order matches the historical group order.
    groups = [
        ("query_model.query_model", getattr(query_model, 'query_model', None), args.query_model_lr),
        ("data_model", data_model, args.datastructure_model_lr),
        ("embedder", embedder, args.embedder_lr),
        ("query_model.pred_model", getattr(query_model, 'pred_model', None), args.prednet_lr),
    ]
    params = [
        {"name": name, "params": module.parameters(), "lr": lr}
        for name, module, lr in groups
        if module is not None
    ]

    return torch.optim.AdamW(params, weight_decay=args.weight_decay)

def _evaluate_hard(
    args: Args,
    data_model: CSDatastructureModel,
    query_model: CSMLPQueryModel,
    embedder: Embedder
):
    """Evaluate with hard (argmax) masks on a fresh batch; return ``(hard_loss, hard_acc)`` as floats.

    ``hard_loss`` is the mean absolute error of the final prediction. ``hard_acc``
    is the fraction of final predictions that round to the exact target count.
    """
    stream, queries, _, query_counts = generate_freq_data(
        args.eval_batch_size,
        args.n_stream,
        args.n_vals,
        args.dist_type,
        args.zipf_alpha,
        args.device,
        n_resample=1,
        shuffle_ordering=args.shuffle_ordering
    )

    with torch.no_grad():
        preds, *_ = query_and_compute(
            query_model,
            data_model,
            embedder,
            stream,
            queries,
            args,
            n_resample=1,
            hard=True
        )
        # Average over batch (dim 0) and queries (dim 2); keep the final output only.
        hard_loss = (preds - query_counts[:, None]).abs().mean([0, 2])[-1]
        hard_acc = (preds.round().int()[:, -1] == query_counts).float().mean()

    return hard_loss.item(), hard_acc.item()


def _snapshot_best_hard(run_state, optimizer, data_model, query_model, embedder):
    """Deep-copy the current states into the ``*_best_hard`` fields so later steps cannot mutate them."""
    run_state.query_model_state_best_hard = copy.deepcopy(query_model.state_dict())
    run_state.data_model_state_best_hard = copy.deepcopy(data_model.state_dict())
    run_state.embedder_state_best_hard = copy.deepcopy(embedder.state_dict())
    run_state.optimizer_state_best_hard = copy.deepcopy(optimizer.state_dict())


def _record_last_states(run_state, optimizer, data_model, query_model, embedder):
    """Store the current ``state_dict()`` results (not deep-copied) in the ``*_state_last`` fields."""
    run_state.query_model_state_last = query_model.state_dict()
    run_state.data_model_state_last = data_model.state_dict()
    run_state.embedder_state_last = embedder.state_dict()
    run_state.optimizer_state_last = optimizer.state_dict()


def train_model(
    run_state: RunState,
    args: Args,
    optimizer: torch.optim.Optimizer,
    data_model: CSDatastructureModel,
    query_model: CSMLPQueryModel,
    embedder: Embedder
) -> None:
    """Train the data-structure, query and embedding models, recording progress in ``run_state``.

    Each step draws a fresh batch from ``generate_freq_data`` and runs the soft
    (Gumbel-softmax) forward pass of ``query_and_compute``. It then takes one
    optimizer step on the smooth-L1 loss of the final prediction only.

    Every ``args.log_freq`` steps the models are evaluated with hard (argmax) masks.
    Whenever that hard loss improves, deep copies of all states are stored in the
    ``*_best_hard`` fields.

    Independently, every ``args.save_freq`` steps (if ``args.save_model``) the
    current states are written out with ``save_run_state``.

    Side effect: ``args.n_stream`` is mutated in place by the stream-length curriculum.

    Args:
        run_state: Receives metrics and state dicts.
        args: Training configuration.
        optimizer: Optimizer over the models' parameters.
        data_model: Data-structure model being trained.
        query_model: Query model being trained (at least its ``pred_model``).
        embedder: Value embedder.
    """
    step_start = datetime.now()
    last_log_step = -1  # so the first timing window (ending at step 0) covers exactly one step
    loss = None  # stays None when n_steps == 0

    for step in range(args.n_steps):
        # Curriculum from shorter to longer streams (we found it helpful): every
        # n_interval steps grow the stream length by n_inc up to final_n_stream.
        if step > 0 and step % args.n_interval == 0 and args.n_stream < args.final_n_stream:
            args.n_stream = min(args.n_stream + args.n_inc, args.final_n_stream)
            # Losses on longer streams are not comparable with earlier ones.
            run_state.loss_best_hard = float('inf')
            run_state.loss_best = float('inf')
            print('New N Stream:', args.n_stream)

        # Generate training data
        stream, queries, _, query_counts = generate_freq_data(
            args.batch_size,
            args.n_stream,
            args.n_vals,
            args.dist_type,
            args.zipf_alpha,
            args.device,
            n_resample=args.data_model_sample_rate,
            shuffle_ordering=args.shuffle_ordering
        )

        # Soft forward pass: build the data structure and answer the queries
        preds, *_ = query_and_compute(
            query_model,
            data_model,
            embedder,
            stream,
            queries,
            args,
            n_resample=args.data_model_sample_rate
        )

        # Per-output loss averaged over batch (dim 0) and queries (dim 2); only the
        # last output is optimized. The target is expanded explicitly because
        # smooth_l1_loss warns on implicit broadcasting.
        target = query_counts[:, None].expand_as(preds)
        loss_all_queries = nn.functional.smooth_l1_loss(
            preds, target, reduction='none', beta=0.2).mean([0, 2])
        loss = loss_all_queries[-1]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % args.log_freq == 0:
            # Mean wall time per training step since the previous log line
            # (this window also includes the previous evaluation).
            elapsed = (datetime.now() - step_start).total_seconds()
            time_per_step = elapsed / (step - last_log_step)
            last_log_step = step
            step_start = datetime.now()

            hard_loss, hard_acc = _evaluate_hard(args, data_model, query_model, embedder)

            metrics = {
                'Step': f'{step:,}',
                'Total Loss': f'{loss.item():.3f}',
                'Hard Loss': f'{hard_loss:.3f}',
                'Hard Acc': f'{hard_acc:.2%}',
                'TPS': f'{time_per_step:.4f}'
            }
            print('\n' + ' | '.join(f'{k}: {v}' for k, v in metrics.items()))

            if hard_loss < run_state.loss_best_hard:
                run_state.loss_best_hard = hard_loss  # plain float, not a device tensor
                run_state.step_best_hard = step
                _snapshot_best_hard(run_state, optimizer, data_model, query_model, embedder)

        # Checkpoint on its own cadence, independent of log_freq.
        if args.save_model and step % args.save_freq == 0:
            run_state.step_last = step
            run_state.loss_last = loss.item()
            _record_last_states(run_state, optimizer, data_model, query_model, embedder)
            save_run_state(run_state)

    if loss is not None:  # at least one step ran, so `step` and `loss` are bound
        run_state.step_last = step
        # .item() detaches: storing the tensor would keep the autograd graph alive.
        run_state.loss_last = loss.item()
    _record_last_states(run_state, optimizer, data_model, query_model, embedder)

def save_run_state(run_state):
    """Serialize ``run_state`` with ``torch.save`` to ``run_state.args.out_dir / 'out.pk'``.

    The checkpoint is first written to a sibling ``out.pk.tmp`` file and then moved
    over ``out.pk`` with ``Path.replace`` (``os.replace`` semantics, atomic on POSIX
    filesystems). A crash or interruption during the write therefore leaves the
    previous checkpoint intact instead of a truncated file.

    Args:
        run_state: Object exposing ``args.out_dir`` as a ``pathlib.Path``; the whole
            object is pickled.
    """
    out_path = run_state.args.out_dir / 'out.pk'
    tmp_path = out_path.with_name(out_path.name + '.tmp')
    torch.save(run_state, tmp_path)
    tmp_path.replace(out_path)

# NOTE: ``create_optimizer`` is defined once, earlier in this module; a verbatim
# duplicate that used to live here silently rebound (shadowed) that definition.

def main():
    """Parse the config, build models and optimizer, train, and save the run state.

    Models are placed on ``args.device`` before the optimizer is constructed, as
    the ``torch.nn.Module.cuda``/``to`` documentation recommends. When
    ``args.share_data_query`` is set, the separate query MLP is discarded after the
    optimizer is built.
    """
    args: Args = parse_args(Args)
    run_state = RunState(args=args)

    set_seed(args.seed)
    if args.save_model:
        args.out_dir.mkdir(parents=True, exist_ok=True)

    # Initialize models directly on the target device, before building the optimizer
    query_model = CSMLPQueryModel(args).to(args.device)
    data_model = CSDatastructureModel(
        d_in=args.n_embd,
        n_state=args.n_state,
        n_layer=args.datamodel_n_layer,
        n_queries=args.max_queries,
        fix_value=args.fix_value
    ).to(args.device)
    embedder = Embedder(n_vals=args.n_vals, n_embd=args.n_embd).to(args.device)

    optimizer = create_optimizer(args, query_model, data_model, embedder)

    print(data_model)
    print(query_model)
    if args.share_data_query:
        # Masks then come from the data model's own outputs (query_and_compute never
        # calls query_model's forward in that path), so the query MLP is unused.
        query_model.query_model = None

    # Train
    start_time = datetime.now()
    train_model(run_state, args, optimizer, data_model, query_model, embedder)

    if args.save_model:
        # Only the hard-eval best is tracked during training; loss_best stays inf.
        # float() handles both a plain float and a 0-dim tensor.
        print(f"Best hard loss: {float(run_state.loss_best_hard)} (step {run_state.step_best_hard})")
        print(f'Saving model to:\n{args.out_dir/"out.pk"}')
        save_run_state(run_state)

    print('Total time:', datetime.now() - start_time)

if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
