

import copy
from dataclasses import dataclass
import os
from typing import List, Optional

import torch
import tqdm
from .data_tools import CustomDataset, eval, batch_end_callback, label_batch
from .trainer import Trainer
from .model import GPT

@dataclass(init=True, repr=True, eq=True)
class FactorizedRepresentations:
    """Bundle of activations returned by ``StreamAligner.get_factorized_representation``.

    Field order is the positional constructor order.
    """

    # Main-model records tagged "attn_block_output", stacked on a new leading axis.
    before_factorization: torch.Tensor
    # Main-model records tagged "factorized_pred", stacked on a new leading axis.
    factorized: torch.Tensor
    # Concatenated per-task teacher activations (detached, on CPU).
    factorized_target: torch.Tensor
    # Main-model records tagged "attn_fwd_reconstructed", stacked on a new leading axis.
    after_reprojection: torch.Tensor

class StreamAligner:
    """Train a four-head GPT's "factorized" parameters to mirror four per-task models.

    For each task in ``TASKS`` a separate single-head GPT, with a quarter of the
    main model's embedding width, is trained (or loaded from a cached
    checkpoint). ``align_model`` then optimizes only the main model's parameters
    whose name contains ``"factorized"``. Their ``"factorized_pred"`` activations
    are pushed towards the concatenated ``"attn_block_output"`` activations of the
    per-task models. This is combined with the main model's language loss.
    """

    # Task names, in the order their activations are concatenated along the
    # last axis of the target representation.
    TASKS = ("ascending", "descending", "add1", "add2")

    def __init__(
        self,
        model,
        prefix_size,
        task_specific_model_iter: int = 20000,
        task_specific_model_lr: float = 5e-5,
    ):
        """
        Args:
            model: GPT to align; ``model.config`` must have ``n_head == 4`` and
                an ``n_embd`` divisible by 4.
            prefix_size: passed as ``prefix_padding`` to every ``CustomDataset``.
            task_specific_model_iter: ``max_iters`` used to train each per-task model.
            task_specific_model_lr: learning rate used to train each per-task model.
        """
        self.model = model
        assert model.config.n_head == 4, "We expect that the model has four heads."
        assert model.config.n_embd % 4 == 0, "We expect that the model n_embd is divisible by 4."

        self.prefix_size = prefix_size
        self.task_specific_model_iter = task_specific_model_iter
        self.task_specific_model_lr = task_specific_model_lr

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # task name -> per-task GPT; populated by train_individual_models().
        self.indiv_models = {}

    def _ensure_individual_models(self):
        """Train (or load) the per-task models if any task is still missing."""
        if any(task not in self.indiv_models for task in self.TASKS):
            self.train_individual_models()

    def train_individual_models(self):
        """Build one single-head GPT per task, loading it from disk or training it.

        Each model uses a deep copy of the main config with ``n_head = 1`` and
        ``n_embd = main n_embd // 4``. Weights are cached in
        ``align_indiv_model_{task}.pth`` (a path relative to the working
        directory). An existing file is loaded; otherwise the model is trained
        with ``Trainer`` and the file is written. Each model is then passed to
        the ``eval`` helper with its test split.
        """
        self.indiv_model_config = copy.deepcopy(self.model.config)
        self.indiv_model_config.n_head = 1
        self.indiv_model_config.n_embd = self.model.config.n_embd // 4

        train_config = Trainer.get_default_config()
        train_config.learning_rate = self.task_specific_model_lr
        train_config.max_iters = self.task_specific_model_iter
        train_config.num_workers = 0
        train_config.batch_size = 128
        train_config.device = self.device

        for task in self.TASKS:
            train_dataset = CustomDataset('train', mode=task, prefix_padding=self.prefix_size)
            test_dataset = CustomDataset('test', mode=task, prefix_padding=self.prefix_size)

            fname = f'align_indiv_model_{task}.pth'
            indiv_model = GPT(self.indiv_model_config).to(self.device)
            self.indiv_models[task] = indiv_model

            if os.path.exists(fname):
                # map_location lets a checkpoint written on a GPU load on a CPU-only host.
                indiv_model.load_state_dict(torch.load(fname, map_location=self.device), strict=True)
                print(f"Model for {task} loaded from cache.")
            else:
                print(f"TRAINING A MODEL FOR THE {task.upper()} TASK:")
                trainer = Trainer(train_config, indiv_model, train_dataset=train_dataset)
                trainer.set_callback('on_batch_end', batch_end_callback)
                trainer.run()
                torch.save(indiv_model.state_dict(), fname)

            print(f"Testing for {task} task: ", end="")
            _ = eval(indiv_model, dataset=test_dataset, device=self.device, max_batches=128)
            print()

    def get_target_aligned_representation(self, x):
        """Concatenate the per-task models' attention-block outputs for input ``x``.

        The per-task models are trained or loaded first if any is missing.

        Returns:
            A zero-initialised tensor of shape
            ``(n_layer, x.size(0), x.size(1), n_embd + 2)``, using the main
            model's config. For each recorded ``"attn_block_output"`` (in record
            order), task ``TASKS[i]`` fills channels ``[i * w, (i + 1) * w)``,
            where ``w`` is the per-task ``n_embd``. The final two channels stay
            zero. The tensor is created without an explicit device (CPU unless
            a torch default device has been configured).
        """
        self._ensure_individual_models()

        cfg = self.model.config
        width = self.indiv_model_config.n_embd
        fa = torch.zeros((cfg.n_layer, x.size(0), x.size(1), cfg.n_embd + 2))
        with torch.no_grad():
            for task_i, task in enumerate(self.TASKS):
                indiv_model = self.indiv_models[task]
                # The per-task models act as fixed teachers here. Eval mode keeps
                # train-only behaviour (e.g. dropout, if configured) from
                # perturbing the targets.
                indiv_model.eval()
                record = []
                indiv_model(x, activations_record=record)
                block_outputs = [a[1] for a in record if a[0] == "attn_block_output"]
                channels = slice(task_i * width, (task_i + 1) * width)
                for layer_idx, acts in enumerate(block_outputs):
                    fa[layer_idx, :, :, channels] = acts
        return fa

    @staticmethod
    def _factorization_loss(fa_predicted, fa):
        """Sum over layers of the MSE between predicted and target activations.

        The last two channels are excluded from every comparison. Each target
        is moved to its prediction's device before comparing. The per-layer
        terms are combined with ``torch.stack``, not ``torch.tensor``, so the
        result stays attached to the autograd graph. Returns a zero scalar
        when there are no predictions.
        """
        per_layer = [
            torch.nn.functional.mse_loss(
                pred.flatten(end_dim=1)[:, :-2],
                target.to(pred.device).flatten(end_dim=1)[:, :-2],
            )
            for pred, target in zip(fa_predicted, fa)
        ]
        if not per_layer:
            return torch.zeros(())
        return torch.stack(per_layer).sum()

    def align_model(self):
        """Optimize only the main model's "factorized" parameters; return ``self``.

        Every parameter whose name contains ``"factorized"`` is re-initialised
        from N(0, 0.02) and made trainable; all others are frozen. Training runs
        over 10M randomly sampled "random"-mode examples in batches of 512. It
        uses AdamW (lr 5e-5), StepLR (x0.95 every 1000 steps) and gradient
        clipping at 1.0. The objective is ``lang_loss + 10 * factorization_loss``.
        The model is left in eval mode afterwards.
        """
        self._ensure_individual_models()

        # Now train the projection layers for the original model to resemble
        # the attention block activations of the four individual models.
        params_to_optimize = []
        for n, p in self.model.named_parameters():
            is_factorized = "factorized" in n
            p.requires_grad = is_factorized
            if is_factorized:
                params_to_optimize.append(p)
                with torch.no_grad():
                    p.normal_(0, 0.02)
                print(f"Will optimize {n}")

        train_dataset_random = CustomDataset('train', mode="random", prefix_padding=self.prefix_size)
        train_loader = torch.utils.data.dataloader.DataLoader(
            train_dataset_random,
            sampler=torch.utils.data.RandomSampler(
                train_dataset_random,
                replacement=True,
                num_samples=10_000_000,
            ),
            shuffle=False,
            pin_memory=True,
            batch_size=512,
            num_workers=0,
        )

        optimizer = torch.optim.AdamW(params_to_optimize, lr=5e-5)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)

        pbar = tqdm.tqdm(train_loader, total=len(train_loader), desc="Optimizing factorization maps")
        self.model.train()
        for batch in pbar:
            x, y = [t.to(self.device) for t in batch]

            fa = self.get_target_aligned_representation(x)

            record = []
            _, lang_loss = self.model(x, y, projection=True, activations_record=record)
            fa_predicted = [a[1] for a in record if a[0] == "factorized_pred"]
            factorization_loss = self._factorization_loss(fa_predicted, fa)

            self.model.zero_grad(set_to_none=True)
            (lang_loss + 10 * factorization_loss).backward()
            torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0)
            optimizer.step()
            scheduler.step()
            pbar.set_postfix({
                "lang_loss": f"{lang_loss.item():.3e}",
                "factorization_loss": f"{factorization_loss.item():.3e}",
            })

        self.model.eval()

        return self

    def get_factorized_representation(self, x: torch.Tensor, prefixes: Optional[torch.Tensor]) -> "FactorizedRepresentations":
        """Collect the main model's activations next to the per-task targets.

        Runs the main model with ``projection=True`` and ``prefixes``. The
        records tagged "attn_block_output", "factorized_pred" and
        "attn_fwd_reconstructed" are each stacked along a new leading axis. The
        target is ``get_target_aligned_representation(x)``, detached and moved
        to CPU.

        NOTE(review): the target is computed from ``x`` alone (no prefixes),
        and the main-model tensors are neither detached nor computed under
        ``no_grad``. Confirm both are intended.
        """
        factorized_target = self.get_target_aligned_representation(x).detach().cpu()
        record = []
        _ = self.model(x, projection=True, activations_record=record, prefixes=prefixes)

        def _stack(tag):
            # Stack every recorded activation carrying this tag.
            return torch.stack([a[1] for a in record if a[0] == tag], dim=0)

        return FactorizedRepresentations(
            before_factorization=_stack("attn_block_output"),
            factorized=_stack("factorized_pred"),
            after_reprojection=_stack("attn_fwd_reconstructed"),
            factorized_target=factorized_target,
        )