import sys
import tracemalloc
import time

from src.models.layers import *

from src.models.utils import *
# from chronos import ChronosPipeline
# from src.models.model import *
# from src.models.crossvit import *
from src.models.mae import *

class ViTAdjust(nn.Module):
    """Frozen ImageNet-pretrained ViT-B/16 adapted to (3, L, F) inputs.

    Changes relative to the stock ``ViT('B_16_imagenet1k')``:

    * inputs go through ``self.trans``: rescaled with ``(x - 0.5) / 0.5`` and
      zero-padded with 16 rows at the top of the height (L) axis;
    * the patch-embedding conv stride is set to ``(4, 16)``: step 4 along the
      height axis (overlapping patches, assuming the usual 16x16 kernel) and 16
      along the width axis;
    * the positional embedding is wrapped in ``VitPosEmbedAdjust``;
    * every weight is frozen except the transformer blocks listed in
      ``finetune_layers``; the classification head ``vit.fc`` is dropped.

    Args:
        finetune_layers: indices into ``vit.transformer.blocks`` that stay
            trainable (negative indices count from the end). ``None`` (the
            default) means no block is fine-tuned.
        upper_layer: stored for backward compatibility; the forward pass runs
            every block regardless of its value.
    """

    def __init__(self, finetune_layers=None, upper_layer=-1):
        super().__init__()
        # Copy, so the caller's sequence is never shared with this instance.
        finetune_layers = [] if finetune_layers is None else list(finetune_layers)

        self.vit = ViT('B_16_imagenet1k', pretrained=True)  # construct and load pretrained weights

        self.trans = nn.Sequential(
            CheckShape(None, key=lambda x: (x - 0.5) / 0.5),
            CheckShape(None, key=lambda x: nn.functional.pad(x, (0, 0, 16, 0), value=0)),
        )

        # Overwrite the Conv2d patching stride: keep 16 on the width (F) axis,
        # use 4 on the height (L) axis so consecutive patches overlap.
        self.vit.patch_embedding.stride = (4, 16)

        # Overwrite the position embedding (presumably to match the token count
        # produced by the new stride — see VitPosEmbedAdjust).
        self.vit.positional_embedding = VitPosEmbedAdjust(self.vit.positional_embedding.pos_embedding)

        freeze_model(self.vit)

        # Fine-tune selected transformer blocks only.
        self.upper_layer = upper_layer
        self.finetune_layers = finetune_layers
        for layer in finetune_layers:
            unfreeze_model(self.vit.transformer.blocks[layer])

        # The classification head is never used.
        self.vit.fc = None

    def train(self, mode=True):
        """Set train/eval mode, keeping frozen transformer blocks in eval mode.

        Matches ``nn.Module.train``: ``mode`` defaults to ``True`` and the
        module itself is returned so calls can be chained.
        """
        super().train(mode)

        blocks = self.vit.transformer.blocks
        n_blocks = len(blocks)
        # Normalize negative indices the same way unfreeze_model's indexing did.
        trainable = {idx % n_blocks for idx in self.finetune_layers} if n_blocks else set()
        for b_i, block in enumerate(blocks):
            if b_i not in trainable:
                block.eval()
        return self

    def patch_forward(self, x):
        """Turn an (N, 3, L, F) input into (N, L+1, D) tokens.

        The last ``L`` patch tokens are kept, a [CLS] token is prepended and
        the (adjusted) positional embedding is applied.
        """
        N, _, L, _ = x.shape

        out = self.trans(x)

        # Overlapping patching.
        out = self.vit.patch_embedding(out)  # b, d, gh, gw
        out = out.flatten(2).transpose(1, 2)  # b, gh*gw, d
        out = out[:, -L:, :]  # b, L, d — keep only the last L patch tokens

        # Prepend the [CLS] token, then add positional information.
        out = torch.cat((self.vit.class_token.expand(N, -1, -1), out), dim=1)  # b, L+1, d
        out = self.vit.positional_embedding(out)

        return out  # N, L+1, 768

    def backbone_forward(self, out, mask=None, hidden_out=False):
        """Run all transformer blocks followed by the final norm.

        Args:
            out: token tensor as returned by ``patch_forward``.
            mask: forwarded unchanged to every block.
            hidden_out: if True, also return each block's output as nested
                Python lists (detached and moved to CPU first).

        Returns:
            The normalized tokens, or ``(tokens, hiddens)`` when ``hidden_out``.
        """
        hiddens = [] if hidden_out else None
        for block in self.vit.transformer.blocks:
            out = block(out, mask=mask)
            if hidden_out:
                # .numpy() would fail on CUDA / grad-requiring / bfloat16 tensors.
                hiddens.append(out.detach().cpu().tolist())

        out = self.vit.norm(out)

        if hidden_out:
            return out, hiddens
        return out

    def forward(self, x, mask=None, hidden_out=False):  # Input shape: (N, 3, L, 65)
        """Patch-embed ``x`` and run it through the transformer backbone."""
        out = self.patch_forward(x)
        return self.backbone_forward(out, mask=mask, hidden_out=hidden_out)

class CVModel(nn.Module):
    """Per-channel ViT feature extractor followed by a linear probe.

    With ``is_preproc=False`` the input is treated as precomputed per-channel
    features (NaNs replaced by 0) and only the probe is applied. With
    ``is_preproc=True`` each channel is encoded separately with a frozen ViT:

    * ``vanilla=True``: input (N, C, L, F); each channel is replicated to 3
      image channels, resized to 384x384 and fed to a stock ViT;
    * ``vanilla=False``: input (N, C, L, F, 3); each channel is fed to
      ``ViTAdjust`` and its output tokens are mean-pooled over dim 1.

    Per-channel 768-d features are collected in a bfloat16 buffer and passed
    to ``LinearProb(num_channels, 768, num_classes)``.
    """

    def __init__(self, num_channels=14, num_classes=1, task="reg", is_preproc=False, vanilla=False, all_attn=False):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc
        self.vanilla = vanilla
        self.all_attn = all_attn

        if vanilla:
            # Older pipeline: stock ViT on a 384x384 resized image.
            self.trans = transforms.Compose([
                transforms.Resize((384, 384)),
            ])
            self.encoder = freeze_model(init_vit(m_name='B_16_imagenet1k'))
        else:
            # Adjusted ViT with overlapping patches along the L axis.
            self.encoder = ViTAdjust()

        # Final linear probing output.
        self.linear_prob = LinearProb(num_channels, 768, num_classes)

    def forward(self, x):
        """Return ``(prediction, features)``.

        Raises:
            NotImplementedError: for ``is_preproc=True`` with ``all_attn=True``,
                which has no implementation.
        """
        if not self.is_preproc:
            # Precomputed features: previously this path fell through and
            # returned None.
            enc_out = torch.nan_to_num(x, nan=0.0)
            return self.linear_prob(enc_out), enc_out

        if self.all_attn:
            raise NotImplementedError(
                "CVModel: all_attn=True (concatenated all-channel attention) is not implemented"
            )

        if self.vanilla:
            N, C, L, F = x.shape
        else:
            N, C, L, F, _ = x.shape

        enc_out = torch.zeros(N, C, 768, dtype=torch.bfloat16, device=DEVICE)
        for c in range(C):
            if self.vanilla:
                # Replicate the single channel into 3 "RGB" planes before resizing.
                curr_x = self.trans(x[:, c, :, :].unsqueeze(1).expand(N, 3, L, F))
                curr_enc = self.encoder(curr_x)
            else:
                curr_x = x[:, c, :, :, :].permute(0, 3, 1, 2)  # N, 3, L, F
                curr_enc = torch.mean(self.encoder(curr_x), dim=1)  # pool over tokens
            enc_out[:, c, :] = curr_enc
        return self.linear_prob(enc_out), enc_out
    
class Chronos(nn.Module):
    """Linear probe on top of Chronos (``amazon/chronos-t5-base``) embeddings.

    * ``is_preproc=True``: ``forward`` takes raw series of shape (N, C, L),
      encodes every channel with the Chronos encoder and returns
      ``(None, embeddings)``, where ``embeddings`` is the encoder output
      averaged over tokens, shape (N, C, E). The probe is not applied.
    * ``is_preproc=False``: ``forward`` takes precomputed features, fuses them
      according to ``fuse_method`` ('mean' passes them through unchanged,
      'msitf' applies ``TemporalFusion`` per channel), then applies the probe.
    """

    def __init__(
            self,
            num_channels,
            num_classes,
            task='reg',
            is_preproc=False,
            fuse_method='mean'  # mean, msitf
        ):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc
        self.fuse_method = fuse_method

        # NOTE(review): `from chronos import ChronosPipeline` is commented out at
        # the top of this file; this raises NameError unless one of the star
        # imports re-exports ChronosPipeline — confirm.
        self.encoder = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",  # small, base
            device_map="cuda",  # cpu, cuda
            torch_dtype=torch.bfloat16,
        )

        if fuse_method == 'mean':
            self.fuse = lambda x: torch.mean(x, dim=1)
        elif fuse_method == 'msitf':
            self.fuse = TemporalFusion(
                num_neurons=512,  # 512, 768
                query_size=384,
                fuse_method='msitf'  # last, msitf
            )

        latent_size = 768 if fuse_method == 'mean' else 384
        self.linear_prob = LinearProb(num_channels, latent_size, num_classes)

    def forward(self, x, query=None, return_scores=False, hidden_out=False):
        """Encode raw series or probe precomputed features (see class docstring).

        ``hidden_out`` is accepted for interface compatibility and ignored.

        Raises:
            ValueError: if ``is_preproc`` is False and ``fuse_method`` is not
                'mean' or 'msitf' (previously an UnboundLocalError).
        """
        if self.is_preproc:
            # Input shape: (N, C, L); encode every channel independently.
            N, C, L = x.shape
            x = x.reshape(N * C, L)

            # Uses private ChronosPipeline helpers; tokenization runs on CPU.
            context_tensor = self.encoder._prepare_and_validate_context(context=x).to(torch.device("cpu"))
            token_ids, attention_mask, _ = self.encoder.tokenizer._input_transform(context_tensor)
            embeddings = self.encoder.model.encode(
                input_ids=token_ids.to(DEVICE),
                attention_mask=attention_mask.to(DEVICE),
            ).to(DEVICE)  # N*C, T, E

            # Temporal fusion: plain mean over tokens.
            # NOTE(review): positions masked out by attention_mask (if any) are
            # included in this mean — confirm that is intended.
            _, T, E = embeddings.shape
            embeddings = torch.mean(embeddings.reshape(N, C, T, E), dim=2)  # N, C, E
            return None, embeddings

        if self.fuse_method == 'mean':
            out = x  # features are already (N, C, E)
        elif self.fuse_method == 'msitf':
            out = torch.stack([
                self.fuse(
                    x[:, c, :, :],
                    query=query,
                    return_scores=return_scores)
                for c in range(x.shape[1])], dim=1)  # N, C, E
        else:
            raise ValueError(
                f"Unsupported fuse_method {self.fuse_method!r}; expected 'mean' or 'msitf'"
            )

        return self.linear_prob(out), out

class PretrainAPI(nn.Module):
    """Linear probe on a frozen, locally pretrained ``PhysioModel`` encoder.

    With ``is_preproc=True`` each channel of a (N, C, L, F, 3) input is encoded
    on its own via ``encoder.forward_plain`` and averaged over dim 1 into a
    384-d vector. Otherwise the input is used as precomputed features (NaNs
    replaced by 0; presumably (N, C, 384)). Either way the features then go
    through ``LinearProb(num_channels, 384, num_classes)``.
    """

    def __init__(self, num_channels=14, num_classes=1, task="reg", is_preproc=False, is_on_cluster=False):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc

        self.encoder = PhysioModel(
            num_layers=6,  # total number of transformer layers
            upper_layer=2,  # layer at which [CLS] is output/exchanged
            emb_size=384,  # embedding size, must be divisible by num_heads
            num_heads=12,
            device=DEVICE
        )

        # Which pretraining run to load; the directory depends on is_on_cluster.
        model_remark = "tiny"
        path_template = (
            "../data/model_checkpoint_{}.pth"
            if is_on_cluster
            else "data/pretrain_model_weights/model_checkpoint_{}.pth"
        )
        weights = torch.load(path_template.format(model_remark), map_location=DEVICE)
        self.encoder.load_state_dict(weights)

        self.encoder = freeze_model(self.encoder)
        # NOTE(review): nn.Module instances already start in training mode, so
        # this only matters if freeze_model() switched the encoder to eval —
        # in that case dropout would be active during feature extraction.
        self.encoder.train()
        self.linear_prob = LinearProb(num_channels, 384, num_classes)

    def forward(self, x):  # Input shape: (N, C, L, 65, 3)
        """Return ``(prediction, per-channel features)``."""
        if self.is_preproc:
            N, C, L, F, _ = x.shape
            enc_out = torch.zeros(N, C, 384).to(DEVICE)
            for ch in range(C):
                tokens = self.encoder.forward_plain(x[:, ch])
                enc_out[:, ch, :] = tokens.mean(dim=1)
        else:
            enc_out = torch.nan_to_num(x, nan=0.0)
        return self.linear_prob(enc_out), enc_out

class FusionModel(nn.Module):
    """Linear head applied directly to precomputed 384-d fused embeddings.

    Only the ``is_preproc=False`` path is implemented: ``forward`` replaces
    NaNs with 0 and applies ``nn.Linear(384, num_classes)`` to the last axis.
    """

    def __init__(self, num_classes=1, task="reg", is_preproc=False):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc

        # Final linear probing output.
        self.linear_prob = nn.Linear(384, num_classes)

    def forward(self, x):  # Input shape: (..., 384)
        """Return ``(prediction, cleaned features)``.

        Raises:
            NotImplementedError: when ``is_preproc`` is True (this path was an
                empty placeholder that crashed with UnboundLocalError).
        """
        if self.is_preproc:
            raise NotImplementedError(
                "FusionModel does not implement the is_preproc=True (raw input) path"
            )

        enc_out = torch.nan_to_num(x, nan=0.0)
        return self.linear_prob(enc_out), enc_out

class MAE_API(nn.Module):
    """Pretrained CWT masked autoencoder + ``TemporalFusion`` matcher + linear head.

    With ``is_preproc=True`` a (N, C, L, 65, 3) input is encoded by
    ``MaskedAutoencoderViT.forward_all`` (commented in the code as
    (N, C, L, 768)). All C*L tokens are then fused by the matcher, optionally
    conditioned on ``query``, into one vector per sample (commented as
    (N, 384)). Otherwise the input is used as precomputed features with NaNs
    replaced. ``nn.Linear(384, num_classes)`` produces the prediction.

    Checkpoint paths are relative to the current working directory.
    """

    def __init__(self, num_classes, task='reg', is_preproc=False):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc

        if is_preproc:
            mae = MaskedAutoencoderViT(img_size=(387,65), patch_size=(9,5),mask_scheme='random',mask_prob=0.8,use_cwt=True,nvar=4, comb_freq=True)
            checkpoint = torch.load('../data/results/model_mae_checkpoint-140.pth', map_location=torch.device('cpu'))
            mae.load_state_dict(checkpoint['model'])
            self.encoder = mae
            print("Model load successfull.")

            fusion = TemporalFusion().to(torch.bfloat16).to(DEVICE)
            fusion.load_state_dict(torch.load("../data/matcher_trained_weights_{}.pt".format("embeddings_mae_var"), map_location=DEVICE))
            self.matcher = fusion

        # Final output layer.
        self.linear_prob = nn.Linear(384, num_classes)

    def forward(self, x, query=None):
        """Return ``(prediction, fused features)``."""
        if self.is_preproc:
            # (N, C, L, 65, 3) -> (N, C, 3, L, 65) for the encoder.
            tokens = self.encoder.forward_all(x.permute(0, 1, 4, 2, 3))  # N, C, L, 768
            N, C, L, H = tokens.shape
            out = self.matcher(tokens.view(N, C * L, H), query)  # N, 384
        else:
            out = torch.nan_to_num(x)
        return self.linear_prob(out), out

class MAE_LP_API(nn.Module):
    """Pretrained CWT masked autoencoder with a per-channel linear probe.

    With ``is_preproc=True`` a (N, C, L, 65, 3) input is encoded by
    ``MaskedAutoencoderViT.forward_all`` (commented as (N, C, L, 768)) and
    mean-pooled over L, giving (N, C, 768). Otherwise the input is used as
    precomputed features with NaNs replaced. The features then go through
    ``LinearProb(num_channels, 768, num_classes)``.

    The matcher and ``norm`` created in preprocessing mode are not used by
    ``forward``. They belong to an alternative (commented-out in the original)
    384-d reduction path.
    """

    def __init__(self, num_channels, num_classes, task='reg', is_preproc=False):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc

        if is_preproc:
            mae = MaskedAutoencoderViT(img_size=(387,65), patch_size=(9,5),mask_scheme='random',mask_prob=0.8,use_cwt=True,nvar=4, comb_freq=True)
            checkpoint = torch.load('../data/results/model_mae_checkpoint-140.pth', map_location=torch.device('cpu'))
            mae.load_state_dict(checkpoint['model'])
            self.encoder = mae
            print("Model load successfull.")

            # Loaded for the dimension-reduction variant (not used by forward).
            fusion = TemporalFusion().to(torch.bfloat16).to(DEVICE)
            fusion.load_state_dict(torch.load("../data/matcher_trained_weights_{}.pt".format("embeddings_mae_var"), map_location=DEVICE))
            self.matcher = fusion
            self.norm = nn.LayerNorm(384, eps=1e-6)

        # Final output layer (use 384 instead of 768 for the reduction variant).
        self.linear_prob = LinearProb(num_channels, 768, num_classes)

    def forward(self, x):
        """Return ``(prediction, per-channel features)``."""
        if self.is_preproc:
            # (N, C, L, 65, 3) -> (N, C, 3, L, 65) for the encoder.
            tokens = self.encoder.forward_all(x.permute(0, 1, 4, 2, 3))  # N, C, L, 768
            out = tokens.mean(dim=2)  # N, C, 768
        else:
            out = torch.nan_to_num(x)
        return self.linear_prob(out), out

class CrossVitAPI(nn.Module):
    """Pretrained ``CrossSignalViT`` encoder with a per-channel linear probe.

    With ``is_preproc=True`` a (N, C, L, 65, 3) input is encoded by
    ``CrossSignalViT.forward_all`` (commented as (N, C, L, 768)), mean-pooled
    over L and cast to bfloat16. Otherwise the input is used as precomputed
    features with NaNs replaced. ``LinearProb(num_channels, 768, num_classes)``
    produces the prediction.

    NOTE(review): the ``src.models.crossvit`` star import is commented out at
    the top of this file. ``CrossSignalViT`` must come from another star import
    for preprocessing mode to work — confirm.
    """

    def __init__(self, num_channels, num_classes, task='reg', is_preproc=False):
        super().__init__()
        self.task = task
        self.is_preproc = is_preproc

        if is_preproc:
            vit = CrossSignalViT(device='cuda')
            checkpoint = torch.load('../data/model_checkpoint_cross_freeze_vit100_99.pth', map_location=torch.device('cpu'))
            vit.load_state_dict(checkpoint['model'])
            self.encoder = vit

        self.linear_prob = LinearProb(num_channels, 768, num_classes)

    def forward(self, x):
        """Return ``(prediction, per-channel features)``."""
        if self.is_preproc:
            # (N, C, L, 65, 3) -> (N, C, 3, L, 65) for the encoder.
            tokens = self.encoder.forward_all(x.permute(0, 1, 4, 2, 3))  # N, C, L, 768
            out = tokens.mean(dim=2).to(torch.bfloat16)  # N, C, 768
        else:
            out = torch.nan_to_num(x)
        return self.linear_prob(out), out

class CrossViTOpt(nn.Module):
    """Adjusted ViT-B/16 run per channel, with [CLS] tokens exchanged across channels.

    Each channel of an (N, C, 3, L, F) input is patch-embedded like
    ``ViTAdjust`` does: normalize, pad 16 rows on top, stride (4, 16), keep the
    last L tokens, prepend [CLS]. The channels then pass through the ViT blocks.
    After block ``i`` the C [CLS] tokens of each sample go through
    ``cross_attn.blocks[i]`` and are written back as the new [CLS] tokens.

    Args:
        finetune_layers: indices of ViT blocks to keep trainable (negative
            indices count from the end). ``None`` means none.
    """

    def __init__(self, finetune_layers=None):
        super().__init__()
        # Copy, so the caller's sequence is never shared with this instance.
        finetune_layers = [] if finetune_layers is None else list(finetune_layers)

        self.vit = ViT('B_16_imagenet1k', pretrained=True)  # construct and load pretrained weights

        self.trans = nn.Sequential(
            CheckShape(None, key=lambda x: (x - 0.5) / 0.5),
            CheckShape(None, key=lambda x: nn.functional.pad(x, (0, 0, 16, 0), value=0)),
        )

        # Overwrite the Conv2d patching stride: keep 16 on the width (F) axis,
        # use 4 on the height (L) axis so consecutive patches overlap.
        self.vit.patch_embedding.stride = (4, 16)

        # Overwrite the position embedding (presumably to match the new token count).
        self.vit.positional_embedding = VitPosEmbedAdjust(self.vit.positional_embedding.pos_embedding)

        freeze_model(self.vit)

        # Fine-tune selected ViT blocks only.
        self.finetune_layers = finetune_layers
        for layer in finetune_layers:
            unfreeze_model(self.vit.transformer.blocks[layer])

        # The classification head is never used.
        self.vit.fc = None

        # Cross-channel attention over [CLS] tokens, one block per ViT block.
        self.cross_attn = Transformer(
            num_layers=12,
            emb_size=768,
            num_heads=12
        )

    def train(self, mode=True):
        """Set train/eval mode, keeping frozen ViT blocks in eval mode.

        Matches ``nn.Module.train``: ``mode`` defaults to ``True`` and the
        module itself is returned.
        """
        super().train(mode)

        blocks = self.vit.transformer.blocks
        n_blocks = len(blocks)
        # Normalize negative indices the same way unfreeze_model's indexing did.
        trainable = {idx % n_blocks for idx in self.finetune_layers} if n_blocks else set()
        for b_i, block in enumerate(blocks):
            if b_i not in trainable:
                block.eval()
        return self

    def patching_forward(self, x):  # (N, C, 3, L, 65)
        """Patch-embed every channel independently; returns (N, C, L+1, D)."""
        N, C, S, L, F = x.shape
        out = self.trans(x.reshape(N * C, S, L, F))

        # Overlapping patching.
        out = self.vit.patch_embedding(out)  # b, d, gh, gw
        out = out.flatten(2).transpose(1, 2)  # b, gh*gw, d
        out = out[:, -L:, :]  # b, L, d — keep only the last L patch tokens

        # Prepend the [CLS] token, then add positional information.
        out = torch.cat((self.vit.class_token.expand(N * C, -1, -1), out), dim=1)  # b, L+1, d
        out = self.vit.positional_embedding(out)

        return out.reshape(N, C, L + 1, -1)  # N, C, L+1, 768

    def transformer_forward(self, x, idx, mask=None):
        """Apply ViT block ``idx`` per channel, then cross-channel attention on [CLS].

        Args:
            x: tokens of shape (N, C, L, H).
            idx: block index, used for both ``vit.transformer.blocks`` and
                ``cross_attn.blocks``.
            mask: forwarded to both the self-attention and the cross-attention
                block. NOTE(review): these attend over different axes (tokens
                vs. channels), so a non-None mask must fit both — confirm.

        Returns:
            Tokens of shape (N, C, L, H) whose position 0 holds the new [CLS].
        """
        N, C, L, H = x.shape

        # Self-attention within each channel.
        out = self.vit.transformer.blocks[idx](x.reshape(N * C, L, H), mask=mask)  # N*C, L, H
        out = out.reshape(N, C, L, H)

        # Cross-attention across channels on the [CLS] tokens.
        cls_out = self.cross_attn.blocks[idx](out[:, :, 0, :], mask=mask)  # N, C, H

        # Substitute the new [CLS] out of place: writing into `out` in place
        # would modify the input cross_attn saved for backward and make
        # autograd fail.
        return torch.cat((cls_out.unsqueeze(2).to(out.dtype), out[:, :, 1:, :]), dim=2)

    def forward(self, x, mask=None, hidden_out=False):  # Input shape: (N, C, 3, L, 65)
        """Run patching, all ViT/cross-attention block pairs and the final norm.

        Returns:
            Normalized tokens (N, C, L+1, 768), or ``(tokens, hiddens)`` when
            ``hidden_out`` is True. ``hiddens`` holds each block's output as
            nested Python lists (detached, on CPU).
        """
        out = self.patching_forward(x)

        hiddens = [] if hidden_out else None
        for b_i in range(len(self.vit.transformer.blocks)):
            # `mask` was previously dropped here (hard-coded to None).
            out = self.transformer_forward(out, b_i, mask=mask)
            if hidden_out:
                # .numpy() would fail on CUDA / grad-requiring / bfloat16 tensors.
                hiddens.append(out.detach().cpu().tolist())

        out = self.vit.norm(out)

        if hidden_out:
            return out, hiddens
        return out

if __name__ == "__main__":
    # Smoke test / benchmark: build MAE_API in preprocessing mode, run one
    # forward pass on random input and report wall time and memory.
    # Usage: python <this file> [DEVICE]   e.g. "cuda", "cuda:0", "cpu" (default "cpu")
    device = torch.device(sys.argv[1] if len(sys.argv) > 1 else "cpu")
    num_classes = 10
    num_channels = 6
    sequence_len = 65*6 # 6 seconds
    dtype_ = torch.float32 # torch.bfloat16, torch.float32

    # init model (alternatives: ViTAdjust(), CrossViTOpt())
    # tracemalloc only sees allocations made through Python's allocator; tensor
    # storage allocated by PyTorch itself is presumably not included.
    tracemalloc.start()
    print('init...')
    model = MAE_API(num_classes, task='class', is_preproc=True).to(dtype_)
    model.to(device)
    model.eval()
    print("model load complete.")

    # generate data
    x = torch.rand(1, num_channels, sequence_len, 65, 3).to(dtype_).to(device) # N, C, L, F, 3
    query = torch.rand(1, 384).to(dtype_).to(device)

    # main inference
    print('Inferencing...')
    is_cuda = device.type == "cuda"
    if is_cuda:
        torch.cuda.synchronize(device)  # don't charge previously queued work to the timer
    start = time.perf_counter()
    with torch.no_grad():
        y, latent = model(x, query=query)
    if is_cuda:
        torch.cuda.synchronize(device)  # CUDA kernels run asynchronously; wait before stopping the clock
    end = time.perf_counter()

    # log of ram
    current, peak = tracemalloc.get_traced_memory()
    print(f"Current memory usage is {current / 10**3}KB; Peak was {peak / 10**3}KB; Diff = {(peak - current) / 10**3}KB")
    tracemalloc.stop()
    if torch.cuda.is_available():
        # mem_get_info() returns (free_bytes, total_bytes) and fails without CUDA.
        print("VRAM usage:", torch.cuda.mem_get_info())

    # log
    print("DEVICE:", device)
    print("Infer time:", round(end-start, 2), "s")
    print("In shape:", x.shape) # N, C, L, F, 3
    print("Out shape:", y.shape) # N, num_classes (linear head on the matcher output)
    print("Embed shape:", latent.shape) # matcher output, commented in MAE_API as (N, 384)
    