import torch
import torch.nn as nn
import time
from torchvision.models import swin_t, Swin_T_Weights, vit_b_32, ViT_B_32_Weights
from custom_op.register import register_normal_linear, register_INSTANT, register_LBPWHT, register_GF
from timm import create_model

from calibration import *
import gc

import os

class ModelTrainer:
    """Fine-tune a vision backbone, optionally with a custom operator variant.

    Builds one of the supported models (``swinT``, ``vit_b_32``,
    ``efficientformer_l1``) and freezes every parameter except those of the
    last ``num_of_finetune`` linear / 1x1-conv layers. It then registers at
    most one operator variant from ``custom_op.register`` on those layers:
    base linear, GF, LBPWHT or INSTANT.
    """

    # Model names understood by get_model / freeze_layers.
    _SUPPORTED_MODELS = ("swinT", "vit_b_32", "efficientformer_l1")

    def __init__(self, model_name, batch_size, num_epochs, device='cuda',
                 with_base=False, with_INSTANT=False, over_sampling=0, with_GF=False, with_LBPWHT=False,
                 dataloader=None, output_channels=None, num_of_finetune=None,
                 explained_var=None, checkpoint=None):
        """Build the model, freeze it and register the selected operator variant.

        Args:
            model_name: one of ``_SUPPORTED_MODELS``.
            batch_size: stored on the instance; not read elsewhere in this class.
            num_epochs: number of passes over ``dataloader`` in ``train_model``.
            device: device the model and each batch are moved to.
            with_base / with_GF / with_LBPWHT / with_INSTANT: operator variant
                selectors. Only the first true flag in that order is applied.
            over_sampling, explained_var: forwarded to ``SVD_expected_value``
                during INSTANT calibration.
            dataloader: iterable of ``(inputs, labels)`` batches.
            output_channels: number of classes for the replaced classifier head.
            num_of_finetune: number of trailing linear layers to fine-tune.
                ``"all"``, ``None`` or a value larger than the layer count
                means every collected layer.
            checkpoint: optional path to a saved (pruned) model for
                ``vit_b_32`` / ``efficientformer_l1``.

        Raises:
            ValueError: if ``model_name`` is not supported.
        """
        self.device = device
        self.model_name = model_name
        self.output_channels = output_channels
        self.model_dict = self.get_model(model_name, checkpoint)

        self.batch_size = batch_size
        self.dataloader = dataloader
        self.num_epochs = num_epochs

        if self.model_name in self._SUPPORTED_MODELS:
            self.all_linear_layers = self.get_all_linear_with_name()
            n_layers = len(self.all_linear_layers)
            # None (the default) is treated like "all"; comparing None > int
            # would otherwise raise TypeError.
            if num_of_finetune is None or num_of_finetune == "all" or num_of_finetune > n_layers:
                print("[Warning] Finetuning all layers")
                self.num_of_finetune = n_layers
            else:
                self.num_of_finetune = num_of_finetune
                print("Number of finetuned layers: ", self.num_of_finetune)

        self.explained_var = explained_var
        self.with_base = with_base
        self.with_INSTANT = with_INSTANT
        self.with_GF = with_GF
        self.with_LBPWHT = with_LBPWHT
        self.over_sampling = over_sampling
        # Shared timing lists handed to the registered operators via filter_cfgs.
        self.backward_time = []
        self.forward_time = []
        self.inference_time = []

        self.config_model(self.backward_time, self.forward_time, self.inference_time)

    def get_all_linear_with_name(self):
        """Return ``{module_name: module}`` for fine-tunable linear-like layers.

        Collects every ``nn.Linear`` and every ``nn.Conv2d`` with a 1x1 kernel,
        in ``named_modules()`` order. Classifier-like modules are skipped:
        any name containing "classifier", "head" or "bert.pooler". The
        "head" test also covers "head.head" and torchvision ViT's "heads.*".
        """
        excluded = ("classifier", "head", "bert.pooler")
        linear_layers = {}
        for name, mod in self.model_dict['model'].named_modules():
            is_linear_like = isinstance(mod, nn.Linear) or (
                isinstance(mod, nn.Conv2d) and mod.kernel_size == (1, 1)
            )
            if not is_linear_like or any(tag in name for tag in excluded):
                continue
            linear_layers[name] = mod
        return linear_layers

    def get_model(self, model_name, checkpoint):
        """Instantiate ``model_name``, replace its classifier and move it to ``self.device``.

        Returns:
            ``{"model": nn.Module, "name": model_name}``.

        Raises:
            ValueError: if ``model_name`` is not one of ``_SUPPORTED_MODELS``.
        """
        if model_name == 'swinT':
            model = swin_t(weights=Swin_T_Weights.DEFAULT)
        elif model_name == 'vit_b_32':
            if checkpoint is not None:
                # SECURITY: weights_only=False unpickles arbitrary objects; only
                # load checkpoints from a trusted source.
                pruned_dict = torch.load(checkpoint, weights_only=False, map_location='cpu')
                model = pruned_dict['model']
            else:
                model = vit_b_32(weights=ViT_B_32_Weights.IMAGENET1K_V1)
        elif model_name == 'efficientformer_l1':
            if checkpoint is not None:
                # SECURITY: weights_only=False unpickles arbitrary objects; only
                # load checkpoints from a trusted source.
                pruned_dict = torch.load(checkpoint, weights_only=False, map_location='cpu')
                print(pruned_dict.keys())
                # NOTE(review): whatever is stored under 'state_dict' is used as
                # the model object. If it is a plain state dict rather than a
                # module, the .to() call below will fail. Confirm checkpoint format.
                model = pruned_dict['state_dict']
            else:
                model = create_model('efficientformer_l1', pretrained=True, num_classes=self.output_channels)
        else:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of {self._SUPPORTED_MODELS}"
            )

        if model_name == 'swinT':
            model.head = nn.Linear(in_features=768, out_features=self.output_channels, bias=True)  # Change classifier
        elif model_name == 'vit_b_32':
            model.heads = nn.Sequential(nn.Linear(in_features=768, out_features=self.output_channels, bias=True))  # Change classifier

        model.to(self.device)

        return {"model": model, "name": model_name}

    def freeze_layers(self, num_of_finetune):
        """Freeze the model except the last ``num_of_finetune`` collected layers.

        All parameters are frozen and the model is put in eval mode. The
        selected layers, plus any submodules nested under them, are then
        switched back to train mode with ``requires_grad=True``.

        Returns:
            ``{name: module}`` of the unfrozen layers. This is empty when
            ``num_of_finetune`` is 0. Returns ``None`` for unsupported models.
        """
        if self.model_name not in self._SUPPORTED_MODELS:
            return None

        model = self.model_dict['model']
        layer_items = list(self.all_linear_layers.items())
        # Explicit start index: `[-0:]` would select *every* layer for 0.
        start = max(len(layer_items) - num_of_finetune, 0)
        finetuned_layers = dict(layer_items[start:])

        # 1) Freeze everything
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        # 2) Unfreeze only the target layers. Match the exact name or a dotted
        # child path; plain substring matching would also hit e.g. "fc2" for "fc".
        child_prefixes = tuple(f"{k}." for k in finetuned_layers)
        for name, mod in model.named_modules():
            if name in finetuned_layers or name.startswith(child_prefixes):
                mod.train()
                for p in mod.parameters():
                    p.requires_grad = True
                print("UNFROZEN:", name)
        trainable = [name for name, p in model.named_parameters() if p.requires_grad]
        print("Trainable parameters:", trainable)
        return finetuned_layers

    def config_model(self, backward_time, forward_time, inference_time):
        """Freeze layers and register the operator variant selected by the ``with_*`` flags.

        The flags are checked in the order base, GF, LBPWHT, INSTANT; only
        the first true one is applied. If none is set, no operator is
        registered. The timing lists are passed by reference inside
        ``filter_cfgs``. Presumably the registered operators append
        measurements to them; TODO confirm in ``custom_op.register``.
        """
        finetuned_layers = self.freeze_layers(self.num_of_finetune)
        print("finetuned layers: ", finetuned_layers)
        filter_cfgs = {"finetuned_layer": finetuned_layers,
                       "type": "conv",
                       "backward_time": backward_time,
                       "forward_time": forward_time,
                       "inference_time": inference_time}
        if self.with_base:
            filter_cfgs["type"] = "linear"
            register_normal_linear(self.model_dict['model'], filter_cfgs)

        elif self.with_GF:
            filter_cfgs["type"] = "linear"
            register_GF(self.model_dict['model'], filter_cfgs)

        elif self.with_LBPWHT:
            filter_cfgs["type"] = "linear"
            register_LBPWHT(self.model_dict['model'], filter_cfgs)

        elif self.with_INSTANT:
            add_activate_to_module()
            add_compression_tensor_to_module()

            filter_cfgs["type"] = "linear"

            # register_INSTANT returns the (possibly rewrapped) model.
            self.model_dict['model'] = register_INSTANT(self.model_dict['model'], filter_cfgs)

    def _train_step(self, model, optimizer, criterion, inputs, labels):
        """Run one forward / loss / backward / optimizer step and return the loss."""
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss

    def _update_instant_compression(self, model, calib_iter):
        """Turn hook-accumulated statistics into INSTANT bases and install them.

        ``layer_inputs``, ``layer_gradients``, ``U_dict``, ``U_grad_dict``,
        ``SVD_expected_value`` and ``unregister_hooks`` are not defined in
        this file. Presumably they are provided by ``from calibration import *``.
        """
        for layer_name, input_sum in layer_inputs.items():
            average_X_XT = input_sum / calib_iter
            U_dict[layer_name] = SVD_expected_value(average_X_XT, self.explained_var, self.over_sampling)

        for layer_name, grad_sum in layer_gradients.items():
            average_grad_X_XT = grad_sum / calib_iter
            U_grad_dict[layer_name] = SVD_expected_value(average_grad_X_XT, self.explained_var, self.over_sampling)

        model.update_compression(U_dict, U_grad_dict)
        unregister_hooks()

    def train_model(self, compression_iter=205, calib_iter=5):
        """Train for ``self.num_epochs`` epochs with Adam (lr=1e-3) and cross-entropy.

        With INSTANT enabled, every window of ``compression_iter`` iterations
        (counted per epoch) starts with ``calib_iter`` calibration iterations:

        - Compression is disabled and hooks are registered.
        - Each batch is processed as up to two 32-sample micro-batches.
        - On the last calibration iteration, the accumulated statistics are
          turned into projection bases.

        The remaining iterations in the window train with compression enabled.

        Args:
            compression_iter: length of one calibration + training window.
            calib_iter: number of calibration iterations at the start of each window.
        """
        print("Training begin ...")

        # Calibration splits a batch into `calib_chunks` slices of `calib_chunk_size`.
        calib_chunk_size = 32
        calib_chunks = 2

        model = self.model_dict['model']
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(self.num_epochs):
            model.train()
            for i, (inputs, labels) in enumerate(self.dataloader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                if not self.with_INSTANT:
                    print(f"Training iteration {i+1}!")
                    self._train_step(model, optimizer, criterion, inputs, labels)
                    continue

                phase = i % compression_iter
                if phase < calib_iter:
                    print(f"Calibration iteration: {i+1}")
                    if phase == 0:
                        model.disable_compression()
                        register_hooks(model)

                    for j in range(calib_chunks):
                        lo = calib_chunk_size * j
                        input_sample = inputs[lo:lo + calib_chunk_size]
                        label_sample = labels[lo:lo + calib_chunk_size]
                        if input_sample.shape[0] == 0:
                            # Batch too small for this slice: CrossEntropyLoss
                            # (mean) on an empty batch would yield NaN.
                            break
                        self._train_step(model, optimizer, criterion, input_sample, label_sample)

                    if phase == calib_iter - 1:
                        self._update_instant_compression(model, calib_iter)
                else:
                    gc.collect()  # Force garbage collection
                    torch.cuda.empty_cache()  # Release CUDA cache
                    print(f"Training iteration {i - calib_iter + 1} with INSTANT!")
                    model.enable_compression()
                    self._train_step(model, optimizer, criterion, inputs, labels)

        print("Done", end='\n')

    def warmup_model(self, warmup_steps=5):
        """Run ``warmup_steps`` plain training steps with a fresh Adam optimizer."""
        model = self.model_dict['model']
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        print(f"Starting warm-up for {warmup_steps} steps...")
        model.train()
        for i, (inputs, labels) in enumerate(self.dataloader):
            if i >= warmup_steps:
                break
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self._train_step(model, optimizer, criterion, inputs, labels)
        print("Warm-up completed.")
