from __future__ import annotations

import os
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.jit as jit
from tqdm import tqdm, trange


# Select CPU as the default device for every Paddle tensor/op created in this script.
paddle.device.set_device("cpu")
# from paddlenlp.transformers import PretrainedModel

# One-off conversion kept for reference (commented out): load a checkpoint with
# torch, cast every tensor to float16 via paddle, and re-save it with paddle.save.
# import torch
# state_dict = torch.load("model_state.pdparams")
# state_dict = {k: paddle.cast(paddle.to_tensor(v.numpy()), paddle.float16) for k, v in state_dict.items()}
# paddle.save(state_dict, "pytorch_model.fp16.bin")

def test_convert_weight():
    """Re-save ``model_state.pdparams`` with every parameter cast to bfloat16.

    Reads the checkpoint from the current working directory and writes the
    converted copy to ``model_state.bf16.pdparams`` alongside it.
    """
    import paddle

    source = paddle.load("model_state.pdparams")
    bf16_state = {}
    for param_name, param in source.items():
        bf16_state[param_name] = paddle.cast(param, paddle.bfloat16)
    paddle.save(bf16_state, "model_state.bf16.pdparams")



class MLP(nn.Layer):
    """Minimal single-layer model (``Linear(2, 3)``) used by the static-export check."""

    def __init__(self, name_scope=None, dtype="float32"):
        super(MLP, self).__init__(name_scope, dtype)
        # One fully connected layer: 2 input features -> 3 outputs.
        self.fc = nn.Linear(in_features=2, out_features=3)

    def forward(self, inputs):
        logits = self.fc(inputs)
        return logits

def test_compile_static_program():
    """Export a tiny ``MLP`` with ``jit.save`` and load it back as a static inference program.

    The dygraph model is converted with ``jit.to_static``, saved under
    ``./pretrained/model/model``, then reloaded with ``static.load_inference_model``.
    """
    model = MLP()
    # MLP wraps nn.Linear(2, 3), so the feature dimension is fixed at 2;
    # only the batch dimension is left dynamic.
    model = jit.to_static(model, input_spec=[static.InputSpec(shape=[None, 2], dtype="float32")])
    path = "./pretrained/model/model"
    jit.save(model, path)

    # load_inference_model is a static-graph-only API, while Paddle 2.x runs in
    # dynamic-graph mode by default. Switch modes only for the load and restore
    # the previous mode afterwards so later code in this process is unaffected.
    was_dynamic = paddle.in_dynamic_mode()
    if was_dynamic:
        paddle.enable_static()
    try:
        executor = static.Executor()
        _program, feed_names, _fetch_targets = static.load_inference_model(path, executor)
    finally:
        if was_dynamic:
            paddle.disable_static()

    # Exactly one InputSpec was exported, so the program should have one feed.
    assert len(feed_names) == 1, f"unexpected feed variables: {feed_names}"


import paddle
from paddle import Tensor


def tensor_info(tensor):
    """Return a one-line log message describing ``tensor``'s name, shape and dtype."""
    name, shape, dtype = tensor.name, tensor.shape, tensor.dtype
    return "start to split tensor<{}, {}, {}>".format(name, shape, dtype)


class Split:
    """How one checkpoint parameter is sharded for a single rank.

    Instances are callable: ``split(tensor)`` splits ``tensor`` into
    ``rank_size`` equal chunks along ``axis`` with ``paddle.split`` and returns
    chunk ``n_rank``. When ``rank_size`` is falsy (``None`` or ``0``), the
    tensor is returned unchanged, i.e. the parameter is not split.

    Attributes:
        name: key of the parameter in the source state dict.
        variable_name: Paddle-style parameter name it maps to (e.g. ``linear_0.w_0``).
        axis: axis to split along; only used when ``rank_size`` is set.
        rank_size: number of chunks; falsy means "do not split".
        n_rank: index of the chunk to keep, must be in ``[0, rank_size)``.
    """

    def __init__(self, name: str, variable_name: str, axis: int | None = None, rank_size: int | None = None, n_rank: int | None = None) -> None:
        self.name = name
        self.variable_name = variable_name

        self.axis = axis
        self.rank_size = rank_size
        self.n_rank = n_rank

    def __call__(self, tensor: Tensor) -> Tensor:
        """Return this rank's chunk of ``tensor`` (or ``tensor`` itself if not split).

        Raises:
            ValueError: if splitting is requested but ``n_rank`` is missing or
                outside ``[0, rank_size)``.
        """
        if not self.rank_size:
            return tensor
        # Validate before splitting so a bad rank index fails with a clear
        # message instead of an opaque TypeError/IndexError on the chunk list.
        if self.n_rank is None or not 0 <= self.n_rank < self.rank_size:
            raise ValueError(
                f"n_rank must be in [0, {self.rank_size}) to split {self.name}, got {self.n_rank}"
            )
        chunks = paddle.split(tensor, num_or_sections=self.rank_size, axis=self.axis)
        return chunks[self.n_rank]

    def __str__(self) -> str:
        return f"split<{self.name}, {self.variable_name}, {self.axis}, {self.rank_size}, {self.n_rank}>"

def get_bloom_mesh_info(rank_size: int, n_rank: int, layer_size: int) -> list[Split]:
    """Build the per-parameter sharding plan for a BLOOM checkpoint.

    Each entry maps a ``bloom.*`` state-dict key to a Paddle-style parameter
    name and, for tensor-parallel parameters, the split axis. Entries without
    an axis are not split (replicated on every rank).

    Args:
        rank_size: number of tensor-parallel shards; a falsy value (e.g. ``0``)
            makes every ``Split`` return its tensor unchanged.
        n_rank: index of the shard this plan extracts.
        layer_size: number of transformer blocks (``bloom.h.<i>``); ``0`` is allowed.

    Returns:
        One ``Split`` per parameter: embedding, embedding LayerNorm, the
        ``layer_size`` blocks, then the final LayerNorm ``ln_f``.
    """
    # Each block has two LayerNorms (input, post-attention) and four Linears
    # (qkv, attention dense, MLP up, MLP down). The mapping below assumes they
    # are numbered globally in that order after layer_norm_0 (the embedding
    # LayerNorm): block i -> layer_norm_{2i+1}, layer_norm_{2i+2},
    # linear_{4i} .. linear_{4i+3}.
    # Defined outside the loop so the final-LayerNorm index below is also valid
    # when layer_size == 0.
    layer_norm_size = 2
    layer_linear_size = 4

    all_splits = [
        ["bloom.word_embeddings.weight", "embedding_0.w_0", 0, rank_size, n_rank],
        ["bloom.word_embeddings_layernorm.weight", "layer_norm_0.w_0"],
        ["bloom.word_embeddings_layernorm.bias", "layer_norm_0.b_0"],
    ]
    for layer_index in range(layer_size):
        prefix = f"bloom.h.{layer_index}"
        norm_base = layer_index * layer_norm_size
        linear_base = layer_index * layer_linear_size
        # qkv and h_to_4h are split on the last axis (weight and bias);
        # attention dense and 4h_to_h are split on axis 0 with the bias kept whole.
        all_splits.extend([
            [f"{prefix}.input_layernorm.weight", f"layer_norm_{norm_base + 1}.w_0"],
            [f"{prefix}.input_layernorm.bias", f"layer_norm_{norm_base + 1}.b_0"],
            [f"{prefix}.self_attention.query_key_value.weight", f"linear_{linear_base + 0}.w_0", -1, rank_size, n_rank],
            [f"{prefix}.self_attention.query_key_value.bias", f"linear_{linear_base + 0}.b_0", -1, rank_size, n_rank],
            [f"{prefix}.self_attention.dense.weight", f"linear_{linear_base + 1}.w_0", 0, rank_size, n_rank],
            [f"{prefix}.self_attention.dense.bias", f"linear_{linear_base + 1}.b_0"],
            [f"{prefix}.mlp.dense_h_to_4h.weight", f"linear_{linear_base + 2}.w_0", -1, rank_size, n_rank],
            [f"{prefix}.mlp.dense_h_to_4h.bias", f"linear_{linear_base + 2}.b_0", -1, rank_size, n_rank],
            [f"{prefix}.mlp.dense_4h_to_h.weight", f"linear_{linear_base + 3}.w_0", 0, rank_size, n_rank],
            [f"{prefix}.mlp.dense_4h_to_h.bias", f"linear_{linear_base + 3}.b_0"],
            [f"{prefix}.post_attention_layernorm.weight", f"layer_norm_{norm_base + 2}.w_0"],
            [f"{prefix}.post_attention_layernorm.bias", f"layer_norm_{norm_base + 2}.b_0"],
        ])

    final_norm_index = layer_size * layer_norm_size + 1
    all_splits.extend([
        ["bloom.ln_f.weight", f"layer_norm_{final_norm_index}.w_0"],
        ["bloom.ln_f.bias", f"layer_norm_{final_norm_index}.b_0"],
    ])

    return [Split(*split) for split in all_splits]

def get_176b_info(rank_size: int = 0, n_rank: int = 8, layer_size: int = 70) -> list[Split]:
    """Return the sharding plan for BLOOM-176B (70 transformer blocks by default).

    The defaults reproduce the original call ``get_bloom_mesh_info(0, 8, layer_size=70)``.
    NOTE(review): with ``rank_size=0`` no parameter is split, and ``n_rank=8``
    would be out of range for an 8-way split; the two values look swapped
    relative to ``run_176b_split`` (which uses 8 ranks) — confirm before relying on it.
    """
    return get_bloom_mesh_info(rank_size, n_rank, layer_size=layer_size)

def gen_split_weight_file(state_dict, rank_size = 2, n_rank = 1, layer_size = 24, path_prefix: str = "./pretrained/bloom-saved-mesh/auto_dist_"):
    """Write the ``n_rank``-th tensor-parallel shard of a complete BLOOM checkpoint.

    Every parameter in the checkpoint must be covered by ``get_bloom_mesh_info``.
    If ``<path_prefix><n_rank>.pdparams`` already exists, its contents are
    loaded and the new shard entries are added to it before saving.

    Args:
        state_dict: an already-loaded state dict (``dict`` of name -> tensor/array),
            or anything ``paddle.load`` accepts (e.g. a ``.pdparams`` path).
            A passed-in dict is not modified.
        rank_size: number of tensor-parallel shards.
        n_rank: index of the shard to write.
        layer_size: number of transformer blocks in the model.
        path_prefix: output path prefix; ``f"{n_rank}.pdparams"`` is appended.

    Raises:
        AssertionError: if a parameter already exists in the output file, or
            if the checkpoint contains parameters that no split covers.
    """
    if isinstance(state_dict, dict):
        # Shallow copy: entries are popped below to detect uncovered
        # parameters, which must not empty the caller's dict.
        state_dict = dict(state_dict)
    else:
        state_dict = paddle.load(state_dict)

    output_dir = os.path.dirname(path_prefix)
    # A bare prefix such as "auto_dist_" has no directory part, and
    # os.makedirs("") raises FileNotFoundError.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    result_state_dict_file = path_prefix + f"{n_rank}.pdparams"
    if os.path.exists(result_state_dict_file):
        splited_state_dict = paddle.load(result_state_dict_file)
    else:
        splited_state_dict = {}

    for split in get_bloom_mesh_info(rank_size, n_rank, layer_size):
        assert split.name not in splited_state_dict, f"{split.name} must not be in splited state_dict"
        if split.name not in state_dict:
            continue

        tensor = state_dict.pop(split.name)
        if not paddle.is_tensor(tensor):
            tensor = paddle.to_tensor(tensor)
        splited_state_dict[split.name] = split(tensor)

    # Every source parameter must have been consumed by some split.
    # Message: "Not all parameters in the weights were traversed; remaining: ..."
    assert not state_dict, f"权重中的参数没有遍历完毕，还剩余：{state_dict.keys()}"
    paddle.save(splited_state_dict, result_state_dict_file)


def gen_split_weight_file_with_whole(state_dict, rank_size = 2, n_rank = 1, layer_size = 24, path_prefix: str = "./pretrained/bloom-saved-mesh/auto_dist_"):
    """Write the ``n_rank``-th tensor-parallel shard of a (possibly partial) BLOOM state dict.

    Split entries whose parameter is absent from ``state_dict`` are skipped,
    and parameters without a split entry are ignored. If
    ``<path_prefix><n_rank>.pdparams`` already exists, its contents are loaded
    and the new shard entries are added to it before saving.

    Args:
        state_dict: a ``.pdparams`` path (``str``) or a mapping of parameter
            name -> tensor/array. Keys lacking the ``bloom.`` prefix are
            prefixed so they match the names from ``get_bloom_mesh_info``.
            The input mapping itself is not modified.
        rank_size: number of tensor-parallel shards.
        n_rank: index of the shard to write.
        layer_size: number of transformer blocks in the model.
        path_prefix: output path prefix; ``f"{n_rank}.pdparams"`` is appended.

    Raises:
        AssertionError: if a parameter is already present in the existing output file.
    """
    if isinstance(state_dict, str):
        state_dict = paddle.load(state_dict)

    output_dir = os.path.dirname(path_prefix)
    # A bare prefix has no directory part, and os.makedirs("") raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    result_state_dict_file = path_prefix + f"{n_rank}.pdparams"
    if os.path.exists(result_state_dict_file):
        splited_state_dict = paddle.load(result_state_dict_file)
    else:
        splited_state_dict = {}

    # Normalize each key independently: the previous first-key check raised
    # IndexError on an empty dict, and double-prefixed every key whenever the
    # first key happened to lack the prefix.
    state_dict = {
        key if key.startswith("bloom.") else "bloom." + key: value
        for key, value in state_dict.items()
    }

    for split in get_bloom_mesh_info(rank_size, n_rank, layer_size):
        assert split.name not in splited_state_dict, f"{split.name} must not be in splited state_dict"
        if split.name not in state_dict:
            continue

        tensor = state_dict[split.name]
        if not paddle.is_tensor(tensor):
            tensor = paddle.to_tensor(tensor)
        splited_state_dict[split.name] = split(tensor)

    paddle.save(splited_state_dict, result_state_dict_file)


import threading
def run_176b_split(ranks=range(6, 8)):
    """Merge the 72 BLOOM-176B shard files and write per-rank tensor-parallel weights.

    Shard files ``./bloom-176b-paddle/model_state_{1..72}.pdparams`` are loaded
    sequentially into one state dict. Then one 8-way shard is written per rank
    in ``ranks`` under ``./bloom-176b-paddle-mesh/``.

    Args:
        ranks: ranks (out of 8) to generate. The default ``range(6, 8)``
            (ranks 6 and 7 only) matches the previously hard-coded loop.
    """
    state_dict = {}
    # Loading is sequential, so no lock is needed around the merge.
    for i in trange(1, 73, desc="loading state dict"):
        weight_file = f"./bloom-176b-paddle/model_state_{i}.pdparams"
        state_dict.update(paddle.load(weight_file))

    for rank in tqdm(ranks, desc="gen mesh weight file"):
        gen_split_weight_file_with_whole(
            state_dict, rank_size = 8, n_rank = rank, layer_size = 70,
            path_prefix="./bloom-176b-paddle-mesh/auto_dist"
        )

# run_176b_split()        

def run_560m_split():
    """Shard the BLOOM-560M checkpoint into two tensor-parallel weight files.

    The whole checkpoint is loaded once and passed to
    ``gen_split_weight_file_with_whole`` for each rank, with output prefix
    ``./pretrained/saved-bloom-560m-mesh-2/auto_dist``.
    """
    world_size = 2
    whole_state_dict = paddle.load("./pretrained/saved-bloom-560m/model_state.pdparams")
    output_prefix = f"./pretrained/saved-bloom-560m-mesh-{world_size}/auto_dist"

    for current_rank in trange(world_size):
        gen_split_weight_file_with_whole(
            whole_state_dict,
            rank_size=world_size,
            n_rank=current_rank,
            layer_size=24,
            path_prefix=output_prefix,
        )

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. pytest collecting the
    # test_* functions in this file) does not start a weight split or make a
    # network request.
    run_560m_split()

    import requests

    # Ping a local sync endpoint once the split is done (presumably triggers a
    # sync of the generated files — confirm).
    # NOTE(review): no timeout is set, so this call can block indefinitely.
    requests.get("http://0.0.0.0:8084/all-sync")
