from torch import optim
from tqdm import tqdm
import pickle

from src.models.layers import *

class PhysioModel(nn.Module):
    """Two-stage transformer over multi-channel physiological signal inputs.

    The input first goes through an upper ``SiT`` backbone and then through a
    lower ``Transformer`` backbone. Between the two stages the first output
    token (treated as the CLS token) can be swapped for a stored sensor
    embedding, which conditions the lower stage on a chosen target sensor.
    A linear head maps every output token to one scalar for reconstruction.

    NOTE: ``self.signal_tk`` is a plain dict of tensors, not a registered
    buffer, so ``model.to(...)`` does not move it and it is not part of the
    ``state_dict``. The tensors are placed on ``device`` once in ``__init__``.
    """

    def __init__(
        self,
        num_layers=6,
        upper_layer=2,
        emb_size=384,
        num_heads=12,
        # data related
        in_channel=65, # number of frequency band/scales
        device=None,
        signal_tk_path="nlp_supervise/sensor_embed_reduced.pkl",
    ):
        """Build both backbones and load the sensor token embeddings.

        Args:
            num_layers: total layer count; the lower backbone receives
                ``num_layers - upper_layer`` layers.
            upper_layer: number of layers in the upper ``SiT`` backbone.
            emb_size: embedding width shared by both backbones and the head.
            num_heads: attention head count passed to both backbones.
            in_channel: number of frequency bands/scales, passed to ``SiT``.
            device: device the sensor embeddings are moved to.
            signal_tk_path: pickle file holding a mapping from sensor name
                (e.g. "PPG", "ECG") to its embedding vector.
        """
        super().__init__()

        self.device = device

        # for statistics track
        self.stats = dict()

        # embedding of signal tokens (PPG, EEG_F, ECG, etc.)
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # point signal_tk_path at files from a trusted source.
        with open(signal_tk_path, 'rb') as f:
            raw_tokens = pickle.load(f)
        self.signal_tk = {
            name: torch.tensor(vec).float().to(device)
            for name, vec in raw_tokens.items()
        }

        # backbone
        self.backbone_upper = SiT(
            num_layers=upper_layer,
            emb_size=emb_size,
            in_channel=in_channel,
            num_heads=num_heads,
        )
        self.backbone_lower = Transformer(
            num_layers=num_layers-upper_layer,
            emb_size=emb_size,
            num_heads=num_heads,
        )

        # output layers
        self.reconstruct_fc = nn.Linear(emb_size, 1)

        # define losses
        self.l1_loss = nn.L1Loss()

    @staticmethod
    def _prepare_input(x):
        """Move the last axis to position 1 and zero out NaN/inf values.

        (N, L, F, 3) -> (N, 3, L, F)
        """
        return torch.nan_to_num(
            x.permute(0, 3, 1, 2),
            nan=0.0, posinf=0.0, neginf=0.0
        )

    def input_cls(self, cls_out):
        """Return the name of the sensor whose embedding best matches ``cls_out``.

        Args:
            cls_out: a single CLS embedding (one vector; it is flattened to
                one row before comparison).

        Returns:
            The key of ``self.signal_tk`` with the highest cosine similarity.
        """
        sensor_names = list(self.signal_tk)
        bank = torch.stack([self.signal_tk[s] for s in sensor_names])  # (S, E)
        query = cls_out.reshape(1, -1)  # (1, E), broadcast against the bank
        cos_sims = nn.functional.cosine_similarity(query, bank, dim=1)
        return sensor_names[int(torch.argmax(cos_sims))]

    def forward_plain(self, x):
        """Run both backbones with no masking and no token exchange.

        Args:
            x: input of shape (N, L, F, 3).

        Returns:
            The lower backbone's output; (N, L+1, E) per the shape comments
            on the backbone calls.
        """
        x = self._prepare_input(x)

        # 1st part forward
        out = self.backbone_upper(x) # (N, L+1, E)
        out = self.backbone_lower(out)
        return out

    def forward(self, x, input_mask=None, exchange_tks=None, return_reconstruct=True):
        '''
        Run the full model, optionally masking the input and swapping the CLS
        token for a stored sensor embedding before the lower backbone.

        @param x: (N, L, F, 3)
        @param input_mask: (N, L) multiplicative mask over time steps, or None
            to leave the input unmasked
        @param exchange_tks: list(str) ["PPG", "EEG_F", "ACC_X", "ECG", etc.],
            one sensor name per sample, or None to keep the computed CLS token
        @param return_reconstruct: when True, add the "rec" entry
        @return: dict with "embed" (N, L+1, E), "cls" (N, E) taken from the
            upper backbone before any exchange, and optionally "rec" (N, L+1, 1)
        '''
        x = self._prepare_input(x)

        # apply mask on input
        if input_mask is not None:
            N, L = input_mask.shape
            x = x * input_mask.view(N, 1, L, 1)

        # 1st part forward
        out = self.backbone_upper(x) # (N, L+1, E)
        cls_out = out[:, 0, :] # (N, E)

        # exchange token for reconstruct/ar output
        if exchange_tks is not None:
            exchange_cls = torch.stack(
                [self.signal_tk[tk] for tk in exchange_tks]
            ).to(device=out.device, dtype=out.dtype)
            # Build a new tensor instead of writing into `out` in place:
            # `cls_out` is a view of `out`, so an in-place write would also
            # overwrite the CLS embedding returned to the caller.
            exchange_cls = exchange_cls.unsqueeze(1).expand(out.shape[0], -1, -1)
            out = torch.cat([exchange_cls, out[:, 1:, :]], dim=1)

        # final part forward
        out = self.backbone_lower(out) # (N, L+1, E)

        final_out = {
            "embed": out, # (N, L+1, E)
            "cls": cls_out, # (N, E)
        }

        if return_reconstruct:
            final_out["rec"] = self.reconstruct_fc(out) # (N, L+1, 1)
        return final_out

    def num_params(self):
        """Print the total and trainable parameter counts, in millions."""
        total_num = sum(p.numel() for p in self.parameters())
        train_num = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("Total of Parameters: {}M".format(round(total_num / 1e6, 2)))
        print("Train Parameters: {}M".format(round(train_num / 1e6, 2)))

    def loss_f(self, pred_ts, padding_mask, target_ts, pred_cls, target_cls):
        """Weighted sum of three L1 losses.

        The last time step is weighted 0.6 (autoregression), all earlier
        steps 0.2 (reconstruction), and the CLS match 0.2.

        Args:
            pred_ts: predicted series with a trailing singleton axis; its
                second axis must be compatible with ``padding_mask``.
            padding_mask: (N, L) multiplicative mask applied to ``pred_ts``.
            target_ts: target series, indexed the same way as the masked
                prediction (N, L).
            pred_cls: predicted CLS embedding.
            target_cls: reference CLS embedding.

        NOTE(review): only the prediction is masked; this presumably relies
        on ``target_ts`` already being zero at padded positions -- confirm.
        """
        N, L = padding_mask.shape
        pred_ts = (pred_ts * padding_mask.view(N, L, 1))[:, :, 0] # N, L
        losses = [
            0.6*self.l1_loss(pred_ts[:, -1], target_ts[:, -1]), # autoregression
            0.2*self.l1_loss(pred_ts[:, :-1], target_ts[:, :-1]), # reconstruction
            0.2*self.l1_loss(pred_cls, target_cls) # CLS teacher-student
        ]
        return sum(losses)

    def fit(
        self,
        dataloader,
        epochs,
        lr=1e-3,
        weight_decay=1e-2
    ):
        """Train with AdamW and return the per-batch loss history.

        Each batch is a dict with keys "sample_in", "mask", "exchange_tk",
        "in_tk", "target_ts" and "padding_mask". ``self.stats`` counts how
        often each input token, exchange token and "in-out" pair was seen.
        The model is left in eval mode when training finishes.

        Args:
            dataloader: iterable of batch dicts as described above.
            epochs: number of passes over ``dataloader``.
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay.

        Returns:
            List of float losses, one per processed batch.
        """
        optimizer = optim.AdamW(
            self.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        # train
        train_losses = list()
        print("Start training")
        for _ in tqdm(range(epochs)):
            self.train()
            for data in tqdm(dataloader):
                # forward
                out = self(
                    data["sample_in"],
                    input_mask=data["mask"],
                    exchange_tks=data["exchange_tk"],
                    return_reconstruct=True
                )
                target_cls = torch.stack(
                    [self.signal_tk[tk] for tk in data["in_tk"]]
                ).to(out["cls"].device)
                # Keyword arguments keep the mask and the target from being
                # swapped against loss_f's parameter order.
                loss = self.loss_f(
                    pred_ts=out["rec"],
                    padding_mask=data["padding_mask"],
                    target_ts=data["target_ts"],
                    pred_cls=out["cls"],
                    target_cls=target_cls,
                )

                # backprop
                optimizer.zero_grad() # clear cache
                loss.backward() # calculate gradient
                optimizer.step() # update parameters

                # update record
                train_losses.append(loss.item())
                for in_tk, out_tk in zip(data["in_tk"], data["exchange_tk"]):
                    comb_tk = "{}-{}".format(in_tk, out_tk)
                    for key in (in_tk, out_tk, comb_tk):
                        self.stats[key] = self.stats.get(key, 0) + 1
        self.eval()

        return train_losses

if __name__ == '__main__':
    # Smoke test: push one random batch through the model and print shapes.
    batch, seq_len, n_freq = 3, 128, 65
    # forward() expects (N, L, F, 3) and permutes it to (N, 3, L, F) itself.
    fake_x = torch.rand((batch, seq_len, n_freq, 3))
    # All-ones (N, L) mask: keeps every time step.
    fake_mask = torch.ones((batch, seq_len))

    model = PhysioModel(
        num_layers=6,
        upper_layer=2,
        emb_size=384,
        # data related
        in_channel=n_freq
    )
    out = model(fake_x, input_mask=fake_mask, exchange_tks=None)
    print("out shape:", out["embed"].shape) # (N, L+1, E = 3, 129, 384)
    print("cls shape:", out["cls"].shape) # (N, E = 3, 384)
    print("reconstruct shape:", out["rec"].shape) # (N, L+1, 1 = 3, 129, 1)